├── requirements.txt
├── models
│   ├── __pycache__
│   │   ├── GAT.cpython-38.pyc
│   │   ├── GIN.cpython-38.pyc
│   │   ├── SGCN.cpython-38.pyc
│   │   ├── SGMP.cpython-38.pyc
│   │   ├── Dimenet.cpython-38.pyc
│   │   ├── PPFNet.cpython-38.pyc
│   │   ├── Schnet.cpython-38.pyc
│   │   ├── PointNet.cpython-38.pyc
│   │   ├── GatedGraphConv.cpython-38.pyc
│   │   └── dimenet_utils.cpython-38.pyc
│   ├── SGCN.py
│   ├── GatedGraphConv.py
│   ├── GIN.py
│   ├── PointNet.py
│   ├── GAT.py
│   ├── PPFNet.py
│   ├── dimenet_utils.py
│   ├── Schnet.py
│   ├── SGMP.py
│   └── Dimenet.py
├── utils
│   ├── __pycache__
│   │   ├── utils.cpython-38.pyc
│   │   └── moleculenet.cpython-38.pyc
│   ├── brain_load_data.py
│   ├── utils.py
│   └── moleculenet.py
├── run_SGMP_BACE.sh
├── run_PointNet_BACE.sh
├── run_PointNet_QM9.sh
├── run_SGMP_QM9.sh
├── run_SGMP_st_BACE.py
├── run_SGMP_st_QM9.sh
├── README.md
├── data
│   └── build_synthetic_data.py
├── main_base.py
└── main_base_st.py

/requirements.txt:
--------------------------------------------------------------------------------
pytorch 1.7.1
pytorch-geometric 1.6.3
rdkit
numpy
scipy
scikit
tqdm
--------------------------------------------------------------------------------
/run_SGMP_BACE.sh:
--------------------------------------------------------------------------------
python3 main_base.py --save_dir ./results \
    --data_dir ./data \
    --model SGMP \
    --dataset BACE \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --num_layers 3
--------------------------------------------------------------------------------
/run_PointNet_BACE.sh:
--------------------------------------------------------------------------------
python3 main_base.py --save_dir ./results \
    --data_dir ./data \
    --model PointNet \
    --dataset BACE \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --num_layers 3
--------------------------------------------------------------------------------
/run_PointNet_QM9.sh:
--------------------------------------------------------------------------------
python3 main_base.py --save_dir ./results \
    --data_dir ./data \
    --model PointNet \
    --dataset QM9 \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --label 0 \
    --num_layers 3
--------------------------------------------------------------------------------
/run_SGMP_QM9.sh:
--------------------------------------------------------------------------------
python3 main_base.py --save_dir ./results \
    --data_dir ./data \
    --model SGMP \
    --dataset QM9 \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --label 0 \
    --num_layers 3
--------------------------------------------------------------------------------
/run_SGMP_st_BACE.py:
--------------------------------------------------------------------------------
python3 main_base_st.py --save_dir ./results \
    --data_dir ./data \
    --model SGMP \
    --dataset BACE \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --spanning_tree True \
    --num_layers 3
--------------------------------------------------------------------------------
/run_SGMP_st_QM9.sh:
--------------------------------------------------------------------------------
python3 main_base_st.py --save_dir ./results \
    --data_dir ./data \
    --model SGMP \
    --dataset QM9 \
    --split 811 \
    --device gpu \
    --random_seed 1 \
    --batch_size 64 \
    --epoch 500 \
    --lr 1e-3 \
    --test_per_round 5 \
    --label 0 \
    --spanning_tree True \
    --num_layers 3
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Representation Learning on Spatial Networks

This repository is the official implementation of [Representation Learning on Spatial Networks]().

# Training

## Examples of running the models:

**1. Training the SGMP model on the BACE dataset:**
- bash run_SGMP_BACE.sh

**2. Training the SGMP model with sampled spanning trees on the BACE dataset:**
- bash run_SGMP_st_BACE.sh

**3. Training the PointNet benchmark model on the BACE dataset:**
- bash run_PointNet_BACE.sh

**4. Training the SGMP model on the QM9 dataset for target 0 ($\mu$):**
- bash run_SGMP_QM9.sh

**5. Training the SGMP model with sampled spanning trees on the QM9 dataset for target 0 ($\mu$):**
- bash run_SGMP_st_QM9.sh

**6. Training the PointNet benchmark model on the QM9 dataset for target 0 ($\mu$):**
- bash run_PointNet_QM9.sh

**7. Generating the synthetic dataset:**
- cd data
- python build_synthetic_data.py

## Evaluation

Evaluation results are written to `./results`.
--------------------------------------------------------------------------------
/utils/brain_load_data.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data, DataLoader
import pandas as pd

def load_brain_data(data_dir='./data', structure='sc', target='ReadEng_Unadj', threshold=1e5, random_seed=12345):
    np.random.seed(random_seed)
    coor_PATH = os.path.join(data_dir, 'Brain', 'ROI_coordinates')
    graph_PATH = os.path.join(data_dir, 'Brain', 'SC_FC_dataset_0905')
    df = pd.read_csv(os.path.join(graph_PATH, 'targets.csv'))

    with open(os.path.join(coor_PATH, 'coordinates_mean.npy'), 'rb') as f:
        coordinates = np.load(f)

    task_list = ['EMOTION', 'GAMBLING', 'LANGUAGE', 'MOTOR', 'RELATIONAL', 'SOCIAL', 'WM']

    age_dict = {'22-25': 0, '26-30': 1, '31-35': 2, '36+': 3}

    dataset = []
    subjects_list = []
    for task in task_list:
        inputFile = 'SC_' + task + '_correlationFC.npz'
        data = np.load(os.path.join(graph_PATH, inputFile))
        if structure == 'fc':
            struc = data['fc'] if task != 'RESTINGSTATE' else data['rawFC']
        elif structure == 'sc':
            struc = data['sc'] if task != 'RESTINGSTATE' else data['rawSC']
        subjects = data['subjects']
        for i in np.random.permutation(len(subjects)):
            subject = subjects[i]
            if not subject in subjects_list:
                y_temp = df.loc[df['Subject'] == subject, target].values[0]
                y0 = age_dict[y_temp]

                item = Data(
                    x=torch.arange(68, dtype=torch.long),
                    pos=torch.tensor(coordinates, dtype=torch.float),
                    edge_index=torch.tensor(np.concatenate(np.nonzero((struc[i] > threshold) | (struc[i] < -threshold))).reshape(2, -1), dtype=torch.long),
                    edge_attr=torch.tensor(struc[i][(struc[i] > threshold) | (struc[i] < -threshold)].reshape(-1), dtype=torch.float),
                    y=torch.tensor([[y0]], dtype=torch.float),
                    name_struc=structure,
                )
                dataset.append(item)
                subjects_list.append(subject)

    return dataset
--------------------------------------------------------------------------------
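A minimal usage sketch, not part of the repository (run from the repo root). It assumes the `./data/Brain` files referenced above exist. Note that the loader maps the target value through `age_dict`, so an age-bracket column is expected; the column name `'Age'` below is an assumption.

```python
# Hypothetical usage sketch -- assumes ./data/Brain exists and that targets.csv
# has an 'Age' column holding the brackets expected by age_dict.
from torch_geometric.data import DataLoader
from utils.brain_load_data import load_brain_data

dataset = load_brain_data(data_dir='./data', structure='sc', target='Age')
loader = DataLoader(dataset, batch_size=16, shuffle=True)
batch = next(iter(loader))
print(batch.x.shape, batch.pos.shape, batch.edge_index.shape)
```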
/models/SGCN.py:
--------------------------------------------------------------------------------
# benchmark sgcn
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torch_geometric.utils import softmax

class SConv(MessagePassing):
    def __init__(self, hidden_channels, num_gaussians):  # num_gaussians is accepted but unused here
        super(SConv, self).__init__(aggr='mean')
        self.mlp1 = Sequential(
            Linear(hidden_channels*2, hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        self.mlp2 = Sequential(
            Linear(hidden_channels*2, hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        self.mlp3 = Sequential(
            Linear(3, hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )

    def forward(self, x, pos, edge_index):
        h = self.propagate(edge_index, x=pos, h=x)
        x = torch.cat([h, x], dim=1)
        x = self.mlp1(x)
        return x

    def message(self, x_i, x_j, h_j):
        dist = (x_j - x_i)  # 3-D relative displacement, not a scalar distance
        spatial = self.mlp3(dist)
        temp = torch.cat([h_j, spatial], dim=1)
        return self.mlp2(temp)

class SGCN(torch.nn.Module):

    def __init__(self, input_channels_node=1, hidden_channels=128, output_channels=1, readout='add', num_layers=3):
        super(SGCN, self).__init__()

        assert readout in ['add', 'sum', 'mean']

        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.num_layers = num_layers
        self.interactions = ModuleList()
        for _ in range(num_layers):
            block = SConv(hidden_channels, num_gaussians=hidden_channels)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.lin2 = Linear(hidden_channels // 2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        x = self.node_lin(x)
        for block in self.interactions:
            x = block(x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        out = scatter(x, batch, dim=0, reduce=self.readout)

        return out

class GaussianSmearing(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))
--------------------------------------------------------------------------------
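A quick smoke test for the SGCN baseline (toy shapes assumed; run from the repository root). It shows the expected inputs: float node features `[N, input_channels_node]`, 3-D positions `[N, 3]`, a directed `edge_index`, and a per-node batch vector.

```python
# Minimal smoke test with assumed toy shapes: 10 nodes, one input feature,
# a random graph, and a single-graph batch vector.
import torch
from models.SGCN import SGCN

x = torch.randn(10, 1)                      # node features
pos = torch.randn(10, 3)                    # 3-D coordinates
edge_index = torch.randint(0, 10, (2, 40))  # random directed edges
batch = torch.zeros(10, dtype=torch.long)   # all nodes belong to graph 0

model = SGCN(input_channels_node=1, hidden_channels=128, output_channels=1)
out = model(x, pos, edge_index, batch)
print(out.shape)                            # torch.Size([1, 1])
```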
/models/GatedGraphConv.py:
--------------------------------------------------------------------------------
########################################################
### modified from the version by pytorch-geometric   ###
########################################################
import torch
from torch import Tensor
from torch.nn import Parameter as Param
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

from torch_scatter import scatter

class GatedGraphConv(MessagePassing):
    def __init__(self, output_channels: int, num_layers: int, aggr: str = 'add', **kwargs):
        super(GatedGraphConv, self).__init__(aggr=aggr, **kwargs)

        self.output_channels = output_channels
        self.num_layers = num_layers

        self.weight = Param(Tensor(num_layers, output_channels, output_channels))
        self.rnn = torch.nn.GRUCell(output_channels, output_channels)

        self.edge_weight = Linear(output_channels, 1)

        self.reset_parameters()

    def reset_parameters(self):
        self.rnn.reset_parameters()
        self.edge_weight.reset_parameters()
        torch.nn.init.xavier_uniform_(self.weight)
        torch.nn.init.xavier_uniform_(self.edge_weight.weight)
        self.edge_weight.bias.data.fill_(0)

    def forward(self, x, pos, edge_index):
        """"""
        for i in range(self.num_layers):
            m = torch.matmul(x, self.weight[i])
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            m = self.propagate(edge_index, x=m, edge_weight=None, size=None)
            x = self.rnn(m, x)  # GRU-style update of the node state

        return x

    def message(self, x_j: Tensor, edge_weight: OptTensor):
        return x_j if edge_weight is None else self.edge_weight(edge_weight) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({}, num_layers={})'.format(self.__class__.__name__,
                                              self.output_channels,  # was self.out_channels, which is undefined
                                              self.num_layers)

class GatedNet(torch.nn.Module):
    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super(GatedNet, self).__init__()
        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node+3, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.num_layers = num_layers
        self.convs = ModuleList()
        for i in range(self.num_layers):
            conv = GatedGraphConv(hidden_channels, num_layers=num_layers)
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, hidden_channels//2)
        self.lin2 = Linear(hidden_channels//2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        x = torch.cat([x, pos], dim=1)  # append raw coordinates to the node features
        x = self.node_lin(x)
        for i in range(self.num_layers):
            x = self.convs[i](x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        x = scatter(x, batch, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
/models/GIN.py:
--------------------------------------------------------------------------------
########################################################
### modified from the version by pytorch-geometric   ###
########################################################
import torch
from torch import Tensor
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU
from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

class GINConv(MessagePassing):
    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GINConv, self).__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x: Union[Tensor, OptPairTensor], pos: Tensor, edge_index: Adj) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x)

        x_r = x[1]
        if x_r is not None:
            out += (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def __repr__(self):
        return '{}(nn={})'.format(self.__class__.__name__, self.nn)

class GINNet(torch.nn.Module):
    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', eps=0., num_layers=3):
        super(GINNet, self).__init__()
        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node+3, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.num_layers = num_layers

        self.mlp = ModuleList()
        for _ in range(self.num_layers):
            block = Sequential(
                Linear(hidden_channels, hidden_channels),
                ReLU(),
                Linear(hidden_channels, hidden_channels)
            )
            self.mlp.append(block)

        self.convs = ModuleList()
        for i in range(self.num_layers):
            conv = GINConv(nn=self.mlp[i], eps=eps)
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, hidden_channels//2)
        self.lin2 = Linear(hidden_channels//2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        for nn in self.mlp:
            torch.nn.init.xavier_uniform_(nn[0].weight)
            nn[0].bias.data.fill_(0)
            torch.nn.init.xavier_uniform_(nn[2].weight)
            nn[2].bias.data.fill_(0)

        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        x = torch.cat([x, pos], dim=1)
        x = self.node_lin(x)
        for i in range(self.num_layers):
            x = self.convs[i](x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        x = scatter(x, batch, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
/models/PointNet.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor
from torch.nn import Parameter as Param
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

from torch_scatter import scatter

class PointConv(MessagePassing):
    def __init__(self, local_nn, global_nn, aggr='add'):
        super(PointConv, self).__init__(aggr=aggr)
        self.local_nn = local_nn
        self.global_nn = global_nn

    def forward(self, x, pos, edge_index):
        """"""
        out = self.propagate(edge_index, x=x, pos=pos)
        return self.global_nn(out)

    def message(self, x_i, x_j, pos_i, pos_j):
        temp = torch.cat([x_i, x_j, pos_j], dim=-1)
        return self.local_nn(temp)

    def __repr__(self):
        # reconstructed: the original referenced undefined attributes
        # (self.input_channels, self.output_channels, self.dim)
        return '{}(local_nn={}, global_nn={})'.format(self.__class__.__name__,
                                                      self.local_nn,
                                                      self.global_nn)

class PointNet(torch.nn.Module):
    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super(PointNet, self).__init__()
        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.num_layers = num_layers
        self.local_nn = ModuleList()
        self.global_nn = ModuleList()
        for _ in range(self.num_layers):
            block = Sequential(
                Linear(hidden_channels*2+3, hidden_channels),
                ReLU(),
                Linear(hidden_channels, hidden_channels)
            )
            self.local_nn.append(block)
        for _ in range(self.num_layers):
            block = Sequential(
                Linear(hidden_channels, hidden_channels),
                ReLU(),
                Linear(hidden_channels, hidden_channels)
            )
            self.global_nn.append(block)

        self.convs = ModuleList()
        for i in range(self.num_layers):
            conv = PointConv(self.local_nn[i], self.global_nn[i])
            self.convs.append(conv)

        self.embedding = Embedding(100, hidden_channels)
        self.lin1 = Linear(hidden_channels, hidden_channels//2)
        self.lin2 = Linear(hidden_channels//2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        for nn in self.local_nn:
            torch.nn.init.xavier_uniform_(nn[0].weight)
            nn[0].bias.data.fill_(0)
            torch.nn.init.xavier_uniform_(nn[2].weight)
            nn[2].bias.data.fill_(0)
        for nn in self.global_nn:
            torch.nn.init.xavier_uniform_(nn[0].weight)
            nn[0].bias.data.fill_(0)
            torch.nn.init.xavier_uniform_(nn[2].weight)
            nn[2].bias.data.fill_(0)

        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        if x.dim() == 1:
            x = self.embedding(x)  # integer inputs (e.g. atomic numbers) use the embedding table
        else:
            x = self.node_lin(x)   # float feature matrices go through the MLP
        for i in range(self.num_layers):
            x = self.convs[i](x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        x = scatter(x, batch, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
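A sketch of the two input paths in `PointNet.forward` (toy shapes assumed; run from the repository root): a 1-D integer tensor is looked up in the `Embedding` table, while a 2-D float matrix goes through `node_lin`.

```python
# Illustrative sketch of PointNet's two input paths (assumed toy graph).
import torch
from models.PointNet import PointNet

model = PointNet(input_channels_node=1, hidden_channels=128, output_channels=1)
pos = torch.randn(10, 3)
edge_index = torch.randint(0, 10, (2, 40))
batch = torch.zeros(10, dtype=torch.long)

z = torch.randint(0, 100, (10,))   # 1-D integer features -> Embedding lookup
out_a = model(z, pos, edge_index, batch)

x = torch.randn(10, 1)             # 2-D float features -> node_lin MLP
out_b = model(x, pos, edge_index, batch)
print(out_a.shape, out_b.shape)    # both torch.Size([1, 1])
```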
/models/GAT.py:
--------------------------------------------------------------------------------
########################################################
### modified from the version by pytorch-geometric   ###
########################################################
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
from torch_geometric.utils import softmax

class GATConv(MessagePassing):
    def __init__(self, input_channels_node: int, output_channels: int, heads: int = 1, negative_slope: float = 0.1, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.input_channels_node = input_channels_node
        self.output_channels = output_channels
        self.heads = heads
        self.negative_slope = negative_slope

        self.lin_l = Linear(input_channels_node, heads * output_channels, bias=False)
        self.lin_r = self.lin_l  # source and target transforms share weights

        self.att_l = Parameter(torch.Tensor(1, heads, output_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, output_channels))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin_l.weight)
        torch.nn.init.xavier_uniform_(self.att_l)
        torch.nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, pos, edge_index):
        H, C = self.heads, self.output_channels

        assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
        x_l = x_r = self.lin_l(x).view(-1, H, C)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r))
        out = out.mean(dim=1)  # average over attention heads
        return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, index: Tensor) -> Tensor:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.input_channels_node, self.output_channels, self.heads)

class GATNet(torch.nn.Module):
    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super(GATNet, self).__init__()
        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node+3, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.num_layers = num_layers
        self.convs = ModuleList()
        for i in range(self.num_layers):
            # note: the third positional argument of GATConv is `heads`,
            # so this instantiates hidden_channels attention heads
            conv = GATConv(hidden_channels, hidden_channels, hidden_channels)
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, hidden_channels//2)
        self.lin2 = Linear(hidden_channels//2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        x = torch.cat([x, pos], dim=1)
        x = self.node_lin(x)
        for i in range(self.num_layers):
            x = self.convs[i](x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        x = scatter(x, batch, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
/data/build_synthetic_data.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from tqdm import tqdm
from math import pi as PI
import time

import ase
import torch
import pickle as pkl
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import remove_isolated_nodes, to_networkx

import networkx as nx
from networkx.algorithms.cluster import average_clustering
from networkx.algorithms.distance_measures import diameter as G_diameter
from networkx.algorithms.distance_measures import radius as G_radius
from scipy.spatial import distance_matrix
from scipy.sparse import csr_matrix

from collections import Counter
import matplotlib.pyplot as plt

def S_eccentricity(pos):
    dis_matrix = distance_matrix(pos, pos)
    eccentricity = dis_matrix.max(axis=1)
    return eccentricity

def S_diameter(pos):
    eccentricity = S_eccentricity(pos)
    return np.max(eccentricity)

def S_radius(pos):
    eccentricity = S_eccentricity(pos)
    return np.min(eccentricity)

def build_synthetic_data(n, m, dim, D, p, L):
    x = np.zeros(n, dtype=int)
    pos = np.random.random((n, dim)) * D  # random positions in a box of side D
    k_max = 1  # initialize k_max
    k_list = [1, 1]
    edge_index = [[0], [1]]

    for i in range(2, n):
        # P_spatial: closer nodes are more likely to be connected
        # (Minkowski-style distance; assumes p is even, as with the default p=2)
        distance = np.power(np.sum((pos[:i] - pos[i])**p, axis=-1), 1/p)
        P_spatial = np.exp(-distance/L)
        # P_graph: preferential attachment toward high-degree nodes
        P_graph = (np.array(k_list)) / (k_max)
        # P total
        P = P_spatial * P_graph

        # random sample
        flag = np.random.random(size=(i,)) < P
        neighbors = np.nonzero(flag)[0]

        for j in neighbors:
            edge_index[0].append(j)
            edge_index[1].append(i)
            k_list[j] += 1
        k_list.append(len(neighbors))
        k_max = max(k_list)

    undirected_edge_index = [edge_index[0] + edge_index[1], edge_index[1] + edge_index[0]]
    x = torch.zeros((len(pos), 1), dtype=torch.float)
    pos = torch.tensor(pos, dtype=torch.float)
    edge_index = torch.tensor(undirected_edge_index, dtype=torch.long)
    edge_index, _, mask = remove_isolated_nodes(edge_index, num_nodes=n)
    x, pos = x[mask], pos[mask]
    edge_attr = torch.zeros(len(edge_index[0]), dtype=torch.float)
    row, col = edge_index.numpy()
    distances = torch.norm(pos[row] - pos[col], dim=-1).numpy()
    edges = [(row[i], col[i], distances[i]) for i in range(len(row))]
    # build networkx graph
    temp = Data(x=x, edge_index=edge_index)
    G = to_networkx(temp)
    for row, col, weight in edges:
        G[row][col]['weight'] = weight

    y = torch.tensor([[average_clustering(G), S_diameter(pos), S_radius(pos), L]], dtype=torch.float)

    data = Data(
        x=x,
        pos=pos,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
    )
    return data, np.array(k_list)[mask.numpy()]

N = 20   # number of generated graphs per parameter setting
n = 20
m = 3    # unused by build_synthetic_data; kept for reference
dim = 3  # spatial dimension
D = 5.0  # the space limit for all nodes (rectangular space)
p = 2

synthetic_dataset = []
counter_all = Counter()

for num in tqdm(range(N)):
    for n in [15, 20, 25, 30]:
        for m in [2]:
            for D in np.arange(1.0, 10.1, 1.0):
                for L in [D*0.75, D*1.0, D*1.25, D*1.5]:
                    data, k_list = build_synthetic_data(n, m, dim, D, p, L)
                    counter = Counter(k_list)
                    counter_all += counter
                    synthetic_dataset.append(data)

print('Number of graphs', len(synthetic_dataset))

if not os.path.exists('./synthetic'):
    os.makedirs('./synthetic')
with open('./synthetic/synthetic.pkl', 'wb') as file:
    pkl.dump(synthetic_dataset, file)
--------------------------------------------------------------------------------
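The script pickles a plain list of `torch_geometric.data.Data` objects, so loading it back is a one-liner; the 80/20 split below is an assumption for illustration, not the repository's split.

```python
# Loading sketch for the pickled synthetic dataset produced above.
import pickle as pkl
from torch_geometric.data import DataLoader

with open('./synthetic/synthetic.pkl', 'rb') as file:
    synthetic_dataset = pkl.load(file)

n_train = int(0.8 * len(synthetic_dataset))   # assumed illustrative split
train_loader = DataLoader(synthetic_dataset[:n_train], batch_size=64, shuffle=True)
test_loader = DataLoader(synthetic_dataset[n_train:], batch_size=64)
print(len(synthetic_dataset), synthetic_dataset[0])
```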
/models/PPFNet.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor
from torch.nn import Parameter as Param
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

from torch_scatter import scatter

def get_angle(v1: Tensor, v2: Tensor) -> Tensor:
    return torch.atan2(
        torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))


def point_pair_features(pos_i: Tensor, pos_j: Tensor, normal_i: Tensor,
                        normal_j: Tensor) -> Tensor:
    pseudo = pos_j - pos_i
    return torch.stack([
        pseudo.norm(p=2, dim=1),
        get_angle(normal_i, pseudo),
        get_angle(normal_j, pseudo),
        get_angle(normal_i, normal_j)
    ], dim=1)

class PPFConv(MessagePassing):
    def __init__(self, local_nn, global_nn, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(PPFConv, self).__init__(**kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn

    def forward(self, x, pos, edge_index):  # yapf: disable
        """"""
        epsilon = 1e-12  # for numerical stability
        # the normalized position vector is used as a surrogate surface normal
        normal = pos / (pos.norm(dim=-1).reshape(-1, 1) + epsilon)
        out = self.propagate(edge_index, x=x, pos=pos, normal=normal)
        if self.global_nn is not None:
            out = self.global_nn(out)

        return out

    def message(self, x_j, pos_i, pos_j,
                normal_i, normal_j) -> Tensor:
        msg = point_pair_features(pos_i, pos_j, normal_i, normal_j)
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def __repr__(self):
        return '{}(local_nn={}, global_nn={})'.format(self.__class__.__name__,
                                                      self.local_nn,
                                                      self.global_nn)

class PPFNet(torch.nn.Module):
    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super(PPFNet, self).__init__()
        self.readout = readout
        self.node_lin = Sequential(
            Linear(input_channels_node, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )
        self.embedding = Embedding(100, hidden_channels)
        self.num_layers = num_layers
        self.local_nn = ModuleList()
        self.global_nn = ModuleList()
        for _ in range(self.num_layers):
            block = Sequential(
                Linear(hidden_channels+4, hidden_channels),
                ReLU(),
                Linear(hidden_channels, hidden_channels)
            )
            self.local_nn.append(block)
        for _ in range(self.num_layers):
            block = Sequential(
                Linear(hidden_channels, hidden_channels),
                ReLU(),
                Linear(hidden_channels, hidden_channels)
            )
            self.global_nn.append(block)

        self.convs = ModuleList()
        for i in range(self.num_layers):
            conv = PPFConv(self.local_nn[i], self.global_nn[i])
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, hidden_channels//2)
        self.lin2 = Linear(hidden_channels//2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        for nn in self.local_nn:
            torch.nn.init.xavier_uniform_(nn[0].weight)
            nn[0].bias.data.fill_(0)
            torch.nn.init.xavier_uniform_(nn[2].weight)
            nn[2].bias.data.fill_(0)
        for nn in self.global_nn:
            torch.nn.init.xavier_uniform_(nn[0].weight)
            nn[0].bias.data.fill_(0)
            torch.nn.init.xavier_uniform_(nn[2].weight)
            nn[2].bias.data.fill_(0)

        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch):
        if x.dim() == 1:
            x = self.embedding(x)
        else:
            x = self.node_lin(x)
        for i in range(self.num_layers):
            x = self.convs[i](x, pos, edge_index)
            x = x.relu()

        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
        x = scatter(x, batch, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
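A toy check of `point_pair_features` (values assumed): with both normals along z and the point offset along x, the distance is 1, the two point-to-normal angles are pi/2, and the angle between the normals is 0.

```python
# Toy sanity check for point_pair_features (assumed toy values).
import torch
from models.PPFNet import point_pair_features

pos_i = torch.tensor([[0.0, 0.0, 0.0]])
pos_j = torch.tensor([[1.0, 0.0, 0.0]])
n_i = torch.tensor([[0.0, 0.0, 1.0]])
n_j = torch.tensor([[0.0, 0.0, 1.0]])
print(point_pair_features(pos_i, pos_j, n_i, n_j))
# tensor([[1.0000, 1.5708, 1.5708, 0.0000]])  (distance, two angles, normal angle)
```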
/models/dimenet_utils.py:
--------------------------------------------------------------------------------
#######################################################
### This is from the version by pytorch-geometric  ###
#######################################################
# Shameless copy and paste from: https://github.com/klicperajo/dimenet

import numpy as np
from scipy.optimize import brentq
from scipy import special as sp

try:
    import sympy as sym
except ImportError:
    sym = None


def Jn(r, n):
    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)


def Jn_zeros(n, k):
    zerosj = np.zeros((n, k), dtype='float32')
    zerosj[0] = np.arange(1, k + 1) * np.pi
    points = np.arange(1, k + n) * np.pi
    racines = np.zeros(k + n - 1, dtype='float32')
    for i in range(1, n):
        for j in range(k + n - 1 - i):
            foo = brentq(Jn, points[j], points[j + 1], (i, ))
            racines[j] = foo
        points = racines
        zerosj[i][:k] = racines[:k]

    return zerosj


def spherical_bessel_formulas(n):
    x = sym.symbols('x')

    f = [sym.sin(x) / x]
    a = sym.sin(x) / x
    for i in range(1, n):
        b = sym.diff(a, x) / x
        f += [sym.simplify(b * (-x)**i)]
        a = sym.simplify(b)
    return f


def bessel_basis(n, k):
    zeros = Jn_zeros(n, k)
    normalizer = []
    for order in range(n):
        normalizer_tmp = []
        for i in range(k):
            normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]
        normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5
        normalizer += [normalizer_tmp]

    f = spherical_bessel_formulas(n)
    x = sym.symbols('x')
    bess_basis = []
    for order in range(n):
        bess_basis_tmp = []
        for i in range(k):
            bess_basis_tmp += [
                sym.simplify(normalizer[order][i] *
                             f[order].subs(x, zeros[order, i] * x))
            ]
        bess_basis += [bess_basis_tmp]
    return bess_basis


def sph_harm_prefactor(k, m):
    return ((2 * k + 1) * np.math.factorial(k - abs(m)) /
            (4 * np.pi * np.math.factorial(k + abs(m))))**0.5


def associated_legendre_polynomials(k, zero_m_only=True):
    z = sym.symbols('z')
    P_l_m = [[0] * (j + 1) for j in range(k)]

    P_l_m[0][0] = 1
    if k > 0:
        P_l_m[1][0] = z

    for j in range(2, k):
        P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
                                    (j - 1) * P_l_m[j - 2][0]) / j)
    if not zero_m_only:
        for i in range(1, k):
            P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
            if i + 1 < k:
                P_l_m[i + 1][i] = sym.simplify(
                    (2 * i + 1) * z * P_l_m[i][i])
            for j in range(i + 2, k):
                P_l_m[j][i] = sym.simplify(
                    ((2 * j - 1) * z * P_l_m[j - 1][i] -
                     (i + j - 1) * P_l_m[j - 2][i]) / (j - i))

    return P_l_m


def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True):
    if not zero_m_only:
        S_m = [0]
        C_m = [1]
        for i in range(1, k):
            x = sym.symbols('x')
            y = sym.symbols('y')
            S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
            C_m += [x * C_m[i - 1] - y * S_m[i - 1]]

    P_l_m = associated_legendre_polynomials(k, zero_m_only)
    if spherical_coordinates:
        theta = sym.symbols('theta')
        z = sym.symbols('z')
        for i in range(len(P_l_m)):
            for j in range(len(P_l_m[i])):
                if type(P_l_m[i][j]) != int:
                    P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
        if not zero_m_only:
            phi = sym.symbols('phi')
            for i in range(len(S_m)):
                S_m[i] = S_m[i].subs(x,
                                     sym.sin(theta) * sym.cos(phi)).subs(
                                         y,
                                         sym.sin(theta) * sym.sin(phi))
            for i in range(len(C_m)):
                C_m[i] = C_m[i].subs(x,
                                     sym.sin(theta) * sym.cos(phi)).subs(
                                         y,
                                         sym.sin(theta) * sym.sin(phi))

    Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)]
    for i in range(k):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])

    if not zero_m_only:
        for i in range(1, k):
            for j in range(1, i + 1):
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
        for i in range(1, k):
            for j in range(1, i + 1):
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])

    return Y_func_l_m
--------------------------------------------------------------------------------
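These helpers return symbolic sympy expressions rather than numeric functions. A sketch (usage assumed, mirroring how the original DimeNet code consumes them) of evaluating the bases numerically with `sympy.lambdify`:

```python
# Assumed usage sketch: turn the symbolic bases into numpy-callable functions.
import numpy as np
import sympy as sym
from models.dimenet_utils import bessel_basis, real_sph_harm

bess = bessel_basis(n=2, k=2)   # 2x2 nested list of sympy expressions in x
sph = real_sph_harm(k=2)        # spherical harmonics in theta (m=0 only)

x, theta = sym.symbols('x theta')
f = sym.lambdify([x], bess[0][0], 'numpy')
g = sym.lambdify([theta], sph[1][0], 'numpy')
print(f(np.linspace(0.1, 1.0, 4)), g(0.5))
```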
/utils/utils.py:
--------------------------------------------------------------------------------
import numpy as np

import networkx as nx
from networkx.utils import UnionFind

from typing import Optional
import torch
from torch import Tensor

from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

def random_spanning_tree(edge_index):
    r = torch.randperm(len(edge_index[0]), device=edge_index.device)
    row, col = edge_index[:, r].numpy()  # visit edges in random order
    subtrees = UnionFind()
    spanning_edges = []
    i = 0

    while i < len(row):
        if subtrees[row[i]] != subtrees[col[i]]:
            subtrees.union(row[i], col[i])
            spanning_edges.append([row[i], col[i]])
        i += 1
    return spanning_edges

def scipy_spanning_tree(edge_index, num_nodes, num_edges):
    row, col = edge_index.numpy()
    cgraph = csr_matrix((np.random.random(num_edges) + 1, (row, col)), shape=(num_nodes, num_nodes))
    Tcsr = minimum_spanning_tree(cgraph)
    tree_row, tree_col = Tcsr.nonzero()
    spanning_edges = np.concatenate([[tree_row], [tree_col]]).T
    return spanning_edges

def build_spanning_tree_edge(edge_index, algo='union', num_nodes=None, num_edges=None):
    # spanning_edges
    if algo == 'union':
        spanning_edges = random_spanning_tree(edge_index)
    elif algo == 'scipy':
        spanning_edges = scipy_spanning_tree(edge_index, num_nodes, num_edges)

    spanning_edges = torch.tensor(spanning_edges, dtype=torch.long, device=edge_index.device).T
    spanning_edges_undirected = torch.stack(
        [
            torch.cat([spanning_edges[0], spanning_edges[1]]),
            torch.cat([spanning_edges[1], spanning_edges[0]]),
        ]
    )
    return spanning_edges_undirected


def add_self_loops(edge_index, edge_attr: Optional[torch.Tensor] = None,
                   fill_value: float = 1., num_nodes: Optional[int] = None):
    N = num_nodes if num_nodes is not None else len(torch.unique(edge_index))

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_attr is not None:
        assert edge_attr.size(0) == edge_index.size(1)
        loop_attr = edge_attr.new_full((N, edge_attr.size(1)), fill_value)
        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_attr

def triplets(edge_index, num_nodes):
    row, col = edge_index  # i->j

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[col]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = row.repeat_interleave(num_triplets)
    idx_j = col.repeat_interleave(num_triplets)
    edx_1st = value.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    edx_2nd = adj_t_row.storage.value()
    mask1 = (idx_i == idx_k) & (idx_j != idx_i)  # Remove go-back triplets.
    mask2 = (idx_i == idx_j) & (idx_j != idx_k)  # Remove repeated self-loop triplets.
    mask3 = (idx_j == idx_k) & (idx_i != idx_k)  # Remove self-loop neighbors.
    mask = ~(mask1 | mask2 | mask3)  # keep 0 -> 0 -> 0 or 0 -> 1 -> 2
    idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]

    # count the real number of triplets per edge
    num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0)-1]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd

def fourthplets(edge_index, triplets, num_nodes, edge_attr_index_1st, edge_attr_index_2nd):
    row, col = edge_index  # i->j
    i, j, k = triplets  # i->j->k

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[k]
    num_fourthlets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (i->j->k->p) for fourthlets.
    idx_i = i.repeat_interleave(num_fourthlets)
    idx_j = j.repeat_interleave(num_fourthlets)
    idx_k = k.repeat_interleave(num_fourthlets)
    edge_attr_index_1st = edge_attr_index_1st.repeat_interleave(num_fourthlets)
    edge_attr_index_2nd = edge_attr_index_2nd.repeat_interleave(num_fourthlets)
    idx_p = adj_t_row.storage.col()
    edge_attr_index_3rd = adj_t_row.storage.value()
    mask1 = (idx_i != idx_p) & (idx_j != idx_k) & (idx_j != idx_p) & (idx_k != idx_p)  # 0 -> 1 -> 2 -> 3
    mask2 = (idx_i == idx_p) & (idx_j == idx_p) & (idx_k == idx_p)  # 0 -> 0 -> 0 -> 0
    mask = mask1 | mask2

    idx_i, idx_j, idx_k, idx_p = idx_i[mask], idx_j[mask], idx_k[mask], idx_p[mask]
    edge_attr_index_1st, edge_attr_index_2nd, edge_attr_index_3rd = edge_attr_index_1st[mask], edge_attr_index_2nd[mask], edge_attr_index_3rd[mask]

    # count the real number of fourthlets per triplet
    num_fourthlets_real = torch.cumsum(num_fourthlets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_fourthlets, dim=0)-1]

    return torch.stack([idx_i, idx_j, idx_k, idx_p]), num_fourthlets_real.to(torch.long), edge_attr_index_1st, edge_attr_index_2nd, edge_attr_index_3rd

def find_higher_order_neighbors(edge_index, num_of_nodes, order=1):
    edge_index_1st = edge_index

    if order == 1:
        return edge_index_1st
    elif order == 2:
        # find 2nd-order neighbors
        edge_index_2nd, num_2nd_neighbors, edge_attr_index_1st, edge_attr_index_2nd = triplets(edge_index_1st, num_of_nodes)
        return edge_index_1st, edge_index_2nd, num_2nd_neighbors, edge_attr_index_1st, edge_attr_index_2nd
    elif order == 3:
        # find 2nd- & 3rd-order neighbors
        edge_index_2nd, num_2nd_neighbors, edge_attr_index_1st, edge_attr_index_2nd = triplets(edge_index_1st, num_of_nodes)
        edge_index_3rd, num_3rd_neighbors, edge_attr_index_1st, edge_attr_index_2nd, edge_attr_index_3rd = fourthplets(edge_index_1st, edge_index_2nd, num_of_nodes, edge_attr_index_1st, edge_attr_index_2nd)
        return edge_index_1st, edge_index_2nd, edge_index_3rd, num_2nd_neighbors, num_3rd_neighbors, edge_attr_index_1st, edge_attr_index_2nd, edge_attr_index_3rd
    else:
        raise NotImplementedError('We currently only support up to 3rd neighbors')
--------------------------------------------------------------------------------
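A sketch tying the utilities together (toy graph assumed; run from the repository root): sample a random spanning tree, then enumerate paths up to 3rd-order neighbors, which yields the `(i, j, k, p)` index rows that SGMP consumes.

```python
# Assumed toy graph: an undirected path 0-1-2-3 stored as directed edge pairs.
import torch
from utils.utils import build_spanning_tree_edge, find_higher_order_neighbors

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]], dtype=torch.long)
num_nodes = 4

tree_edge_index = build_spanning_tree_edge(edge_index, algo='union')
out = find_higher_order_neighbors(tree_edge_index, num_nodes, order=3)
edge_index_3rd = out[2]
print(edge_index_3rd.shape)   # [4, num_paths]: rows are (i, j, k, p)
```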
/models/Schnet.py:
--------------------------------------------------------------------------------
import os
import warnings
import os.path as osp
from math import pi as PI

import ase
import torch
import torch.nn.functional as F

from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
import numpy as np

from torch_scatter import scatter
from torch_geometric.data.makedirs import makedirs
from torch_geometric.data import download_url, extract_zip
from torch_geometric.nn import radius_graph, MessagePassing


class Schnet(torch.nn.Module):

    def __init__(self, input_channels_node=1, hidden_channels=128,
                 output_channels=1, num_filters=128,
                 num_interactions=3, num_gaussians=50, cutoff=10.0,
                 readout='add'):
        super(Schnet, self).__init__()

        assert readout in ['add', 'sum', 'mean']

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.readout = readout

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)
        self.node_lin = Sequential(
            Linear(input_channels_node, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels)
        )

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, output_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[0].weight)
        self.node_lin[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.node_lin[2].weight)
        self.node_lin[2].bias.data.fill_(0)

    def forward(self, x, pos, edge_index, batch=None):
        batch = torch.zeros_like(x) if batch is None else batch
        if x.dim() == 1:
            h = self.embedding(x)
        else:
            h = self.node_lin(x)
        # modified to use the given graph connections instead of a radius graph
        # edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')

class InteractionBlock(torch.nn.Module):
    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[2].bias.data.fill_(0)  # was self.mlp[0] again (copy-paste slip)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


class CFConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super(CFConv, self).__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)  # smooth cutoff envelope
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(torch.nn.Module):
    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift
--------------------------------------------------------------------------------
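A shape sketch for `GaussianSmearing` (toy distances assumed): each scalar distance is expanded into `num_gaussians` radial-basis features, which is the `edge_attr` that the interaction blocks consume.

```python
# Toy shape check for the radial basis expansion.
import torch
from models.Schnet import GaussianSmearing

expand = GaussianSmearing(start=0.0, stop=10.0, num_gaussians=50)
dist = torch.tensor([0.5, 1.2, 3.4])
print(expand(dist).shape)   # torch.Size([3, 50])
```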
/models/SGMP.py:
--------------------------------------------------------------------------------
from math import pi as PI
import torch
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList
from torch import Tensor
from torch_scatter import scatter

def get_angle(v1: Tensor, v2: Tensor) -> Tensor:
    return torch.atan2(
        torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))

class GaussianSmearing(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class SGMP(torch.nn.Module):

    def __init__(self, input_channels_node=None, hidden_channels=128,
                 output_channels=1, num_interactions=3,
                 num_gaussians=(50, 6, 12), cutoff=10.0,
                 readout='add'):
        super(SGMP, self).__init__()

        assert readout in ['add', 'sum', 'mean']

        self.input_channels_node = input_channels_node
        self.hidden_channels = hidden_channels
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.readout = readout
        # the gaussian expansions here help the model converge more quickly
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians[0])
        self.theta_expansion = GaussianSmearing(0.0, PI, num_gaussians[1])
        self.phi_expansion = GaussianSmearing(0.0, 2*PI, num_gaussians[2])
        self.embedding = Sequential(
            Linear(input_channels_node, hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels)
        )

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = SPNN(hidden_channels, num_gaussians, self.distance_expansion, self.theta_expansion, self.phi_expansion, input_channels=hidden_channels)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = torch.nn.ReLU()
        self.lin2 = Linear(hidden_channels // 2, output_channels)
        self.reset_parameters()

    def reset_parameters(self):
        for block in self.interactions:
            block.reset_parameters()

        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, pos, batch, edge_index_3rd=None):
        x = self.embedding(x)

        distances = {}
        thetas = {}
        phis = {}
        i, j, k, p = edge_index_3rd
        i_to_j_dis = (pos[j] - pos[i]).norm(p=2, dim=1)
        k_to_j_dis = (pos[k] - pos[j]).norm(p=2, dim=1)
        p_to_j_dis = (pos[p] - pos[j]).norm(p=2, dim=1)
        distances[1] = i_to_j_dis
        distances[2] = k_to_j_dis
        distances[3] = p_to_j_dis
        theta_ijk = get_angle(pos[j] - pos[i], pos[k] - pos[j])
        theta_ijp = get_angle(pos[j] - pos[i], pos[p] - pos[j])
        thetas[1] = theta_ijk
        thetas[2] = theta_ijp

        v1 = torch.cross(pos[j] - pos[i], pos[k] - pos[j], dim=1)
        v2 = torch.cross(pos[j] - pos[i], pos[p] - pos[j], dim=1)
        phi_ijkp = get_angle(v1, v2)
        vn = torch.cross(v1, v2, dim=1)
        flag = torch.sign((vn * (pos[j] - pos[i])).sum(dim=1))
        phis[1] = phi_ijkp * flag  # signed dihedral angle between the (i,j,k) and (i,j,p) planes

        for interaction in self.interactions:
            x = x + interaction(
                x,
                distances,
                thetas,
                phis,
                edge_index_3rd,
            )

        if batch is None:
            batch = torch.zeros(len(x), dtype=torch.long, device=x.device)

        x = scatter(x, batch, dim=0, reduce=self.readout)
        x = self.lin1(x)
        x = self.act(x)
        x = self.lin2(x)

        return x

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_layers={self.num_interactions})')


class SPNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_gaussians, distance_expansion, theta_expansion, phi_expansion, input_channels=None, readout='add'):
        super(SPNN, self).__init__()

        self.readout = readout
        self.distance_expansion = distance_expansion
        self.theta_expansion = theta_expansion
        self.phi_expansion = phi_expansion

        self.mlps_dist = ModuleList()
        mlp_1st = Sequential(
            Linear(num_gaussians[0], hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        mlp_2nd = Sequential(
            Linear(num_gaussians[0]+num_gaussians[1], hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        mlp_3rd = Sequential(
            Linear(num_gaussians[0]+num_gaussians[1]+num_gaussians[2], hidden_channels),
            torch.nn.ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        self.mlps_dist.append(mlp_1st)
        self.mlps_dist.append(mlp_2nd)
        self.mlps_dist.append(mlp_3rd)

        self.combine = Sequential(
            Linear(hidden_channels*7, hidden_channels*4),
            torch.nn.ReLU(),
            Linear(hidden_channels*4, hidden_channels*2),
            torch.nn.ReLU(),
            Linear(hidden_channels*2, hidden_channels),
        )

        self.reset_parameters()

    def reset_parameters(self):
        for i in range(3):
            torch.nn.init.xavier_uniform_(self.mlps_dist[i][0].weight)
            self.mlps_dist[i][0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.combine[0].weight)
        self.combine[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.combine[2].weight)
        self.combine[2].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.combine[4].weight)
        self.combine[4].bias.data.fill_(0)

    def forward(self,
                x,
                distances,
                thetas,
                phis,
                edge_index_3rd=None,
                ):
        i, j, k, p = edge_index_3rd

        node_attr_0st = x[i]
        node_attr_1st = x[j]
        node_attr_2nd = x[k]
        node_attr_3rd = x[p]
        geo_encoding_1st = self.mlps_dist[0](
            self.distance_expansion(distances[1])
        )
        geo_encoding_2nd = self.mlps_dist[1](
            torch.cat([self.distance_expansion(distances[2]), self.theta_expansion(thetas[1])], dim=-1)
        )
        geo_encoding_3rd = self.mlps_dist[2](
            torch.cat([self.distance_expansion(distances[3]), self.theta_expansion(thetas[2]), self.phi_expansion(phis[1])], dim=-1)
        )

        concatenated_vector = torch.cat(
            [
                node_attr_0st,
                node_attr_1st,
                node_attr_2nd,
                node_attr_3rd,
                geo_encoding_1st,
                geo_encoding_2nd,
                geo_encoding_3rd,
            ],
            dim=-1
        )
        x = self.combine(concatenated_vector)

        # aggregate the per-path messages back onto the source nodes
        x = scatter(x, i, dim=0, reduce=self.readout)

        return x
--------------------------------------------------------------------------------
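An end-to-end sketch (toy graph assumed; run from the repository root) that builds the 3rd-order paths with the utilities above and runs one SGMP forward pass. Note that `scatter` in SPNN sizes its output by the largest source index, so every node should appear in the path indices for the residual addition to line up; that holds for this toy graph.

```python
# Assumed toy pipeline: path graph 0-1-2-3, one SGMP forward pass.
import torch
from models.SGMP import SGMP
from utils.utils import find_higher_order_neighbors

x = torch.randn(4, 1)
pos = torch.randn(4, 3)
batch = torch.zeros(4, dtype=torch.long)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]], dtype=torch.long)

out = find_higher_order_neighbors(edge_index, num_of_nodes=4, order=3)
edge_index_3rd = out[2]                     # rows: (i, j, k, p)

model = SGMP(input_channels_node=1, hidden_channels=128, output_channels=1)
print(model(x, pos, batch, edge_index_3rd).shape)   # torch.Size([1, 1])
```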
self.lin2(h) 89 | 90 | out = scatter(h, batch, dim=0, reduce=self.readout) 91 | 92 | return out 93 | 94 | 95 | def __repr__(self): 96 | return (f'{self.__class__.__name__}(' 97 | f'hidden_channels={self.hidden_channels}, ' 98 | f'num_filters={self.num_filters}, ' 99 | f'num_interactions={self.num_interactions}, ' 100 | f'num_gaussians={self.num_gaussians}, ' 101 | f'cutoff={self.cutoff})') 102 | 103 | class InteractionBlock(torch.nn.Module): 104 | def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff): 105 | super(InteractionBlock, self).__init__() 106 | self.mlp = Sequential( 107 | Linear(num_gaussians, num_filters), 108 | ShiftedSoftplus(), 109 | Linear(num_filters, num_filters), 110 | ) 111 | self.conv = CFConv(hidden_channels, hidden_channels, num_filters, 112 | self.mlp, cutoff) 113 | self.act = ShiftedSoftplus() 114 | self.lin = Linear(hidden_channels, hidden_channels) 115 | 116 | self.reset_parameters() 117 | 118 | def reset_parameters(self): 119 | torch.nn.init.xavier_uniform_(self.mlp[0].weight) 120 | self.mlp[0].bias.data.fill_(0) 121 | torch.nn.init.xavier_uniform_(self.mlp[2].weight) 122 | self.mlp[0].bias.data.fill_(0) 123 | self.conv.reset_parameters() 124 | torch.nn.init.xavier_uniform_(self.lin.weight) 125 | self.lin.bias.data.fill_(0) 126 | 127 | def forward(self, x, edge_index, edge_weight, edge_attr): 128 | x = self.conv(x, edge_index, edge_weight, edge_attr) 129 | x = self.act(x) 130 | x = self.lin(x) 131 | return x 132 | 133 | 134 | class CFConv(MessagePassing): 135 | def __init__(self, in_channels, out_channels, num_filters, nn, cutoff): 136 | super(CFConv, self).__init__(aggr='add') 137 | self.lin1 = Linear(in_channels, num_filters, bias=False) 138 | self.lin2 = Linear(num_filters, out_channels) 139 | self.nn = nn 140 | self.cutoff = cutoff 141 | 142 | self.reset_parameters() 143 | 144 | def reset_parameters(self): 145 | torch.nn.init.xavier_uniform_(self.lin1.weight) 146 | torch.nn.init.xavier_uniform_(self.lin2.weight) 147 | self.lin2.bias.data.fill_(0) 148 | 149 | def forward(self, x, edge_index, edge_weight, edge_attr): 150 | C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0) 151 | W = self.nn(edge_attr) * C.view(-1, 1) 152 | 153 | x = self.lin1(x) 154 | x = self.propagate(edge_index, x=x, W=W) 155 | x = self.lin2(x) 156 | return x 157 | 158 | def message(self, x_j, W): 159 | return x_j * W 160 | 161 | 162 | class GaussianSmearing(torch.nn.Module): 163 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50): 164 | super(GaussianSmearing, self).__init__() 165 | offset = torch.linspace(start, stop, num_gaussians) 166 | self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 167 | self.register_buffer('offset', offset) 168 | 169 | def forward(self, dist): 170 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 171 | return torch.exp(self.coeff * torch.pow(dist, 2)) 172 | 173 | 174 | class ShiftedSoftplus(torch.nn.Module): 175 | def __init__(self): 176 | super(ShiftedSoftplus, self).__init__() 177 | self.shift = torch.log(torch.tensor(2.0)).item() 178 | 179 | def forward(self, x): 180 | return F.softplus(x) - self.shift -------------------------------------------------------------------------------- /models/SGMP.py: -------------------------------------------------------------------------------- 1 | from math import pi as PI 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Embedding, Sequential, Linear, ModuleList 5 | from torch import Tensor 6 | from torch_scatter import scatter 7 | 8 | def get_angle(v1: 
Tensor, v2: Tensor) -> Tensor: 9 | return torch.atan2( 10 | torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1)) 11 | 12 | class GaussianSmearing(torch.nn.Module): 13 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50): 14 | super(GaussianSmearing, self).__init__() 15 | offset = torch.linspace(start, stop, num_gaussians) 16 | self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 17 | self.register_buffer('offset', offset) 18 | 19 | def forward(self, dist): 20 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 21 | return torch.exp(self.coeff * torch.pow(dist, 2)) 22 | 23 | 24 | class SGMP(torch.nn.Module): 25 | 26 | def __init__(self, input_channels_node=None, hidden_channels=128, 27 | output_channels=1, num_interactions=3, 28 | num_gaussians=(50,6,12), cutoff=10.0, 29 | readout='add'): 30 | super(SGMP, self).__init__() 31 | 32 | assert readout in ['add', 'sum', 'mean'] 33 | 34 | self.input_channels_node = input_channels_node 35 | self.hidden_channels = hidden_channels 36 | self.num_interactions = num_interactions 37 | self.num_gaussians = num_gaussians 38 | self.readout = readout 39 | # the gaussian expansion here is used to help quicker converge 40 | self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians[0]) 41 | self.theta_expansion = GaussianSmearing(0.0, PI, num_gaussians[1]) 42 | self.phi_expansion = GaussianSmearing(0.0, 2*PI, num_gaussians[2]) 43 | self.embedding = Sequential( 44 | Linear(input_channels_node, hidden_channels), 45 | torch.nn.ReLU(), 46 | Linear(hidden_channels, hidden_channels) 47 | ) 48 | 49 | self.interactions = ModuleList() 50 | for _ in range(num_interactions): 51 | block = SPNN(hidden_channels, num_gaussians, self.distance_expansion, self.theta_expansion, self.phi_expansion, input_channels=hidden_channels) 52 | self.interactions.append(block) 53 | 54 | self.lin1 = Linear(hidden_channels, hidden_channels // 2) 55 | self.act = torch.nn.ReLU() 56 | self.lin2 = Linear(hidden_channels // 2, output_channels) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | for block in self.interactions: 61 | block.reset_parameters() 62 | 63 | torch.nn.init.xavier_uniform_(self.lin1.weight) 64 | self.lin1.bias.data.fill_(0) 65 | torch.nn.init.xavier_uniform_(self.lin2.weight) 66 | self.lin2.bias.data.fill_(0) 67 | 68 | def forward(self, x, pos, batch, edge_index_3rd=None): 69 | x = self.embedding(x) 70 | 71 | distances = {} 72 | thetas = {} 73 | phis = {} 74 | i, j, k, p = edge_index_3rd 75 | i_to_j_dis = (pos[j] - pos[i]).norm(p=2, dim=1) 76 | k_to_j_dis = (pos[k] - pos[j]).norm(p=2, dim=1) 77 | p_to_j_dis = (pos[p] - pos[j]).norm(p=2, dim=1) 78 | distances[1] = i_to_j_dis 79 | distances[2] = k_to_j_dis 80 | distances[3] = p_to_j_dis 81 | theta_ijk = get_angle(pos[j] - pos[i], pos[k] - pos[j]) 82 | theta_ijp = get_angle(pos[j] - pos[i], pos[p] - pos[j]) 83 | thetas[1] = theta_ijk 84 | thetas[2] = theta_ijp 85 | 86 | v1 = torch.cross(pos[j] - pos[i], pos[k] - pos[j], dim=1) 87 | v2 = torch.cross(pos[j] - pos[i], pos[p] - pos[j], dim=1) 88 | phi_ijkp = get_angle(v1, v2) 89 | vn = torch.cross(v1, v2, dim=1) 90 | flag = torch.sign((vn * (pos[j]- pos[i])).sum(dim=1)) 91 | phis[1] = phi_ijkp * flag 92 | 93 | for interaction in self.interactions: 94 | x = x + interaction( 95 | x, 96 | distances, 97 | thetas, 98 | phis, 99 | edge_index_3rd, 100 | ) 101 | 102 | 103 | if batch is None: 104 | batch = torch.zeros(len(x), dtype=torch.long, device=x.device) 105 | 106 | x = scatter(x, batch, dim=0, reduce=self.readout) 107 | x = 
self.lin1(x) 108 | x = self.act(x) 109 | x = self.lin2(x) 110 | 111 | return x 112 | 113 | def __repr__(self): 114 | return (f'{self.__class__.__name__}(' 115 | f'hidden_channels={self.hidden_channels}, ' 116 | f'num_layers={self.num_interactions})') 117 | 118 | 119 | class SPNN(torch.nn.Module): 120 | def __init__(self, hidden_channels, num_gaussians, distance_expansion, theta_expansion, phi_expansion, input_channels=None, readout='add'): 121 | super(SPNN, self).__init__() 122 | 123 | self.readout = readout 124 | self.distance_expansion = distance_expansion 125 | self.theta_expansion = theta_expansion 126 | self.phi_expansion = phi_expansion 127 | 128 | self.mlps_dist = ModuleList() 129 | mlp_1st = Sequential( 130 | Linear(num_gaussians[0], hidden_channels), 131 | torch.nn.ReLU(), 132 | Linear(hidden_channels, hidden_channels), 133 | ) 134 | mlp_2nd = Sequential( 135 | Linear(num_gaussians[0]+num_gaussians[1], hidden_channels), 136 | torch.nn.ReLU(), 137 | Linear(hidden_channels, hidden_channels), 138 | ) 139 | mlp_3rd = Sequential( 140 | Linear(num_gaussians[0]+num_gaussians[1]+num_gaussians[2], hidden_channels), 141 | torch.nn.ReLU(), 142 | Linear(hidden_channels, hidden_channels), 143 | ) 144 | self.mlps_dist.append(mlp_1st) 145 | self.mlps_dist.append(mlp_2nd) 146 | self.mlps_dist.append(mlp_3rd) 147 | 148 | self.combine = Sequential( 149 | Linear(hidden_channels*7, hidden_channels*4), 150 | torch.nn.ReLU(), 151 | Linear(hidden_channels*4, hidden_channels*2), 152 | torch.nn.ReLU(), 153 | Linear(hidden_channels*2, hidden_channels), 154 | ) 155 | 156 | self.reset_parameters() 157 | 158 | def reset_parameters(self): 159 | for i in range(3): 160 | torch.nn.init.xavier_uniform_(self.mlps_dist[i][0].weight) 161 | self.mlps_dist[i][0].bias.data.fill_(0) 162 | torch.nn.init.xavier_uniform_(self.combine[0].weight) 163 | self.combine[0].bias.data.fill_(0) 164 | torch.nn.init.xavier_uniform_(self.combine[2].weight) 165 | self.combine[2].bias.data.fill_(0) 166 | torch.nn.init.xavier_uniform_(self.combine[4].weight) 167 | self.combine[4].bias.data.fill_(0) 168 | 169 | def forward(self, 170 | x, 171 | distances, 172 | thetas, 173 | phis, 174 | edge_index_3rd=None, 175 | ): 176 | i, j, k, p = edge_index_3rd 177 | 178 | node_attr_0th = x[i] 179 | node_attr_1st = x[j] 180 | node_attr_2nd = x[k] 181 | node_attr_3rd = x[p] 182 | geo_encoding_1st = self.mlps_dist[0]( 183 | self.distance_expansion(distances[1]) 184 | ) 185 | geo_encoding_2nd = self.mlps_dist[1]( 186 | torch.cat([self.distance_expansion(distances[2]), self.theta_expansion(thetas[1])], dim=-1) 187 | ) 188 | geo_encoding_3rd = self.mlps_dist[2]( 189 | torch.cat([self.distance_expansion(distances[3]), self.theta_expansion(thetas[2]), self.phi_expansion(phis[1])], dim=-1) 190 | ) 191 | 192 | concatenated_vector = torch.cat( 193 | [ 194 | node_attr_0th, 195 | node_attr_1st, 196 | node_attr_2nd, 197 | node_attr_3rd, 198 | geo_encoding_1st, 199 | geo_encoding_2nd, 200 | geo_encoding_3rd, 201 | ], 202 | dim=-1 203 | ) 204 | x = self.combine(concatenated_vector) 205 | 206 | # aggregate path messages onto the root node i 207 | x = scatter(x, i, dim=0, reduce=self.readout) 208 | 209 | return x -------------------------------------------------------------------------------- /utils/moleculenet.py: -------------------------------------------------------------------------------- 1 | ######################################################## 2 | ### modified from the version by pytorch-geometric ### 3 | ######################################################## 4 | import os 5 | import os.path
as osp 6 | import re 7 | from tqdm import tqdm 8 | import torch 9 | 10 | from torch_geometric.data import (InMemoryDataset, Data, download_url, 11 | extract_gz) 12 | try: 13 | from rdkit import Chem 14 | from rdkit.Chem import AllChem 15 | except ImportError: 16 | Chem = None 17 | AllChem = None 18 | 19 | 20 | x_map = { 21 | 'atomic_num': 22 | list(range(0, 119)), 23 | 'chirality': [ 24 | 'CHI_UNSPECIFIED', 25 | 'CHI_TETRAHEDRAL_CW', 26 | 'CHI_TETRAHEDRAL_CCW', 27 | 'CHI_OTHER', 28 | ], 29 | 'degree': 30 | list(range(0, 11)), 31 | 'formal_charge': 32 | list(range(-5, 7)), 33 | 'num_hs': 34 | list(range(0, 9)), 35 | 'num_radical_electrons': 36 | list(range(0, 5)), 37 | 'hybridization': [ 38 | 'UNSPECIFIED', 39 | 'S', 40 | 'SP', 41 | 'SP2', 42 | 'SP3', 43 | 'SP3D', 44 | 'SP3D2', 45 | 'OTHER', 46 | ], 47 | 'is_aromatic': [False, True], 48 | 'is_in_ring': [False, True], 49 | } 50 | 51 | e_map = { 52 | 'bond_type': [ 53 | 'misc', 54 | 'SINGLE', 55 | 'DOUBLE', 56 | 'TRIPLE', 57 | 'AROMATIC', 58 | ], 59 | 'stereo': [ 60 | 'STEREONONE', 61 | 'STEREOZ', 62 | 'STEREOE', 63 | 'STEREOCIS', 64 | 'STEREOTRANS', 65 | 'STEREOANY', 66 | ], 67 | 'is_conjugated': [False, True], 68 | } 69 | 70 | 71 | class MoleculeNet(InMemoryDataset): 72 | r"""The `MoleculeNet `_ benchmark 73 | collection from the `"MoleculeNet: A Benchmark for Molecular Machine 74 | Learning" `_ paper, containing datasets 75 | from physical chemistry, biophysics and physiology. 76 | All datasets come with the additional node and edge features introduced by 77 | the `Open Graph Benchmark `_. 78 | 79 | Args: 80 | root (string): Root directory where the dataset should be saved. 81 | name (string): The name of the dataset (:obj:`"ESOL"`, 82 | :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`, 83 | :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBBP"`, :obj:`"Tox21"`, 84 | :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`). 85 | transform (callable, optional): A function/transform that takes in an 86 | :obj:`torch_geometric.data.Data` object and returns a transformed 87 | version. The data object will be transformed before every access. 88 | (default: :obj:`None`) 89 | pre_transform (callable, optional): A function/transform that takes in 90 | an :obj:`torch_geometric.data.Data` object and returns a 91 | transformed version. The data object will be transformed before 92 | being saved to disk. (default: :obj:`None`) 93 | pre_filter (callable, optional): A function that takes in an 94 | :obj:`torch_geometric.data.Data` object and returns a boolean 95 | value, indicating whether the data object should be included in the 96 | final dataset.
(default: :obj:`None`) 97 | """ 98 | 99 | url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}' 100 | 101 | # Format: name: [display_name, url_name, csv_name, smiles_idx, y_idx] 102 | names = { 103 | 'esol': ['ESOL', 'delaney-processed.csv', 'delaney-processed', -1, -2], 104 | 'freesolv': ['FreeSolv', 'SAMPL.csv', 'SAMPL', 1, 2], 105 | 'lipo': ['Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1], 106 | 'pcba': ['PCBA', 'pcba.csv.gz', 'pcba', -1, 107 | slice(0, 128)], 108 | 'muv': ['MUV', 'muv.csv.gz', 'muv', -1, 109 | slice(0, 17)], 110 | 'hiv': ['HIV', 'HIV.csv', 'HIV', 0, -1], 111 | 'bace': ['BACE', 'bace.csv', 'bace', 0, 2], 112 | 'bbbp': ['BBBP', 'BBBP.csv', 'BBBP', -1, -2], 113 | 'tox21': ['Tox21', 'tox21.csv.gz', 'tox21', -1, 114 | slice(0, 12)], 115 | 'toxcast': 116 | ['ToxCast', 'toxcast_data.csv.gz', 'toxcast_data', 0, 117 | slice(1, 618)], 118 | 'sider': ['SIDER', 'sider.csv.gz', 'sider', 0, 119 | slice(1, 28)], 120 | 'clintox': ['ClinTox', 'clintox.csv.gz', 'clintox', 0, 121 | slice(1, 3)], 122 | } 123 | 124 | def __init__(self, root, name, transform=None, pre_transform=None, 125 | pre_filter=None): 126 | 127 | self.name = name.lower() 128 | assert self.name in self.names.keys() 129 | super(MoleculeNet, self).__init__(root, transform, pre_transform, 130 | pre_filter) 131 | self.data, self.slices = torch.load(self.processed_paths[0]) 132 | 133 | @property 134 | def raw_dir(self): 135 | return osp.join(self.root, self.name, 'raw') 136 | 137 | @property 138 | def processed_dir(self): 139 | return osp.join(self.root, self.name, 'processed') 140 | 141 | @property 142 | def raw_file_names(self): 143 | return f'{self.names[self.name][2]}.csv' 144 | 145 | @property 146 | def processed_file_names(self): 147 | return 'data.pt' 148 | 149 | def download(self): 150 | url = self.url.format(self.names[self.name][1]) 151 | path = download_url(url, self.raw_dir) 152 | if self.names[self.name][1][-2:] == 'gz': 153 | extract_gz(path, self.raw_dir) 154 | os.unlink(path) 155 | 156 | def process(self): 157 | with open(self.raw_paths[0], 'r') as f: 158 | dataset = f.read().split('\n')[1:-1] 159 | dataset = [x for x in dataset if len(x) > 0] # Filter empty lines. 160 | 161 | data_list = [] 162 | for line in tqdm(dataset): 163 | line = re.sub(r'\".*\"', '', line) # Remove quoted strings.
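# Note: quoted fields may embed commas that would corrupt the naive comma split below, so they are stripped wholesale above; this is safe as long as neither the SMILES column nor the label columns are quoted in these files.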
164 | line = line.split(',') 165 | 166 | smiles = line[self.names[self.name][3]] 167 | ys = line[self.names[self.name][4]] 168 | ys = ys if isinstance(ys, list) else [ys] 169 | 170 | ys = [float(y) if len(y) > 0 else float('NaN') for y in ys] 171 | y = torch.tensor(ys, dtype=torch.float).view(1, -1) 172 | 173 | mol = Chem.MolFromSmiles(smiles) 174 | if mol is None: 175 | continue 176 | 177 | AllChem.EmbedMolecule(mol,randomSeed=0xf00d) 178 | N = mol.GetNumAtoms() 179 | pos = Chem.MolToMolBlock(mol).split('\n')[4:4 + N] 180 | pos = [[float(x) for x in line.split()[:3]] for line in pos] 181 | pos = torch.tensor(pos, dtype=torch.float) 182 | 183 | xs = [] 184 | for atom in mol.GetAtoms(): 185 | x = [] 186 | x.append(x_map['atomic_num'].index(atom.GetAtomicNum())) 187 | x.append(x_map['chirality'].index(str(atom.GetChiralTag()))) 188 | x.append(x_map['degree'].index(atom.GetTotalDegree())) 189 | x.append(x_map['formal_charge'].index(atom.GetFormalCharge())) 190 | x.append(x_map['num_hs'].index(atom.GetTotalNumHs())) 191 | x.append(x_map['num_radical_electrons'].index( 192 | atom.GetNumRadicalElectrons())) 193 | x.append(x_map['hybridization'].index( 194 | str(atom.GetHybridization()))) 195 | x.append(x_map['is_aromatic'].index(atom.GetIsAromatic())) 196 | x.append(x_map['is_in_ring'].index(atom.IsInRing())) 197 | xs.append(x) 198 | 199 | x = torch.tensor(xs, dtype=torch.long).view(-1, 9) 200 | 201 | edge_indices, edge_attrs = [], [] 202 | for bond in mol.GetBonds(): 203 | i = bond.GetBeginAtomIdx() 204 | j = bond.GetEndAtomIdx() 205 | 206 | e = [] 207 | e.append(e_map['bond_type'].index(str(bond.GetBondType()))) 208 | e.append(e_map['stereo'].index(str(bond.GetStereo()))) 209 | e.append(e_map['is_conjugated'].index(bond.GetIsConjugated())) 210 | 211 | edge_indices += [[i, j], [j, i]] 212 | edge_attrs += [e, e] 213 | 214 | edge_index = torch.tensor(edge_indices) 215 | edge_index = edge_index.t().to(torch.long).view(2, -1) 216 | edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3) 217 | 218 | # Sort indices. 
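# Note on the sort below: encoding each edge as row * num_nodes + col produces a unique scalar key, so the argsort orders edges lexicographically by (source, target).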
219 | if edge_index.numel() > 0: 220 | perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort() 221 | edge_index, edge_attr = edge_index[:, perm], edge_attr[perm] 222 | 223 | data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y, 224 | smiles=smiles) 225 | 226 | if self.pre_filter is not None and not self.pre_filter(data): 227 | continue 228 | 229 | if self.pre_transform is not None: 230 | data = self.pre_transform(data) 231 | 232 | data_list.append(data) 233 | 234 | torch.save(self.collate(data_list), self.processed_paths[0]) 235 | 236 | def __repr__(self): 237 | return '{}({})'.format(self.names[self.name][0], len(self)) -------------------------------------------------------------------------------- /models/Dimenet.py: -------------------------------------------------------------------------------- 1 | ######################################################## 2 | ### modified from the version by pytorch-geometric ### 3 | ######################################################## 4 | import os 5 | import os.path as osp 6 | from math import sqrt, pi as PI 7 | import math 8 | import numpy as np 9 | import torch 10 | from torch.nn import Linear, Embedding 11 | from torch_scatter import scatter 12 | from torch_sparse import SparseTensor 13 | from torch_geometric.nn import radius_graph 14 | from torch_geometric.data import download_url 15 | from torch_geometric.data.makedirs import makedirs 16 | 17 | from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter 18 | from .dimenet_utils import bessel_basis, real_sph_harm 19 | 20 | try: 21 | import sympy as sym 22 | except ImportError: 23 | sym = None 24 | 25 | try: 26 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 27 | import tensorflow as tf 28 | except ImportError: 29 | tf = None 30 | 31 | qm9_target_dict = { 32 | 0: 'mu', 33 | 1: 'alpha', 34 | 2: 'homo', 35 | 3: 'lumo', 36 | 5: 'r2', 37 | 6: 'zpve', 38 | 7: 'U0', 39 | 8: 'U', 40 | 9: 'H', 41 | 10: 'G', 42 | 11: 'Cv', 43 | } 44 | def glorot_orthogonal(tensor, scale): 45 | if tensor is not None: 46 | torch.nn.init.orthogonal_(tensor.data) 47 | scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) 48 | tensor.data *= scale.sqrt() 49 | def glorot(tensor): 50 | if tensor is not None: 51 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 52 | tensor.data.uniform_(-stdv, stdv) 53 | def swish(x): 54 | return x * x.sigmoid() 55 | 56 | class Envelope(torch.nn.Module): 57 | def __init__(self, exponent): 58 | super(Envelope, self).__init__() 59 | self.p = exponent + 1 60 | self.a = -(self.p + 1) * (self.p + 2) / 2 61 | self.b = self.p * (self.p + 2) 62 | self.c = -self.p * (self.p + 1) / 2 63 | 64 | def forward(self, x): 65 | p, a, b, c = self.p, self.a, self.b, self.c 66 | x_pow_p0 = x.pow(p - 1) 67 | x_pow_p1 = x_pow_p0 * x 68 | x_pow_p2 = x_pow_p1 * x 69 | return 1. 
/ x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2 70 | 71 | 72 | class BesselBasisLayer(torch.nn.Module): 73 | def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5): 74 | super(BesselBasisLayer, self).__init__() 75 | self.cutoff = cutoff 76 | self.envelope = Envelope(envelope_exponent) 77 | 78 | self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) 79 | 80 | self.reset_parameters() 81 | 82 | def reset_parameters(self): 83 | torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI) 84 | 85 | def forward(self, dist): 86 | dist = dist.unsqueeze(-1) / self.cutoff 87 | return self.envelope(dist) * (self.freq * dist).sin() 88 | 89 | 90 | class SphericalBasisLayer(torch.nn.Module): 91 | def __init__(self, num_spherical, num_radial, cutoff=5.0, 92 | envelope_exponent=5): 93 | super(SphericalBasisLayer, self).__init__() 94 | assert num_radial <= 64 95 | self.num_spherical = num_spherical 96 | self.num_radial = num_radial 97 | self.cutoff = cutoff 98 | self.envelope = Envelope(envelope_exponent) 99 | 100 | bessel_forms = bessel_basis(num_spherical, num_radial) 101 | sph_harm_forms = real_sph_harm(num_spherical) 102 | self.sph_funcs = [] 103 | self.bessel_funcs = [] 104 | 105 | x, theta = sym.symbols('x theta') 106 | modules = {'sin': torch.sin, 'cos': torch.cos} 107 | for i in range(num_spherical): 108 | if i == 0: 109 | sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) 110 | self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1) 111 | else: 112 | sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) 113 | self.sph_funcs.append(sph) 114 | for j in range(num_radial): 115 | bessel = sym.lambdify([x], bessel_forms[i][j], modules) 116 | self.bessel_funcs.append(bessel) 117 | 118 | def forward(self, dist, angle, idx_kj): 119 | dist = dist / self.cutoff 120 | rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) 121 | rbf = self.envelope(dist).unsqueeze(-1) * rbf 122 | 123 | cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) 124 | 125 | n, k = self.num_spherical, self.num_radial 126 | out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) 127 | return out 128 | 129 | 130 | class EmbeddingBlock(torch.nn.Module): 131 | def __init__(self, num_radial, hidden_channels, input_channels_node=1, act=swish): 132 | super(EmbeddingBlock, self).__init__() 133 | self.act = act 134 | 135 | self.node_lin = Sequential( 136 | Linear(input_channels_node, hidden_channels), 137 | ReLU(), 138 | Linear(hidden_channels, hidden_channels) 139 | ) 140 | self.emb = Embedding(95, hidden_channels) 141 | self.lin_rbf = Linear(num_radial, hidden_channels) 142 | self.lin = Linear(3 * hidden_channels, hidden_channels) 143 | 144 | self.reset_parameters() 145 | 146 | def reset_parameters(self): 147 | self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) 148 | self.lin_rbf.reset_parameters() 149 | self.lin.reset_parameters() 150 | torch.nn.init.xavier_uniform_(self.node_lin[0].weight) 151 | self.node_lin[0].bias.data.fill_(0) 152 | torch.nn.init.xavier_uniform_(self.node_lin[2].weight) 153 | self.node_lin[2].bias.data.fill_(0) 154 | 155 | def forward(self, x, rbf, i, j, flag): 156 | if flag: 157 | x = self.emb(x) 158 | else: 159 | x = self.node_lin(x) 160 | rbf = self.act(self.lin_rbf(rbf)) 161 | return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1))) 162 | 163 | 164 | class ResidualLayer(torch.nn.Module): 165 | def __init__(self, hidden_channels, act=swish): 166 | super(ResidualLayer, self).__init__() 167 | self.act = act 168 | self.lin1 = Linear(hidden_channels, 
hidden_channels) 169 | self.lin2 = Linear(hidden_channels, hidden_channels) 170 | 171 | self.reset_parameters() 172 | 173 | def reset_parameters(self): 174 | glorot_orthogonal(self.lin1.weight, scale=2.0) 175 | self.lin1.bias.data.fill_(0) 176 | glorot_orthogonal(self.lin2.weight, scale=2.0) 177 | self.lin2.bias.data.fill_(0) 178 | 179 | def forward(self, x): 180 | return x + self.act(self.lin2(self.act(self.lin1(x)))) 181 | 182 | 183 | class InteractionBlock(torch.nn.Module): 184 | def __init__(self, hidden_channels, num_bilinear, num_spherical, 185 | num_radial, num_before_skip, num_after_skip, act=swish): 186 | super(InteractionBlock, self).__init__() 187 | self.act = act 188 | 189 | self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) 190 | self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear, 191 | bias=False) 192 | 193 | # Dense transformations of input messages. 194 | self.lin_kj = Linear(hidden_channels, hidden_channels) 195 | self.lin_ji = Linear(hidden_channels, hidden_channels) 196 | 197 | self.W = torch.nn.Parameter( 198 | torch.Tensor(hidden_channels, num_bilinear, hidden_channels)) 199 | 200 | self.layers_before_skip = torch.nn.ModuleList([ 201 | ResidualLayer(hidden_channels, act) for _ in range(num_before_skip) 202 | ]) 203 | self.lin = Linear(hidden_channels, hidden_channels) 204 | self.layers_after_skip = torch.nn.ModuleList([ 205 | ResidualLayer(hidden_channels, act) for _ in range(num_after_skip) 206 | ]) 207 | 208 | self.reset_parameters() 209 | 210 | def reset_parameters(self): 211 | glorot_orthogonal(self.lin_rbf.weight, scale=2.0) 212 | glorot_orthogonal(self.lin_sbf.weight, scale=2.0) 213 | glorot_orthogonal(self.lin_kj.weight, scale=2.0) 214 | self.lin_kj.bias.data.fill_(0) 215 | glorot_orthogonal(self.lin_ji.weight, scale=2.0) 216 | self.lin_ji.bias.data.fill_(0) 217 | self.W.data.normal_(mean=0, std=2 / self.W.size(0)) 218 | for res_layer in self.layers_before_skip: 219 | res_layer.reset_parameters() 220 | glorot_orthogonal(self.lin.weight, scale=2.0) 221 | self.lin.bias.data.fill_(0) 222 | for res_layer in self.layers_after_skip: 223 | res_layer.reset_parameters() 224 | 225 | def forward(self, x, rbf, sbf, idx_kj, idx_ji): 226 | rbf = self.lin_rbf(rbf) 227 | sbf = self.lin_sbf(sbf) 228 | 229 | x_ji = self.act(self.lin_ji(x)) 230 | x_kj = self.act(self.lin_kj(x)) 231 | x_kj = x_kj * rbf 232 | x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W) 233 | x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0)) 234 | 235 | h = x_ji + x_kj 236 | for layer in self.layers_before_skip: 237 | h = layer(h) 238 | h = self.act(self.lin(h)) + x 239 | for layer in self.layers_after_skip: 240 | h = layer(h) 241 | 242 | return h 243 | 244 | 245 | class OutputBlock(torch.nn.Module): 246 | def __init__(self, num_radial, hidden_channels, out_channels, num_layers, 247 | act=swish): 248 | super(OutputBlock, self).__init__() 249 | self.act = act 250 | 251 | self.lin_rbf = Linear(num_radial, hidden_channels, bias=False) 252 | self.lins = torch.nn.ModuleList() 253 | for _ in range(num_layers): 254 | self.lins.append(Linear(hidden_channels, hidden_channels)) 255 | self.lin = Linear(hidden_channels, out_channels, bias=False) 256 | 257 | self.reset_parameters() 258 | 259 | def reset_parameters(self): 260 | glorot_orthogonal(self.lin_rbf.weight, scale=2.0) 261 | for lin in self.lins: 262 | glorot_orthogonal(lin.weight, scale=2.0) 263 | lin.bias.data.fill_(0) 264 | self.lin.weight.data.fill_(0) 265 | 266 | def forward(self, x, rbf, i, num_nodes=None): 267 | x = 
self.lin_rbf(rbf) * x 268 | x = scatter(x, i, dim=0, dim_size=num_nodes) 269 | for lin in self.lins: 270 | x = self.act(lin(x)) 271 | return self.lin(x) 272 | 273 | 274 | class Dimenet(torch.nn.Module): 275 | 276 | url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/' 277 | 'dimenet') 278 | 279 | def __init__(self, input_channels_node=1, hidden_channels=128, output_channels=1, num_blocks=3, num_bilinear=8, 280 | num_spherical=7, num_radial=6, cutoff=5.0, envelope_exponent=5, 281 | num_before_skip=1, num_after_skip=2, num_output_layers=3, 282 | act=swish): 283 | super(Dimenet, self).__init__() 284 | 285 | self.cutoff = cutoff 286 | 287 | if sym is None: 288 | raise ImportError('Package `sympy` could not be found.') 289 | 290 | self.num_blocks = num_blocks 291 | 292 | self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent) 293 | self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff, 294 | envelope_exponent) 295 | 296 | self.emb = EmbeddingBlock(num_radial, hidden_channels, input_channels_node=input_channels_node, act=act) 297 | 298 | self.output_blocks = torch.nn.ModuleList([ 299 | OutputBlock(num_radial, hidden_channels, output_channels, 300 | num_output_layers, act) for _ in range(num_blocks + 1) 301 | ]) 302 | 303 | self.interaction_blocks = torch.nn.ModuleList([ 304 | InteractionBlock(hidden_channels, num_bilinear, num_spherical, 305 | num_radial, num_before_skip, num_after_skip, act) 306 | for _ in range(num_blocks) 307 | ]) 308 | 309 | self.reset_parameters() 310 | 311 | def reset_parameters(self): 312 | self.rbf.reset_parameters() 313 | self.emb.reset_parameters() 314 | for out in self.output_blocks: 315 | out.reset_parameters() 316 | for interaction in self.interaction_blocks: 317 | interaction.reset_parameters() 318 | 319 | 320 | def triplets(self, edge_index, num_nodes): 321 | row, col = edge_index # j->i 322 | 323 | value = torch.arange(row.size(0), device=row.device) 324 | adj_t = SparseTensor(row=col, col=row, value=value, 325 | sparse_sizes=(num_nodes, num_nodes)) 326 | adj_t_row = adj_t[row] 327 | num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) 328 | 329 | # Node indices (k->j->i) for triplets. 330 | idx_i = col.repeat_interleave(num_triplets) 331 | idx_j = row.repeat_interleave(num_triplets) 332 | idx_k = adj_t_row.storage.col() 333 | mask = idx_i != idx_k # Remove i == k triplets. 334 | idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] 335 | 336 | # Edge indices (k-j, j->i) for triplets. 337 | idx_kj = adj_t_row.storage.value()[mask] 338 | idx_ji = adj_t_row.storage.row()[mask] 339 | 340 | return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji 341 | 342 | 343 | def forward(self, x, pos, edge_index, batch=None): 344 | """""" 345 | # modified to use graph connections 346 | # edge_index = radius_graph(pos, r=self.cutoff, batch=batch) 347 | 348 | i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( 349 | edge_index, num_nodes=x.size(0)) 350 | 351 | # Calculate distances. 352 | dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt() 353 | 354 | # Calculate angles. 355 | pos_i = pos[idx_i] 356 | pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i 357 | a = (pos_ji * pos_ki).sum(dim=-1) 358 | b = torch.cross(pos_ji, pos_ki).norm(dim=-1) 359 | angle = torch.atan2(b, a) 360 | 361 | rbf = self.rbf(dist) 362 | sbf = self.sbf(dist, angle, idx_kj) 363 | 364 | # Embedding block. 
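# Note: a 1-D integer input is interpreted as atomic numbers and embedded via the lookup table, while a 2-D float feature matrix is projected by the node MLP instead (see EmbeddingBlock.forward).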
365 | if x.dim() == 1: 366 | x = self.emb(x, rbf, i, j, flag=True) 367 | else: 368 | x = self.emb(x, rbf, i, j, flag=False) 369 | 370 | 371 | P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) 372 | 373 | # Interaction blocks. 374 | for interaction_block, output_block in zip(self.interaction_blocks, 375 | self.output_blocks[1:]): 376 | x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) 377 | P += output_block(x, rbf, i) 378 | 379 | return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0) 380 | -------------------------------------------------------------------------------- /main_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import time 6 | import pickle as pkl 7 | import json 8 | 9 | import torch 10 | import copy 11 | from torch_geometric.data import Data, DataLoader 12 | from torch_scatter import scatter 13 | 14 | from utils.utils import build_spanning_tree_edge, find_higher_order_neighbors, add_self_loops 15 | from sklearn.metrics import r2_score 16 | from sklearn.metrics import roc_auc_score 17 | from sklearn.metrics import classification_report 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--data_dir', type=str, default='./data') 22 | parser.add_argument('--save_dir', type=str, default='./results') 23 | parser.add_argument('--model', type=str, default='SGMP') 24 | parser.add_argument('--dataset', type=str, default='BACE') 25 | parser.add_argument('--split', type=str, default='811') 26 | parser.add_argument('--device', type=str, default='gpu') 27 | parser.add_argument('--readout', type=str, default='add') 28 | parser.add_argument('--spanning_tree', type=str, default='False') 29 | parser.add_argument('--structure', type=str, default='sc') 30 | 31 | parser.add_argument('--random_seed', type=int, default=12345) 32 | parser.add_argument('--random_seed_2', type=int, default=12345) 33 | parser.add_argument('--label', type=int, default=12) 34 | parser.add_argument('--batch_size', type=int, default=64) 35 | parser.add_argument('--num_layers', type=int, default=3) 36 | parser.add_argument('--epoch', type=int, default=500) 37 | parser.add_argument('--lr', type=float, default=1e-3) 38 | parser.add_argument('--test_per_round', type=int, default=5) 39 | parser.add_argument('--threshold', type=float, default=0.1) 40 | parser.add_argument('--cutoff', type=float, default=10.0) 41 | parser.add_argument('--weight_decay', type=float, default=5e-4) 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | def load_data(args): 47 | if args.dataset == 'synthetic': 48 | with open(os.path.join(args.data_dir, 'synthetic', 'synthetic.pkl'), 'rb') as file: 49 | dataset = pkl.load(file) 50 | dataset = dataset[torch.randperm(len(dataset))] 51 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 52 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 53 | 54 | elif args.dataset == 'QM9': 55 | from torch_geometric.datasets import QM9 56 | dataset = QM9(root=os.path.join(args.data_dir, 'QM9')) 57 | random_state = np.random.RandomState(seed=42) 58 | perm = torch.from_numpy(random_state.permutation(np.arange(130831))) 59 | dataset = dataset[perm] 60 | train_valid_split, valid_test_split = 110000, 10000 61 | 62 | elif args.dataset == 'brain': 63 | from utils.brain_load_data import load_brain_data 64 | dataset = load_brain_data(data_dir=args.data_dir, structure='sc', threshold=5e5, 
random_seed=args.random_seed) 65 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 66 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 67 | 68 | elif args.dataset in ["ESOL", "Lipo", "BACE", "BBBP"]: 69 | from utils.moleculenet import MoleculeNet 70 | dataset = MoleculeNet(root=os.path.join(args.data_dir, 'MoleculeNet', args.dataset),name=args.dataset) 71 | dataset = dataset[torch.randperm(len(dataset))] 72 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 73 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 74 | else: 75 | raise Exception('Dataset not recognized.') 76 | 77 | 78 | train_dataset = dataset[:train_valid_split] 79 | valid_dataset = dataset[train_valid_split:train_valid_split+valid_test_split] 80 | test_dataset = dataset[train_valid_split+valid_test_split:] 81 | 82 | print('======================') 83 | print(f'Number of training graphs: {len(train_dataset)}') 84 | print(f'Number of valid graphs: {len(valid_dataset)}') 85 | print(f'Number of test graphs: {len(test_dataset)}') 86 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 87 | valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False) 88 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 89 | 90 | return train_loader, valid_loader, test_loader 91 | 92 | def main(data, args): 93 | # logging 94 | if args.spanning_tree == 'True': 95 | LOG_DIR = os.path.join(args.save_dir, args.dataset, args.model+'_st') 96 | else: 97 | LOG_DIR = os.path.join(args.save_dir, args.dataset, args.model) 98 | if args.dataset == 'QM9': 99 | LOG_DIR = os.path.join(LOG_DIR, str(args.label)) 100 | 101 | if not os.path.exists(LOG_DIR): 102 | os.makedirs(LOG_DIR) 103 | log_file = os.path.join(LOG_DIR, 'log.csv') 104 | result_file = os.path.join(LOG_DIR, 'results.txt') 105 | 106 | # device 107 | if args.device == 'cpu': 108 | device = torch.device("cpu") 109 | elif args.device == 'gpu': 110 | device = torch.device("cuda:0") 111 | else: 112 | raise Exception('Please assign the type of device: cpu or gpu.') 113 | 114 | if args.dataset == 'brain': 115 | input_channels_node, hidden_channels, readout = 1, 64, args.readout 116 | elif args.dataset == 'QM9': 117 | input_channels_node, hidden_channels, readout = 11, 64, args.readout 118 | else: 119 | input_channels_node, hidden_channels, readout = 9, 64, args.readout 120 | 121 | if args.dataset in [ 'BACE', 'BBBP']: 122 | task = 'classification' 123 | elif args.dataset in ['QM9', 'ESOL', 'Lipo']: 124 | task = 'regression' 125 | elif args.dataset in ['brain']: 126 | task = 'regression' 127 | 128 | if task == 'regression': 129 | output_channels = 1 130 | else: 131 | output_channels = 2 132 | 133 | # select model and its parameter 134 | print(args.model) 135 | if args.model == 'GIN': 136 | from models.GIN import GINNet 137 | net = GINNet(input_channels_node, hidden_channels, output_channels, readout=readout, eps=0., num_layers=args.num_layers) 138 | elif args.model == 'GAT': 139 | from models.GAT import GATNet 140 | net = GATNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 141 | elif args.model == 'GatedGraphConv': 142 | from models.GatedGraphConv import GatedNet 143 | net = GatedNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 144 | elif args.model == 'PointNet': 145 | from models.PointNet import PointNet 146 | net = PointNet(input_channels_node, 
hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 147 | elif args.model == 'PPFNet': 148 | from models.PPFNet import PPFNet 149 | net = PPFNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 150 | elif args.model == 'SGCN': 151 | from models.SGCN import SGCN 152 | net = SGCN(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 153 | elif args.model == 'Schnet': 154 | from models.Schnet import Schnet 155 | net = Schnet(input_channels_node=input_channels_node, 156 | hidden_channels=hidden_channels, output_channels=output_channels, num_interactions=args.num_layers, 157 | num_gaussians=hidden_channels, cutoff=args.cutoff, readout=readout) 158 | elif args.model == 'Dimenet': 159 | from models.Dimenet import Dimenet 160 | net = Dimenet(input_channels_node=input_channels_node, 161 | hidden_channels=hidden_channels, output_channels=output_channels, num_blocks=args.num_layers, 162 | cutoff=args.cutoff) 163 | elif args.model == 'SGMP': 164 | from models.SGMP import SGMP 165 | net = SGMP(input_channels_node=input_channels_node, 166 | hidden_channels=hidden_channels, output_channels=output_channels, 167 | num_interactions=args.num_layers, cutoff=args.cutoff, 168 | readout=readout) 169 | 170 | model = net.to(device) 171 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 172 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=10, min_lr=5e-5) 173 | 174 | if task == 'regression': 175 | criterion = torch.nn.MSELoss() 176 | from sklearn.metrics import mean_absolute_error 177 | measure = mean_absolute_error 178 | elif task == 'classification': 179 | criterion = torch.nn.CrossEntropyLoss() 180 | from sklearn.metrics import accuracy_score 181 | measure = accuracy_score 182 | 183 | # get train/valid/test data 184 | train_loader, valid_loader, test_loader = data 185 | 186 | def train(loader, model, args): 187 | model.train() 188 | for data in (loader): # Iterate in batches over the training dataset. 189 | x, pos, edge_index, batch = data.x.float(), data.pos, data.edge_index, data.batch 190 | if args.dataset == 'brain': 191 | x = x.long() 192 | if args.dataset == 'QM9': 193 | y = data.y[:, args.label] 194 | else: 195 | y = data.y.long() if task == 'classification' else data.y 196 | num_nodes = data.num_nodes 197 | num_edges = data.num_edges 198 | if args.spanning_tree == 'True': 199 | edge_index = build_spanning_tree_edge(edge_index.cpu(), algo='scipy', num_nodes=num_nodes, num_edges=num_edges) 200 | x, pos, edge_index, batch, y = x.to(device), pos.to(device), edge_index.to(device), batch.to(device), y.to(device) 201 | if args.model == 'SGMP': 202 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) # add self loop to avoid crash on specific data point with longest path < 3 203 | _, _, edge_index_3rd, _, _, _, _, _ = find_higher_order_neighbors(edge_index, num_nodes, order=3) 204 | out = model(x, pos, batch, edge_index_3rd) 205 | else: 206 | out = model(x, pos, edge_index, batch) 207 | 208 | if task == 'classification': 209 | loss = criterion(out, y.reshape(-1)) # Compute the loss. 210 | else: 211 | loss = criterion(out.reshape(-1, 1), y.reshape(-1, 1)) # Compute the loss. 212 | loss.backward() # Derive gradients. 213 | optimizer.step() # Update parameters based on gradients. 214 | optimizer.zero_grad() # Clear gradients. 
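# Note: with --spanning_tree True, build_spanning_tree_edge above draws a fresh random spanning tree for every batch, so each training epoch sees a different tree sample of each graph.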
215 | 216 | def test(loader, model, args): 217 | model.eval() 218 | y_hat, y_true = [], [] 219 | loss_total, total_graph = 0, 0 220 | for data in loader: # Iterate in batches over the training/test dataset. 221 | x, pos, edge_index, batch = data.x.float(), data.pos, data.edge_index, data.batch 222 | if args.dataset == 'brain': 223 | x = x.long() 224 | if args.dataset == 'QM9': 225 | y = data.y[:, args.label] 226 | else: 227 | y = data.y.long() if task == 'classification' else data.y 228 | num_nodes = data.num_nodes 229 | num_edges = data.num_edges 230 | if args.spanning_tree == 'True': 231 | edge_index = build_spanning_tree_edge(edge_index.cpu(), algo='scipy', num_nodes=num_nodes, num_edges=num_edges) 232 | x, pos, edge_index, batch, y = x.to(device), pos.to(device), edge_index.to(device), batch.to(device), y.to(device) 233 | if args.model == 'SGMP': 234 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes, fill_value=-1.) 235 | _, _, edge_index_3rd, _, _, _, _, _ = find_higher_order_neighbors(edge_index, num_nodes, order=3) 236 | out = model(x, pos, batch, edge_index_3rd) 237 | else: 238 | out = model(x, pos, edge_index, batch) 239 | 240 | if task == 'classification': 241 | loss = criterion(out, y.reshape(-1)) # Compute the loss. 242 | else: 243 | loss = criterion(out.reshape(-1, 1), y.reshape(-1, 1)) # Compute the loss. 244 | loss_total += loss.detach().cpu() * data.num_graphs 245 | total_graph += data.num_graphs 246 | if task == 'classification': 247 | pred = out.argmax(dim=1) # Use the class with highest probability. 248 | y_hat += list(pred.cpu().detach().numpy().reshape(-1)) 249 | else: 250 | y_hat += list(out.cpu().detach().numpy().reshape(-1)) 251 | y_true += list(y.cpu().detach().numpy().reshape(-1)) 252 | 253 | return loss_total/total_graph, y_hat, y_true 254 | 255 | with open(log_file, 'a') as f: 256 | print("Epoch, Valid loss, Valid score, Elapsed seconds", file=f) 257 | 258 | start_time = time.time() 259 | best_valid_score = 1e10 if task == 'regression' else 0 260 | best_model = None 261 | for epoch in range(1, args.epoch): 262 | # training 263 | train(train_loader, model, args) 264 | 265 | if epoch % args.test_per_round == 0: 266 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, model, args) 267 | valid_score = measure(ytrue_valid, yhat_valid) 268 | 269 | if epoch >= 100: 270 | lr = scheduler.optimizer.param_groups[0]['lr'] 271 | scheduler.step(valid_loss) 272 | 273 | with open(log_file, 'a') as f: 274 | print(f"{epoch:03d}, {valid_loss:.4f}, {valid_score:.4f}, {(time.time() - start_time):.4f}", file=f) 275 | 276 | if task == 'regression': 277 | if valid_score < best_valid_score: 278 | best_valid_score = valid_score 279 | best_model = copy.deepcopy(model) 280 | else: 281 | if valid_score > best_valid_score: 282 | best_valid_score = valid_score 283 | best_model = copy.deepcopy(model) 284 | 285 | train_loss, yhat_train, ytrue_train = test(train_loader, model, args) 286 | train_score = measure(ytrue_train, yhat_train) 287 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, model, args) 288 | valid_score = measure(ytrue_valid, yhat_valid) 289 | test_loss, yhat_test, ytrue_test = test(test_loader, model, args) 290 | test_score = measure(ytrue_test, yhat_test) 291 | with open(result_file, 'a') as f: 292 | if task == 'regression': 293 | print(f"Final, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: 
{test_score:.4f}", file=f) 294 | elif task == 'classification': 295 | print(f"Final, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}", file=f) 296 | 297 | 298 | train_loss, yhat_train, ytrue_train = test(train_loader, best_model, args) 299 | train_score = measure(ytrue_train, yhat_train) 300 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, best_model, args) 301 | valid_score = measure(ytrue_valid, yhat_valid) 302 | test_loss, yhat_test, ytrue_test = test(test_loader, best_model, args) 303 | test_score = measure(ytrue_test, yhat_test) 304 | with open(result_file, 'a') as f: 305 | if task == 'regression': 306 | print(f"Best Model, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: {test_score:.4f}") 307 | print(f"Best Model, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: {test_score:.4f}", file=f) 308 | elif task == 'classification': 309 | print(f"Best Model, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}") 310 | print(f"Best Model, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}", file=f) 311 | 312 | 313 | if __name__ == '__main__': 314 | args = get_args() 315 | 316 | if not os.path.exists(args.save_dir): 317 | os.makedirs(args.save_dir) 318 | 319 | torch.manual_seed(args.random_seed) 320 | data = load_data(args) 321 | main(data, args) 322 | 323 | 324 | -------------------------------------------------------------------------------- /main_base_st.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import time 6 | import pickle as pkl 7 | import json 8 | 9 | import torch 10 | import copy 11 | from torch_geometric.data import Data, DataLoader 12 | from torch_scatter import scatter 13 | 14 | from utils.utils import build_spanning_tree_edge, find_higher_order_neighbors, add_self_loops 15 | from sklearn.metrics import r2_score 16 | from sklearn.metrics import roc_auc_score 17 | from sklearn.metrics import classification_report 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--data_dir', type=str, default='./data') 22 | parser.add_argument('--save_dir', type=str, default='./results') 23 | parser.add_argument('--model', type=str, default='SGMP') 24 | parser.add_argument('--dataset', type=str, default='BACE') 25 | parser.add_argument('--split', type=str, default='811') 26 | parser.add_argument('--device', type=str, default='gpu') 27 | parser.add_argument('--readout', type=str, default='add') 28 | parser.add_argument('--spanning_tree', type=str, default='False') 29 | parser.add_argument('--structure', type=str, default='sc') 30 | 31 | parser.add_argument('--random_seed', type=int, default=12345) 32 | parser.add_argument('--random_seed_2', type=int, default=12345) 33 | parser.add_argument('--label', type=int, default=12) 34 | parser.add_argument('--batch_size', type=int, default=64) 35 | 
parser.add_argument('--num_layers', type=int, default=3) 36 | parser.add_argument('--epoch', type=int, default=500) 37 | parser.add_argument('--lr', type=float, default=1e-3) 38 | parser.add_argument('--test_per_round', type=int, default=5) 39 | parser.add_argument('--threshold', type=float, default=0.1) 40 | parser.add_argument('--cutoff', type=float, default=10.0) 41 | parser.add_argument('--weight_decay', type=float, default=5e-4) 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | def load_data(args): 47 | if args.dataset == 'synthetic': 48 | with open(os.path.join(args.data_dir, 'synthetic.pkl'), 'rb') as file: 49 | dataset = pkl.load(file) 50 | dataset = dataset[torch.randperm(len(dataset))] 51 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 52 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 53 | 54 | elif args.dataset == 'QM9': 55 | from torch_geometric.datasets import QM9 56 | dataset = QM9(root=os.path.join(args.data_dir, 'QM9')) 57 | random_state = np.random.RandomState(seed=42) 58 | perm = torch.from_numpy(random_state.permutation(np.arange(130831))) 59 | dataset = dataset[perm] 60 | train_valid_split, valid_test_split = 110000, 10000 61 | 62 | elif args.dataset == 'brain': 63 | from utils.brain_load_data import load_brain_data 64 | dataset = load_brain_data(data_dir=args.data_dir, structure='sc', threshold=5e5, random_seed=args.random_seed) 65 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 66 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 67 | 68 | elif args.dataset in ["ESOL", "Lipo", "BACE", "BBBP"]: 69 | from utils.moleculenet import MoleculeNet 70 | dataset = MoleculeNet(root=os.path.join(args.data_dir, 'MoleculeNet', args.dataset),name=args.dataset) 71 | dataset = dataset[torch.randperm(len(dataset))] 72 | train_valid_split = int( int(args.split[0]) / 10 * len(dataset) ) 73 | valid_test_split = int( int(args.split[1]) / 10 * len(dataset) ) 74 | else: 75 | raise Exception('Dataset not recognized.') 76 | 77 | 78 | train_dataset = dataset[:train_valid_split] 79 | valid_dataset = dataset[train_valid_split:train_valid_split+valid_test_split] 80 | test_dataset = dataset[train_valid_split+valid_test_split:] 81 | 82 | print('======================') 83 | print(f'Number of training graphs: {len(train_dataset)}') 84 | print(f'Number of valid graphs: {len(valid_dataset)}') 85 | print(f'Number of test graphs: {len(test_dataset)}') 86 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 87 | valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False) 88 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 89 | 90 | return train_loader, valid_loader, test_loader 91 | 92 | def main(data, args): 93 | # logging 94 | if args.spanning_tree == 'True': 95 | LOG_DIR = os.path.join(args.save_dir, args.dataset, args.model+'_st') 96 | else: 97 | LOG_DIR = os.path.join(args.save_dir, args.dataset, args.model) 98 | if args.dataset == 'QM9': 99 | LOG_DIR = os.path.join(LOG_DIR, str(args.label)) 100 | 101 | if not os.path.exists(LOG_DIR): 102 | os.makedirs(LOG_DIR) 103 | log_file = os.path.join(LOG_DIR, 'log.csv') 104 | result_file = os.path.join(LOG_DIR, 'results.txt') 105 | 106 | # device 107 | if args.device == 'cpu': 108 | device = torch.device("cpu") 109 | elif args.device == 'gpu': 110 | device = torch.device("cuda:0") 111 | else: 112 | raise Exception('Please assign the type of device: cpu or gpu.') 113 | 114 | if 
args.dataset == 'brain': 115 | input_channels_node, hidden_channels, readout = 1, 64, args.readout 116 | elif args.dataset == 'QM9': 117 | input_channels_node, hidden_channels, readout = 11, 64, args.readout 118 | else: 119 | input_channels_node, hidden_channels, readout = 9, 64, args.readout 120 | 121 | if args.dataset in [ 'BACE', 'BBBP']: 122 | task = 'classification' 123 | elif args.dataset in ['QM9', 'ESOL', 'Lipo']: 124 | task = 'regression' 125 | elif args.dataset in ['brain']: 126 | task = 'regression' 127 | 128 | if task == 'regression': 129 | output_channels = 1 130 | else: 131 | output_channels = 2 132 | 133 | # select model and its parameter 134 | print(args.model) 135 | if args.model == 'GIN': 136 | from models.GIN import GINNet 137 | net = GINNet(input_channels_node, hidden_channels, output_channels, readout=readout, eps=0., num_layers=args.num_layers) 138 | elif args.model == 'GAT': 139 | from models.GAT import GATNet 140 | net = GATNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 141 | elif args.model == 'GatedGraphConv': 142 | from models.GatedGraphConv import GatedNet 143 | net = GatedNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 144 | elif args.model == 'PointNet': 145 | from models.PointNet import PointNet 146 | net = PointNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 147 | elif args.model == 'PPFNet': 148 | from models.PPFNet import PPFNet 149 | net = PPFNet(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 150 | elif args.model == 'SGCN': 151 | from models.SGCN import SGCN 152 | net = SGCN(input_channels_node, hidden_channels, output_channels, readout=readout, num_layers=args.num_layers) 153 | elif args.model == 'Schnet': 154 | from models.Schnet import Schnet 155 | net = Schnet(input_channels_node=input_channels_node, 156 | hidden_channels=hidden_channels, output_channels=output_channels, num_interactions=args.num_layers, 157 | num_gaussians=hidden_channels, cutoff=args.cutoff, readout=readout) 158 | elif args.model == 'Dimenet': 159 | from models.Dimenet import Dimenet 160 | net = Dimenet(input_channels_node=input_channels_node, 161 | hidden_channels=hidden_channels, output_channels=output_channels, num_blocks=args.num_layers, 162 | cutoff=args.cutoff) 163 | elif args.model == 'SGMP': 164 | from models.SGMP import SGMP 165 | net = SGMP(input_channels_node=input_channels_node, 166 | hidden_channels=hidden_channels, output_channels=output_channels, 167 | num_interactions=args.num_layers, cutoff=args.cutoff, 168 | readout=readout) 169 | 170 | model = net.to(device) 171 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 172 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5, min_lr=5e-5) 173 | 174 | if task == 'regression': 175 | criterion = torch.nn.MSELoss() 176 | from sklearn.metrics import mean_absolute_error 177 | measure = mean_absolute_error 178 | elif task == 'classification': 179 | criterion = torch.nn.CrossEntropyLoss() 180 | from sklearn.metrics import accuracy_score 181 | measure = accuracy_score 182 | 183 | # get train/valid/test data 184 | train_loader, valid_loader, test_loader = data 185 | 186 | def train(loader, model, args): 187 | model.train() 188 | for data in (loader): # Iterate in batches over the training dataset. 
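# Note: training uses a single spanning-tree sample per batch; the 25-sample Monte Carlo averaging in test() below is applied only at evaluation time.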
189 | x, pos, edge_index, batch = data.x.float(), data.pos, data.edge_index, data.batch 190 | if args.dataset == 'brain': 191 | x = x.long() 192 | if args.dataset == 'QM9': 193 | y = data.y[:, args.label] 194 | else: 195 | y = data.y.long() if task == 'classification' else data.y 196 | num_nodes = data.num_nodes 197 | num_edges = data.num_edges 198 | if args.spanning_tree == 'True': 199 | edge_index = build_spanning_tree_edge(edge_index.cpu(), algo='scipy', num_nodes=num_nodes, num_edges=num_edges) 200 | x, pos, edge_index, batch, y = x.to(device), pos.to(device), edge_index.to(device), batch.to(device), y.to(device) 201 | if args.model == 'SGMP': 202 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) # add self loop to avoid crash on specific data point with longest path < 3 203 | _, _, edge_index_3rd, _, _, _, _, _ = find_higher_order_neighbors(edge_index, num_nodes, order=3) 204 | out = model(x, pos, batch, edge_index_3rd) 205 | else: 206 | out = model(x, pos, edge_index, batch) 207 | 208 | if task == 'classification': 209 | loss = criterion(out, y.reshape(-1)) # Compute the loss. 210 | else: 211 | loss = criterion(out.reshape(-1, 1), y.reshape(-1, 1)) # Compute the loss. 212 | loss.backward() # Derive gradients. 213 | optimizer.step() # Update parameters based on gradients. 214 | optimizer.zero_grad() # Clear gradients. 215 | 216 | def test(loader, model, args): 217 | model.eval() 218 | y_hat, y_true = [], [] 219 | loss_total, total_graph = 0, 0 220 | for data in loader: # Iterate in batches over the training/test dataset. 221 | x, pos, edge_index, batch = data.x.float(), data.pos, data.edge_index, data.batch 222 | if args.dataset == 'brain': 223 | x = x.long() 224 | if args.dataset == 'QM9': 225 | y = data.y[:, args.label] 226 | else: 227 | y = data.y.long() if task == 'classification' else data.y 228 | num_nodes = data.num_nodes 229 | num_edges = data.num_edges 230 | if args.spanning_tree == 'True': 231 | out_list = [] 232 | num_samples = 25 233 | for _ in range(num_samples): # Monte Carlo sampling over spanning trees to reduce the variance at test time 234 | sp_edge_index = build_spanning_tree_edge(edge_index.cpu(), algo='scipy', num_nodes=num_nodes, num_edges=num_edges) 235 | x, pos, sp_edge_index, batch, y = x.to(device), pos.to(device), sp_edge_index.to(device), batch.to(device), y.to(device) 236 | if args.model == 'SGMP': 237 | sp_edge_index, _ = add_self_loops(sp_edge_index, num_nodes=num_nodes, fill_value=-1.) 238 | _, _, edge_index_3rd, _, _, _, _, _ = find_higher_order_neighbors(sp_edge_index, num_nodes, order=3) 239 | out = model(x, pos, batch, edge_index_3rd) 240 | else: 241 | out = model(x, pos, sp_edge_index, batch) # use the sampled tree (already on device) rather than the original edge_index 242 | out_list.append(out.detach().cpu()) 243 | 244 | if task == 'classification': 245 | out = torch.cat(out_list,dim=0).reshape(num_samples,-1,2).mean(dim=0).to(device) 246 | loss = criterion(out, y.reshape(-1)) # Compute the loss. 247 | else: 248 | out = torch.cat(out_list,dim=0).reshape(num_samples,-1,1).mean(dim=0).to(device) 249 | loss = criterion(out.reshape(-1, 1), y.reshape(-1, 1)) # Compute the loss. 250 | loss_total += loss.detach().cpu() * data.num_graphs 251 | total_graph += data.num_graphs 252 | if task == 'classification': 253 | pred = out.argmax(dim=1) # Use the class with highest probability. 
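# Note: class predictions come from the logits averaged over the sampled trees, not from majority voting over per-sample predictions.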
254 | y_hat += list(pred.cpu().detach().numpy().reshape(-1)) 255 | else: 256 | y_hat += list(out.cpu().detach().numpy().reshape(-1)) 257 | y_true += list(y.cpu().detach().numpy().reshape(-1)) 258 | 259 | return loss_total/total_graph, y_hat, y_true 260 | 261 | with open(log_file, 'a') as f: 262 | print("Epoch, Valid loss, Valid score, Elapsed seconds", file=f) 263 | 264 | start_time = time.time() 265 | best_valid_score = 1e10 if task == 'regression' else 0 266 | best_model = None 267 | for epoch in range(1, args.epoch): 268 | # training 269 | train(train_loader, model, args) 270 | 271 | if epoch % args.test_per_round == 0: 272 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, model, args) 273 | valid_score = measure(ytrue_valid, yhat_valid) 274 | 275 | if epoch >= 100: 276 | lr = scheduler.optimizer.param_groups[0]['lr'] 277 | scheduler.step(valid_loss) 278 | 279 | with open(log_file, 'a') as f: 280 | print(f"{epoch:03d}, {valid_loss:.4f}, {valid_score:.4f}, {(time.time() - start_time):.4f}", file=f) 281 | 282 | if task == 'regression': 283 | if valid_score < best_valid_score: 284 | best_valid_score = valid_score 285 | best_model = copy.deepcopy(model) 286 | else: 287 | if valid_score > best_valid_score: 288 | best_valid_score = valid_score 289 | best_model = copy.deepcopy(model) 290 | 291 | train_loss, yhat_train, ytrue_train = test(train_loader, model, args) 292 | train_score = measure(ytrue_train, yhat_train) 293 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, model, args) 294 | valid_score = measure(ytrue_valid, yhat_valid) 295 | test_loss, yhat_test, ytrue_test = test(test_loader, model, args) 296 | test_score = measure(ytrue_test, yhat_test) 297 | with open(result_file, 'a') as f: 298 | if task == 'regression': 299 | print(f"Final, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: {test_score:.4f}", file=f) 300 | elif task == 'classification': 301 | print(f"Final, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}", file=f) 302 | 303 | 304 | train_loss, yhat_train, ytrue_train = test(train_loader, best_model, args) 305 | train_score = measure(ytrue_train, yhat_train) 306 | valid_loss, yhat_valid, ytrue_valid = test(valid_loader, best_model, args) 307 | valid_score = measure(ytrue_valid, yhat_valid) 308 | test_loss, yhat_test, ytrue_test = test(test_loader, best_model, args) 309 | test_score = measure(ytrue_test, yhat_test) 310 | with open(result_file, 'a') as f: 311 | if task == 'regression': 312 | print(f"Best Model, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: {test_score:.4f}") 313 | print(f"Best Model, Train RMSE: {np.sqrt(train_loss):.4f}, Train MAE: {train_score:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}, Valid MAE: {valid_score:.4f}, Test RMSE: {np.sqrt(test_loss):.4f}, Test MAE: {test_score:.4f}", file=f) 314 | elif task == 'classification': 315 | print(f"Best Model, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}") 316 | print(f"Best Model, Train loss: {train_loss:.4f}, Train acc: {train_score:.4f}, Valid loss: {valid_loss:.4f}, 
Valid acc: {valid_score:.4f}, Test loss: {test_loss:.4f}, Test acc: {test_score:.4f}", file=f) 317 | 318 | 319 | if __name__ == '__main__': 320 | args = get_args() 321 | 322 | if not os.path.exists(args.save_dir): 323 | os.makedirs(args.save_dir) 324 | 325 | torch.manual_seed(args.random_seed) 326 | data = load_data(args) 327 | main(data, args) 328 | 329 | 330 | --------------------------------------------------------------------------------
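A minimal usage sketch of the SGMP pipeline, assuming the add_self_loops and find_higher_order_neighbors signatures exactly as they are called in main_base.py; the toy graph, the 9-dimensional node features (matching the MoleculeNet atom encoding), and the hyperparameters are illustrative placeholders, not files from this repository:

import torch
from models.SGMP import SGMP
from utils.utils import add_self_loops, find_higher_order_neighbors

num_nodes, num_features = 4, 9                       # 9 matches the MoleculeNet atom encoding
x = torch.randn(num_nodes, num_features)             # node features
pos = torch.randn(num_nodes, 3)                      # 3D coordinates
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])      # undirected 4-node path graph
batch = torch.zeros(num_nodes, dtype=torch.long)     # all nodes belong to graph 0

# Self-loops guard against graphs whose longest path is shorter than 3 (see main_base.py).
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
_, _, edge_index_3rd, _, _, _, _, _ = find_higher_order_neighbors(edge_index, num_nodes, order=3)

model = SGMP(input_channels_node=num_features, hidden_channels=64,
             output_channels=1, num_interactions=3)
out = model(x, pos, batch, edge_index_3rd)           # one prediction per graph: shape [1, 1]

For spanning-tree training, main_base.py additionally applies build_spanning_tree_edge to edge_index before the two helper calls above.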