├── deepgraph ├── data │ ├── substructure_dataset_utils │ │ ├── __init__.py │ │ ├── LICENSE │ │ ├── utils_encoding.py │ │ ├── neighbor_extractors.py │ │ ├── utils.py │ │ ├── utils_graph_processing.py │ │ └── substructure_transform.py │ ├── ogb_datasets │ │ ├── __init__.py │ │ └── ogb_dataset_lookup_table.py │ ├── pyg_datasets │ │ ├── __init__.py │ │ └── pyg_dataset_lookup_table.py │ ├── __init__.py │ ├── algos.pyx │ ├── subsampling │ │ ├── sampler.py │ │ └── sampling.py │ ├── dataset.py │ ├── collator.py │ ├── wrapper.py │ └── substructure_dataset.py ├── tasks │ ├── __init__.py │ └── is2re.py ├── models │ └── __init__.py ├── modules │ ├── __init__.py │ ├── deepgraph_graph_encoder_layer.py │ ├── deepgraph_layers.py │ └── deepgraph_graph_encoder.py ├── criterions │ ├── __init__.py │ ├── contrastive_loss.py │ ├── l1_loss.py │ ├── multiclass_cross_entropy.py │ ├── multilabel_multiclass_cross_entropy.py │ ├── binary_logloss.py │ ├── node_multiclass_cross_entropy.py │ └── mae_deltapos.py ├── __init__.py ├── pretrain │ └── __init__.py └── evaluate │ ├── cache_data.py │ └── evaluate.py ├── overview.png ├── train.py ├── LICENSE └── README.md /deepgraph/data/substructure_dataset_utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhao-ht/DeepGraph/HEAD/overview.png -------------------------------------------------------------------------------- /deepgraph/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | -------------------------------------------------------------------------------- /deepgraph/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .deepgraph import DeepGraphModel 5 | -------------------------------------------------------------------------------- /deepgraph/data/ogb_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .ogb_dataset_lookup_table import OGBDatasetLookupTable 5 | -------------------------------------------------------------------------------- /deepgraph/data/pyg_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .pyg_dataset_lookup_table import PYGDatasetLookupTable 5 | 6 | -------------------------------------------------------------------------------- /deepgraph/data/__init__.py: -------------------------------------------------------------------------------- 1 | DATASET_REGISTRY = {} 2 | 3 | def register_dataset(name: str): 4 | def register_dataset_func(func): 5 | DATASET_REGISTRY[name] = func() 6 | return register_dataset_func 7 | 8 | -------------------------------------------------------------------------------- /deepgraph/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
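The `DATASET_REGISTRY` decorator in `deepgraph/data/__init__.py` above stores the *return value* of a dataset factory under a name. A minimal usage sketch, assuming the package and its torch/fairseq dependencies are importable; the `toy_dataset` name and factory are illustrative, not part of the repository:

```python
# Illustrative sketch only: registering a dataset factory with the registry
# defined in deepgraph/data/__init__.py (the name and factory are made up).
from deepgraph.data import DATASET_REGISTRY, register_dataset

@register_dataset("toy_dataset")
def create_toy_dataset():
    # The inner decorator calls the factory immediately and stores its return
    # value; it does not return the function, so `create_toy_dataset` itself is
    # rebound to None after decoration -- only the registry entry is used.
    return {"dataset": None, "source": "pyg", "num_classes": 2}

print(DATASET_REGISTRY["toy_dataset"])  # {'dataset': None, 'source': 'pyg', 'num_classes': 2}
```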
3 | 4 | from .multihead_attention import MultiheadAttention 5 | from .deepgraph_layers import GraphNodeFeature, GraphAttnBias 6 | from .deepgraph_graph_encoder_layer import DeepGraphGraphEncoderLayer 7 | from .deepgraph_graph_encoder import DeepGraphGraphEncoder, init_graphormer_params 8 | -------------------------------------------------------------------------------- /deepgraph/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from pathlib import Path 5 | import importlib 6 | 7 | # automatically import any Python files in the criterions/ directory 8 | for file in sorted(Path(__file__).parent.glob("*.py")): 9 | if not file.name.startswith("_"): 10 | importlib.import_module("deepgraph.criterions." + file.name[:-3]) 11 | -------------------------------------------------------------------------------- /deepgraph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import importlib 5 | 6 | 7 | try: 8 | import torch 9 | 10 | torch.multiprocessing.set_start_method("fork", force=True) 11 | except: 12 | import sys 13 | 14 | print( 15 | "Your OS does not support multiprocessing based on fork, please use num_workers=0", 16 | file=sys.stderr, 17 | flush=True, 18 | ) 19 | 20 | import deepgraph.criterions 21 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from rdkit import Chem 3 | import graph_tool 4 | import torch_geometric 5 | import joblib 6 | from ogb.graphproppred import PygGraphPropPredDataset 7 | import warnings 8 | 9 | warnings.filterwarnings('ignore') 10 | # import fairseq 11 | # from fairseq.distributed import utils as distributed_utils 12 | # from fairseq.logging import meters, metrics, progress_bar # noqa 13 | 14 | # from fairseq.logging import metrics 15 | 16 | # sys.modules["fairseq.distributed_utils"] = distributed_utils 17 | # sys.modules["fairseq.meters"] = meters 18 | # sys.modules["fairseq.metrics"] = metrics 19 | # sys.modules["fairseq.progress_bar"] = progress_bar 20 | from fairseq_cli.train import cli_main 21 | import logging 22 | logging.getLogger().setLevel(logging.INFO) 23 | 24 | 25 | 26 | if __name__ == "__main__": 27 | cli_main() -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Giorgos Bouritsas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /deepgraph/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.hub import load_state_dict_from_url 2 | import torch.distributed as dist 3 | 4 | PRETRAINED_MODEL_URLS = { 5 | "pcqm4mv1_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_best_pcqm4mv1.pt", 6 | "pcqm4mv2_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv2/checkpoint_best_pcqm4mv2.pt", 7 | "oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt", 8 | "pcqm4mv1_graphormer_base_for_molhiv":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_base_preln_pcqm4mv1_for_hiv.pt", 9 | } 10 | 11 | def load_pretrained_model(pretrained_model_name): 12 | if pretrained_model_name not in PRETRAINED_MODEL_URLS: 13 | raise ValueError("Unknown pretrained model name %s", pretrained_model_name) 14 | if not dist.is_initialized(): 15 | return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"] 16 | else: 17 | pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"] 18 | dist.barrier() 19 | return pretrained_model 20 | -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/utils_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 
| from sklearn.preprocessing import KBinsDiscretizer, OneHotEncoder, MinMaxScaler, StandardScaler 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | def find_max_id(dataset_list): 10 | atom_id_max = 0 11 | edge_id_max = 0 12 | for dataset in dataset_list: 13 | for id,graph in tqdm(enumerate(dataset)): 14 | try: 15 | atom_id_max = max(atom_id_max, graph.x.max()) 16 | edge_id_max = max(edge_id_max, graph.edge_attr.max()) 17 | except: 18 | pass 19 | 20 | return int(atom_id_max),int(edge_id_max) 21 | 22 | 23 | class one_hot_unique: 24 | 25 | def __init__(self, tensor_list, **kwargs): 26 | tensor_list = torch.cat(tensor_list, 0) 27 | self.d = list() 28 | self.corrs = dict() 29 | for col in range(tensor_list.shape[1]): 30 | uniques, corrs = np.unique(tensor_list[:, col], return_inverse=True, axis=0) 31 | self.d.append(len(uniques)) 32 | self.corrs[col] = corrs 33 | return 34 | 35 | def fit(self, tensor_list): 36 | pointer = 0 37 | encoded_tensors = list() 38 | for tensor in tensor_list: 39 | n = tensor.shape[0] 40 | for col in range(tensor.shape[1]): 41 | translated = torch.LongTensor(self.corrs[col][pointer:pointer+n]).unsqueeze(1) 42 | encoded = torch.cat((encoded, translated), 1) if col > 0 else translated 43 | encoded_tensors.append(encoded) 44 | pointer += n 45 | return encoded_tensors 46 | 47 | 48 | class one_hot_max: 49 | 50 | def __init__(self, tensor_list, **kwargs): 51 | tensor_list = torch.cat(tensor_list,0) 52 | self.d = [int(tensor_list[:,i].max()+1) for i in range(tensor_list.shape[1])] 53 | 54 | def fit(self, tensor_list): 55 | return tensor_list 56 | 57 | 58 | -------------------------------------------------------------------------------- /deepgraph/criterions/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | import torch.nn as nn 8 | from fairseq.logging import metrics 9 | # from fairseq import metrics 10 | from fairseq.criterions import FairseqCriterion, register_criterion 11 | import numpy as np 12 | 13 | @register_criterion("contrustive_loss", dataclass=FairseqDataclass) 14 | class GraphPretrainContrustiveLoss(FairseqCriterion): 15 | """ 16 | Implementation for the L1 loss (MAE loss) used in graphormer model training. 17 | """ 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample. 
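Stepping back to `utils_encoding.py` above: `one_hot_max` and `one_hot_unique` turn lists of per-graph integer feature tensors into per-column vocabulary sizes (`.d`) and, for the latter, contiguous per-column ids. A small sketch with toy tensors, assuming the module and its torch/sklearn dependencies import cleanly:

```python
# Toy illustration of the feature encoders in utils_encoding.py.
import torch
from deepgraph.data.substructure_dataset_utils.utils_encoding import one_hot_max, one_hot_unique

x_list = [torch.tensor([[0, 3], [1, 5]]),   # per-graph node feature tensors
          torch.tensor([[2, 3], [0, 7]])]

enc_max = one_hot_max(x_list)
print(enc_max.d)               # [3, 8]: per-column max + 1, usable as embedding sizes
print(enc_max.fit(x_list))     # features are passed through unchanged

enc_unique = one_hot_unique(x_list)
print(enc_unique.d)            # [3, 3]: number of distinct values per column
print(enc_unique.fit(x_list))  # each column remapped to contiguous 0..d-1 ids
```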
21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | sample_size = sample["nsamples"] 28 | 29 | with torch.no_grad(): 30 | natoms = sample['net_input']['batched_data']['data']['x'].shape[1] 31 | 32 | x1 = model(sample['net_input']['batched_data']['data1']) 33 | x2 = model(sample['net_input']['batched_data']['data2']) 34 | 35 | T = 0.1 36 | batch_size, _ = x1.size() 37 | x1_abs = x1.norm(dim=1) 38 | x2_abs = x2.norm(dim=1) 39 | 40 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs).type(torch.float64) 41 | sim_matrix = torch.exp(sim_matrix / T) 42 | pos_sim = sim_matrix[np.arange(batch_size), np.arange(batch_size)] 43 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 44 | loss = - torch.log(loss).mean().type(torch.float16) 45 | 46 | logging_output = { 47 | "loss": loss.data, 48 | "sample_size": x1.size(0), 49 | "nsentences": sample_size, 50 | "ntokens": natoms, 51 | } 52 | return loss, sample_size, logging_output 53 | 54 | @staticmethod 55 | def reduce_metrics(logging_outputs) -> None: 56 | """Aggregate logging outputs from data parallel training.""" 57 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 58 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 59 | 60 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=6) 61 | 62 | @staticmethod 63 | def logging_outputs_can_be_summed() -> bool: 64 | """ 65 | Whether the logging outputs returned by `forward` can be summed 66 | across workers prior to calling `reduce_metrics`. Setting this 67 | to True will improves distributed training speed. 68 | """ 69 | return True 70 | -------------------------------------------------------------------------------- /deepgraph/data/algos.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
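The loss body of `contrastive_loss.py` just above is an InfoNCE-style objective on two augmented views of each graph. A self-contained sketch of the same computation on dummy embeddings; the temperature and shapes are illustrative, and the float64/float16 casts of the criterion are omitted:

```python
# Standalone sketch of the contrastive objective computed by the criterion above.
import torch

def contrastive_loss_sketch(x1: torch.Tensor, x2: torch.Tensor, T: float = 0.1) -> torch.Tensor:
    # cosine-similarity matrix between the two views, scaled by temperature T
    sim = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum("i,j->ij", x1.norm(dim=1), x2.norm(dim=1))
    sim = torch.exp(sim / T)
    pos = sim.diag()                                    # matching pairs (i, i)
    return -torch.log(pos / (sim.sum(dim=1) - pos)).mean()

x1, x2 = torch.randn(8, 16), torch.randn(8, 16)         # two views of 8 graphs
print(contrastive_loss_sketch(x1, x2))
```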
3 | 4 | import cython 5 | from cython.parallel cimport prange, parallel 6 | cimport numpy 7 | import numpy 8 | 9 | def floyd_warshall(adjacency_matrix): 10 | 11 | (nrows, ncols) = adjacency_matrix.shape 12 | assert nrows == ncols 13 | cdef unsigned int n = nrows 14 | 15 | adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True) 16 | assert adj_mat_copy.flags['C_CONTIGUOUS'] 17 | cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy 18 | cdef numpy.ndarray[long, ndim=2, mode='c'] path = -1 * numpy.ones([n, n], dtype=numpy.int64) 19 | 20 | cdef unsigned int i, j, k 21 | cdef long M_ij, M_ik, cost_ikkj 22 | cdef long* M_ptr = &M[0,0] 23 | cdef long* M_i_ptr 24 | cdef long* M_k_ptr 25 | 26 | # set unreachable nodes distance to 510 27 | for i in range(n): 28 | for j in range(n): 29 | if i == j: 30 | M[i][j] = 0 31 | elif M[i][j] == 0: 32 | M[i][j] = 510 33 | 34 | # floyed algo 35 | for k in range(n): 36 | M_k_ptr = M_ptr + n*k 37 | for i in range(n): 38 | M_i_ptr = M_ptr + n*i 39 | M_ik = M_i_ptr[k] 40 | for j in range(n): 41 | cost_ikkj = M_ik + M_k_ptr[j] 42 | M_ij = M_i_ptr[j] 43 | if M_ij > cost_ikkj: 44 | M_i_ptr[j] = cost_ikkj 45 | path[i][j] = k 46 | 47 | # set unreachable path to 510 48 | for i in range(n): 49 | for j in range(n): 50 | if M[i][j] >= 510: 51 | path[i][j] = 510 52 | M[i][j] = 510 53 | 54 | return M, path 55 | 56 | 57 | def get_all_edges(path, i, j): 58 | cdef int k = path[i][j] 59 | if k == -1: 60 | return [] 61 | else: 62 | return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) 63 | 64 | 65 | def gen_edge_input(max_dist, path, edge_feat): 66 | 67 | (nrows, ncols) = path.shape 68 | assert nrows == ncols 69 | cdef unsigned int n = nrows 70 | cdef unsigned int max_dist_copy = max_dist 71 | 72 | path_copy = path.astype(long, order='C', casting='safe', copy=True) 73 | edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True) 74 | assert path_copy.flags['C_CONTIGUOUS'] 75 | assert edge_feat_copy.flags['C_CONTIGUOUS'] 76 | 77 | cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64) 78 | cdef unsigned int i, j, k, num_path, cur 79 | 80 | for i in range(n): 81 | for j in range(n): 82 | if i == j: 83 | continue 84 | if path_copy[i][j] == 510: 85 | continue 86 | path = [i] + get_all_edges(path_copy, i, j) + [j] 87 | num_path = len(path) - 1 88 | for k in range(num_path): 89 | edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] 90 | 91 | return edge_fea_all 92 | -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/neighbor_extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_sparse import SparseTensor # for propagation 3 | 4 | def k_hop_subgraph(edge_index, num_nodes, num_hops): 5 | # return k-hop subgraphs for all nodes in the graph 6 | row, col = edge_index 7 | sparse_adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes)) 8 | hop_masks = [torch.eye(num_nodes, dtype=torch.bool, device=edge_index.device)] # each one contains <= i hop masks 9 | hop_indicator = row.new_full((num_nodes, num_nodes), -1) 10 | hop_indicator[hop_masks[0]] = 0 11 | for i in range(num_hops): 12 | next_mask = sparse_adj.matmul(hop_masks[i].float()) > 0 13 | hop_masks.append(next_mask) 14 | hop_indicator[(hop_indicator==-1) & next_mask] = i+1 15 | hop_indicator = hop_indicator.T # N x N 16 | 
node_mask = (hop_indicator >= 0) # N x N dense mask matrix 17 | return node_mask, hop_indicator 18 | 19 | 20 | from torch_cluster import random_walk 21 | def random_walk_subgraph(edge_index, num_nodes, walk_length, p=1, q=1, repeat=1, cal_hops=True, max_hops=10): 22 | """ 23 | p (float, optional): Likelihood of immediately revisiting a node in the 24 | walk. (default: :obj:`1`) Setting it to a high value (> max(q, 1)) ensures 25 | that we are less likely to sample an already visited node in the following two steps. 26 | q (float, optional): Control parameter to interpolate between 27 | breadth-first strategy and depth-first strategy (default: :obj:`1`) 28 | if q > 1, the random walk is biased towards nodes close to node t. 29 | if q < 1, the walk is more inclined to visit nodes which are further away from the node t. 30 | p, q ∈ {0.25, 0.50, 1, 2, 4}. 31 | Typical values: 32 | Fix p and tune q 33 | 34 | repeat: restart the random walk many times and combine together for the result 35 | 36 | """ 37 | row, col = edge_index 38 | start = torch.arange(num_nodes, device=edge_index.device) 39 | walks = [random_walk(row, col, 40 | start=start, 41 | walk_length=walk_length, 42 | p=p, q=q, 43 | num_nodes=num_nodes) for _ in range(repeat)] 44 | walk = torch.cat(walks, dim=-1) 45 | node_mask = row.new_empty((num_nodes, num_nodes), dtype=torch.bool) 46 | # print(walk.shape) 47 | node_mask.fill_(False) 48 | node_mask[start.repeat_interleave((walk_length+1)*repeat), walk.reshape(-1)] = True 49 | if cal_hops: # this is fast enough 50 | sparse_adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes)) 51 | hop_masks = [torch.eye(num_nodes, dtype=torch.bool, device=edge_index.device)] 52 | hop_indicator = row.new_full((num_nodes, num_nodes), -1) 53 | hop_indicator[hop_masks[0]] = 0 54 | for i in range(max_hops): 55 | next_mask = sparse_adj.matmul(hop_masks[i].float())>0 56 | hop_masks.append(next_mask) 57 | hop_indicator[(hop_indicator==-1) & next_mask] = i+1 58 | if hop_indicator[node_mask].min() != -1: 59 | break 60 | return node_mask, hop_indicator 61 | return node_mask, None 62 | 63 | 64 | -------------------------------------------------------------------------------- /deepgraph/criterions/l1_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | import torch.nn as nn 8 | from fairseq.logging import metrics 9 | # from fairseq import metrics 10 | from fairseq.criterions import FairseqCriterion, register_criterion 11 | 12 | 13 | @register_criterion("l1_loss", dataclass=FairseqDataclass) 14 | class GraphPredictionL1Loss(FairseqCriterion): 15 | """ 16 | Implementation for the L1 loss (MAE loss) used in graphormer model training. 17 | """ 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample. 
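A usage sketch for `k_hop_subgraph` from `neighbor_extractors.py` above, run on a toy 4-node path graph (the graph is an assumption; `torch_sparse` must be installed):

```python
import torch
from deepgraph.data.substructure_dataset_utils.neighbor_extractors import k_hop_subgraph

# undirected path graph 0-1-2-3, edges stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
node_mask, hop_indicator = k_hop_subgraph(edge_index, num_nodes=4, num_hops=2)
print(node_mask[0])       # tensor([True, True, True, False]): nodes within 2 hops of node 0
print(hop_indicator[0])   # tensor([0, 1, 2, -1]): hop distance, -1 = outside the 2-hop ball
```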
21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | sample_size = sample["nsamples"] 28 | 29 | with torch.no_grad(): 30 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 31 | 32 | logits = model(**sample["net_input"]) 33 | logits = logits[:, 0, :] 34 | targets = model.get_targets(sample, [logits]) 35 | 36 | loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)]) 37 | 38 | logging_output = { 39 | "loss": loss.data, 40 | "sample_size": logits.size(0), 41 | "nsentences": sample_size, 42 | "ntokens": natoms, 43 | } 44 | return loss, sample_size, logging_output 45 | 46 | @staticmethod 47 | def reduce_metrics(logging_outputs) -> None: 48 | """Aggregate logging outputs from data parallel training.""" 49 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 50 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 51 | 52 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=6) 53 | 54 | @staticmethod 55 | def logging_outputs_can_be_summed() -> bool: 56 | """ 57 | Whether the logging outputs returned by `forward` can be summed 58 | across workers prior to calling `reduce_metrics`. Setting this 59 | to True will improves distributed training speed. 60 | """ 61 | return True 62 | 63 | 64 | @register_criterion("l1_loss_with_flag", dataclass=FairseqDataclass) 65 | class GraphPredictionL1LossWithFlag(GraphPredictionL1Loss): 66 | """ 67 | Implementation for the binary log loss used in graphormer model training. 68 | """ 69 | 70 | def perturb_forward(self, model, sample, perturb, reduce=True): 71 | """Compute the loss for the given sample. 72 | 73 | Returns a tuple with three elements: 74 | 1) the loss 75 | 2) the sample size, which is used as the denominator for the gradient 76 | 3) logging outputs to display while training 77 | """ 78 | sample_size = sample["nsamples"] 79 | 80 | batch_data = sample["net_input"]["batched_data"]["x"] 81 | with torch.no_grad(): 82 | natoms = batch_data.shape[1] 83 | logits = model(**sample["net_input"], perturb=perturb)[:, 0, :] 84 | targets = model.get_targets(sample, [logits]) 85 | loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)]) 86 | 87 | logging_output = { 88 | "loss": loss.data, 89 | "sample_size": logits.size(0), 90 | "nsentences": sample_size, 91 | "ntokens": natoms, 92 | } 93 | return loss, sample_size, logging_output 94 | -------------------------------------------------------------------------------- /deepgraph/evaluate/cache_data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from fairseq import checkpoint_utils, utils, options, tasks 4 | from fairseq.logging import progress_bar 5 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf 6 | 7 | import os 8 | 9 | import sys 10 | from os import path 11 | 12 | from multiprocessing import Pool 13 | 14 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) ) 15 | 16 | import logging 17 | import lmdb 18 | 19 | logging.root.setLevel(logging.INFO) 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | def f(x): 23 | return x * x 24 | 25 | if __name__ == '__main__': 26 | with Pool(5) as p: 27 | print(p.map(f, [1, 2, 3])) 28 | 29 | 30 | 31 | parser = options.get_training_parser() 32 | 33 | res_lis = [] 34 | args = options.parse_args_and_arch(parser, modify_parser=None) 35 | logger = 
logging.getLogger(__name__) 36 | 37 | checkpoint_path = None 38 | logger.info(f"evaluating checkpoint file {checkpoint_path}") 39 | 40 | 41 | 42 | 43 | cfg = convert_namespace_to_omegaconf(args) 44 | np.random.seed(cfg.common.seed) 45 | utils.set_torch_seed(cfg.common.seed) 46 | 47 | # initialize task 48 | task = tasks.setup_task(cfg.task) 49 | model = task.build_model(cfg.model) 50 | 51 | 52 | split='inner' 53 | task.load_dataset(split) 54 | batch_iterator = task.get_batch_iterator( 55 | dataset=task.dataset(split), 56 | max_tokens=cfg.dataset.max_tokens_valid, 57 | max_sentences=cfg.dataset.batch_size_valid, 58 | max_positions=utils.resolve_max_positions( 59 | task.max_positions(), 60 | model.max_positions(), 61 | ), 62 | ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, 63 | required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, 64 | seed=cfg.common.seed, 65 | num_workers=cfg.dataset.num_workers, 66 | epoch=0, 67 | data_buffer_size=cfg.dataset.data_buffer_size, 68 | disable_iterator_cache=False, 69 | ) 70 | itr = batch_iterator.next_epoch_itr( 71 | shuffle=False, set_dataset_epoch=False 72 | ) 73 | progress = progress_bar.progress_bar( 74 | itr, 75 | log_format=cfg.common.log_format, 76 | log_interval=cfg.common.log_interval, 77 | default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple") 78 | ) 79 | 80 | 81 | path_lmdb = os.path.join(task.dm.dataset.args['data_dir'],task.dm.dataset.args['lmdb_root_dir'],task.dm.dataset.args['transform_cache_path']+ task.dm.dataset.args['transform_dir']) 82 | print('saving to ', path_lmdb) 83 | 84 | if not os.path.exists(path_lmdb): 85 | os.makedirs(path_lmdb) 86 | lmdb_env = lmdb.open(path_lmdb, map_size=1e12) 87 | txn = lmdb_env.begin(write=True) 88 | 89 | 90 | 91 | for data in progress: 92 | 93 | idx=data["net_input"]["batched_data"]['idx'] 94 | subgraph_tensor_res=data["net_input"]["batched_data"]['subgraph_tensor_res'] 95 | sorted_adj_res=data["net_input"]["batched_data"]['sorted_adj_res'] 96 | 97 | for i in range(len(idx)): 98 | id_cur=idx[i] 99 | subgraph_tensor = subgraph_tensor_res[i] 100 | sorted_adj = sorted_adj_res[i] 101 | assert len(subgraph_tensor)==args.transform_cache_number 102 | if len(subgraph_tensor)>1: 103 | assert subgraph_tensor[0] != subgraph_tensor[1] 104 | for sample_id in range(args.transform_cache_number): 105 | result = [subgraph_tensor[sample_id], sorted_adj[sample_id]] 106 | txn.put(key=(str(id_cur) + '_' + str(sample_id)).encode(), value=str(result).encode()) 107 | 108 | 109 | print('********************************commit substructure sampling caching **************************') 110 | txn.commit() 111 | lmdb_env.close() 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /deepgraph/data/subsampling/sampler.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | from deepgraph.data.subsampling.sampling import * 3 | 4 | import re 5 | 6 | 7 | 8 | 9 | class Subgraphs_Sampler(object): 10 | def __init__(self, 11 | sampling_mode=None, 12 | minimum_redundancy=2, 13 | shortest_path_mode_stride=2, 14 | random_mode_sampling_rate=0.3, 15 | random_init=False, 16 | only_unused_nodes=True): 17 | super().__init__() 18 | 19 | self.random_init = random_init 20 | self.sampling_mode = sampling_mode 21 | self.minimum_redundancy = minimum_redundancy 22 | self.shortest_path_mode_stride = shortest_path_mode_stride 23 | self.random_mode_sampling_rate = 
random_mode_sampling_rate 24 | self.only_unused_nodes=only_unused_nodes 25 | 26 | 27 | def __call__(self, edge_index,subgraphs_nodes_tensor,num_nodes,must_select_list): 28 | 29 | if subgraphs_nodes_tensor.shape[1]>0: 30 | selected_subgraphs, node_selected_times = subsampling_subgraphs_general(edge_index, 31 | subgraphs_nodes_tensor, 32 | num_nodes, 33 | sampling_mode=self.sampling_mode, 34 | random_init=self.random_init, 35 | minimum_redundancy=self.minimum_redundancy, 36 | shortest_path_mode_stride=self.shortest_path_mode_stride, 37 | random_mode_sampling_rate=self.random_mode_sampling_rate, 38 | must_select_list=must_select_list,only_unused_nodes=self.only_unused_nodes) 39 | else: 40 | selected_subgraphs=[] 41 | node_selected_times=None 42 | 43 | return selected_subgraphs,node_selected_times 44 | 45 | 46 | 47 | 48 | def subsampling_subgraphs_general(edge_index, subgraphs_nodes, num_nodes=None, sampling_mode='shortest_path', random_init=False, minimum_redundancy=0, 49 | shortest_path_mode_stride=2, random_mode_sampling_rate=0.5,must_select_list=None,only_unused_nodes=None): 50 | 51 | assert sampling_mode in ['shortest_path', 'random', 'min_set_cover','min_set_cover_random'] 52 | if sampling_mode == 'random': 53 | selected_subgraphs, node_selected_times = random_sampling_general(subgraphs_nodes, num_nodes=num_nodes, rate=random_mode_sampling_rate, minimum_redundancy=minimum_redundancy,must_select_list=must_select_list) 54 | elif sampling_mode == 'shortest_path': 55 | selected_subgraphs, node_selected_times = shortest_path_sampling_general(edge_index, subgraphs_nodes, num_nodes=num_nodes, minimum_redundancy=minimum_redundancy, 56 | stride=max(1, shortest_path_mode_stride), random_init=random_init) 57 | elif sampling_mode in ['min_set_cover']: 58 | 59 | selected_subgraphs, node_selected_times = min_set_cover_sampling_general(subgraphs_nodes, 60 | minimum_redundancy=minimum_redundancy, random_init=random_init,num_nodes=num_nodes,must_select_list=must_select_list,only_unused_nodes=only_unused_nodes) 61 | 62 | elif sampling_mode in ['min_set_cover_random']: 63 | 64 | selected_subgraphs, node_selected_times = min_set_cover_random_sampling_general(subgraphs_nodes, 65 | minimum_redundancy=minimum_redundancy, random_init=random_init,num_nodes=num_nodes,must_select_list=must_select_list,only_unused_nodes=only_unused_nodes) 66 | else: 67 | raise ValueError('not supported sampling mode') 68 | 69 | return selected_subgraphs, node_selected_times 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /deepgraph/criterions/multiclass_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | from torch.nn import functional 8 | from fairseq.logging import metrics 9 | # from fairseq import metrics 10 | from fairseq.criterions import FairseqCriterion, register_criterion 11 | 12 | 13 | @register_criterion("multiclass_cross_entropy", dataclass=FairseqDataclass) 14 | class GraphPredictionMulticlassCrossEntropy(FairseqCriterion): 15 | """ 16 | Implementation for the multi-class log loss used in graphormer model training. 17 | """ 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample.
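The concrete sampling strategies (`random_sampling_general`, `min_set_cover_sampling_general`, ...) used by `subsampling_subgraphs_general` above live in `sampling.py`, which is not shown here. As a generic illustration of the min-set-cover idea the sampler exposes — keep selecting subgraphs until every node is covered at least `minimum_redundancy` times — here is a self-contained greedy sketch; it is not the repository's implementation:

```python
# Generic greedy set-cover sketch (illustrative; not the repo's sampling.py code).
import torch

def greedy_min_set_cover(subgraph_nodes, num_nodes, minimum_redundancy=1):
    # subgraph_nodes: list of 1-D LongTensors, one per candidate subgraph
    counts = torch.zeros(num_nodes, dtype=torch.long)   # per-node coverage counter
    selected = []
    while (counts < minimum_redundancy).any():
        # choose the subgraph covering the most still-undercovered nodes
        gains = [int((counts[nodes] < minimum_redundancy).sum()) for nodes in subgraph_nodes]
        best = max(range(len(subgraph_nodes)), key=lambda i: gains[i])
        if gains[best] == 0:
            break  # remaining nodes cannot be covered any further
        selected.append(best)
        counts[subgraph_nodes[best]] += 1
    return selected, counts

subgraphs = [torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([2, 3])]
print(greedy_min_set_cover(subgraphs, num_nodes=4, minimum_redundancy=1))  # ([0, 2], tensor([1, 1, 1, 1]))
```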
21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | sample_size = sample["nsamples"] 28 | 29 | with torch.no_grad(): 30 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 31 | 32 | logits = model(**sample["net_input"]) 33 | logits = logits[:, 0, :] 34 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 35 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum() 36 | 37 | loss = functional.cross_entropy( 38 | logits, targets.reshape(-1), reduction="sum" 39 | ) 40 | 41 | logging_output = { 42 | "loss": loss.data, 43 | "sample_size": sample_size, 44 | "nsentences": sample_size, 45 | "ntokens": natoms, 46 | "ncorrect": ncorrect, 47 | } 48 | return loss, sample_size, logging_output 49 | 50 | @staticmethod 51 | def reduce_metrics(logging_outputs) -> None: 52 | """Aggregate logging outputs from data parallel training.""" 53 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 54 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 55 | 56 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) 57 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: 58 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) 59 | metrics.log_scalar( 60 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1 61 | ) 62 | 63 | @staticmethod 64 | def logging_outputs_can_be_summed() -> bool: 65 | """ 66 | Whether the logging outputs returned by `forward` can be summed 67 | across workers prior to calling `reduce_metrics`. Setting this 68 | to True will improves distributed training speed. 69 | """ 70 | return True 71 | 72 | 73 | @register_criterion("multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass) 74 | class GraphPredictionMulticlassCrossEntropyWithFlag(GraphPredictionMulticlassCrossEntropy): 75 | """ 76 | Implementation for the multi-class log loss used in graphormer model training. 77 | """ 78 | 79 | def forward(self, model, sample, reduce=True): 80 | """Compute the loss for the given sample. 81 | 82 | Returns a tuple with three elements: 83 | 1) the loss 84 | 2) the sample size, which is used as the denominator for the gradient 85 | 3) logging outputs to display while training 86 | """ 87 | sample_size = sample["nsamples"] 88 | perturb = sample.get("perturb", None) 89 | 90 | with torch.no_grad(): 91 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 92 | 93 | logits = model(**sample["net_input"], perturb=perturb) 94 | logits = logits[:, 0, :] 95 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 96 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum() 97 | 98 | loss = functional.cross_entropy( 99 | logits, targets.reshape(-1), reduction="sum" 100 | ) 101 | 102 | logging_output = { 103 | "loss": loss.data, 104 | "sample_size": sample_size, 105 | "nsentences": sample_size, 106 | "ntokens": natoms, 107 | "ncorrect": ncorrect, 108 | } 109 | return loss, sample_size, logging_output 110 | -------------------------------------------------------------------------------- /deepgraph/criterions/multilabel_multiclass_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
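The multiclass criterion above relies on the convention that position 0 of the encoder output is the graph token; a dummy-shaped sketch of that readout and the summed cross entropy (shapes are illustrative only):

```python
import torch
from torch.nn import functional

logits = torch.randn(4, 10, 3)        # [batch, 1 + num_nodes, num_classes]
graph_logits = logits[:, 0, :]        # graph-token readout -> [batch, num_classes]
targets = torch.tensor([0, 2, 1, 2])
loss = functional.cross_entropy(graph_logits, targets, reduction="sum")
ncorrect = (graph_logits.argmax(dim=-1) == targets).sum()
print(loss.item(), int(ncorrect))
```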
3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | from torch.nn import functional 8 | from fairseq.logging import metrics 9 | # from fairseq import metrics 10 | from fairseq.criterions import FairseqCriterion, register_criterion 11 | 12 | 13 | @register_criterion("multilabel_multiclass_cross_entropy", dataclass=FairseqDataclass) 14 | class GraphPredictionMultiLabelMulticlassCrossEntropy(FairseqCriterion): 15 | """ 16 | Implementation for the multi-class log loss used in graphormer model training. 17 | """ 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample. 21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | 28 | logits = model(**sample["net_input"]) 29 | 30 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 31 | 32 | 33 | logits=logits.reshape([logits.shape[0]*logits.shape[1],logits.shape[2]]) 34 | targets=targets 35 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum() 36 | 37 | loss = functional.cross_entropy( 38 | logits, targets.reshape(-1), reduction="sum" 39 | ) 40 | 41 | sample_size = sample["nsamples"] 42 | 43 | with torch.no_grad(): 44 | natoms = targets.shape[0]*targets.shape[1] 45 | 46 | 47 | logging_output = { 48 | "loss": float(loss.data), 49 | "sample_size": natoms, 50 | "nsentences": sample_size, 51 | "ntokens": natoms, 52 | "ncorrect": ncorrect, 53 | } 54 | return loss, natoms, logging_output 55 | 56 | @staticmethod 57 | def reduce_metrics(logging_outputs) -> None: 58 | """Aggregate logging outputs from data parallel training.""" 59 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 60 | sample_size = sum(log.get("ntokens", 0) for log in logging_outputs) 61 | 62 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) 63 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: 64 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) 65 | metrics.log_scalar( 66 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1 67 | ) 68 | 69 | @staticmethod 70 | def logging_outputs_can_be_summed() -> bool: 71 | """ 72 | Whether the logging outputs returned by `forward` can be summed 73 | across workers prior to calling `reduce_metrics`. Setting this 74 | to True will improves distributed training speed. 75 | """ 76 | return True 77 | 78 | 79 | @register_criterion("multilabel_multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass) 80 | class GraphPredictionMultiLabelMulticlassCrossEntropyWithFlag(GraphPredictionMultiLabelMulticlassCrossEntropy): 81 | """ 82 | Implementation for the multi-class log loss used in graphormer model training. 83 | """ 84 | 85 | def forward(self, model, sample, reduce=True): 86 | """Compute the loss for the given sample. 
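In the multilabel multiclass criterion above, every (graph, label-slot) pair is flattened into one classification row before the summed cross entropy, and the metric reduction divides by the number of label slots; a dummy-shaped sketch:

```python
import torch
from torch.nn import functional

logits = torch.randn(4, 5, 3)                        # [batch, num_labels, num_classes]
targets = torch.randint(0, 3, (4, 5))                # one class index per label slot
flat_logits = logits.reshape(-1, logits.shape[-1])   # [batch * num_labels, num_classes]
loss = functional.cross_entropy(flat_logits, targets.reshape(-1), reduction="sum")
print(loss.item() / targets.numel())                 # per-label average, as in reduce_metrics
```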
87 | 88 | Returns a tuple with three elements: 89 | 1) the loss 90 | 2) the sample size, which is used as the denominator for the gradient 91 | 3) logging outputs to display while training 92 | """ 93 | sample_size = sample["nsamples"] 94 | perturb = sample.get("perturb", None) 95 | 96 | with torch.no_grad(): 97 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 98 | 99 | logits = model(**sample["net_input"], perturb=perturb) 100 | logits = logits[:, 0, :,:] 101 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 102 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum() 103 | 104 | loss = functional.cross_entropy( 105 | logits, targets.reshape(-1), reduction="sum" 106 | ) 107 | 108 | logging_output = { 109 | "loss": loss.data, 110 | "sample_size": sample_size, 111 | "nsentences": sample_size, 112 | "ntokens": natoms, 113 | "ncorrect": ncorrect, 114 | } 115 | return loss, sample_size, logging_output 116 | -------------------------------------------------------------------------------- /deepgraph/criterions/binary_logloss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | from torch.nn import functional 8 | from fairseq.logging import metrics 9 | from fairseq import utils 10 | # from fairseq import metrics, utils 11 | from fairseq.criterions import FairseqCriterion, register_criterion 12 | 13 | 14 | @register_criterion("binary_logloss", dataclass=FairseqDataclass) 15 | class GraphPredictionBinaryLogLoss(FairseqCriterion): 16 | """ 17 | Implementation for the binary log loss used in graphormer model training. 18 | """ 19 | 20 | def forward(self, model, sample, reduce=True): 21 | """Compute the loss for the given sample. 
22 | 23 | Returns a tuple with three elements: 24 | 1) the loss 25 | 2) the sample size, which is used as the denominator for the gradient 26 | 3) logging outputs to display while training 27 | """ 28 | sample_size = sample["nsamples"] 29 | 30 | with torch.no_grad(): 31 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 32 | 33 | logits = model(**sample["net_input"]) 34 | logits = logits[:, 0, :] 35 | targets = model.get_targets(sample, [logits]) 36 | preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1) 37 | 38 | logits_flatten = logits.reshape(-1) 39 | targets_flatten = targets[: logits.size(0)].reshape(-1) 40 | mask = ~torch.isnan(targets_flatten) 41 | loss = functional.binary_cross_entropy_with_logits( 42 | logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum" 43 | ) 44 | 45 | logging_output = { 46 | "loss": loss.data, 47 | "sample_size": torch.sum(mask.type(torch.int64)), 48 | "nsentences": sample_size, 49 | "ntokens": natoms, 50 | "ncorrect": (preds == targets[:preds.size(0)]).sum(), 51 | } 52 | return loss, sample_size, logging_output 53 | 54 | @staticmethod 55 | def reduce_metrics(logging_outputs) -> None: 56 | """Aggregate logging outputs from data parallel training.""" 57 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 58 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 59 | 60 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) 61 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: 62 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) 63 | metrics.log_scalar( 64 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1 65 | ) 66 | 67 | @staticmethod 68 | def logging_outputs_can_be_summed() -> bool: 69 | """ 70 | Whether the logging outputs returned by `forward` can be summed 71 | across workers prior to calling `reduce_metrics`. Setting this 72 | to True will improves distributed training speed. 73 | """ 74 | return True 75 | 76 | 77 | @register_criterion("binary_logloss_with_flag", dataclass=FairseqDataclass) 78 | class GraphPredictionBinaryLogLossWithFlag(GraphPredictionBinaryLogLoss): 79 | """ 80 | Implementation for the binary log loss used in graphormer model training. 81 | """ 82 | 83 | def forward(self, model, sample, reduce=True): 84 | """Compute the loss for the given sample. 
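The binary log loss above masks out NaN targets so graphs with missing task labels do not contribute to the loss; a dummy-data sketch of that masking:

```python
import torch
from torch.nn import functional

logits = torch.randn(3, 2)                            # [batch, num_tasks]
targets = torch.tensor([[1.0, float("nan")],
                        [0.0, 1.0],
                        [float("nan"), 0.0]])         # NaN marks an unlabeled task
mask = ~torch.isnan(targets.reshape(-1))
loss = functional.binary_cross_entropy_with_logits(
    logits.reshape(-1)[mask], targets.reshape(-1)[mask], reduction="sum"
)
print(loss.item(), int(mask.sum()))   # the mask count is logged as sample_size above
```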
85 | 86 | Returns a tuple with three elements: 87 | 1) the loss 88 | 2) the sample size, which is used as the denominator for the gradient 89 | 3) logging outputs to display while training 90 | """ 91 | sample_size = sample["nsamples"] 92 | perturb = sample.get("perturb", None) 93 | 94 | batch_data = sample["net_input"]["batched_data"]["x"] 95 | with torch.no_grad(): 96 | natoms = batch_data.shape[1] 97 | logits = model(**sample["net_input"], perturb=perturb)[:, 0, :] 98 | targets = model.get_targets(sample, [logits]) 99 | preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1) 100 | 101 | logits_flatten = logits.reshape(-1) 102 | targets_flatten = targets[: logits.size(0)].reshape(-1) 103 | mask = ~torch.isnan(targets_flatten) 104 | loss = functional.binary_cross_entropy_with_logits( 105 | logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum" 106 | ) 107 | 108 | logging_output = { 109 | "loss": loss.data, 110 | "sample_size": logits.size(0), 111 | "nsentences": sample_size, 112 | "ntokens": natoms, 113 | "ncorrect": (preds == targets[:preds.size(0)]).sum(), 114 | } 115 | return loss, sample_size, logging_output 116 | -------------------------------------------------------------------------------- /deepgraph/criterions/node_multiclass_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from fairseq.dataclass.configs import FairseqDataclass 5 | 6 | import torch 7 | from torch.nn import functional 8 | from fairseq.logging import metrics 9 | # from fairseq import metrics 10 | from fairseq.criterions import FairseqCriterion, register_criterion 11 | 12 | 13 | @register_criterion("node_multiclass_cross_entropy", dataclass=FairseqDataclass) 14 | class GraphNodePredictionMulticlassCrossEntropy(FairseqCriterion): 15 | """ 16 | Implementation for the multi-class log loss used in graphormer model training. 17 | """ 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample. 
21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | 28 | logits = model(**sample["net_input"]) 29 | # logits = logits[:, 0, :] 30 | logits = logits[:, 1:, :] 31 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 32 | 33 | if 'sub_adj_mask' in sample['net_input']['batched_data']: 34 | node_ind=(sample['net_input']['batched_data']['sub_adj_mask'][:,:,0]==1) 35 | node_ind_y=node_ind[:,0:targets.shape[1]] 36 | else: 37 | node_ind=~(sample['net_input']['batched_data']['attn_bias'][:,1,:].isinf())[:,1:] 38 | node_ind_y=node_ind 39 | logits_node=logits[node_ind] 40 | 41 | targets_node=targets[node_ind_y]-1 42 | ncorrect = (torch.argmax(logits_node, dim=-1).reshape(-1) == targets_node.reshape(-1)).sum() 43 | 44 | loss = functional.cross_entropy( 45 | logits_node, targets_node.reshape(-1), reduction="sum" 46 | ) 47 | 48 | 49 | sample_size = sample["nsamples"] 50 | 51 | with torch.no_grad(): 52 | natoms = int(node_ind_y.sum()) 53 | 54 | 55 | logging_output = { 56 | "loss": float(loss.data), 57 | "sample_size": natoms, 58 | "nsentences": sample_size, 59 | "ntokens": natoms, 60 | "ncorrect": ncorrect, 61 | } 62 | return loss, natoms, logging_output 63 | 64 | @staticmethod 65 | def reduce_metrics(logging_outputs) -> None: 66 | """Aggregate logging outputs from data parallel training.""" 67 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 68 | sample_size = sum(log.get("ntokens", 0) for log in logging_outputs) 69 | 70 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) 71 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: 72 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) 73 | metrics.log_scalar( 74 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1 75 | ) 76 | 77 | @staticmethod 78 | def logging_outputs_can_be_summed() -> bool: 79 | """ 80 | Whether the logging outputs returned by `forward` can be summed 81 | across workers prior to calling `reduce_metrics`. Setting this 82 | to True will improves distributed training speed. 83 | """ 84 | return True 85 | 86 | 87 | @register_criterion("node_multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass) 88 | class GraphNodePredictionMulticlassCrossEntropyWithFlag(GraphNodePredictionMulticlassCrossEntropy): 89 | """ 90 | Implementation for the multi-class log loss used in graphormer model training. 91 | """ 92 | 93 | def forward(self, model, sample, reduce=True): 94 | """Compute the loss for the given sample. 
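The node-level criterion above drops the graph token, masks invalid node positions, and shifts targets by -1, which suggests node labels are stored 1-indexed with 0 acting as padding. A simplified sketch; here the valid-node mask is derived from the targets for brevity, whereas the criterion derives it from `attn_bias` or `sub_adj_mask`:

```python
import torch
from torch.nn import functional

logits = torch.randn(2, 6, 4)[:, 1:, :]      # drop graph token -> [batch, 5 nodes, 4 classes]
targets = torch.tensor([[1, 3, 2, 0, 0],      # 0 assumed to mark padded node slots
                        [4, 1, 0, 0, 0]])
node_mask = targets > 0
loss = functional.cross_entropy(logits[node_mask], targets[node_mask] - 1, reduction="sum")
print(loss.item(), int(node_mask.sum()))
```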
95 | 96 | Returns a tuple with three elements: 97 | 1) the loss 98 | 2) the sample size, which is used as the denominator for the gradient 99 | 3) logging outputs to display while training 100 | """ 101 | sample_size = sample["nsamples"] 102 | perturb = sample.get("perturb", None) 103 | 104 | with torch.no_grad(): 105 | natoms = sample["net_input"]["batched_data"]["x"].shape[1] 106 | 107 | logits = model(**sample["net_input"], perturb=perturb) 108 | logits = logits[:, 0, :] 109 | targets = model.get_targets(sample, [logits])[: logits.size(0)] 110 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum() 111 | 112 | loss = functional.cross_entropy( 113 | logits, targets.reshape(-1), reduction="sum" 114 | ) 115 | 116 | logging_output = { 117 | "loss": loss.data, 118 | "sample_size": sample_size, 119 | "nsentences": sample_size, 120 | "ntokens": natoms, 121 | "ncorrect": ncorrect, 122 | } 123 | return loss, sample_size, logging_output 124 | -------------------------------------------------------------------------------- /deepgraph/criterions/mae_deltapos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing import Callable, Mapping, Sequence, Tuple 5 | from numpy import mod 6 | import torch 7 | from torch import Tensor 8 | import torch.nn.functional as F 9 | 10 | 11 | from fairseq.logging import metrics 12 | # from fairseq import metrics 13 | from fairseq.criterions import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion("mae_deltapos") 17 | class IS2RECriterion(FairseqCriterion): 18 | e_thresh = 0.02 19 | e_mean = -1.4729953244844094 20 | e_std = 2.2707848125378405 21 | d_mean = [0.1353900283575058, 0.06877671927213669, 0.08111362904310226] 22 | d_std = [1.7862379550933838, 1.78688645362854, 0.8023099899291992] 23 | 24 | def __init__(self, task, cfg): 25 | super().__init__(task) 26 | self.node_loss_weight = cfg.node_loss_weight 27 | self.min_node_loss_weight = cfg.min_node_loss_weight 28 | self.max_update = cfg.max_update 29 | self.node_loss_weight_range = max( 30 | 0, self.node_loss_weight - self.min_node_loss_weight 31 | ) 32 | 33 | def forward( 34 | self, 35 | model: Callable[..., Tuple[Tensor, Tensor, Tensor]], 36 | sample: Mapping[str, Mapping[str, Tensor]], 37 | reduce=True, 38 | ): 39 | update_num = model.num_updates 40 | assert update_num >= 0 41 | node_loss_weight = ( 42 | self.node_loss_weight 43 | - self.node_loss_weight_range * update_num / self.max_update 44 | ) 45 | 46 | valid_nodes = sample["net_input"]["atoms"].ne(0).sum() 47 | output, node_output, node_target_mask = model( 48 | **sample["net_input"], 49 | ) 50 | 51 | relaxed_energy = sample["targets"]["relaxed_energy"] 52 | relaxed_energy = relaxed_energy.float() 53 | relaxed_energy = (relaxed_energy - self.e_mean) / self.e_std 54 | sample_size = relaxed_energy.numel() 55 | loss = F.l1_loss(output.float().view(-1), relaxed_energy, reduction="none") 56 | with torch.no_grad(): 57 | energy_within_threshold = (loss.detach() * self.e_std < self.e_thresh).sum() 58 | loss = loss.sum() 59 | 60 | deltapos = sample["targets"]["deltapos"].float() 61 | deltapos = (deltapos - deltapos.new_tensor(self.d_mean)) / deltapos.new_tensor( 62 | self.d_std 63 | ) 64 | deltapos *= node_target_mask 65 | node_output *= node_target_mask 66 | target_cnt = node_target_mask.sum(dim=[1, 2]) 67 | node_loss = ( 68 | F.l1_loss(node_output.float(), deltapos, reduction="none") 69 
| .mean(dim=-1) 70 | .sum(dim=-1) 71 | / target_cnt 72 | ).sum() 73 | 74 | logging_output = { 75 | "loss": loss.detach(), 76 | "energy_within_threshold": energy_within_threshold, 77 | "node_loss": node_loss.detach(), 78 | "sample_size": sample_size, 79 | "nsentences": sample_size, 80 | "num_nodes": valid_nodes.detach(), 81 | "node_loss_weight": node_loss_weight * sample_size, 82 | } 83 | return loss + node_loss_weight * node_loss, sample_size, logging_output 84 | 85 | @staticmethod 86 | def reduce_metrics(logging_outputs: Sequence[Mapping]) -> None: 87 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 88 | energy_within_threshold_sum = sum( 89 | log.get("energy_within_threshold", 0) for log in logging_outputs 90 | ) 91 | node_loss_sum = sum(log.get("node_loss", 0) for log in logging_outputs) 92 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 93 | 94 | mean_loss = (loss_sum / sample_size) * IS2RECriterion.e_std 95 | energy_within_threshold = energy_within_threshold_sum / sample_size 96 | mean_node_loss = (node_loss_sum / sample_size) * sum(IS2RECriterion.d_std) / 3.0 97 | mean_n_nodes = ( 98 | sum([log.get("num_nodes", 0) for log in logging_outputs]) / sample_size 99 | ) 100 | node_loss_weight = ( 101 | sum([log.get("node_loss_weight", 0) for log in logging_outputs]) 102 | / sample_size 103 | ) 104 | 105 | metrics.log_scalar("loss", mean_loss, sample_size, round=6) 106 | metrics.log_scalar("ewth", energy_within_threshold, sample_size, round=6) 107 | metrics.log_scalar("node_loss", mean_node_loss, sample_size, round=6) 108 | metrics.log_scalar("nodes_per_graph", mean_n_nodes, sample_size, round=6) 109 | metrics.log_scalar("node_loss_weight", node_loss_weight, sample_size, round=6) 110 | 111 | @staticmethod 112 | def logging_outputs_can_be_summed() -> bool: 113 | """ 114 | Whether the logging outputs returned by `forward` can be summed 115 | across workers prior to calling `reduce_metrics`. Setting this 116 | to True will improves distributed training speed. 
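The IS2RE criterion above standardizes relaxed energies with `e_mean`/`e_std` and reports both the MAE and the energy-within-threshold count back in the original energy scale; a small numeric sketch with dummy predictions (values are illustrative only):

```python
import torch
import torch.nn.functional as F

e_mean, e_std, e_thresh = -1.4729953244844094, 2.2707848125378405, 0.02

pred_norm = torch.tensor([-0.5, 0.1, 1.2])      # model output, normalized scale
target = torch.tensor([-2.6, -1.3, 1.4])        # relaxed energies, original scale
target_norm = (target - e_mean) / e_std
per_sample = F.l1_loss(pred_norm, target_norm, reduction="none")
ewth = (per_sample * e_std < e_thresh).sum()    # "close enough" predictions, original scale
print(per_sample.mean().item() * e_std, int(ewth))   # MAE reported back in energy units
```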
117 | """ 118 | return True 119 | -------------------------------------------------------------------------------- /deepgraph/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from fairseq import checkpoint_utils, utils, options, tasks 5 | from fairseq.logging import progress_bar 6 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf 7 | import ogb 8 | import sys 9 | import os 10 | from pathlib import Path 11 | from sklearn.metrics import roc_auc_score 12 | 13 | import sys 14 | from os import path 15 | 16 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) ) 17 | from pretrain import load_pretrained_model 18 | 19 | import logging 20 | 21 | def eval(args, use_pretrained, checkpoint_path=None, logger=None): 22 | cfg = convert_namespace_to_omegaconf(args) 23 | np.random.seed(cfg.common.seed) 24 | utils.set_torch_seed(cfg.common.seed) 25 | 26 | # initialize task 27 | task = tasks.setup_task(cfg.task) 28 | model = task.build_model(cfg.model) 29 | 30 | # load checkpoint 31 | if use_pretrained: 32 | model_state = load_pretrained_model(cfg.task.pretrained_model_name) 33 | else: 34 | model_state = torch.load(checkpoint_path)["model"] 35 | model.load_state_dict( 36 | model_state, strict=True, model_cfg=cfg.model 37 | ) 38 | del model_state 39 | 40 | model.to(torch.cuda.current_device()) 41 | # load dataset 42 | split = args.split 43 | task.load_dataset(split) 44 | batch_iterator = task.get_batch_iterator( 45 | dataset=task.dataset(split), 46 | max_tokens=cfg.dataset.max_tokens_valid, 47 | max_sentences=cfg.dataset.batch_size_valid, 48 | max_positions=utils.resolve_max_positions( 49 | task.max_positions(), 50 | model.max_positions(), 51 | ), 52 | ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, 53 | required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, 54 | seed=cfg.common.seed, 55 | num_workers=cfg.dataset.num_workers, 56 | epoch=0, 57 | data_buffer_size=cfg.dataset.data_buffer_size, 58 | disable_iterator_cache=False, 59 | ) 60 | itr = batch_iterator.next_epoch_itr( 61 | shuffle=False, set_dataset_epoch=False 62 | ) 63 | progress = progress_bar.progress_bar( 64 | itr, 65 | log_format=cfg.common.log_format, 66 | log_interval=cfg.common.log_interval, 67 | default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple") 68 | ) 69 | 70 | # infer 71 | y_pred = [] 72 | y_true = [] 73 | with torch.no_grad(): 74 | model.eval() 75 | for i, sample in enumerate(progress): 76 | sample = utils.move_to_cuda(sample) 77 | if args.criterion=='multilabel_multiclass_cross_entropy': 78 | y = model(**sample["net_input"]).squeeze(0).argmax(1) 79 | y_pred.extend(y.detach().cpu()) 80 | y_true.extend(sample["target"].detach().cpu()) 81 | else: 82 | y = model(**sample["net_input"])[:, 0, :].reshape(-1) 83 | y_pred.extend(y.detach().cpu()) 84 | y_true.extend(sample["target"].detach().cpu().reshape(-1)[:y.shape[0]]) 85 | torch.cuda.empty_cache() 86 | 87 | # save predictions 88 | y_pred = torch.Tensor(y_pred) 89 | y_true = torch.Tensor(y_true) 90 | 91 | # evaluate pretrained models 92 | if use_pretrained: 93 | if cfg.task.pretrained_model_name == "pcqm4mv1_graphormer_base": 94 | evaluator = ogb.lsc.PCQM4MEvaluator() 95 | input_dict = {'y_pred': y_pred, 'y_true': y_true} 96 | result_dict = evaluator.eval(input_dict) 97 | logger.info(f'PCQM4Mv1Evaluator: {result_dict}') 98 | elif cfg.task.pretrained_model_name == "pcqm4mv2_graphormer_base": 99 
| evaluator = ogb.lsc.PCQM4Mv2Evaluator() 100 | input_dict = {'y_pred': y_pred, 'y_true': y_true} 101 | result_dict = evaluator.eval(input_dict) 102 | logger.info(f'PCQM4Mv2Evaluator: {result_dict}') 103 | else: 104 | if args.metric == "auc": 105 | auc = roc_auc_score(y_true, y_pred) 106 | logger.info(f"auc: {auc}") 107 | return auc 108 | elif args.metric == "mae": 109 | mae = np.mean(np.abs(y_true - y_pred).numpy()) 110 | logger.info(f"mae: {mae}") 111 | return mae 112 | else: 113 | raise ValueError(f"Unsupported metric {args.metric}") 114 | 115 | def main(): 116 | parser = options.get_training_parser() 117 | parser.add_argument( 118 | "--split", 119 | type=str, 120 | ) 121 | parser.add_argument( 122 | "--metric", 123 | type=str, 124 | ) 125 | res_lis = [] 126 | args = options.parse_args_and_arch(parser, modify_parser=None) 127 | logger = logging.getLogger(__name__) 128 | if args.pretrained_model_name != "none": 129 | eval(args, True, logger=logger) 130 | elif hasattr(args, "save_dir"): 131 | 132 | # for checkpoint_fname in ['checkpoint_best.pt','checkpoint_last.pt']: 133 | for checkpoint_fname in os.listdir(args.save_dir): 134 | checkpoint_path = Path(args.save_dir) / checkpoint_fname 135 | logger.info(f"evaluating checkpoint file {checkpoint_path}") 136 | res=eval(args, False, checkpoint_path, logger) 137 | print('-----------------------------------------------------------------result: ',res) 138 | res_lis.append(res) 139 | print(np.max(res_lis),np.argmax(res_lis),np.min(res_lis),np.argmin(res_lis)) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /deepgraph/data/dataset.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import ogb 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | from fairseq.data import data_utils, FairseqDataset, BaseWrapperDataset 8 | import pkg_resources 9 | from .wrapper import MyPygGraphPropPredDataset 10 | from .collator import collator,collatorcontrust,collator_adj,collator_node_label,collator_adj_cache 11 | 12 | from typing import Optional, Union 13 | from torch_geometric.data import Data as PYGDataset 14 | from dgl.data import DGLDataset 15 | from .pyg_datasets import PYGDatasetLookupTable 16 | from .substructure_dataset import SubstructureDataset 17 | from .ogb_datasets import OGBDatasetLookupTable 18 | 19 | 20 | 21 | class BatchedDataDataset(FairseqDataset): 22 | def __init__( 23 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024 24 | ): 25 | super().__init__() 26 | self.dataset = dataset 27 | self.max_node = max_node 28 | self.multi_hop_max_dist = multi_hop_max_dist 29 | self.spatial_pos_max = spatial_pos_max 30 | 31 | def __getitem__(self, index): 32 | item = self.dataset[int(index)] 33 | return item 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | def collater(self, samples): 39 | return collator( 40 | samples, 41 | max_node=self.max_node, 42 | multi_hop_max_dist=self.multi_hop_max_dist, 43 | spatial_pos_max=self.spatial_pos_max, 44 | ) 45 | 46 | class BatchedDataDataset_Substructure(FairseqDataset): 47 | def __init__( 48 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024 49 | ): 50 | super().__init__() 51 | self.dataset = dataset 52 | self.max_node = max_node 53 | self.multi_hop_max_dist = multi_hop_max_dist 54 | self.spatial_pos_max = spatial_pos_max 55 | 56 | def __getitem__(self, 
index): 57 | item = self.dataset[int(index)] 58 | return item 59 | 60 | def __len__(self): 61 | return len(self.dataset) 62 | 63 | def collater(self, samples): 64 | return collator_adj( 65 | samples, 66 | max_node=self.max_node, 67 | multi_hop_max_dist=self.multi_hop_max_dist, 68 | spatial_pos_max=self.spatial_pos_max, 69 | ) 70 | 71 | 72 | class BatchedDataDataset_Substructure_Cache(BatchedDataDataset_Substructure): 73 | def collater(self, samples): 74 | return collator_adj_cache( 75 | samples, 76 | max_node=self.max_node, 77 | multi_hop_max_dist=self.multi_hop_max_dist, 78 | spatial_pos_max=self.spatial_pos_max, 79 | ) 80 | 81 | 82 | class BatchedDataContrustDataset(FairseqDataset): 83 | def __init__( 84 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024 85 | ): 86 | super().__init__() 87 | self.dataset = dataset 88 | self.max_node = max_node 89 | self.multi_hop_max_dist = multi_hop_max_dist 90 | self.spatial_pos_max = spatial_pos_max 91 | 92 | def __getitem__(self, index): 93 | item = self.dataset[int(index)] 94 | return item 95 | 96 | def __len__(self): 97 | return len(self.dataset) 98 | 99 | def collater(self, samples): 100 | return collatorcontrust( 101 | samples, 102 | max_node=self.max_node, 103 | multi_hop_max_dist=self.multi_hop_max_dist, 104 | spatial_pos_max=self.spatial_pos_max, 105 | ) 106 | 107 | 108 | class TargetDataset(FairseqDataset): 109 | def __init__(self, dataset,node_level_task=False): 110 | super().__init__() 111 | self.dataset = dataset 112 | self.node_level_task = node_level_task 113 | 114 | @lru_cache(maxsize=16) 115 | def __getitem__(self, index): 116 | return self.dataset[index].y 117 | 118 | def __len__(self): 119 | return len(self.dataset) 120 | 121 | def collater(self, samples): 122 | if self.node_level_task: 123 | return collator_node_label(samples) 124 | else: 125 | return torch.stack(samples, dim=0) 126 | 127 | 128 | class GraphormerDataset: 129 | def __init__( 130 | self, 131 | dataset: Optional[Union[PYGDataset, DGLDataset]] = None, 132 | dataset_spec: Optional[str] = None, 133 | dataset_source: Optional[str] = None, 134 | seed: int = 0, 135 | train_idx = None, 136 | valid_idx = None, 137 | test_idx = None, 138 | args=None, 139 | **kwargs 140 | ): 141 | super().__init__() 142 | if dataset is not None: 143 | if dataset_source == "pyg": 144 | self.dataset = SubstructureDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx) 145 | else: 146 | raise ValueError("customized dataset can only have source pyg") 147 | 148 | elif dataset_source == "pyg": 149 | self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed=seed,args=args,**kwargs) 150 | elif dataset_source == "ogb": 151 | self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed=seed,args=args,**kwargs) 152 | else: 153 | raise ValueError("dataset_source not implied") 154 | self.setup() 155 | 156 | def setup(self): 157 | self.train_idx = self.dataset.train_idx 158 | self.valid_idx = self.dataset.valid_idx 159 | self.test_idx = self.dataset.test_idx 160 | 161 | self.dataset_train = self.dataset.train_data 162 | self.dataset_val = self.dataset.valid_data 163 | self.dataset_test = self.dataset.test_data 164 | 165 | 166 | 167 | class EpochShuffleDataset(BaseWrapperDataset): 168 | def __init__(self, dataset, num_samples, seed): 169 | super().__init__(dataset) 170 | self.num_samples = num_samples 171 | self.seed = seed 172 | self.set_epoch(1) 173 | 174 | def set_epoch(self, epoch): 175 | with data_utils.numpy_seed(self.seed + epoch - 1): 176 
| self.sort_order = np.random.permutation(self.num_samples) 177 | 178 | def ordered_indices(self): 179 | return self.sort_order 180 | 181 | @property 182 | def can_reuse_epoch_itr_across_epochs(self): 183 | return False 184 | 185 | 186 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Are More Layers Beneficial to Graph Transformers? 2 | 3 | 4 | ## Introduction 5 | 6 | This is the code for our work [Are More Layers Beneficial to Graph Transformers?](https://openreview.net/pdf?id=uagC-X9XMi8), published at ICLR 2023. 7 |
8 | 9 |
10 | 11 | 12 | 13 | 14 | ## Installation 15 | To run DeepGraph, please clone the repository to your local machine and install the required dependencies using the script provided. 16 | #### Note 17 | Please note that we use CUDA 10.2 and python 3.7. If you are using a different version of CUDA or python, please adjust the package version as necessary. 18 | #### Environment 19 | ``` 20 | conda create -n DeepGraph python=3.7 21 | 22 | source activate DeepGraph 23 | 24 | pip3 install torch==1.9.0+cu102 torchaudio torchvision -f https://download.pytorch.org/whl/cu102/torch_stable.html --user 25 | 26 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl 27 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl 28 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl 29 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl 30 | pip install torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl 31 | pip install torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl 32 | pip install torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl 33 | pip install torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl 34 | 35 | pip install torch-geometric==1.7.2 36 | 37 | conda install -c conda-forge/label/cf202003 graph-tool 38 | 39 | pip install lmdb 40 | pip install tensorboardX==2.4.1 41 | pip install ogb==1.3.2 42 | pip install rdkit-pypi==2021.9.3 43 | pip install dgl==0.7.2 -f https://data.dgl.ai/wheels/repo.html 44 | pip install tqdm 45 | pip install wandb 46 | pip install networkx 47 | pip install setuptools==59.5.0 48 | pip install multiprocess 49 | 50 | git clone -b 0.12.2-release https://github.com/facebookresearch/fairseq 51 | cd fairseq 52 | pip install ./ 53 | python setup.py build_ext --inplace 54 | cd .. 55 | 56 | ``` 57 | 58 | 59 | ## Run the Application 60 | 61 | We provide the training scripts for ZINC, CLUSTER, PATTERN and PCQM4M-LSC. 
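All four training scripts below follow the same invocation pattern and differ only in the dataset, substructure, loss, and model-size flags. As a schematic summary (angle-bracket placeholders are illustrative, not extra options; see the full commands below for the exact flags used in each experiment):

```
CUDA_VISIBLE_DEVICES=<gpu_id> python train.py --user-dir ./deepgraph --save-dir <checkpoint_dir> \
    --dataset-name <zinc|CLUSTER|PATTERN|pcqm4m> --dataset-source <pyg|ogb> --data-dir dataset/ \
    --task graph_prediction_substructure --id-type <substructure families> --ks <substructure sizes> \
    --criterion <loss> --arch <graphormer_slim|graphormer_base> --deepnorm \
    <optimizer, schedule, and encoder-size flags as in the full commands below>
```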
62 | 63 | 64 | #### Script for ZINC 65 | ``` 66 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/zinc --ddp-backend=legacy_ddp --dataset-name zinc --dataset-source pyg --data-dir dataset/ --task graph_prediction_substructure --id-type cycle_graph+path_graph+star_graph+k_neighborhood --ks [8,4,6,2] --sampling-redundancy 6 --valid-on-test --criterion l1_loss --arch graphormer_slim --deepnorm --num-classes 1 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 640000 --total-num-update 2560000 --lr 2e-4 --end-learning-rate 1e-6 --batch-size 16 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 80 --encoder-ffn-embed-dim 80 --encoder-attention-heads 8 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 67 | 68 | ``` 69 | 70 | #### Script for CLUSTER 71 | 72 | ``` 73 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/CLUSTER --ddp-backend=legacy_ddp --dataset-name CLUSTER --dataset-source pyg --node-level-task --data-dir dataset/ --task graph_prediction_substructure --id-type random_walk --ks [10] --sampling-redundancy 2 --valid-on-test --criterion node_multiclass_cross_entropy --arch graphormer_slim --deepnorm --num-classes 6 --num-workers 16 --attention-dropout 0.4 --act-dropout 0.4 --dropout 0.4 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 300 --total-num-update 40000 --lr 5e-4 --end-learning-rate 5e-4 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 48 --encoder-ffn-embed-dim 96 --encoder-attention-heads 8 --max-nodes 1024 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric 74 | 75 | ``` 76 | 77 | #### Script for PATTERN 78 | 79 | ``` 80 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/PATTERN --ddp-backend=legacy_ddp --dataset-name PATTERN --dataset-source pyg --node-level-task --data-dir dataset/ --task graph_prediction_substructure --id-type random_walk --ks [10] --sampling-redundancy 2 --valid-on-test --criterion node_multiclass_cross_entropy --arch graphormer_slim --deepnorm --num-classes 2 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 6000 --total-num-update 40000 --lr 2e-4 --end-learning-rate 1e-6 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 80 --encoder-ffn-embed-dim 80 --encoder-attention-heads 8 --max-nodes 512 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric 81 | 82 | ``` 83 | 84 | #### Script for PCQM4M 85 | ``` 86 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/pcqm4m --ddp-backend=legacy_ddp --dataset-name pcqm4m --dataset-source ogb --data-dir dataset/ --task graph_prediction_substructure --id-type cycle_graph+path_graph+star_graph+k_neighborhood --ks [8,4,6,2] --sampling-redundancy 6 --criterion l1_loss --arch graphormer_base --deepnorm --num-classes 1 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 50000 --total-num-update 500000 --lr 1e-4 --end-learning-rate 1e-5 --batch-size 100 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 768 --encoder-ffn-embed-dim 768 --encoder-attention-heads 32 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 87 | 88 | ``` 89 | 90 | ## Guidance for Further Development 91 | 92 | Our code is organized into several folders and is compatible with other Fairseq-based frameworks. You can integrate DeepGraph into other Fairseq frameworks by merging the folders, and extend it further by adding the corresponding modules. 93 | ``` 94 | deepgraph 95 | ├─criterions 96 | ├─data 97 | │ ├─ogb_datasets 98 | │ ├─pyg_datasets 99 | │ ├─subsampling 100 | │ └─substructure_dataset_utils 101 | ├─evaluate 102 | ├─models 103 | ├─modules 104 | ├─pretrain 105 | └─tasks 106 | ``` 107 | 108 | 109 | ## Citation 110 | 111 | Please cite this paper if you find our work useful: 112 | ``` 113 | @inproceedings{zhaomore, 114 | title={Are More Layers Beneficial to Graph Transformers?}, 115 | author={Zhao, Haiteng and Ma, Shuming and Zhang, Dongdong and Deng, Zhi-Hong and Wei, Furu}, 116 | booktitle={International Conference on Learning Representations}, 117 | year={2023} 118 | } 119 | ``` 120 | 121 | -------------------------------------------------------------------------------- /deepgraph/modules/deepgraph_graph_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
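# Note: the DeepGraphGraphEncoderLayer below optionally applies DeepNorm scaling
# (DeepNet, Wang et al., 2022) for an encoder-only stack of N layers; with deepnorm=True
# the code computes
#     shortcut_scale = (2 * N) ** 0.25   # multiplies the residual branch in forward()
#     fixup_scale    = (8 * N) ** 0.25   # divides v_proj/out_proj and FFN weights at init
# For the 48-layer configurations used in the README this is roughly 3.13 and 4.43.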
8 | 9 | from typing import Callable, Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | from fairseq import utils 14 | from fairseq.modules import LayerNorm 15 | from fairseq.modules.fairseq_dropout import FairseqDropout 16 | from fairseq.modules.quant_noise import quant_noise 17 | 18 | from .multihead_attention import MultiheadAttention,MultiheadAttentionRe 19 | import math 20 | 21 | class DeepGraphGraphEncoderLayer(nn.Module): 22 | def __init__( 23 | self, 24 | embedding_dim: int = 768, 25 | ffn_embedding_dim: int = 3072, 26 | num_attention_heads: int = 8, 27 | dropout: float = 0.1, 28 | attention_dropout: float = 0.1, 29 | activation_dropout: float = 0.1, 30 | activation_fn: str = "relu", 31 | export: bool = False, 32 | q_noise: float = 0.0, 33 | qn_block_size: int = 8, 34 | init_fn: Callable = None, 35 | pre_layernorm: bool = False, 36 | deepnorm: bool = False, 37 | encoder_layers: int = 12, 38 | reattention:bool=False 39 | ) -> None: 40 | super().__init__() 41 | 42 | if init_fn is not None: 43 | init_fn() 44 | 45 | # Initialize parameters 46 | self.embedding_dim = embedding_dim 47 | self.num_attention_heads = num_attention_heads 48 | self.attention_dropout = attention_dropout 49 | self.q_noise = q_noise 50 | self.qn_block_size = qn_block_size 51 | self.pre_layernorm = pre_layernorm 52 | 53 | self.dropout_module = FairseqDropout( 54 | dropout, module_name=self.__class__.__name__ 55 | ) 56 | self.activation_dropout_module = FairseqDropout( 57 | activation_dropout, module_name=self.__class__.__name__ 58 | ) 59 | 60 | # Initialize blocks 61 | self.activation_fn = utils.get_activation_fn(activation_fn) 62 | self.self_attn = self.build_self_attention( 63 | self.embedding_dim, 64 | num_attention_heads, 65 | dropout=attention_dropout, 66 | self_attention=True, 67 | q_noise=q_noise, 68 | qn_block_size=qn_block_size, 69 | reattention=reattention 70 | ) 71 | 72 | # layer norm associated with the self attention layer 73 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) 74 | 75 | self.fc1 = self.build_fc1( 76 | self.embedding_dim, 77 | ffn_embedding_dim, 78 | q_noise=q_noise, 79 | qn_block_size=qn_block_size, 80 | ) 81 | self.fc2 = self.build_fc2( 82 | ffn_embedding_dim, 83 | self.embedding_dim, 84 | q_noise=q_noise, 85 | qn_block_size=qn_block_size, 86 | ) 87 | 88 | # layer norm associated with the position wise feed-forward NN 89 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) 90 | 91 | 92 | if deepnorm: 93 | # self.shortcut_scale = math.pow(math.pow(cfg.encoder_layers, 4) * cfg.decoder_layers, 0.0625) * 0.81 94 | # if utils.safe_getattr(cfg, "deepnorm_encoder_only", False): 95 | self.shortcut_scale = math.pow(2.0 * encoder_layers, 0.25) 96 | # print('deepnorm shortcut_scale ',self.shortcut_scale) 97 | else: 98 | self.shortcut_scale = 1.0 99 | 100 | if deepnorm: 101 | # self.fixup_scale = math.pow(math.pow(cfg.encoder_layers, 4) * cfg.decoder_layers, 0.0625) / 1.15 102 | # if utils.safe_getattr(cfg, "deepnorm_encoder_only", False): 103 | self.fixup_scale = math.pow(8.0 * encoder_layers, 0.25) 104 | # print('deepnorm fixup_scale ', self.fixup_scale) 105 | self.deepnorm_init() 106 | 107 | def deepnorm_init(self): 108 | def rescale(param): 109 | param.div_(self.fixup_scale) 110 | 111 | rescale(self.self_attn.v_proj.weight.data) 112 | rescale(self.self_attn.v_proj.bias.data) 113 | rescale(self.self_attn.out_proj.weight.data) 114 | rescale(self.self_attn.out_proj.bias.data) 115 | 116 | rescale(self.fc1.weight.data) 117 | 
rescale(self.fc2.weight.data) 118 | rescale(self.fc1.bias.data) 119 | rescale(self.fc2.bias.data) 120 | # print('deepnorm init') 121 | 122 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): 123 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 124 | 125 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): 126 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) 127 | 128 | def build_self_attention( 129 | self, 130 | embed_dim, 131 | num_attention_heads, 132 | dropout, 133 | self_attention, 134 | q_noise, 135 | qn_block_size, 136 | reattention=False 137 | ): 138 | if not reattention: 139 | return MultiheadAttention( 140 | embed_dim, 141 | num_attention_heads, 142 | dropout=dropout, 143 | self_attention=True, 144 | q_noise=q_noise, 145 | qn_block_size=qn_block_size, 146 | ) 147 | else: 148 | return MultiheadAttentionRe( 149 | embed_dim, 150 | num_attention_heads, 151 | dropout=dropout, 152 | self_attention=True, 153 | q_noise=q_noise, 154 | qn_block_size=qn_block_size, 155 | ) 156 | 157 | 158 | def forward( 159 | self, 160 | x: torch.Tensor, 161 | self_attn_bias: Optional[torch.Tensor] = None, 162 | self_attn_mask: Optional[torch.Tensor] = None, 163 | self_attn_padding_mask: Optional[torch.Tensor] = None, 164 | ): 165 | """ 166 | LayerNorm is applied either before or after the self-attention/ffn 167 | modules similar to the original Transformer implementation. 168 | """ 169 | # x: T x B x C 170 | residual = x 171 | if self.pre_layernorm: 172 | x = self.self_attn_layer_norm(x) 173 | x, attn = self.self_attn( 174 | query=x, 175 | key=x, 176 | value=x, 177 | attn_bias=self_attn_bias, 178 | key_padding_mask=self_attn_padding_mask, 179 | need_weights=True, 180 | attn_mask=self_attn_mask, 181 | ) 182 | x = self.dropout_module(x) 183 | x = residual* self.shortcut_scale + x 184 | if not self.pre_layernorm: 185 | x = self.self_attn_layer_norm(x) 186 | 187 | residual = x 188 | if self.pre_layernorm: 189 | x = self.final_layer_norm(x) 190 | x = self.activation_fn(self.fc1(x)) 191 | x = self.activation_dropout_module(x) 192 | x = self.fc2(x) 193 | x = self.dropout_module(x) 194 | x = residual * self.shortcut_scale + x 195 | if not self.pre_layernorm: 196 | x = self.final_layer_norm(x) 197 | return x, attn 198 | -------------------------------------------------------------------------------- /deepgraph/modules/deepgraph_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision.models import resnet18 14 | 15 | def init_params(module, n_layers): 16 | if isinstance(module, nn.Linear): 17 | module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers)) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | if isinstance(module, nn.Embedding): 21 | module.weight.data.normal_(mean=0.0, std=0.02) 22 | 23 | 24 | class GraphNodeFeature(nn.Module): 25 | """ 26 | Compute node features for each node in the graph. 
27 | """ 28 | 29 | def __init__( 30 | self, num_heads, num_atoms, num_in_degree, num_out_degree, hidden_dim, n_layers 31 | ): 32 | super(GraphNodeFeature, self).__init__() 33 | self.num_heads = num_heads 34 | self.num_atoms = num_atoms 35 | 36 | # 1 for graph token 37 | self.atom_encoder = nn.Embedding(num_atoms + 1, hidden_dim, padding_idx=0) 38 | self.in_degree_encoder = nn.Embedding(num_in_degree, hidden_dim, padding_idx=0) 39 | self.out_degree_encoder = nn.Embedding( 40 | num_out_degree, hidden_dim, padding_idx=0 41 | ) 42 | 43 | self.graph_token = nn.Embedding(1, hidden_dim) 44 | 45 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 46 | 47 | def forward(self, batched_data): 48 | x, in_degree, out_degree = ( 49 | batched_data["x"], 50 | batched_data["in_degree"], 51 | batched_data["out_degree"], 52 | ) 53 | n_graph, n_node = x.size()[:2] 54 | 55 | # node feauture + graph token 56 | node_feature = self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden] 57 | 58 | # if self.flag and perturb is not None: 59 | # node_feature += perturb 60 | 61 | node_feature = ( 62 | node_feature 63 | + self.in_degree_encoder(in_degree) 64 | + self.out_degree_encoder(out_degree) 65 | ) 66 | 67 | graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1) 68 | 69 | graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1) 70 | 71 | return graph_node_feature 72 | 73 | 74 | class GraphAttnBias(nn.Module): 75 | """ 76 | Compute attention bias for each head. 77 | """ 78 | 79 | def __init__( 80 | self, 81 | num_heads, 82 | num_atoms, 83 | num_edges, 84 | num_spatial, 85 | num_edge_dis, 86 | hidden_dim, 87 | edge_type, 88 | multi_hop_max_dist, 89 | n_layers, 90 | ): 91 | super(GraphAttnBias, self).__init__() 92 | self.num_heads = num_heads 93 | self.multi_hop_max_dist = multi_hop_max_dist 94 | 95 | self.edge_encoder = nn.Embedding(num_edges + 1, num_heads, padding_idx=0) 96 | self.edge_type = edge_type 97 | if self.edge_type == "multi_hop": 98 | self.edge_dis_encoder = nn.Embedding( 99 | num_edge_dis * num_heads * num_heads, 1 100 | ) 101 | self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0) 102 | 103 | self.graph_token_virtual_distance = nn.Embedding(1, num_heads) 104 | 105 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 106 | 107 | def forward(self, batched_data): 108 | attn_bias, spatial_pos, x = ( 109 | batched_data["attn_bias"], 110 | batched_data["spatial_pos"], 111 | batched_data["x"], 112 | ) 113 | # in_degree, out_degree = batched_data.in_degree, batched_data.in_degree 114 | edge_input, attn_edge_type = ( 115 | batched_data["edge_input"], 116 | batched_data["attn_edge_type"], 117 | ) 118 | 119 | n_graph, n_node = x.size()[:2] 120 | graph_attn_bias = attn_bias.clone() 121 | graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat( 122 | 1, self.num_heads, 1, 1 123 | ) # [n_graph, n_head, n_node+1, n_node+1] 124 | 125 | # spatial pos 126 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] 127 | spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2) 128 | graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias 129 | 130 | # reset spatial pos here 131 | t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1) 132 | graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t 133 | graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t 134 | 135 | # edge feature 136 | if self.edge_type == "multi_hop": 137 | 
spatial_pos_ = spatial_pos.clone() 138 | spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1 139 | # set 1 to 1, x > 1 to x - 1 140 | spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_) 141 | if self.multi_hop_max_dist > 0: 142 | spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist) 143 | edge_input = edge_input[:, :, :, : self.multi_hop_max_dist, :] 144 | # [n_graph, n_node, n_node, max_dist, n_head] 145 | edge_input = self.edge_encoder(edge_input).mean(-2) 146 | max_dist = edge_input.size(-2) 147 | edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape( 148 | max_dist, -1, self.num_heads 149 | ) 150 | edge_input_flat = torch.bmm( 151 | edge_input_flat, 152 | self.edge_dis_encoder.weight.reshape( 153 | -1, self.num_heads, self.num_heads 154 | )[:max_dist, :, :], 155 | ) 156 | edge_input = edge_input_flat.reshape( 157 | max_dist, n_graph, n_node, n_node, self.num_heads 158 | ).permute(1, 2, 3, 0, 4) 159 | edge_input = ( 160 | edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1)) 161 | ).permute(0, 3, 1, 2) 162 | else: 163 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] 164 | edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2) 165 | 166 | graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + edge_input 167 | graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset 168 | 169 | return graph_attn_bias 170 | 171 | 172 | class SubStructure_Adj_Encoder(nn.Module): 173 | def __init__(self,subgraph_max_size,output_size): 174 | super(SubStructure_Adj_Encoder,self).__init__() 175 | self.subgraph_max_size=subgraph_max_size 176 | self.output_size=output_size 177 | 178 | self.encoder=nn.Sequential(nn.Linear(subgraph_max_size*subgraph_max_size,output_size), 179 | nn.ReLU(inplace=True), 180 | nn.Linear(output_size, output_size)) 181 | 182 | 183 | def forward(self,sorted_adj): 184 | input=sorted_adj.reshape([-1,self.subgraph_max_size*self.subgraph_max_size]) 185 | output=self.encoder(input) 186 | 187 | result=output.reshape([sorted_adj.shape[0],sorted_adj.shape[1],-1]) 188 | return result -------------------------------------------------------------------------------- /deepgraph/data/pyg_datasets/pyg_dataset_lookup_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
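# Note: PYGDatasetLookupTable.GetPYGDataset below accepts either a bare dataset name
# (e.g. "zinc", "CLUSTER") or a spec of the form "name:key=value,...". For instance,
# "moleculenet:name=bbbp" would select a MoleculeNet subset; this example spec is
# illustrative and not taken from the provided training scripts.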
3 | 4 | from typing import Optional 5 | import graph_tool 6 | from torch_geometric.datasets import * 7 | from torch_geometric.data import Dataset 8 | from ..substructure_dataset import SubstructureDataset 9 | import torch.distributed as dist 10 | import torch 11 | 12 | class MyQM7b(QM7b): 13 | def download(self): 14 | if not dist.is_initialized() or dist.get_rank() == 0: 15 | super(MyQM7b, self).download() 16 | if dist.is_initialized(): 17 | dist.barrier() 18 | 19 | def process(self): 20 | if not dist.is_initialized() or dist.get_rank() == 0: 21 | super(MyQM7b, self).process() 22 | if dist.is_initialized(): 23 | dist.barrier() 24 | 25 | 26 | class MyQM9(QM9): 27 | def download(self): 28 | if not dist.is_initialized() or dist.get_rank() == 0: 29 | super(MyQM9, self).download() 30 | if dist.is_initialized(): 31 | dist.barrier() 32 | 33 | def process(self): 34 | if not dist.is_initialized() or dist.get_rank() == 0: 35 | super(MyQM9, self).process() 36 | if dist.is_initialized(): 37 | dist.barrier() 38 | 39 | class MyZINC(ZINC): 40 | def download(self): 41 | if not dist.is_initialized() or dist.get_rank() == 0: 42 | super(MyZINC, self).download() 43 | if dist.is_initialized(): 44 | dist.barrier() 45 | 46 | def process(self): 47 | if not dist.is_initialized() or dist.get_rank() == 0: 48 | super(MyZINC, self).process() 49 | if dist.is_initialized(): 50 | dist.barrier() 51 | 52 | class MYZINC_UNI(torch.utils.data.Dataset): 53 | def __init__(self,data_list): 54 | self.data_list=data_list 55 | 56 | def __len__(self): 57 | return len(self.data_list) 58 | 59 | def __getitem__(self, idx): 60 | data=self.data_list[idx] 61 | return data 62 | 63 | 64 | class MyMoleculeNet(MoleculeNet): 65 | def download(self): 66 | if not dist.is_initialized() or dist.get_rank() == 0: 67 | super(MyMoleculeNet, self).download() 68 | if dist.is_initialized(): 69 | dist.barrier() 70 | 71 | def process(self): 72 | if not dist.is_initialized() or dist.get_rank() == 0: 73 | super(MyMoleculeNet, self).process() 74 | if dist.is_initialized(): 75 | dist.barrier() 76 | 77 | 78 | 79 | 80 | class PYGDatasetLookupTable: 81 | @staticmethod 82 | def GetPYGDataset(dataset_spec: str, seed: int,args=None,**kwargs) -> Optional[Dataset]: 83 | split_result = dataset_spec.split(":") 84 | if len(split_result) == 2: 85 | name, params = split_result[0], split_result[1] 86 | params = params.split(",") 87 | elif len(split_result) == 1: 88 | name = dataset_spec 89 | params = [] 90 | inner_dataset = None 91 | num_class = 1 92 | 93 | train_set = None 94 | valid_set = None 95 | test_set = None 96 | 97 | root = "dataset" 98 | if name == "qm7b": 99 | inner_dataset = MyQM7b(root=kwargs['data_dir']) 100 | elif name == "qm9": 101 | inner_dataset = MyQM9(root=kwargs['data_dir']) 102 | elif name == "zinc_full": 103 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=False) 104 | train_set = MyZINC(root=kwargs['data_dir'],subset=False, split="train") 105 | valid_set = MyZINC(root=kwargs['data_dir'],subset=False, split="val") 106 | test_set = MyZINC(root=kwargs['data_dir'],subset=False, split="test") 107 | elif name == "zinc": 108 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True) 109 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train") 110 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val") 111 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test") 112 | elif name == "zinc_val_on_test": 113 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True) 114 | train_set = 
MyZINC(root=kwargs['data_dir'],subset=True, split="train") 115 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test") 116 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test") 117 | elif name == 'zinc_uniform': 118 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True) 119 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train") 120 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val") 121 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test") 122 | total_data=[] 123 | for i in range(len(train_set)): 124 | total_data.append(train_set[i]) 125 | for i in range(len(valid_set)): 126 | total_data.append(valid_set[i]) 127 | for i in range(len(test_set)): 128 | total_data.append(test_set[i]) 129 | train_data,valid_data,test_data=torch.utils.data.random_split(total_data, [len(train_set), len(valid_set),len(test_set)]) 130 | train_set=MYZINC_UNI(train_data) 131 | valid_set = MYZINC_UNI(valid_data) 132 | test_set = MYZINC_UNI(test_data) 133 | inner_dataset=train_set 134 | elif name == 'zinc_uniform_val_on_test': 135 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True) 136 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train") 137 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val") 138 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test") 139 | total_data=[] 140 | for i in range(len(train_set)): 141 | total_data.append(train_set[i]) 142 | for i in range(len(valid_set)): 143 | total_data.append(valid_set[i]) 144 | for i in range(len(test_set)): 145 | total_data.append(test_set[i]) 146 | train_data,valid_data,test_data=torch.utils.data.random_split(total_data, [len(train_set), len(valid_set),len(test_set)]) 147 | train_set=MYZINC_UNI(train_data) 148 | valid_set = MYZINC_UNI(test_data) 149 | test_set = MYZINC_UNI(test_data) 150 | inner_dataset=train_set 151 | elif name == "moleculenet": 152 | nm = None 153 | for param in params: 154 | name, value = param.split("=") 155 | if name == "name": 156 | nm = value 157 | inner_dataset = MyMoleculeNet(root=kwargs['data_dir'], name=nm) 158 | 159 | elif name in ["CLUSTER","PATTERN"]: 160 | train_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='train') 161 | valid_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='val') 162 | test_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='test') 163 | 164 | else: 165 | raise ValueError(f"Unknown dataset name {name} for pyg source.") 166 | 167 | if args['valid_on_test']: 168 | valid_set = test_set 169 | 170 | if train_set is not None: 171 | result= SubstructureDataset( 172 | None, 173 | seed, 174 | None, 175 | None, 176 | None, 177 | train_set, 178 | valid_set, 179 | test_set, 180 | args['not_re_define'], 181 | args=args 182 | ) 183 | else: 184 | result= ( 185 | None 186 | if inner_dataset is None 187 | else SubstructureDataset(inner_dataset, seed) 188 | ) 189 | 190 | return result 191 | -------------------------------------------------------------------------------- /deepgraph/data/collator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
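# Note: the pad_* helpers below right-pad each per-graph tensor to the largest size in the
# batch and stack the results along a new batch dimension. Integer features are shifted by +1
# first so that index 0 is reserved for padding, matching the padding_idx=0 embeddings used
# by the encoder's feature and bias modules.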
3 | 4 | import torch 5 | 6 | 7 | def pad_1d_unsqueeze(x, padlen): 8 | x = x + 1 # pad id = 0 9 | xlen = x.size(0) 10 | if xlen < padlen: 11 | new_x = x.new_zeros([padlen], dtype=x.dtype) 12 | new_x[:xlen] = x 13 | x = new_x 14 | return x.unsqueeze(0) 15 | 16 | 17 | def pad_2d_unsqueeze(x, padlen): 18 | x = x + 1 # pad id = 0 19 | xlen, xdim = x.size() 20 | if xlen < padlen: 21 | new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) 22 | new_x[:xlen, :] = x 23 | x = new_x 24 | return x.unsqueeze(0) 25 | 26 | def pad_sub_adjs_unsqueeze(x, padlen): 27 | xlen, xdim1,xdim2 = x.size() 28 | if xlen < padlen: 29 | new_x = x.new_zeros([padlen, xdim1,xdim2], dtype=x.dtype) 30 | new_x[:xlen, :,:] = x 31 | x = new_x 32 | return x.unsqueeze(0) 33 | 34 | def pad_attn_bias_unsqueeze(x, padlen): 35 | xlen = x.size(0) 36 | if xlen < padlen: 37 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float("-inf")) 38 | new_x[:xlen, :xlen] = x 39 | new_x[xlen:, :xlen] = 0 40 | x = new_x 41 | return x.unsqueeze(0) 42 | 43 | 44 | def pad_edge_type_unsqueeze(x, padlen): 45 | xlen = x.size(0) 46 | if xlen < padlen: 47 | new_x = x.new_zeros([padlen, padlen, x.size(-1)], dtype=x.dtype) 48 | new_x[:xlen, :xlen, :] = x 49 | x = new_x 50 | return x.unsqueeze(0) 51 | 52 | 53 | def pad_spatial_pos_unsqueeze(x, padlen): 54 | x = x + 1 55 | xlen = x.size(0) 56 | if xlen < padlen: 57 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype) 58 | new_x[:xlen, :xlen] = x 59 | x = new_x 60 | return x.unsqueeze(0) 61 | 62 | 63 | def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3): 64 | x = x + 1 65 | xlen1, xlen2, xlen3, xlen4 = x.size() 66 | if xlen1 < padlen1 or xlen2 < padlen2 or xlen3 < padlen3: 67 | new_x = x.new_zeros([padlen1, padlen2, padlen3, xlen4], dtype=x.dtype) 68 | new_x[:xlen1, :xlen2, :xlen3, :] = x 69 | x = new_x 70 | return x.unsqueeze(0) 71 | 72 | def pad_3d_sequences(xs, padlen1, padlen2, padlen3): 73 | result=torch.zeros([len(xs),padlen1,padlen2,padlen3,xs[0].shape[-1]], dtype=xs[0].dtype) 74 | for i,x in enumerate(xs): 75 | result[i,:x.shape[0], :x.shape[1], :x.shape[2], :]=x+1 76 | return result 77 | 78 | 79 | 80 | def collator(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20): 81 | items = [item for item in items if item is not None and item.x.size(0) <= max_node] 82 | items = [ 83 | ( 84 | item.idx, 85 | item.attn_bias, 86 | item.attn_edge_type, 87 | item.spatial_pos, 88 | item.in_degree, 89 | item.out_degree, 90 | item.x, 91 | item.edge_input[:, :, :multi_hop_max_dist, :], 92 | item.y, 93 | ) 94 | for item in items 95 | ] 96 | ( 97 | idxs, 98 | attn_biases, 99 | attn_edge_types, 100 | spatial_poses, 101 | in_degrees, 102 | out_degrees, 103 | xs, 104 | edge_inputs, 105 | ys, 106 | ) = zip(*items) 107 | 108 | for idx, _ in enumerate(attn_biases): 109 | attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float("-inf") 110 | max_node_num = max(i.size(0) for i in xs) 111 | max_dist = max(i.size(-2) for i in edge_inputs) 112 | y = torch.cat(ys) 113 | x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs]) 114 | edge_input = torch.cat( 115 | [pad_3d_unsqueeze(i, max_node_num, max_node_num, max_dist) for i in edge_inputs] 116 | ) 117 | attn_bias = torch.cat( 118 | [pad_attn_bias_unsqueeze(i, max_node_num + 1) for i in attn_biases] 119 | ) 120 | attn_edge_type = torch.cat( 121 | [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types] 122 | ) 123 | spatial_pos = torch.cat( 124 | [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses] 125 | ) 126 | 
in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees]) 127 | 128 | return dict( 129 | idx=torch.LongTensor(idxs), 130 | attn_bias=attn_bias, 131 | attn_edge_type=attn_edge_type, 132 | spatial_pos=spatial_pos, 133 | in_degree=in_degree, 134 | out_degree=in_degree, # for undirected graph 135 | x=x, 136 | edge_input=edge_input, 137 | y=y, 138 | ) 139 | 140 | def collator_adj(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20): 141 | items = [item for item in items if item is not None and item.x.size(0) <= max_node] 142 | items = [ 143 | ( 144 | item.idx, 145 | item.attn_bias, 146 | item.attn_edge_type, 147 | item.spatial_pos, 148 | item.in_degree, 149 | item.out_degree, 150 | item.x, 151 | item.edge_input[:, :, :multi_hop_max_dist, :], 152 | item.y, 153 | item.sorted_adj, 154 | item.sub_adj_mask 155 | ) 156 | for item in items 157 | ] 158 | ( 159 | idxs, 160 | attn_biases, 161 | attn_edge_types, 162 | spatial_poses, 163 | in_degrees, 164 | out_degrees, 165 | xs, 166 | edge_inputs, 167 | ys, 168 | sorted_adjs, 169 | sub_adj_masks 170 | ) = zip(*items) 171 | 172 | for idx, _ in enumerate(attn_biases): 173 | attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float("-inf") 174 | max_node_num = max(i.size(0) for i in xs) 175 | max_dist = max(i.size(-2) for i in edge_inputs) 176 | if sorted_adjs[0] is not None: 177 | max_subs = max(i.size(0) for i in sorted_adjs) 178 | y = torch.cat(ys) 179 | x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs]) 180 | # edge_input = torch.cat( 181 | # [pad_3d_unsqueeze(i, max_node_num, max_node_num, max_dist) for i in edge_inputs] 182 | # ) 183 | edge_input = pad_3d_sequences(edge_inputs, max_node_num, max_node_num, max_dist) 184 | attn_bias = torch.cat( 185 | [pad_attn_bias_unsqueeze(i, max_node_num + 1) for i in attn_biases] 186 | ) 187 | attn_edge_type = torch.cat( 188 | [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types] 189 | ) 190 | spatial_pos = torch.cat( 191 | [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses] 192 | ) 193 | in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees]) 194 | 195 | 196 | if sorted_adjs[0] is None: 197 | sorted_adj=None 198 | else: 199 | sorted_adj=torch.cat([pad_sub_adjs_unsqueeze(i, max_subs) for i in sorted_adjs]) 200 | if sub_adj_masks[0] is None: 201 | sub_adj_mask=None 202 | else: 203 | sub_adj_mask = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in sub_adj_masks]) 204 | 205 | return dict( 206 | idx=torch.LongTensor(idxs), 207 | attn_bias=attn_bias, 208 | attn_edge_type=attn_edge_type, 209 | spatial_pos=spatial_pos, 210 | in_degree=in_degree, 211 | out_degree=in_degree, # for undirected graph 212 | x=x, 213 | edge_input=edge_input, 214 | y=y, 215 | sorted_adj=sorted_adj, 216 | sub_adj_mask=sub_adj_mask 217 | ) 218 | 219 | def collator_adj_cache(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20): 220 | items = [ 221 | ( 222 | item.idx, 223 | item.subgraph_tensor_res, 224 | item.sorted_adj_res 225 | ) 226 | for item in items 227 | ] 228 | 229 | ( 230 | idx, 231 | subgraph_tensor_res, 232 | sorted_adj_res, 233 | ) = zip(*items) 234 | 235 | return dict( 236 | idx=idx, 237 | subgraph_tensor_res=subgraph_tensor_res, 238 | sorted_adj_res=sorted_adj_res, 239 | ) 240 | 241 | 242 | 243 | def collator_node_label(labels): 244 | max_node_num = max(i.size(0) for i in labels) 245 | label_padded=torch.cat([pad_1d_unsqueeze(i,max_node_num) for i in labels]) 246 | 247 | return label_padded 248 | 249 | 250 | 251 
| def collatorcontrust(datas, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20): 252 | datas_data=[a_data.data for a_data in datas] 253 | datas_data1 = [a_data.data1 for a_data in datas] 254 | datas_data2 = [a_data.data2 for a_data in datas] 255 | result=[] 256 | for items in [datas_data,datas_data1,datas_data2]: 257 | result.append(collator(items, max_node=max_node, multi_hop_max_dist=multi_hop_max_dist, spatial_pos_max=spatial_pos_max)) 258 | res_dic={} 259 | res_dic['data']=result[0] 260 | res_dic['data1'] = result[1] 261 | res_dic['data2'] = result[2] 262 | return res_dic 263 | 264 | 265 | -------------------------------------------------------------------------------- /deepgraph/tasks/is2re.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from pathlib import Path 5 | from typing import Sequence, Union 6 | 7 | import pickle 8 | from functools import lru_cache 9 | 10 | import lmdb 11 | 12 | import numpy as np 13 | import torch 14 | from torch import Tensor 15 | from fairseq.data import ( 16 | FairseqDataset, 17 | BaseWrapperDataset, 18 | NestedDictionaryDataset, 19 | data_utils, 20 | ) 21 | from fairseq.tasks import FairseqTask, register_task 22 | 23 | from ..data.dataset import EpochShuffleDataset 24 | 25 | class LMDBDataset: 26 | def __init__(self, db_path): 27 | super().__init__() 28 | assert Path(db_path).exists(), f"{db_path}: No such file or directory" 29 | self.env = lmdb.Environment( 30 | db_path, 31 | map_size=(1024 ** 3) * 256, 32 | subdir=False, 33 | readonly=True, 34 | readahead=True, 35 | meminit=False, 36 | ) 37 | self.len: int = self.env.stat()["entries"] 38 | 39 | def __len__(self): 40 | return self.len 41 | 42 | @lru_cache(maxsize=16) 43 | # def __getitem__(self, idx: int) -> dict[str, Union[Tensor, float]]: 44 | def __getitem__(self, idx: int) : 45 | if idx < 0 or idx >= self.len: 46 | raise IndexError 47 | data = pickle.loads(self.env.begin().get(f"{idx}".encode())) 48 | return dict( 49 | pos=torch.as_tensor(data["pos"]).float(), 50 | pos_relaxed=torch.as_tensor(data["pos_relaxed"]).float(), 51 | cell=torch.as_tensor(data["cell"]).float().view(3, 3), 52 | atoms=torch.as_tensor(data["atomic_numbers"]).long(), 53 | tags=torch.as_tensor(data["tags"]).long(), 54 | relaxed_energy=data["y_relaxed"], # python float 55 | ) 56 | 57 | 58 | class PBCDataset: 59 | def __init__(self, dataset: LMDBDataset): 60 | self.dataset = dataset 61 | self.cell_offsets = torch.tensor( 62 | [ 63 | [-1, -1, 0], 64 | [-1, 0, 0], 65 | [-1, 1, 0], 66 | [0, -1, 0], 67 | [0, 1, 0], 68 | [1, -1, 0], 69 | [1, 0, 0], 70 | [1, 1, 0], 71 | ], 72 | ).float() 73 | self.n_cells = self.cell_offsets.size(0) 74 | self.cutoff = 8 75 | self.filter_by_tag = True 76 | 77 | def __len__(self): 78 | return len(self.dataset) 79 | 80 | @lru_cache(maxsize=16) 81 | def __getitem__(self, idx): 82 | data = self.dataset[idx] 83 | 84 | pos = data["pos"] 85 | pos_relaxed = data["pos_relaxed"] 86 | cell = data["cell"] 87 | atoms = data["atoms"] 88 | tags = data["tags"] 89 | 90 | offsets = torch.matmul(self.cell_offsets, cell).view(self.n_cells, 1, 3) 91 | expand_pos = (pos.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets).view( 92 | -1, 3 93 | ) 94 | expand_pos_relaxed = ( 95 | pos.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets 96 | ).view(-1, 3) 97 | src_pos = pos[tags > 1] if self.filter_by_tag else pos 98 | 99 | dist: Tensor = (src_pos.unsqueeze(1) - expand_pos.unsqueeze(0)).norm(dim=-1) 
100 | used_mask = (dist < self.cutoff).any(dim=0) & tags.ne(2).repeat( 101 | self.n_cells 102 | ) # not copy ads 103 | used_expand_pos = expand_pos[used_mask] 104 | used_expand_pos_relaxed = expand_pos_relaxed[used_mask] 105 | 106 | used_expand_tags = tags.repeat(self.n_cells)[ 107 | used_mask 108 | ] # original implementation use zeros, need to test 109 | return dict( 110 | pos=torch.cat([pos, used_expand_pos], dim=0), 111 | atoms=torch.cat([atoms, atoms.repeat(self.n_cells)[used_mask]]), 112 | tags=torch.cat([tags, used_expand_tags]), 113 | real_mask=torch.cat( 114 | [ 115 | torch.ones_like(tags, dtype=torch.bool), 116 | torch.zeros_like(used_expand_tags, dtype=torch.bool), 117 | ] 118 | ), 119 | deltapos=torch.cat( 120 | [pos_relaxed - pos, used_expand_pos_relaxed - used_expand_pos], dim=0 121 | ), 122 | relaxed_energy=data["relaxed_energy"], 123 | ) 124 | 125 | 126 | def pad_1d(samples: Sequence[Tensor], fill=0, multiplier=8): 127 | max_len = max(x.size(0) for x in samples) 128 | max_len = (max_len + multiplier - 1) // multiplier * multiplier 129 | n_samples = len(samples) 130 | out = torch.full( 131 | (n_samples, max_len, *samples[0].shape[1:]), fill, dtype=samples[0].dtype 132 | ) 133 | for i in range(n_samples): 134 | x_len = samples[i].size(0) 135 | out[i][:x_len] = samples[i] 136 | return out 137 | 138 | 139 | class AtomDataset(FairseqDataset): 140 | def __init__(self, dataset, keyword): 141 | super().__init__() 142 | self.dataset = dataset 143 | self.keyword = keyword 144 | self.atom_list = [ 145 | 1, 146 | 5, 147 | 6, 148 | 7, 149 | 8, 150 | 11, 151 | 13, 152 | 14, 153 | 15, 154 | 16, 155 | 17, 156 | 19, 157 | 20, 158 | 21, 159 | 22, 160 | 23, 161 | 24, 162 | 25, 163 | 26, 164 | 27, 165 | 28, 166 | 29, 167 | 30, 168 | 31, 169 | 32, 170 | 33, 171 | 34, 172 | 37, 173 | 38, 174 | 39, 175 | 40, 176 | 41, 177 | 42, 178 | 43, 179 | 44, 180 | 45, 181 | 46, 182 | 47, 183 | 48, 184 | 49, 185 | 50, 186 | 51, 187 | 52, 188 | 55, 189 | 72, 190 | 73, 191 | 74, 192 | 75, 193 | 76, 194 | 77, 195 | 78, 196 | 79, 197 | 80, 198 | 81, 199 | 82, 200 | 83, 201 | ] 202 | # fill others as unk 203 | unk_idx = len(self.atom_list) + 1 204 | self.atom_mapper = torch.full((128,), unk_idx) 205 | for idx, atom in enumerate(self.atom_list): 206 | self.atom_mapper[atom] = idx + 1 # reserve 0 for paddin 207 | 208 | @lru_cache(maxsize=16) 209 | def __getitem__(self, index): 210 | atoms: Tensor = self.dataset[index][self.keyword] 211 | return self.atom_mapper[atoms] 212 | 213 | def __len__(self): 214 | return len(self.dataset) 215 | 216 | def collater(self, samples): 217 | return pad_1d(samples) 218 | 219 | 220 | class KeywordDataset(FairseqDataset): 221 | def __init__(self, dataset, keyword, is_scalar=False, pad_fill=0): 222 | super().__init__() 223 | self.dataset = dataset 224 | self.keyword = keyword 225 | self.is_scalar = is_scalar 226 | self.pad_fill = pad_fill 227 | 228 | @lru_cache(maxsize=16) 229 | def __getitem__(self, index): 230 | return self.dataset[index][self.keyword] 231 | 232 | def __len__(self): 233 | return len(self.dataset) 234 | 235 | def collater(self, samples): 236 | if self.is_scalar: 237 | return torch.tensor(samples) 238 | return pad_1d(samples, fill=self.pad_fill) 239 | 240 | 241 | @register_task("is2re") 242 | class IS2RETask(FairseqTask): 243 | @classmethod 244 | def add_args(cls, parser): 245 | parser.add_argument("data", metavar="FILE", help="directory for data") 246 | 247 | @property 248 | def target_dictionary(self): 249 | return None 250 | 251 | def load_dataset(self, split, 
combine=False, **kwargs): 252 | assert split in [ 253 | "train", 254 | "val_id", 255 | "val_ood_ads", 256 | "val_ood_cat", 257 | "val_ood_both", 258 | "test_id", 259 | "test_ood_ads", 260 | "test_ood_cat", 261 | "test_ood_both", 262 | ], "invalid split: {}!".format(split) 263 | print(" > Loading {} ...".format(split)) 264 | 265 | db_path = str(Path(self.cfg.data) / split / "data.lmdb") 266 | lmdb_dataset = LMDBDataset(db_path) 267 | pbc_dataset = PBCDataset(lmdb_dataset) 268 | 269 | atoms = AtomDataset(pbc_dataset, "atoms") 270 | tags = KeywordDataset(pbc_dataset, "tags") 271 | real_mask = KeywordDataset(pbc_dataset, "real_mask") 272 | 273 | pos = KeywordDataset(pbc_dataset, "pos") 274 | 275 | relaxed_energy = KeywordDataset(pbc_dataset, "relaxed_energy", is_scalar=True) 276 | deltapos = KeywordDataset(pbc_dataset, "deltapos") 277 | 278 | dataset = NestedDictionaryDataset( 279 | { 280 | "net_input": { 281 | "pos": pos, 282 | "atoms": atoms, 283 | "tags": tags, 284 | "real_mask": real_mask, 285 | }, 286 | "targets": { 287 | "relaxed_energy": relaxed_energy, 288 | "deltapos": deltapos, 289 | }, 290 | }, 291 | sizes=[np.zeros(len(atoms))], 292 | ) 293 | 294 | if split == "train": 295 | dataset = EpochShuffleDataset( 296 | dataset, 297 | num_samples=len(atoms), 298 | seed=self.cfg.seed, 299 | ) 300 | 301 | print("| Loaded {} with {} samples".format(split, len(dataset))) 302 | self.datasets[split] = dataset 303 | -------------------------------------------------------------------------------- /deepgraph/data/wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import numpy as np 6 | import joblib 7 | from ogb.graphproppred import PygGraphPropPredDataset 8 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset 9 | from functools import lru_cache 10 | import pyximport 11 | import torch.distributed as dist 12 | 13 | pyximport.install(setup_args={"include_dirs": np.get_include()}) 14 | from . 
import algos 15 | import sys 16 | from tqdm import tqdm 17 | 18 | @torch.jit.script 19 | def convert_to_single_emb(x, offset: int = 10): 20 | feature_num = x.size(1) if len(x.size()) > 1 else 1 21 | feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long) 22 | x = x + feature_offset 23 | return x 24 | 25 | 26 | def encode_token_single_tensor(graph, atom_id_max, edge_id_max): 27 | cur_id = graph.x.shape[0] 28 | graph.cur_id = cur_id 29 | _, index = np.unique(graph.identifiers[0, :], return_index=True) 30 | index = np.sort(index) 31 | token_ids = graph.identifiers[2, index].unsqueeze(1) + atom_id_max + 1 32 | graph.x = torch.cat([graph.x, token_ids.repeat([1, graph.x.shape[1]])], 0).type(torch.int64) 33 | 34 | edges = torch.cat([graph.identifiers[0, :].unsqueeze(0) + cur_id, graph.identifiers[1, :].unsqueeze(0)], 0).long() 35 | graph.edge_index = torch.cat([graph.edge_index, edges, edges[[1, 0], :]], 1) 36 | 37 | if graph.edge_attr is not None: 38 | edge_attr = graph.identifiers[2, :].unsqueeze(1).repeat([1, graph.edge_attr.shape[1]]) + edge_id_max + 2 39 | graph.edge_attr = torch.cat([graph.edge_attr, edge_attr, edge_attr], 0).long() 40 | 41 | graph.sorted_adj = None 42 | 43 | graph.sub_adj_mask = torch.ones([graph.x.shape[0],1]).long() 44 | graph.sub_adj_mask[0:graph.cur_id]=0 45 | graph.sub_adj_mask=graph.sub_adj_mask.long() 46 | 47 | return graph 48 | 49 | 50 | def encode_token_single_tensor_with_adj(graph, local_attention_on_substructures): 51 | cur_id = graph.x.shape[0] 52 | graph.cur_id = cur_id 53 | 54 | _, index = np.unique(graph.identifiers[0, :], return_index=True) 55 | index = np.sort(index) 56 | token_ids = -1 * torch.ones_like(graph.identifiers[2, index].unsqueeze(1)) 57 | 58 | graph.sorted_adj = torch.cat([torch.zeros([graph.num_nodes, 59 | graph.sorted_adj.shape[1], 60 | graph.sorted_adj.shape[2]]), graph.sorted_adj], 0) 61 | 62 | graph.x = torch.cat([graph.x, token_ids.repeat([1, graph.x.shape[1]])], 0).type(torch.int64) 63 | graph.sub_adj_mask = (graph.x < 0).long() 64 | if graph.sub_adj_mask.shape[1] > 1: 65 | graph.sub_adj_mask = graph.sub_adj_mask[:, 0].unsqueeze(1) 66 | 67 | 68 | # expanded edges is used for attention mask, not for position embedding 69 | edges = torch.cat([graph.identifiers[0, :].unsqueeze(0) + cur_id, graph.identifiers[1, :].unsqueeze(0)], 0).long() 70 | graph.edge_index = torch.cat([graph.edge_index, edges, edges[[1, 0], :]], 1) 71 | 72 | if not local_attention_on_substructures: 73 | # expanded edge_attr is actually not used in local_attention_on_substructures 74 | edge_id = graph.edge_attr.shape[0] 75 | graph.edge_id = edge_id 76 | if graph.edge_attr is not None: 77 | edge_attr = (-1 * torch.ones_like(graph.identifiers[2].unsqueeze(1))) \ 78 | .repeat([1, graph.edge_attr.shape[1]]) \ 79 | .type(torch.int64) + edge_id_max + 2 80 | graph.edge_attr = torch.cat([graph.edge_attr, edge_attr, edge_attr], 0).long() 81 | 82 | return graph 83 | 84 | def graph_data_modification_single(data,substructures,sorted_adj=None): 85 | 86 | if data.x.dim()==1: 87 | data.x=data.x.unsqueeze(1) 88 | if data.y.dim() == 2: 89 | data.y = data.y.squeeze(1) 90 | if hasattr(data, 'edge_features'): 91 | data.edge_attr=data.edge_features 92 | del(data.edge_features) 93 | if hasattr(data,'degree'): 94 | del(data.degrees) 95 | if hasattr(data,'graph_size'): 96 | del (data.graph_size) 97 | if data.edge_attr is not None and data.edge_attr.dim()==1: 98 | data.edge_attr=data.edge_attr.unsqueeze(1) 99 | assert hasattr(data,'edge_attr') 100 | assert 
hasattr(data, 'edge_index') 101 | setattr(data,'identifiers', substructures) 102 | setattr(data,'sorted_adj',sorted_adj) 103 | 104 | return data 105 | 106 | 107 | 108 | def preprocess_item(item,local_attention_on_substructures=False,continuous_feature=False): 109 | max_dist_const=510 110 | substructure_dist_const=max_dist_const-1 111 | 112 | edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x 113 | 114 | N = x.size(0) 115 | x = convert_to_single_emb(x) 116 | 117 | # node adj matrix [N, N] bool 118 | adj_matrix = torch.zeros([N, N], dtype=torch.bool) 119 | adj_matrix[edge_index[0, :], edge_index[1, :]] = True 120 | if local_attention_on_substructures: 121 | 122 | adj = adj_matrix[0:item.cur_id, 0:item.cur_id] 123 | else: 124 | adj = adj_matrix[0:,0:] 125 | 126 | 127 | 128 | if edge_attr is None: 129 | attn_edge_type = -1*torch.ones([N, N, 1], dtype=torch.long) 130 | else: 131 | if len(edge_attr.size()) == 1: 132 | edge_attr = edge_attr[:, None] 133 | attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long) 134 | edge_attr=edge_attr.long() 135 | attn_edge_type[edge_index[0, :edge_attr.size(0)], edge_index[1, :edge_attr.size(0)]] = ( 136 | convert_to_single_emb(edge_attr) + 1 137 | ) 138 | 139 | 140 | #compute attention mask (attn_bias) 141 | if local_attention_on_substructures: 142 | adj_matrix = adj_matrix.float() 143 | adj_matrix[0:item.cur_id, 0:item.cur_id] = 1 144 | attn_bias=1-adj_matrix 145 | attn_bias[attn_bias>0]=float('-inf') 146 | attn_bias_res = torch.zeros([N + 1, N + 1], dtype=torch.float) 147 | attn_bias_res[1:,1:]=attn_bias 148 | attn_bias=attn_bias_res 149 | else: 150 | attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token 151 | 152 | 153 | 154 | shortest_path_result, path = algos.floyd_warshall(adj.numpy()) 155 | max_dist = np.amax(shortest_path_result) 156 | edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy()) 157 | spatial_pos = torch.from_numpy((shortest_path_result)).long() 158 | 159 | #Expand the size of edge_input and spatial_pos 160 | if local_attention_on_substructures: 161 | edge_input_new=(-1*np.ones([N,N,edge_input.shape[2],edge_input.shape[3]])).astype(np.int64) 162 | edge_input_new[0:item.cur_id,0:item.cur_id,:,:]=edge_input 163 | edge_input=edge_input_new 164 | spatial_pos_new=(max_dist_const*torch.ones(N,N)).long() 165 | spatial_pos_new[0:item.cur_id,0:item.cur_id]=spatial_pos 166 | spatial_pos=spatial_pos_new 167 | 168 | 169 | # combine 170 | item.x = x.long() 171 | item.attn_bias = attn_bias 172 | item.attn_edge_type = attn_edge_type 173 | item.spatial_pos = spatial_pos 174 | item.in_degree = adj.long().sum(dim=1).view(-1) 175 | item.out_degree = item.in_degree # for undirected graph 176 | item.edge_input = torch.from_numpy(edge_input).long() 177 | 178 | 179 | return item 180 | 181 | 182 | 183 | def preprocess_item_local_attention(item,local_attention_on_substructures=False,continuous_feature=False): 184 | max_dist_const=510 185 | substructure_dist_const=max_dist_const-1 186 | 187 | edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x 188 | N = x.size(0) 189 | x = convert_to_single_emb(x) 190 | 191 | # node adj matrix [N, N] bool 192 | adj = torch.zeros([N, N], dtype=torch.bool) 193 | adj[edge_index[0, :], edge_index[1, :]] = True 194 | 195 | # edge feature here 196 | if edge_attr is None: 197 | attn_edge_type = -1*torch.ones([N, N, 1], dtype=torch.long) 198 | else: 199 | if len(edge_attr.size()) == 1: 200 | edge_attr = edge_attr[:, None] 201 | attn_edge_type = 
torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long) 202 | edge_attr=edge_attr.long() 203 | attn_edge_type[edge_index[0, :], edge_index[1, :]] = ( 204 | convert_to_single_emb(edge_attr) + 1 205 | ) 206 | 207 | 208 | if local_attention_on_substructures: 209 | adj_matrix = torch.sparse_coo_tensor(item.edge_index, torch.ones_like(item.edge_index[0]), 210 | [item.x.shape[0], item.x.shape[0]] 211 | ).to_dense().float() 212 | adj_matrix[0:item.cur_id, 0:item.cur_id] = 1 213 | attn_bias=1-adj_matrix 214 | attn_bias[attn_bias>0]=float('-inf') 215 | attn_bias_res = torch.zeros([N + 1, N + 1], dtype=torch.float) 216 | attn_bias_res[1:,1:]=attn_bias 217 | attn_bias=attn_bias_res 218 | else: 219 | attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token 220 | 221 | 222 | if local_attention_on_substructures: 223 | adj_new=torch.zeros_like(adj) 224 | adj_new[0:item.cur_id, 0:item.cur_id]=adj[0:item.cur_id, 0:item.cur_id] 225 | adj=adj_new 226 | shortest_path_result, path = algos.floyd_warshall(adj.numpy()) 227 | max_dist = np.amax(shortest_path_result) 228 | edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy()) 229 | spatial_pos = torch.from_numpy((shortest_path_result)).long() 230 | 231 | if local_attention_on_substructures: 232 | mask=torch.tensor(shortest_path_result == max_dist_const) & adj_matrix.bool() 233 | shortest_path_result[mask]=substructure_dist_const 234 | 235 | # combine 236 | item.x = x.long() 237 | item.attn_bias = attn_bias 238 | item.attn_edge_type = attn_edge_type 239 | item.spatial_pos = spatial_pos 240 | item.in_degree = adj.long().sum(dim=1).view(-1) 241 | item.out_degree = item.in_degree # for undirected graph 242 | item.edge_input = torch.from_numpy(edge_input).long() 243 | 244 | 245 | return item 246 | 247 | 248 | class MyPygPCQM4MDataset(PygPCQM4Mv2Dataset): 249 | def download(self): 250 | super(MyPygPCQM4MDataset, self).download() 251 | 252 | def process(self): 253 | super(MyPygPCQM4MDataset, self).process() 254 | 255 | @lru_cache(maxsize=16) 256 | def __getitem__(self, idx): 257 | item = self.get(self.indices()[idx]) 258 | item.idx = idx 259 | return preprocess_item(item) 260 | 261 | 262 | class MyPygGraphPropPredDataset(PygGraphPropPredDataset): 263 | def download(self): 264 | if dist.get_rank() == 0: 265 | super(MyPygGraphPropPredDataset, self).download() 266 | dist.barrier() 267 | 268 | def process(self): 269 | if dist.get_rank() == 0: 270 | super(MyPygGraphPropPredDataset, self).process() 271 | dist.barrier() 272 | 273 | @lru_cache(maxsize=16) 274 | def __getitem__(self, idx): 275 | item = self.get(self.indices()[idx]) 276 | item.idx = idx 277 | item.y = item.y.reshape(-1) 278 | return preprocess_item(item) 279 | -------------------------------------------------------------------------------- /deepgraph/modules/deepgraph_graph_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
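# Note: the DeepGraphGraphEncoder below prepends a learned graph-level token to every graph
# (hence the [n_node + 1, n_node + 1] sizing of the attention biases) and stacks
# num_encoder_layers DeepGraphGraphEncoderLayer blocks, with optional LayerDrop, DeepNorm
# scaling, quant-noise, and substructure-adjacency encoding via SubStructure_Adj_Encoder.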
8 | 9 | from typing import Optional, Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | from fairseq.modules import FairseqDropout, LayerDropModuleList, LayerNorm 14 | from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ 15 | 16 | from .multihead_attention import MultiheadAttention 17 | from .deepgraph_layers import GraphNodeFeature, GraphAttnBias,SubStructure_Adj_Encoder 18 | from .deepgraph_graph_encoder_layer import DeepGraphGraphEncoderLayer 19 | 20 | 21 | def init_graphormer_params(module): 22 | """ 23 | Initialize the weights specific to the Graphormer Model. 24 | """ 25 | 26 | def normal_(data): 27 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 28 | # so that the RNG is consistent with and without FSDP 29 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 30 | 31 | if isinstance(module, nn.Linear): 32 | normal_(module.weight.data) 33 | if module.bias is not None: 34 | module.bias.data.zero_() 35 | if isinstance(module, nn.Embedding): 36 | normal_(module.weight.data) 37 | if module.padding_idx is not None: 38 | module.weight.data[module.padding_idx].zero_() 39 | if isinstance(module, MultiheadAttention): 40 | normal_(module.q_proj.weight.data) 41 | normal_(module.k_proj.weight.data) 42 | normal_(module.v_proj.weight.data) 43 | 44 | 45 | class DeepGraphGraphEncoder(nn.Module): 46 | def __init__( 47 | self, 48 | num_atoms: int, 49 | num_in_degree: int, 50 | num_out_degree: int, 51 | num_edges: int, 52 | num_spatial: int, 53 | num_edge_dis: int, 54 | edge_type: str, 55 | multi_hop_max_dist: int, 56 | num_encoder_layers: int = 12, 57 | embedding_dim: int = 768, 58 | ffn_embedding_dim: int = 768, 59 | num_attention_heads: int = 32, 60 | dropout: float = 0.1, 61 | attention_dropout: float = 0.1, 62 | activation_dropout: float = 0.1, 63 | layerdrop: float = 0.0, 64 | encoder_normalize_before: bool = False, 65 | pre_layernorm: bool = False, 66 | apply_graphormer_init: bool = False, 67 | activation_fn: str = "gelu", 68 | embed_scale: float = None, 69 | freeze_embeddings: bool = False, 70 | n_trans_layers_to_freeze: int = 0, 71 | export: bool = False, 72 | traceable: bool = False, 73 | q_noise: float = 0.0, 74 | qn_block_size: int = 8, 75 | deepnorm:bool = False, 76 | encoder_layers: int = 12, 77 | encode_adj: bool = False, 78 | subgraph_max_size=None, 79 | reattention=False 80 | ) -> None: 81 | 82 | super().__init__() 83 | self.dropout_module = FairseqDropout( 84 | dropout, module_name=self.__class__.__name__ 85 | ) 86 | self.layerdrop = layerdrop 87 | self.embedding_dim = embedding_dim 88 | self.apply_graphormer_init = apply_graphormer_init 89 | self.traceable = traceable 90 | 91 | self.graph_node_feature = GraphNodeFeature( 92 | num_heads=num_attention_heads, 93 | num_atoms=num_atoms, 94 | num_in_degree=num_in_degree, 95 | num_out_degree=num_out_degree, 96 | hidden_dim=embedding_dim, 97 | n_layers=num_encoder_layers, 98 | ) 99 | 100 | self.graph_attn_bias = GraphAttnBias( 101 | num_heads=num_attention_heads, 102 | num_atoms=num_atoms, 103 | num_edges=num_edges, 104 | num_spatial=num_spatial, 105 | num_edge_dis=num_edge_dis, 106 | edge_type=edge_type, 107 | multi_hop_max_dist=multi_hop_max_dist, 108 | hidden_dim=embedding_dim, 109 | n_layers=num_encoder_layers, 110 | ) 111 | 112 | self.encode_adj=encode_adj 113 | if self.encode_adj: 114 | self.adj_encoder =SubStructure_Adj_Encoder(subgraph_max_size=subgraph_max_size, 115 | output_size=embedding_dim) 116 | 117 | 118 | self.embed_scale = embed_scale 119 | 120 | if 
q_noise > 0: 121 | self.quant_noise = apply_quant_noise_( 122 | nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), 123 | q_noise, 124 | qn_block_size, 125 | ) 126 | else: 127 | self.quant_noise = None 128 | 129 | if encoder_normalize_before: 130 | self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) 131 | else: 132 | self.emb_layer_norm = None 133 | 134 | if pre_layernorm: 135 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) 136 | 137 | if self.layerdrop > 0.0: 138 | self.layers = LayerDropModuleList(p=self.layerdrop) 139 | else: 140 | self.layers = nn.ModuleList([]) 141 | self.layers.extend( 142 | [ 143 | self.build_graphormer_graph_encoder_layer( 144 | embedding_dim=self.embedding_dim, 145 | ffn_embedding_dim=ffn_embedding_dim, 146 | num_attention_heads=num_attention_heads, 147 | dropout=self.dropout_module.p, 148 | attention_dropout=attention_dropout, 149 | activation_dropout=activation_dropout, 150 | activation_fn=activation_fn, 151 | export=export, 152 | q_noise=q_noise, 153 | qn_block_size=qn_block_size, 154 | pre_layernorm=pre_layernorm, 155 | deepnorm=deepnorm, 156 | encoder_layers=encoder_layers, 157 | reattention=reattention 158 | ) 159 | for _ in range(num_encoder_layers) 160 | ] 161 | ) 162 | 163 | # Apply initialization of model params after building the model 164 | if self.apply_graphormer_init: 165 | print('reiniting by graphormer_init') 166 | self.apply(init_graphormer_params) 167 | 168 | def freeze_module_params(m): 169 | if m is not None: 170 | for p in m.parameters(): 171 | p.requires_grad = False 172 | 173 | if freeze_embeddings: 174 | raise NotImplementedError("Freezing embeddings is not implemented yet.") 175 | 176 | for layer in range(n_trans_layers_to_freeze): 177 | freeze_module_params(self.layers[layer]) 178 | 179 | def build_graphormer_graph_encoder_layer( 180 | self, 181 | embedding_dim, 182 | ffn_embedding_dim, 183 | num_attention_heads, 184 | dropout, 185 | attention_dropout, 186 | activation_dropout, 187 | activation_fn, 188 | export, 189 | q_noise, 190 | qn_block_size, 191 | pre_layernorm, 192 | deepnorm, 193 | encoder_layers, 194 | reattention=False 195 | ): 196 | return DeepGraphGraphEncoderLayer( 197 | embedding_dim=embedding_dim, 198 | ffn_embedding_dim=ffn_embedding_dim, 199 | num_attention_heads=num_attention_heads, 200 | dropout=dropout, 201 | attention_dropout=attention_dropout, 202 | activation_dropout=activation_dropout, 203 | activation_fn=activation_fn, 204 | export=export, 205 | q_noise=q_noise, 206 | qn_block_size=qn_block_size, 207 | pre_layernorm=pre_layernorm, 208 | deepnorm=deepnorm, 209 | encoder_layers=encoder_layers, 210 | reattention=reattention 211 | ) 212 | 213 | def forward( 214 | self, 215 | batched_data, 216 | perturb=None, 217 | last_state_only: bool = False, 218 | token_embeddings: Optional[torch.Tensor] = None, 219 | attn_mask: Optional[torch.Tensor] = None, 220 | ) -> Tuple[torch.Tensor, torch.Tensor]: 221 | is_tpu = False 222 | # compute padding mask. 
This is needed for multi-head attention 223 | data_x = batched_data["x"] 224 | n_graph, n_node = data_x.size()[:2] 225 | padding_mask = (data_x[:, :, 0]).eq(0) # B x T x 1 226 | padding_mask_cls = torch.zeros( 227 | n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype 228 | ) 229 | padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1) 230 | # B x (T+1) x 1 231 | 232 | if token_embeddings is not None: 233 | x = token_embeddings 234 | else: 235 | x = self.graph_node_feature(batched_data) 236 | 237 | if self.encode_adj: 238 | sub_adj_mask=torch.cat([torch.zeros_like(batched_data['sub_adj_mask'][:,0,:].unsqueeze(1)), 239 | batched_data['sub_adj_mask']],1) 240 | index=(sub_adj_mask==2).type(x.dtype) 241 | start_index=index.sum(0).squeeze(1).nonzero().min() 242 | 243 | adj_emb=self.adj_encoder(batched_data['sorted_adj'][:,start_index:,:,:]) 244 | adj_emb=torch.cat([torch.zeros([adj_emb.shape[0],1+start_index, adj_emb.shape[2]]).to(adj_emb.device),adj_emb],1).type(x.dtype) 245 | 246 | x=x*(1-index)+adj_emb*index 247 | 248 | if perturb is not None: 249 | #ic(torch.mean(torch.abs(x[:, 1, :]))) 250 | #ic(torch.mean(torch.abs(perturb))) 251 | x[:, 1:, :] += perturb 252 | 253 | # x: B x T x C 254 | 255 | attn_bias = self.graph_attn_bias(batched_data) 256 | 257 | if self.embed_scale is not None: 258 | x = x * self.embed_scale 259 | 260 | if self.quant_noise is not None: 261 | x = self.quant_noise(x) 262 | 263 | if self.emb_layer_norm is not None: 264 | x = self.emb_layer_norm(x) 265 | 266 | x = self.dropout_module(x) 267 | 268 | # account for padding while computing the representation 269 | 270 | # B x T x C -> T x B x C 271 | x = x.transpose(0, 1) 272 | 273 | inner_states = [] 274 | if not last_state_only: 275 | inner_states.append(x) 276 | 277 | atts=[] 278 | 279 | for layer in self.layers: 280 | x, att = layer( 281 | x, 282 | self_attn_padding_mask=padding_mask, 283 | self_attn_mask=attn_mask, 284 | self_attn_bias=attn_bias, 285 | ) 286 | if not last_state_only: 287 | inner_states.append(x) 288 | atts.append(att) 289 | graph_rep = x[0, :, :] 290 | 291 | if last_state_only: 292 | inner_states = [x] 293 | 294 | if self.traceable: 295 | return torch.stack(inner_states), graph_rep,atts 296 | else: 297 | return inner_states, graph_rep,atts 298 | -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils_graph_processing import subgraph_isomorphism_edge_counts, subgraph_isomorphism_vertex_counts, induced_edge_automorphism_orbits, edge_automorphism_orbits, automorphism_orbits,subgraph_isomorphism_vertex_extraction,subgraph_counts2ids,subgraph_2token 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch_geometric.data import Data 8 | import networkx as nx 9 | from torch_geometric.data import Data 10 | import glob 11 | import re 12 | import types 13 | from ast import literal_eval 14 | from omegaconf import open_dict 15 | import lmdb 16 | from tqdm import tqdm 17 | import time 18 | 19 | def get_custom_edge_list(ks, substructure_type=None, filename=None): 20 | ''' 21 | Instantiates a list of `edge_list`s representing substructures 22 | of type `substructure_type` with sizes specified by `ks`. 
23 | ''' 24 | if substructure_type is None and filename is None: 25 | raise ValueError('You must specify either a type or a filename where to read substructures from.') 26 | edge_lists = [] 27 | for k in ks: 28 | if substructure_type is not None: 29 | graphs_nx = getattr(nx, substructure_type)(k) 30 | else: 31 | graphs_nx = nx.read_graph6(os.path.join(filename, 'graph{}c.g6'.format(k))) 32 | if isinstance(graphs_nx, list) or isinstance(graphs_nx, types.GeneratorType): 33 | edge_lists += [list(graph_nx.edges) for graph_nx in graphs_nx] 34 | else: 35 | edge_lists.append(list(graphs_nx.edges)) 36 | return edge_lists 37 | 38 | 39 | def process_arguments_substructure(args): 40 | with open_dict(args): 41 | args['k']=literal_eval(args['ks']) 42 | args['subgraph_max_size']=0 43 | tem= (args['id_type']) 44 | del args['id_type'] 45 | args['id_type']=tem.split('+') 46 | tem= (args['must_select_sub']) 47 | del args['must_select_sub'] 48 | args['must_select_sub']=tem.split('+') 49 | 50 | 51 | with open_dict(args): 52 | del args['custom_edge_list'] 53 | args['custom_edge_list']=[] 54 | 55 | if args['extra_method']=='feature': 56 | extract_id_fn = subgraph_counts2ids 57 | else: 58 | extract_id_fn = subgraph_2token 59 | ###### choose the function that computes the automorphism group and the orbits ####### 60 | if args['edge_automorphism'] == 'induced': 61 | automorphism_fn = induced_edge_automorphism_orbits if args['id_scope'] == 'local' else automorphism_orbits 62 | elif args['edge_automorphism'] == 'line_graph': 63 | automorphism_fn = edge_automorphism_orbits if args['id_scope'] == 'local' else automorphism_orbits 64 | else: 65 | raise NotImplementedError 66 | 67 | ###### choose the function that computes the subgraph isomorphisms ####### 68 | if args['extra_method']=='feature': 69 | count_fn = subgraph_isomorphism_edge_counts if args['id_scope'] == 'local' else subgraph_isomorphism_vertex_counts 70 | else: 71 | count_fn = subgraph_isomorphism_vertex_extraction 72 | 73 | ###### choose the substructures: usually loaded from networkx, 74 | ###### except for 'all_simple_graphs' where they need to be precomputed, 75 | ###### or when a custom edge list is provided in the input by the user 76 | with open_dict(args): 77 | del (args['custom_edge_list']) 78 | args['custom_edge_list']=[] 79 | args['must_select_list']=[] 80 | args['neighbor_type']=[] 81 | args['neighbor_size'] = [] 82 | args['pre_defined_path']=args['dataset_name']+'_'+'substructure' 83 | args['node_neighbor_path']=args['dataset_name']+'_'+'neighbor' 84 | args['transform_cache_path']=args['dataset_name']+'_'+args['sampling_mode']+'_'+str(args['sampling_redundancy'])+'_'+str(args['transform_cache_number']) 85 | 86 | for number,id_type in enumerate(args['id_type']): 87 | 88 | if id_type in ['cycle_graph','complete_graph', 89 | 'binomial_tree', 90 | 'star_graph', 91 | 'nonisomorphic_trees']: 92 | # args['k'] = args['k'][0] 93 | k_max = args['k'][number] 94 | k_min = 2 if id_type == 'star_graph' else 3 95 | args['subgraph_max_size']=max(args['subgraph_max_size'],k_max+1 if id_type=='star_graph' else k_max) 96 | with open_dict(args): 97 | add_sub=get_custom_edge_list(list(range(k_min, k_max + 1)), id_type) 98 | if id_type in args['must_select_sub']: 99 | args['must_select_list']+=list(range(len(args['custom_edge_list']), 100 | len(args['custom_edge_list'])+len(add_sub) 101 | )) 102 | args['custom_edge_list'] +=add_sub 103 | args['pre_defined_path'] +='_'+id_type+'_'+str(args['k'][number]) 104 | args['transform_cache_path'] += 
'_'+id_type+'_'+str(args['k'][number]) 105 | 106 | 107 | elif id_type in ['path_graph']: 108 | k_min = args['k'][number] 109 | k_max = 8 110 | args['subgraph_max_size'] = max(args['subgraph_max_size'], k_max+1) 111 | with open_dict(args): 112 | add_sub=get_custom_edge_list(list(range(k_min, k_max + 1)), id_type) 113 | if id_type in args['must_select_sub']: 114 | args['must_select_list']+=list(range(len(args['custom_edge_list']), 115 | len(args['custom_edge_list'])+len(add_sub) 116 | )) 117 | args['custom_edge_list'] +=add_sub 118 | args['pre_defined_path'] += '_' + id_type + '_' + str(args['k'][number]) 119 | args['transform_cache_path'] += '_' + id_type + '_' + str(args['k'][number]) 120 | 121 | elif id_type in ['k_neighborhood', 'random_walk']: 122 | with open_dict(args): 123 | args['neighbor_type'].append(id_type) 124 | args['neighbor_size'].append(args['k'][number]) 125 | args['node_neighbor_path']+='_'+id_type+'_'+str(args['k'][number]) 126 | args['transform_cache_path'] += '_' + id_type + '_' + str(args['k'][number]) 127 | 128 | 129 | 130 | elif id_type in ['cycle_graph_chosen_k', 131 | 'path_graph_chosen_k', 132 | 'complete_graph_chosen_k', 133 | 'binomial_tree_chosen_k', 134 | 'star_graph_chosen_k', 135 | 'nonisomorphic_trees_chosen_k']: 136 | print('warning: not complete subgraph_max_size for ',id_type) 137 | with open_dict(args): 138 | args['custom_edge_list'] += get_custom_edge_list(args['k'], id_type.replace('_chosen_k', '')) 139 | 140 | elif id_type in ['all_simple_graphs']: 141 | # args['k'] = args['k'][0] 142 | k_max = args['k'][number] 143 | k_min = 3 144 | args['subgraph_max_size'] = max(args['subgraph_max_size'], k_max) 145 | filename = os.path.join(args['root_folder'], 'all_simple_graphs') 146 | with open_dict(args): 147 | args['custom_edge_list'] += get_custom_edge_list(list(range(k_min, k_max + 1)), filename=filename) 148 | 149 | elif id_type in ['all_simple_graphs_chosen_k']: 150 | print('warning: not complete subgraph_max_size for ', id_type) 151 | filename = os.path.join(args['root_folder'], 'all_simple_graphs') 152 | with open_dict(args): 153 | args['custom_edge_list'] += get_custom_edge_list(args['k'], filename=filename) 154 | 155 | elif id_type in ['diamond_graph']: 156 | print('warning: not complete subgraph_max_size for ', id_type) 157 | # args['k'] = None 158 | graph_nx = nx.diamond_graph() 159 | with open_dict(args): 160 | args['custom_edge_list'] += [list(graph_nx.edges)] 161 | 162 | elif id_type == 'custom': 163 | print('warning: not complete subgraph_max_size for ', id_type) 164 | assert args['custom_edge_list'] is not None, "Custom edge list must be provided." 
165 | 166 | else: 167 | raise NotImplementedError("Identifiers {} are not currently supported.".format(id_type)) 168 | 169 | return args, extract_id_fn, count_fn, automorphism_fn 170 | 171 | 172 | 173 | 174 | 175 | 176 | def transfer_subgraph_to_batchtensor_complete(substructures): 177 | id_lis=[] 178 | note_lis=[] 179 | type_list = [] 180 | cur_id=0 181 | for id_type,subtype in enumerate(substructures): 182 | for data in subtype: 183 | id_lis=id_lis+[cur_id]*len(data) 184 | type_list=type_list+[int(id_type)]*len(data) 185 | note_lis=note_lis+list(data) 186 | cur_id += 1 187 | 188 | return note_lis,id_lis,type_list 189 | 190 | 191 | def _prepare_process(data, subgraph_dicts, subgraph_params,ex_fn, cnt_fn): 192 | if data.edge_index.shape[1] == 0 and cnt_fn.__name__ == 'subgraph_isomorphism_edge_counts': 193 | setattr(data, 'identifiers', torch.zeros((0, sum(orbit_partition_sizes))).long()) 194 | else: 195 | new_data = ex_fn(cnt_fn, data, subgraph_dicts, subgraph_params) 196 | 197 | return new_data 198 | 199 | 200 | def substructure_to_gt(subgraph_params,automorphism_fn): 201 | subgraph_dicts = [] 202 | if 'edge_list' not in subgraph_params: 203 | raise ValueError('Edge list not provided.') 204 | for edge_list in subgraph_params['edge_list']: 205 | subgraph, orbit_partition, orbit_membership, aut_count = \ 206 | automorphism_fn(edge_list=edge_list, 207 | only_graph=True, 208 | directed=subgraph_params['directed'], 209 | directed_orbits=subgraph_params['directed_orbits']) 210 | subgraph_dicts.append({'subgraph': subgraph}) 211 | return subgraph_dicts 212 | 213 | 214 | 215 | 216 | 217 | 218 | def load_dataset(data_file): 219 | ''' 220 | Loads dataset from `data_file`. 221 | ''' 222 | print("Loading dataset from {}".format(data_file)) 223 | dataset_obj = torch.load(data_file) 224 | graphs_ptg = dataset_obj[0] 225 | num_classes = dataset_obj[1] 226 | orbit_partition_sizes = dataset_obj[2] 227 | 228 | return graphs_ptg, num_classes, orbit_partition_sizes 229 | 230 | 231 | def try_downgrading(data_folder, id_type, induced, directed_orbits, k, k_min): 232 | ''' 233 | Extracts the substructures of size up to the `k`, if a collection of substructures 234 | with size larger than k has already been computed. 235 | ''' 236 | found_data_filename, k_found = find_id_filename(data_folder, id_type, induced, directed_orbits, k) 237 | if found_data_filename is not None: 238 | graphs_ptg, num_classes, orbit_partition_sizes = load_dataset(found_data_filename) 239 | print("Downgrading k from dataset {}...".format(found_data_filename)) 240 | graphs_ptg, orbit_partition_sizes = downgrade_k(graphs_ptg, k, orbit_partition_sizes, k_min) 241 | return True, graphs_ptg, num_classes, orbit_partition_sizes 242 | else: 243 | return False, None, None, None 244 | 245 | 246 | def find_id_filename(data_folder, id_type, induced, directed_orbits, k): 247 | ''' 248 | Looks for existing precomputed datasets in `data_folder` with counts for substructure 249 | `id_type` larger `k`. 
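    Returns the first matching filename together with the detected size `k_found`,
    or (None, None) when no precomputed file with size >= `k` exists.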
250 | ''' 251 | if induced: 252 | if directed_orbits: 253 | pattern = os.path.join(data_folder, '{}_induced_directed_orbits_[0-9]*.pt'.format(id_type)) 254 | else: 255 | pattern = os.path.join(data_folder, '{}_induced_[0-9]*.pt'.format(id_type)) 256 | else: 257 | if directed_orbits: 258 | pattern = os.path.join(data_folder, '{}_directed_orbits_[0-9]*.pt'.format(id_type)) 259 | else: 260 | pattern = os.path.join(data_folder, '{}_[0-9]*.pt'.format(id_type)) 261 | filenames = glob.glob(pattern) 262 | for name in filenames: 263 | k_found = int(re.findall(r'\d+', name)[-1]) 264 | if k_found >= k: 265 | return name, k_found 266 | return None, None 267 | 268 | def downgrade_k(dataset, k, orbit_partition_sizes, k_min): 269 | ''' 270 | Donwgrades `dataset` by keeping only the orbits of the requested substructures. 271 | ''' 272 | feature_vector_size = sum(orbit_partition_sizes[0:k-k_min+1]) 273 | graphs_ptg = list() 274 | for data in dataset: 275 | new_data = Data() 276 | for attr in data.__iter__(): 277 | name, value = attr 278 | setattr(new_data, name, value) 279 | setattr(new_data, 'identifiers', data.identifiers[:, 0:feature_vector_size]) 280 | graphs_ptg.append(new_data) 281 | return graphs_ptg, orbit_partition_sizes[0:k-k_min+1] 282 | 283 | -------------------------------------------------------------------------------- /deepgraph/data/ogb_datasets/ogb_dataset_lookup_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from typing import Optional 5 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset 6 | from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset 7 | from ogb.graphproppred import PygGraphPropPredDataset 8 | from torch_geometric.data import Dataset 9 | from ..substructure_dataset import SubstructureDataset 10 | import torch.distributed as dist 11 | import os 12 | import torch 13 | from torch_scatter import scatter 14 | import numpy as np 15 | from torchvision import transforms 16 | 17 | 18 | 19 | class MyPygPCQM4Mv2Dataset(PygPCQM4Mv2Dataset): 20 | def download(self): 21 | if not dist.is_initialized() or dist.get_rank() == 0: 22 | super(MyPygPCQM4Mv2Dataset, self).download() 23 | if dist.is_initialized(): 24 | dist.barrier() 25 | 26 | def process(self): 27 | if not dist.is_initialized() or dist.get_rank() == 0: 28 | super(MyPygPCQM4Mv2Dataset, self).process() 29 | if dist.is_initialized(): 30 | dist.barrier() 31 | 32 | 33 | class MyPygPCQM4MDataset(PygPCQM4MDataset): 34 | def download(self): 35 | if not dist.is_initialized() or dist.get_rank() == 0: 36 | super(MyPygPCQM4MDataset, self).download() 37 | if dist.is_initialized(): 38 | dist.barrier() 39 | 40 | def process(self): 41 | if not dist.is_initialized() or dist.get_rank() == 0: 42 | super(MyPygPCQM4MDataset, self).process() 43 | if dist.is_initialized(): 44 | dist.barrier() 45 | 46 | 47 | class MyPygGraphPropPredDataset(PygGraphPropPredDataset): 48 | def download(self): 49 | if not dist.is_initialized() or dist.get_rank() == 0: 50 | super(MyPygGraphPropPredDataset, self).download() 51 | if dist.is_initialized(): 52 | dist.barrier() 53 | 54 | def process(self): 55 | if not dist.is_initialized() or dist.get_rank() == 0: 56 | super(MyPygGraphPropPredDataset, self).process() 57 | if dist.is_initialized(): 58 | dist.barrier() 59 | 60 | 61 | 62 | def extract_node_feature(data, reduce='add'): 63 | if reduce in ['mean', 'max', 'add']: 64 | data.x = scatter(data.edge_attr, 65 | data.edge_index[0], 66 | dim=0, 67 | 
dim_size=data.num_nodes, 68 | reduce=reduce) 69 | else: 70 | raise Exception('Unknown Aggregation Type') 71 | return data 72 | 73 | 74 | def get_vocab_mapping(seq_list, num_vocab): 75 | ''' 76 | Input: 77 | seq_list: a list of sequences 78 | num_vocab: vocabulary size 79 | Output: 80 | vocab2idx: 81 | A dictionary that maps vocabulary into integer index. 82 | Additioanlly, we also index '__UNK__' and '__EOS__' 83 | '__UNK__' : out-of-vocabulary term 84 | '__EOS__' : end-of-sentence 85 | idx2vocab: 86 | A list that maps idx to actual vocabulary. 87 | ''' 88 | 89 | vocab_cnt = {} 90 | vocab_list = [] 91 | for seq in seq_list: 92 | for w in seq: 93 | if w in vocab_cnt: 94 | vocab_cnt[w] += 1 95 | else: 96 | vocab_cnt[w] = 1 97 | vocab_list.append(w) 98 | 99 | cnt_list = np.array([vocab_cnt[w] for w in vocab_list]) 100 | topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab] 101 | 102 | print('Coverage of top {} vocabulary:'.format(num_vocab)) 103 | print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list)) 104 | 105 | vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)} 106 | idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab] 107 | 108 | # print(topvocab) 109 | # print([vocab_list[v] for v in topvocab[:10]]) 110 | # print([vocab_list[v] for v in topvocab[-10:]]) 111 | 112 | vocab2idx['__UNK__'] = num_vocab 113 | idx2vocab.append('__UNK__') 114 | 115 | vocab2idx['__EOS__'] = num_vocab + 1 116 | idx2vocab.append('__EOS__') 117 | 118 | # test the correspondence between vocab2idx and idx2vocab 119 | for idx, vocab in enumerate(idx2vocab): 120 | assert(idx == vocab2idx[vocab]) 121 | 122 | # test that the idx of '__EOS__' is len(idx2vocab) - 1. 123 | # This fact will be used in decode_arr_to_seq, when finding __EOS__ 124 | assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1) 125 | 126 | return vocab2idx, idx2vocab 127 | 128 | 129 | def augment_edge(data): 130 | ''' 131 | Input: 132 | data: PyG data object 133 | Output: 134 | data (edges are augmented in the following ways): 135 | data.edge_index: Added next-token edge. The inverse edges were also added. 136 | data.edge_attr (torch.Long): 137 | data.edge_attr[:,0]: whether it is AST edge (0) for next-token edge (1) 138 | data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1) 139 | ''' 140 | 141 | ##### AST edge 142 | edge_index_ast = data.edge_index 143 | edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2)) 144 | 145 | ##### Inverse AST edge 146 | edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim=0) 147 | edge_attr_ast_inverse = torch.cat( 148 | [torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim=1) 149 | 150 | ##### Next-token edge 151 | 152 | ## Obtain attributed nodes and get their indices in dfs order 153 | # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0] 154 | # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))] 155 | 156 | ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following. 
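    ## (torch.where below simply collects the indices of nodes with node_is_attributed == 1,
    ## so the result stays in DFS order.)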
157 | attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1, ) == 1)[0] 158 | 159 | ## build next token edge 160 | # Given: attributed_node_idx_in_dfs_order 161 | # [1, 3, 4, 5, 8, 9, 12] 162 | # Output: 163 | # [[1, 3, 4, 5, 8, 9] 164 | # [3, 4, 5, 8, 9, 12] 165 | edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], 166 | dim=0) 167 | edge_attr_nextoken = torch.cat( 168 | [torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim=1) 169 | 170 | ##### Inverse next-token edge 171 | edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim=0) 172 | edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2)) 173 | 174 | data.edge_index = torch.cat( 175 | [edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim=1) 176 | data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse], 177 | dim=0) 178 | 179 | return data 180 | 181 | def encode_seq_to_arr(seq, vocab2idx, max_seq_len): 182 | ''' 183 | Input: 184 | seq: A list of words 185 | output: add y_arr (torch.Tensor) 186 | ''' 187 | 188 | augmented_seq = seq[:max_seq_len] + ['__EOS__'] * max(0, max_seq_len - len(seq)) 189 | return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] for w in augmented_seq]], dtype = torch.long) 190 | 191 | 192 | def encode_code2(data, vocab2idx, max_seq_len): 193 | ''' 194 | Input: 195 | data: PyG graph object 196 | output: add y_arr to data 197 | ''' 198 | 199 | # PyG >= 1.5.0 200 | data.x=torch.cat([data.x[:,0].unsqueeze(1),data.node_depth,data.x[:,1].unsqueeze(1)],1) 201 | seq = data.y 202 | 203 | # PyG = 1.4.3 204 | # seq = data.y[0] 205 | data.y_ori=seq 206 | data.y = encode_seq_to_arr(seq, vocab2idx, max_seq_len) 207 | 208 | return data 209 | 210 | 211 | 212 | class OGBDatasetLookupTable: 213 | @staticmethod 214 | def GetOGBDataset(dataset_name: str, seed: int,args=None,**kwargs) -> Optional[Dataset]: 215 | 216 | inner_dataset = None 217 | train_idx = None 218 | valid_idx = None 219 | test_idx = None 220 | if dataset_name == "ogbg-molhiv": 221 | folder_name = dataset_name.replace("-", "_") 222 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}") 223 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}") 224 | inner_dataset = MyPygGraphPropPredDataset(dataset_name,root=kwargs['data_dir']) 225 | idx_split = inner_dataset.get_idx_split() 226 | train_idx = idx_split["train"] 227 | valid_idx = idx_split["valid"] 228 | test_idx = idx_split["test"] 229 | elif dataset_name == "ogbg-molpcba": 230 | folder_name = dataset_name.replace("-", "_") 231 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}") 232 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}") 233 | inner_dataset = MyPygGraphPropPredDataset(dataset_name,root=kwargs['data_dir']) 234 | idx_split = inner_dataset.get_idx_split() 235 | train_idx = idx_split["train"] 236 | valid_idx = idx_split["valid"] 237 | test_idx = idx_split["test"] 238 | elif dataset_name == "pcqm4mv2": 239 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m-v2')}") 240 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m-v2','RELEASE_v1.txt')}") 241 | inner_dataset = MyPygPCQM4Mv2Dataset() 242 | idx_split = inner_dataset.get_idx_split() 243 | train_idx = idx_split["train"] 244 | 
valid_idx = idx_split["valid"] 245 | test_idx = idx_split["test-dev"] 246 | elif dataset_name == "pcqm4m": 247 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021')}") 248 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021','RELEASE_v1.txt')}") 249 | inner_dataset = MyPygPCQM4MDataset(root=kwargs['data_dir']) 250 | idx_split = inner_dataset.get_idx_split() 251 | train_idx = idx_split["train"] 252 | valid_idx = idx_split["valid"] 253 | test_idx = idx_split["test"] 254 | elif dataset_name == 'pcqm4m_contrust_pretraining': 255 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021')}") 256 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021','RELEASE_v1.txt')}") 257 | inner_dataset = MyPygPCQM4MDataset(root=kwargs['data_dir']) 258 | idx_split = inner_dataset.get_idx_split() 259 | train_idx = idx_split["train"] 260 | valid_idx = idx_split["valid"] 261 | test_idx = idx_split["test"] 262 | 263 | elif dataset_name in ['ogbg-ppa']: 264 | from functools import partial 265 | transform = partial(extract_node_feature, reduce='add') 266 | folder_name = dataset_name.replace("-", "_") 267 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}") 268 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}") 269 | inner_dataset = MyPygGraphPropPredDataset(name=dataset_name, root=kwargs['data_dir'], transform=transform) 270 | idx_split = inner_dataset.get_idx_split() 271 | train_idx = idx_split["train"] 272 | valid_idx = idx_split["valid"] 273 | test_idx = idx_split["test"] 274 | 275 | elif dataset_name in ['ogbg-code2']: 276 | inner_dataset = PygGraphPropPredDataset(name=dataset_name, root=kwargs['data_dir']) 277 | idx_split = inner_dataset.get_idx_split() 278 | train_idx = idx_split["train"] 279 | valid_idx = idx_split["valid"] 280 | test_idx = idx_split["test"] 281 | split_idx = inner_dataset.get_idx_split() 282 | # 283 | # ### building vocabulary for sequence predition. Only use training data. 284 | vocab2idx, idx2vocab = get_vocab_mapping([inner_dataset.data.y[i] for i in split_idx['train']], 5000) 285 | # 286 | # ### set the transform function 287 | # # augment_edge: add next-token edge as well as inverse edges. add edge attributes. 288 | # # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence. 
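            # The Compose below applies augment_edge first and then encode_code2 with max_seq_len=5;
            # PyG applies this transform on the fly whenever an item is fetched from inner_dataset.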
289 | inner_dataset.transform = transforms.Compose([ 290 | augment_edge, lambda data: encode_code2(data, vocab2idx, 5) 291 | ]) 292 | 293 | else: 294 | raise ValueError(f"Unknown dataset name {dataset_name} for ogb source.") 295 | 296 | 297 | result=None if inner_dataset is None else SubstructureDataset( 298 | inner_dataset, seed, train_idx, valid_idx, test_idx,args=args 299 | ) 300 | 301 | return result 302 | -------------------------------------------------------------------------------- /deepgraph/data/subsampling/sampling.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import torch 3 | import numpy as np 4 | import time 5 | 6 | 7 | 8 | def random_sampling(subgraphs_nodes, rate=0.5, minimum_redundancy=0, num_nodes=None): 9 | if num_nodes is None: 10 | num_nodes = subgraphs_nodes[1].max() + 1 11 | while True: 12 | selected = np.random.choice(num_nodes, int(num_nodes*rate), replace=False) 13 | node_selected_times = torch.bincount(subgraphs_nodes[1][check_values_in_set(subgraphs_nodes[0], selected)], minlength=num_nodes) 14 | if node_selected_times.min() >= minimum_redundancy: 15 | # rate += 0.1 # enlarge the sampling rate 16 | break 17 | return selected, node_selected_times 18 | 19 | 20 | def random_sampling_general(subgraphs_nodes, rate=0.5, minimum_redundancy=2,num_nodes=None,must_select_list=None): 21 | 22 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes) # number_subgraph* number_nodes True when node j in subgraph i 23 | 24 | 25 | 26 | node_selected_times = torch.zeros(num_nodes) 27 | selected_all = [] 28 | 29 | if must_select_list is not None: 30 | for selected in must_select_list: 31 | try: 32 | selected_all.append(selected) 33 | node_selected_times[subgraphs_nodes_mask[selected]] += 1 34 | except: 35 | print('selected ',selected) 36 | print('num_nodes ',num_nodes) 37 | print('subgraphs_nodes_mask',subgraphs_nodes_mask) 38 | raise ValueError("error ") 39 | 40 | num_subgraphs = int(subgraphs_nodes[0].max())+1 41 | init_rate = min(subgraphs_nodes.shape[0]/num_subgraphs,1) 42 | rate=min(init_rate*minimum_redundancy,1) 43 | 44 | for i in range(int(subgraphs_nodes[0].max())+1): 45 | selected_all = np.random.choice(num_subgraphs, int(num_subgraphs*rate), replace=False) 46 | node_selected_times = subgraphs_nodes_mask[selected_all].sum(0) 47 | 48 | if node_selected_times.min() >= minimum_redundancy or rate >=1: 49 | break 50 | 51 | rate=min(rate+init_rate,1) 52 | 53 | 54 | 55 | 56 | return list(selected_all), node_selected_times 57 | 58 | 59 | # Approach 1: based on shortets path distance 60 | def shortest_path_sampling(edge_index, subgraphs_nodes, stride=2, minimum_redundancy=0, random_init=False, num_nodes=None): 61 | if num_nodes is None: 62 | num_nodes = subgraphs_nodes[1].max() + 1 63 | G = nx.from_edgelist(edge_index.t().numpy()) 64 | G.add_nodes_from(range(num_nodes)) 65 | if random_init: 66 | farthest = np.random.choice(num_nodes) # here can also choose the one with highest degree 67 | else: 68 | subgraph_size = torch.bincount(subgraphs_nodes[0], minlength=num_nodes) 69 | farthest = subgraph_size.argmax().item() 70 | 71 | distance = np.ones(num_nodes)*1e10 72 | selected = [] 73 | node_selected_times = torch.zeros(num_nodes) 74 | 75 | for i in range(num_nodes): 76 | selected.append(farthest) 77 | node_selected_times[subgraphs_nodes[1][subgraphs_nodes[0] == farthest]] += 1 78 | length_shortest_dict = nx.single_source_shortest_path_length(G, farthest) 79 | length_shortest = 
np.ones(num_nodes)*1e10 80 | length_shortest[list(length_shortest_dict.keys())] = list(length_shortest_dict.values()) 81 | mask = length_shortest < distance 82 | distance[mask] = length_shortest[mask] 83 | 84 | if (distance.max() < stride) and (node_selected_times.min() >= minimum_redundancy): # stop criterion 85 | break 86 | farthest = np.argmax(distance) 87 | return selected, node_selected_times 88 | 89 | 90 | # Approach 1: based on shortets path distance 91 | def shortest_path_sampling_general(edge_index, subgraphs_nodes, stride=2, minimum_redundancy=0, random_init=False, 92 | num_nodes=None): 93 | if num_nodes is None: 94 | num_nodes = subgraphs_nodes[1].max() + 1 95 | G = nx.from_edgelist(edge_index.t().numpy()) 96 | G.add_nodes_from(range(num_nodes)) 97 | if random_init: 98 | farthest = np.random.choice(num_nodes) # here can also choose the one with highest degree 99 | else: 100 | subgraph_size = torch.bincount(subgraphs_nodes[0], minlength=num_nodes) 101 | farthest = subgraph_size.argmax().item() 102 | 103 | distance = np.ones(num_nodes) * 1e10 104 | selected = [] 105 | node_selected_times = torch.zeros(num_nodes) 106 | 107 | for i in range(num_nodes): 108 | selected.append(farthest) 109 | node_selected_times[subgraphs_nodes[1][subgraphs_nodes[0] == farthest]] += 1 110 | length_shortest_dict = nx.single_source_shortest_path_length(G, farthest) 111 | length_shortest = np.ones(num_nodes) * 1e10 112 | length_shortest[list(length_shortest_dict.keys())] = list(length_shortest_dict.values()) 113 | mask = length_shortest < distance 114 | distance[mask] = length_shortest[mask] 115 | 116 | if (distance.max() < stride) and (node_selected_times.min() >= minimum_redundancy): # stop criterion 117 | break 118 | farthest = np.argmax(distance) 119 | return selected, node_selected_times 120 | 121 | 122 | 123 | def check_values_in_set(x, set, approach=1): 124 | assert min(x.shape) > 0 125 | assert min(set.shape) > 0 126 | if approach == 0: 127 | mask = sum(x==i for i in set) 128 | else: 129 | mapper = torch.zeros(max(x.max()+1, set.max()+1), dtype=torch.bool) 130 | mapper[set] = True 131 | mask = mapper[x] 132 | return mask 133 | 134 | ############################################### use dense input ################################################## 135 | ### this part is hard to change 136 | 137 | def min_set_cover_sampling(edge_index, subgraphs_nodes_mask, random_init=False, minimum_redundancy=2): 138 | 139 | num_nodes = subgraphs_nodes_mask.size(0) 140 | if random_init: 141 | selected = np.random.choice(num_nodes) 142 | else: 143 | selected = subgraphs_nodes_mask.sum(-1).argmax().item() # choose the maximum subgraph size one to remove randomness 144 | 145 | node_selected_times = torch.zeros(num_nodes) 146 | selected_all = [] 147 | 148 | for i in range(num_nodes): 149 | # selected_subgraphs[selected] = True 150 | selected_all.append(selected) 151 | node_selected_times[subgraphs_nodes_mask[selected]] += 1 152 | if node_selected_times.min() >= minimum_redundancy: # stop criterion 153 | break 154 | # calculate how many unused nodes in each subgraph (greedy set cover) 155 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool()) 156 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1) 157 | scores = num_unused_nodes 158 | scores[selected_all] = 0 159 | selected = np.argmax(scores).item() 160 | 161 | return selected_all, node_selected_times 162 | 163 | 164 | def transfer_edge_index_to_mask(subgraphs_nodes,num_nodes): 165 | 
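    # Scatters COO-style (subgraph_id, node_id) index pairs into a dense boolean membership mask
    # of shape [num_subgraphs, num_nodes]: mask[i, j] is True when node j belongs to subgraph i.
    # Illustrative call with hypothetical values:
    #   subgraphs_nodes = torch.tensor([[0, 0, 1], [2, 3, 1]])  # subgraph 0 -> {2, 3}, subgraph 1 -> {1}
    #   transfer_edge_index_to_mask(subgraphs_nodes, num_nodes=4)
    #   # -> [[False, False, True, True], [False, True, False, False]]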
res=torch.zeros([int(subgraphs_nodes[0].max()+1),num_nodes]) 166 | res[subgraphs_nodes[0,:],subgraphs_nodes[1,:]]=1 167 | out=res>0 168 | return out 169 | 170 | def min_set_cover_sampling_general(subgraphs_nodes, random_init=False, minimum_redundancy=2,num_nodes=None,must_select_list=None,only_unused_nodes=True): 171 | 172 | 173 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes) 174 | 175 | node_selected_times = torch.zeros(num_nodes) 176 | selected_all = [] 177 | 178 | if must_select_list is not None: 179 | for selected in must_select_list: 180 | try: 181 | selected_all.append(selected) 182 | node_selected_times[subgraphs_nodes_mask[selected]] += 1 183 | except: 184 | print('selected ',selected) 185 | print('num_nodes ',num_nodes) 186 | print('subgraphs_nodes_mask',subgraphs_nodes_mask) 187 | raise ValueError("error ") 188 | # num_nodes = subgraphs_nodes_mask.size(0) 189 | if random_init: 190 | selected = np.random.choice(int(subgraphs_nodes[0].max())+1) 191 | else: 192 | selected = subgraphs_nodes_mask.sum(-1).argmax().item() # choose the maximum subgraph size one to remove randomness 193 | 194 | for i in range(int(subgraphs_nodes[0].max())+1): 195 | # selected_subgraphs[selected] = True 196 | selected_all.append(selected) 197 | node_selected_times[subgraphs_nodes_mask[selected]] += 1 198 | if node_selected_times.min() >= minimum_redundancy: # stop criterion 199 | break 200 | # calculate how many unused nodes in each subgraph (greedy set cover) 201 | if only_unused_nodes: 202 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool()) 203 | else: 204 | unused_nodes = ((node_selected_times - minimum_redundancy)<0) 205 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1) 206 | scores = num_unused_nodes 207 | scores[selected_all] = 0 208 | if scores.sum()==0: 209 | break 210 | selected = np.argmax(scores).item() 211 | 212 | return selected_all, node_selected_times 213 | 214 | def min_set_cover_random_sampling_general(subgraphs_nodes, random_init=False, minimum_redundancy=2,num_nodes=None,must_select_list=None,must_select_list_ratio=0.2,only_unused_nodes=True,ratio_init_drop=0.5,ratio_each_iter=0.5,balance_different_substructure=True): 215 | 216 | 217 | # The algorithm implies: 218 | # At the beginning, random drop some substructures for randomness (controled by ratio_init_drop) 219 | # Randomly sample top K substructures (controled by ratio_each_iter) 220 | # The final sampled substructure number adapted to the total substructure number. 
Expectation is cover number plus one 221 | # Balanced substructure types (controled by balance_different_substructure) 222 | # times = [] 223 | # times.append(time.time()) 224 | 225 | num_subgraphs = int(subgraphs_nodes[0].max())+1 226 | number_each_iter = max(int(num_nodes*subgraphs_nodes.shape[0] / num_subgraphs), 1) 227 | 228 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes) 229 | 230 | init_drop=np.random.choice(num_subgraphs, int(num_subgraphs*ratio_init_drop), replace=False) 231 | 232 | 233 | if balance_different_substructure: 234 | 235 | # It requires the geometric substructures before the neighbor substructures 236 | 237 | ids=subgraphs_nodes[0,subgraphs_nodes[2]==-1] 238 | if ids.shape[0]>0: 239 | number_predefined=int(ids[0]) 240 | number_neighbor = num_subgraphs-number_predefined 241 | if number_predefined>number_neighbor: 242 | predefined_drop=np.random.choice(number_predefined,(number_predefined-number_neighbor), replace=False) 243 | init_drop=np.concatenate([init_drop,predefined_drop],0) 244 | 245 | 246 | 247 | node_selected_times = torch.zeros(num_nodes) 248 | selected_all = [] 249 | 250 | if must_select_list is not None: 251 | try: 252 | number_must_select=max(int(num_nodes*must_select_list_ratio),1) 253 | if len(must_select_list)>=number_must_select: 254 | must_select_list=[must_select_list[i] for i in np.random.choice(len(must_select_list), number_must_select, replace=False)] 255 | selected_all += must_select_list 256 | node_selected_times += subgraphs_nodes_mask[must_select_list].sum(0) 257 | except: 258 | print('selected ',must_select_list) 259 | print('num_nodes ',num_nodes) 260 | print('subgraphs_nodes_mask',subgraphs_nodes_mask) 261 | raise ValueError("error ") 262 | # num_nodes = subgraphs_nodes_mask.size(0) 263 | if random_init: 264 | selected = [np.random.choice(int(subgraphs_nodes[0].max())+1)] 265 | else: 266 | selected = [subgraphs_nodes_mask.sum(-1).argmax().item()] # choose the maximum subgraph size one to remove randomness 267 | 268 | 269 | for i in range(int(subgraphs_nodes[0].max())+1): 270 | # selected_subgraphs[selected] = True 271 | selected_all+=selected 272 | node_selected_times += subgraphs_nodes_mask[selected].sum(0) 273 | if node_selected_times.min() >= minimum_redundancy: # stop criterion 274 | break 275 | # calculate how many unused nodes in each subgraph (greedy set cover) 276 | if only_unused_nodes: 277 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool()) 278 | else: 279 | unused_nodes = ((node_selected_times - minimum_redundancy)<0) 280 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1) 281 | scores = num_unused_nodes 282 | scores[selected_all] = 0 283 | 284 | scores[init_drop] = 0 285 | 286 | if scores.sum()==0 : 287 | break 288 | number_lefted=int((scores>0).sum()) 289 | _,index=torch.topk(scores, min(int(number_each_iter/ratio_each_iter),number_lefted)) 290 | selected=index[np.random.choice(len(index), min(number_each_iter,number_lefted), replace=False)].tolist() 291 | 292 | 293 | return selected_all, node_selected_times 294 | 295 | -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/utils_graph_processing.py: -------------------------------------------------------------------------------- 1 | 2 | from torch_geometric.utils import to_undirected,remove_self_loops 3 | import networkx as nx 4 | 5 | import graph_tool as gt 6 | import graph_tool.topology as gt_topology 7 | 8 | import torch 9 | import numpy 
as np 10 | import time 11 | 12 | def automorphism_orbits(edge_list, print_msgs=False, only_graph=False,**kwargs): 13 | 14 | ##### vertex automorphism orbits ##### 15 | 16 | directed=kwargs['directed'] if 'directed' in kwargs else False 17 | 18 | graph = gt.Graph(directed=directed) 19 | graph.add_edge_list(edge_list) 20 | 21 | # gt.stats.remove_self_loops(graph) 22 | # gt.stats.remove_parallel_edges(graph) 23 | 24 | if only_graph: 25 | 26 | return graph, None, None, None 27 | 28 | 29 | # compute the vertex automorphism group 30 | aut_group = gt_topology.subgraph_isomorphism(graph, graph, induced=False, subgraph=True, generator=False) 31 | 32 | 33 | 34 | orbit_membership = {} 35 | for v in graph.get_vertices(): 36 | orbit_membership[v] = v 37 | 38 | 39 | 40 | # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit 41 | for aut in aut_group: 42 | for original, vertex in enumerate(aut): 43 | role = min(original, orbit_membership[vertex]) 44 | orbit_membership[vertex] = role 45 | 46 | 47 | 48 | orbit_membership_list = [[],[]] 49 | for vertex, om_curr in orbit_membership.items(): 50 | orbit_membership_list[0].append(vertex) 51 | orbit_membership_list[1].append(om_curr) 52 | 53 | 54 | 55 | # make orbit list contiguous (i.e. 0,1,2,...O) 56 | _, contiguous_orbit_membership = np.unique(orbit_membership_list[1], return_inverse = True) 57 | 58 | orbit_membership = {vertex: contiguous_orbit_membership[i] for i,vertex in enumerate(orbit_membership_list[0])} 59 | 60 | 61 | 62 | orbit_partition = {} 63 | for vertex, orbit in orbit_membership.items(): 64 | orbit_partition[orbit] = [vertex] if orbit not in orbit_partition else orbit_partition[orbit]+[vertex] 65 | 66 | aut_count = len(aut_group) 67 | 68 | 69 | 70 | if print_msgs: 71 | print('Orbit partition of given substructure: {}'.format(orbit_partition)) 72 | print('Number of orbits: {}'.format(len(orbit_partition))) 73 | print('Automorphism count: {}'.format(aut_count)) 74 | 75 | 76 | 77 | return graph, orbit_partition, orbit_membership, aut_count 78 | 79 | 80 | def induced_edge_automorphism_orbits(edge_list, **kwargs): 81 | 82 | ##### induced edge automorphism orbits (according to the vertex automorphism group) ##### 83 | 84 | directed=kwargs['directed'] if 'directed' in kwargs else False 85 | directed_orbits=kwargs['directed_orbits'] if 'directed_orbits' in kwargs else False 86 | 87 | graph, orbit_partition, orbit_membership, aut_count = automorphism_orbits(edge_list=edge_list, 88 | directed=directed, 89 | print_msgs=False) 90 | edge_orbit_partition = dict() 91 | edge_orbit_membership = dict() 92 | edge_orbits2inds = dict() 93 | ind = 0 94 | 95 | if not directed: 96 | edge_list = to_undirected(torch.tensor(graph.get_edges()).transpose(1,0)).transpose(1,0).tolist() 97 | 98 | # infer edge automorphisms from the vertex automorphisms 99 | for i,edge in enumerate(edge_list): 100 | if directed_orbits: 101 | edge_orbit = (orbit_membership[edge[0]], orbit_membership[edge[1]]) 102 | else: 103 | edge_orbit = frozenset([orbit_membership[edge[0]], orbit_membership[edge[1]]]) 104 | if edge_orbit not in edge_orbits2inds: 105 | edge_orbits2inds[edge_orbit] = ind 106 | ind_edge_orbit = ind 107 | ind += 1 108 | else: 109 | ind_edge_orbit = edge_orbits2inds[edge_orbit] 110 | 111 | if ind_edge_orbit not in edge_orbit_partition: 112 | edge_orbit_partition[ind_edge_orbit] = [tuple(edge)] 113 | else: 114 | edge_orbit_partition[ind_edge_orbit] += [tuple(edge)] 115 | 116 | edge_orbit_membership[i] = ind_edge_orbit 117 | 118 | print('Edge 
orbit partition of given substructure: {}'.format(edge_orbit_partition)) 119 | print('Number of edge orbits: {}'.format(len(edge_orbit_partition))) 120 | print('Graph (vertex) automorphism count: {}'.format(aut_count)) 121 | 122 | return graph, edge_orbit_partition, edge_orbit_membership, aut_count 123 | 124 | 125 | def subgraph_isomorphism_vertex_counts(edge_index, **kwargs): 126 | 127 | ##### vertex structural identifiers ##### 128 | 129 | subgraph_dict, induced, num_nodes = kwargs['subgraph_dict'], kwargs['induced'], kwargs['num_nodes'] 130 | directed = kwargs['directed'] if 'directed' in kwargs else False 131 | 132 | G_gt = gt.Graph(directed=directed) 133 | edge_list=list(edge_index.transpose(1,0).cpu().numpy()) 134 | edge_list=[[int(edge_list[i][0]),int(edge_list[i][1])] for i in range(len(edge_list))] 135 | G_gt.add_edge_list(edge_list) 136 | gt.stats.remove_self_loops(G_gt) 137 | gt.stats.remove_parallel_edges(G_gt) 138 | 139 | # compute all subgraph isomorphisms 140 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True, generator=True) 141 | 142 | ## num_nodes should be explicitly set for the following edge case: 143 | ## when there is an isolated vertex whose index is larger 144 | ## than the maximum available index in the edge_index 145 | 146 | counts = np.zeros((num_nodes, len(subgraph_dict['orbit_partition']))) 147 | for sub_iso_curr in sub_iso: 148 | for i,node in enumerate(sub_iso_curr): 149 | # increase the count for each orbit 150 | counts[node, subgraph_dict['orbit_membership'][i]] +=1 151 | counts = counts/subgraph_dict['aut_count'] 152 | 153 | counts = torch.tensor(counts) 154 | 155 | return counts 156 | 157 | 158 | def subgraph_isomorphism_vertex_extraction(edge_index,input_graph=None, **kwargs): 159 | ##### vertex structural identifiers ##### 160 | 161 | subgraph_dict, induced, num_nodes = kwargs['subgraph_dict'], kwargs['induced'], kwargs['num_nodes'] 162 | directed = kwargs['directed'] if 'directed' in kwargs else False 163 | 164 | if input_graph is not None: 165 | G_gt=input_graph 166 | else: 167 | G_gt = gt.Graph(directed=directed) 168 | edge_list = list(edge_index.transpose(1, 0).cpu().numpy()) 169 | edge_list = [[int(edge_list[i][0]), int(edge_list[i][1])] for i in range(len(edge_list))] 170 | G_gt.add_edge_list(edge_list) 171 | gt.stats.remove_self_loops(G_gt) 172 | gt.stats.remove_parallel_edges(G_gt) 173 | 174 | # compute all subgraph isomorphisms 175 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True, 176 | generator=True) 177 | 178 | ## num_nodes should be explicitly set for the following edge case: 179 | ## when there is an isolated vertex whose index is larger 180 | ## than the maximum available index in the edge_index 181 | subgraphs = set() 182 | for sub_iso_curr in sub_iso: 183 | tem = list(sub_iso_curr.a) 184 | tem.sort() 185 | subgraphs.add(tuple(tem)) 186 | return subgraphs 187 | 188 | 189 | def subgraph_isomorphism_edge_counts(edge_index, **kwargs): 190 | 191 | ##### edge structural identifiers ##### 192 | 193 | subgraph_dict, induced = kwargs['subgraph_dict'], kwargs['induced'] 194 | directed = kwargs['directed'] if 'directed' in kwargs else False 195 | 196 | edge_index = edge_index.transpose(1,0).cpu().numpy() 197 | edge_dict = {} 198 | for i, edge in enumerate(edge_index): 199 | edge_dict[tuple(edge)] = i 200 | 201 | if not directed: 202 | subgraph_edges = 
to_undirected(torch.tensor(subgraph_dict['subgraph'].get_edges().tolist()).transpose(1,0)).transpose(1,0).tolist() 203 | 204 | 205 | G_gt = gt.Graph(directed=directed) 206 | G_gt.add_edge_list(list(edge_index)) 207 | gt.stats.remove_self_loops(G_gt) 208 | gt.stats.remove_parallel_edges(G_gt) 209 | 210 | # compute all subgraph isomorphisms 211 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True, generator=True) 212 | 213 | 214 | counts = np.zeros((edge_index.shape[0], len(subgraph_dict['orbit_partition']))) 215 | 216 | for sub_iso_curr in sub_iso: 217 | mapping = sub_iso_curr.get_array() 218 | # import pdb;pdb.set_trace() 219 | for i,edge in enumerate(subgraph_edges): 220 | 221 | # for every edge in the graph H, find the edge in the subgraph G_S to which it is mapped 222 | # (by finding where its endpoints are matched). 223 | # Then, increase the count of the matched edge w.r.t. the corresponding orbit 224 | # Repeat for the reverse edge (the one with the opposite direction) 225 | 226 | edge_orbit = subgraph_dict['orbit_membership'][i] 227 | mapped_edge = tuple([mapping[edge[0]], mapping[edge[1]]]) 228 | counts[edge_dict[mapped_edge], edge_orbit] += 1 229 | 230 | counts = counts/subgraph_dict['aut_count'] 231 | 232 | counts = torch.tensor(counts) 233 | 234 | return counts 235 | 236 | find_id_filename 237 | 238 | #----------------------- line graph edge automorphism: deprecated 239 | 240 | 241 | 242 | def edge_automorphism_orbits(edge_list, **kwargs): 243 | 244 | ##### edge automorphism orbits according to the line graph ##### 245 | 246 | directed=kwargs['directed'] if 'directed' in kwargs else False 247 | 248 | graph_nx = nx.from_edgelist(edge_list) 249 | graph = gt.Graph(directed=directed) 250 | graph.add_edge_list(edge_list) 251 | gt.stats.remove_self_loops(graph) 252 | gt.stats.remove_parallel_edges(graph) 253 | aut_group = gt_topology.subgraph_isomorphism(graph, graph, induced=False, subgraph=True, generator=False) 254 | aut_count = len(aut_group) 255 | 256 | ##### compute line graph vertex automorphism orbits ##### 257 | 258 | graph_nx_line = nx.line_graph(graph_nx) 259 | mapping = {node: i for i,node in enumerate(graph_nx_line.nodes)} 260 | inverse_mapping = {i: node for i,node in enumerate(graph_nx_line.nodes)} 261 | 262 | graph_nx_line = nx.relabel_nodes(graph_nx_line, mapping) 263 | line_graph = gt.Graph(directed=directed) 264 | line_graph.add_edge_list(list(graph_nx_line.edges)) 265 | 266 | gt.stats.remove_self_loops(line_graph) 267 | gt.stats.remove_parallel_edges(line_graph) 268 | 269 | aut_group_edges = gt_topology.subgraph_isomorphism(line_graph, line_graph, induced=False, subgraph=True, generator=False) 270 | 271 | orbit_membership = {} 272 | for v in line_graph.get_vertices(): 273 | orbit_membership[v] = v 274 | 275 | for aut in aut_group_edges: 276 | for original, vertex in enumerate(aut): 277 | role = min(original, orbit_membership[vertex]) 278 | orbit_membership[vertex] = role 279 | 280 | orbit_membership_list = [[],[]] 281 | for vertex, om_curr in orbit_membership.items(): 282 | orbit_membership_list[0].append(vertex) 283 | orbit_membership_list[1].append(om_curr) 284 | 285 | _, contiguous_orbit_membership = np.unique(orbit_membership_list[1], return_inverse = True) 286 | 287 | orbit_membership = {vertex: contiguous_orbit_membership[i] for i,vertex in enumerate(orbit_membership_list[0])} 288 | 289 | orbit_partition= {} 290 | for vertex, orbit in orbit_membership.items(): 291 | orbit_partition[orbit] = 
[inverse_mapping[vertex]] if orbit not in orbit_partition else orbit_partition[orbit]+[inverse_mapping[vertex]] 292 | 293 | ##### transfer line graph vertex automorphism orbits to original edges ##### 294 | 295 | orbit_membership_new = {} 296 | for i,edge in enumerate(graph.get_edges()): 297 | mapped_edge = mapping[tuple(edge)] if tuple(edge) in mapping else mapping[tuple([edge[1],edge[0]])] 298 | orbit_membership_new[i] = orbit_membership[mapped_edge] 299 | 300 | print('Edge orbit partition of given substructure: {}'.format(orbit_partition)) 301 | print('Number of edge orbits: {}'.format(len(orbit_partition))) 302 | print('Graph (vertex) automorphism count: {}'.format(aut_count)) 303 | 304 | return graph, orbit_partition, orbit_membership_new, aut_count 305 | 306 | 307 | 308 | 309 | 310 | def subgraph_counts2ids(count_fn, data, subgraph_dicts, subgraph_params): 311 | #### Remove self loops and then assign the structural identifiers by computing subgraph isomorphisms #### 312 | 313 | if hasattr(data, 'edge_features'): 314 | edge_index, edge_features = remove_self_loops(data.edge_index, data.edge_features) 315 | setattr(data, 'edge_features', edge_features) 316 | else: 317 | edge_index = remove_self_loops(data.edge_index)[0] 318 | 319 | num_nodes = data.x.shape[0] 320 | identifiers = None 321 | for subgraph_dict in subgraph_dicts: 322 | kwargs = {'subgraph_dict': subgraph_dict, 323 | 'induced': subgraph_params['induced'], 324 | 'num_nodes': num_nodes, 325 | 'directed': subgraph_params['directed']} 326 | counts = count_fn(edge_index, **kwargs) 327 | identifiers = counts if identifiers is None else torch.cat((identifiers, counts), 1) 328 | setattr(data, 'edge_index', edge_index) 329 | setattr(data, 'identifiers', identifiers.long()) 330 | 331 | return data 332 | 333 | 334 | def subgraph_2token(count_fn, data, subgraph_dicts, subgraph_params): 335 | #### Remove self loops and then assign the structural identifiers by computing subgraph isomorphisms #### 336 | 337 | if hasattr(data, 'edge_features'): 338 | edge_index, edge_features = remove_self_loops(data.edge_index, data.edge_features) 339 | setattr(data, 'edge_features', edge_features) 340 | else: 341 | edge_index = remove_self_loops(data.edge_index)[0] 342 | 343 | G_gt = gt.Graph(directed=subgraph_params['directed']) 344 | edge_list = list(edge_index.transpose(1, 0).cpu().numpy()) 345 | edge_list = [[int(edge_list[i][0]), int(edge_list[i][1])] for i in range(len(edge_list))] 346 | G_gt.add_edge_list(edge_list) 347 | # gt.stats.remove_self_loops(G_gt) 348 | gt.stats.remove_parallel_edges(G_gt) 349 | 350 | num_nodes = data.x.shape[0] 351 | identifiers = [] 352 | for subgraph_dict in subgraph_dicts: 353 | kwargs = {'subgraph_dict': subgraph_dict, 354 | 'induced': subgraph_params['induced'], 355 | 'num_nodes': num_nodes, 356 | 'directed': subgraph_params['directed']} 357 | subgraphs = count_fn(edge_index, input_graph=G_gt, **kwargs) 358 | identifiers.append(subgraphs) 359 | # setattr(data, 'edge_index', edge_index) 360 | # setattr(data, 'identifiers', identifiers) 361 | 362 | return identifiers -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset_utils/substructure_transform.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import torch 5 | 6 | from .utils import process_arguments_substructure, substructure_to_gt,transfer_subgraph_to_batchtensor_complete 7 | 8 | from tqdm import tqdm 9 | 10 | from 
deepgraph.data.subsampling.sampler import Subgraphs_Sampler 11 | import lmdb 12 | 13 | import numpy as np 14 | from deepgraph.data.substructure_dataset_utils.neighbor_extractors import k_hop_subgraph,random_walk_subgraph 15 | 16 | from collections import deque 17 | 18 | 19 | 20 | def graph_data_modification(dataset,substructures,args): 21 | new_dataset=[] 22 | for i in tqdm(range(len(dataset))): 23 | if dataset[i].x.dim()==1: 24 | dataset[i].x=dataset[i].x.unsqueeze(1) 25 | if dataset[i].y.dim() == 2: 26 | dataset[i].y = dataset[i].y.squeeze(1) 27 | if hasattr(dataset[i], 'edge_features'): 28 | dataset[i].edge_attr=dataset[i].edge_features 29 | del(dataset[i].edge_features) 30 | if hasattr(dataset[i],'degree'): 31 | del(dataset[i].degrees) 32 | if hasattr(dataset[i],'graph_size'): 33 | del (dataset[i].graph_size) 34 | if dataset[i].edge_attr.dim()==1: 35 | dataset[i].edge_attr=dataset[i].edge_attr.unsqueeze(1) 36 | assert hasattr(dataset[i],'edge_attr') 37 | assert hasattr(dataset[i], 'edge_index') 38 | new_data=dataset[i] 39 | setattr(new_data,'identifiers', substructures[i]) 40 | new_dataset.append(new_data) 41 | return new_dataset 42 | 43 | 44 | 45 | 46 | def find_edges_of_substructure(substructures,prototype): 47 | edges_res=[] 48 | for nodes in substructures: 49 | for edge in prototype: 50 | edges_res.append([nodes[edge[0]],nodes[edge[1]]]) 51 | return edges_res 52 | 53 | def transfer_subgraph_to_batchtensor(substructures,must_select_list): 54 | id_lis=[] 55 | note_lis=[] 56 | cur_id=0 57 | must_select_list_data=[] 58 | for id_sub,subtype in enumerate(substructures): 59 | for data in subtype: 60 | id_lis=id_lis+[cur_id]*len(data) 61 | if id_sub in must_select_list: 62 | must_select_list_data.append(cur_id) 63 | note_lis=note_lis+list(data) 64 | 65 | cur_id += 1 66 | 67 | res=torch.tensor([id_lis,note_lis]) 68 | return res,must_select_list_data 69 | 70 | 71 | 72 | 73 | 74 | 75 | def select_result(substructures,selected_subgraphs): 76 | id_lis=[] 77 | note_lis=[] 78 | cur_id=0 79 | substructures_new=[] 80 | for subtype in substructures: 81 | subtype_new=set() 82 | for data in subtype: 83 | if cur_id in selected_subgraphs: 84 | subtype_new.add(data) 85 | cur_id+=1 86 | substructures_new.append(subtype_new) 87 | return substructures_new 88 | 89 | def select_result_as_tensor(subgraph_tensor,selected_subgraphs): 90 | selected_index=(subgraph_tensor[0].unsqueeze(1) ==torch.tensor(selected_subgraphs).unsqueeze(0)).any(1) 91 | return subgraph_tensor[:,selected_index] 92 | 93 | def re_tag(subgraph_ids): 94 | diff=((subgraph_ids[1:]-subgraph_ids[0:-1])>0).long() 95 | sum=diff.cumsum(dim=0) 96 | res=torch.cat([torch.zeros([1]) if subgraph_ids.shape[0]>0 else torch.tensor([]),sum]).long() 97 | return res 98 | 99 | 100 | 101 | def bfs(graph_adj, start_node,sort_func): 102 | visited = set() # set of visited nodes 103 | queue = deque([start_node]) # create a queue and enqueue the starting node 104 | 105 | while queue: 106 | 107 | node = queue.popleft() 108 | 109 | if node not in visited: 110 | # mark the node as visited 111 | visited.add(node) 112 | 113 | neighbors = (graph_adj[node].nonzero())[0] 114 | 115 | neighbors_not_visited=neighbors[~np.in1d(neighbors, visited)] 116 | 117 | sorted_neighbors_not_visited = sort_func(neighbors_not_visited) 118 | 119 | queue.extend(sorted_neighbors_not_visited.tolist()) 120 | 121 | return visited 122 | 123 | def sort_node(node_list, degrees,only_minimum=False): 124 | 125 | if only_minimum: 126 | #compute minimum over all the nodes 127 | min_indices = 
np.argwhere(np.array(degrees) == np.min(degrees)).flatten() 128 | res= np.random.permutation(min_indices) 129 | else: 130 | degrees=degrees[node_list] 131 | 132 | ind_per = np.random.permutation(len(degrees)) 133 | node_list=node_list[ind_per] 134 | degrees=degrees[ind_per] 135 | 136 | ind_sort=np.argsort(degrees) 137 | res= node_list[ind_sort] 138 | 139 | return res 140 | 141 | 142 | def generate_sorted_adj(edge_tensor,node_sorted,subgraph_max_size): 143 | 144 | 145 | 146 | adj_list = [] 147 | if edge_tensor[0].shape[0]>0: 148 | number_subgraph = int(edge_tensor[0, :].max()) + 1 149 | max_node = int(edge_tensor[1:, :].max()) + 1 150 | # edges = torch.cat([edge_tensor, edge_tensor[[0, 2, 1], :]], 1) 151 | edge_sparse = torch.sparse_coo_tensor(edge_tensor, torch.ones_like(edge_tensor[0]), [number_subgraph, max_node, max_node] 152 | ).to_dense().long() 153 | node_sorted_id = [] 154 | node_sorted_sequence = [] 155 | node_sorted_cat = [] 156 | for i, item in enumerate(node_sorted): 157 | node_sorted_cat += item 158 | node_sorted_sequence += list(range(len(item))) 159 | node_sorted_id += [i] * len(item) 160 | 161 | node_sorted_tensor = torch.sparse_coo_tensor([node_sorted_id, node_sorted_sequence, node_sorted_cat], 162 | torch.ones([len(node_sorted_id)]), 163 | [number_subgraph, subgraph_max_size, max_node] 164 | ).to_dense().long() 165 | 166 | adjs = node_sorted_tensor @ edge_sparse @ (node_sorted_tensor.transpose(1, 2)) 167 | 168 | return adjs 169 | else: 170 | return torch.ones([0,subgraph_max_size,subgraph_max_size]) 171 | 172 | class transform_sub: 173 | def __init__(self,args): 174 | args, extract_ids_fn, count_fn, automorphism_fn = process_arguments_substructure(args) 175 | self.args= args 176 | self.extract_ids_fn=extract_ids_fn 177 | self.count_fn=count_fn 178 | self.automorphism_fn=automorphism_fn 179 | self.subgraph_params = {'induced': args['induced'], 180 | 'edge_list': args['custom_edge_list'], 181 | 'directed': args['directed'], 182 | 'directed_orbits': args['directed_orbits'], 183 | } 184 | self.must_select_list=args.must_select_list 185 | self.substructure_cache={} 186 | if args.subsampling: 187 | self.sampler = Subgraphs_Sampler( 188 | sampling_mode=args.sampling_mode, 189 | minimum_redundancy=args.sampling_redundancy, 190 | shortest_path_mode_stride=args.sampling_stride, 191 | random_mode_sampling_rate=args.sampling_random_rate, 192 | random_init=True, 193 | only_unused_nodes=(not self.args.not_only_unused_nodes)) 194 | def __call__(self, data,substructure_tensor_provided): 195 | 196 | if substructure_tensor_provided is None: 197 | substructures = self.compute_substructures(data) 198 | substructure_tensor = transfer_subgraph_to_batchtensor_complete(substructures) 199 | subgraph_tensor = torch.Tensor([substructure_tensor[1], substructure_tensor[0],substructure_tensor[2]]).long() 200 | else: 201 | subgraph_tensor=substructure_tensor_provided 202 | 203 | if self.args.subsampling: 204 | 205 | must_select_list_data=[] 206 | for i in self.must_select_list: 207 | must_select_list_data+=subgraph_tensor[0][subgraph_tensor[2]==i].unique().long().tolist() 208 | 209 | selected_subgraphs, node_selected_times = self.sampler(data.edge_index, subgraph_tensor, data.num_nodes,must_select_list_data) 210 | 211 | subgraph_tensor = select_result_as_tensor(subgraph_tensor, selected_subgraphs) 212 | subgraph_tensor=torch.cat([re_tag(subgraph_tensor[0,:]).unsqueeze(0),subgraph_tensor[1:,:]]) 213 | 214 | edge_tensor=self.compute_edge_tensor_of_substructures(data,subgraph_tensor) 215 | 216 | 
node_sorted=self.sort_subgraphs_randomly(data,subgraph_tensor) 217 | 218 | sorted_adj=generate_sorted_adj(edge_tensor,node_sorted,self.args['subgraph_max_size']) 219 | 220 | 221 | 222 | return subgraph_tensor,sorted_adj 223 | 224 | 225 | def compute_edge_tensor_of_substructures(self,data,subgraph_tensor): 226 | 227 | if subgraph_tensor[0].shape[0] > 0 and data.edge_index[0].shape[0]>0: 228 | try: 229 | num_subgraph=int(subgraph_tensor[0].max()+1) 230 | max_node_number=max(int(data.edge_index.max()+1),int(subgraph_tensor[1].max()+1)) 231 | sub_tensor=torch.sparse_coo_tensor(subgraph_tensor[[0,1],:], 232 | torch.ones_like(subgraph_tensor[0]), 233 | [num_subgraph,max_node_number]).to_dense() 234 | edg_tensor=torch.sparse_coo_tensor(data.edge_index,torch.ones_like(data.edge_index[0]),[max_node_number,max_node_number]).to_dense() 235 | 236 | res=(sub_tensor.unsqueeze(1))*(edg_tensor.unsqueeze(0))*(sub_tensor.unsqueeze(2)) 237 | except: 238 | print(subgraph_tensor) 239 | print(data.edge_index) 240 | print(num_subgraph) 241 | print(max_node_number) 242 | print(sub_tensor) 243 | print(edg_tensor) 244 | return res.nonzero().T 245 | 246 | else: 247 | return torch.zeros([3,0]) 248 | 249 | 250 | 251 | def sort_subgraphs_randomly(self,data,subgraph_tensor): 252 | 253 | 254 | graph_adj = np.zeros((data.x.shape[0],data.x.shape[0]),dtype=np.int64) 255 | graph_adj[data.edge_index[0,:],data.edge_index[1,:]]=1 256 | 257 | if subgraph_tensor[0].shape[0]>0: 258 | nodes_record=[] 259 | for i in range(int(subgraph_tensor[0].max())+1): 260 | 261 | subgraph_node=subgraph_tensor[1][subgraph_tensor[0] == i] 262 | 263 | subgraph_adj=graph_adj[subgraph_node][:,subgraph_node] 264 | 265 | degrees = subgraph_adj.sum(0) 266 | 267 | if subgraph_adj.sum()>0: 268 | 269 | root=sort_node(None,degrees,only_minimum=True)[0] 270 | 271 | nodes=[n for n in bfs(subgraph_adj, start_node=root,sort_func=lambda node_array: sort_node(node_array,degrees))] 272 | 273 | nodes_record.append(subgraph_node[nodes]) 274 | 275 | return nodes_record 276 | else: 277 | return [[]] 278 | 279 | def compute_substructures(self,data,subgraph_dicts=None): 280 | if subgraph_dicts is None: 281 | subgraph_dicts = substructure_to_gt(self.subgraph_params, self.automorphism_fn, ) 282 | substructures = self.extract_ids_fn(self.count_fn, data, subgraph_dicts, self.subgraph_params) 283 | return substructures 284 | 285 | 286 | 287 | def compute_neighbor(self,data): 288 | neighbor=None 289 | if 'k_neighborhood' in self.args['neighbor_type']: 290 | neighbor=k_hop_subgraph(data.edge_index,data.x.shape[0],self.args['neighbor_size'][self.args['neighbor_type'].index('k_neighborhood')]) 291 | elif 'random_walk' in self.args['neighbor_type']: 292 | neighbor = random_walk_subgraph(data.edge_index, data.x.shape[0], 293 | self.args['neighbor_size'][self.args['neighbor_type'].index('random_walk')]) 294 | neighbor_nodes = neighbor[0].nonzero().T 295 | return neighbor_nodes.tolist(),int(neighbor[0].T.sum(0).max()) 296 | 297 | def substructure_as_tensor(self,substructure_tensor=None,neighbor_tensor=None): 298 | result=None 299 | if substructure_tensor is not None: 300 | result = torch.Tensor([substructure_tensor[1], substructure_tensor[0],substructure_tensor[2]]).long() 301 | if neighbor_tensor is not None: 302 | if substructure_tensor is not None and result.shape[1]>0: 303 | ids=(torch.Tensor([neighbor_tensor[0]])+result[0].max()+1).long() 304 | neighbor_tensor=torch.cat([ids,torch.Tensor([neighbor_tensor[1]]).long(),-torch.ones_like(ids)],0) 305 | 
result=torch.cat([result,neighbor_tensor],1) 306 | else: 307 | ids=(torch.Tensor([neighbor_tensor[0]])).long() 308 | type = (-1* torch.ones_like(ids)).long() 309 | result=torch.cat([ids,torch.Tensor([neighbor_tensor[1]]).long(),type],0) 310 | if result is None: 311 | result=torch.ones([3,0]) 312 | return result 313 | 314 | def pre_compute_substructures_direct_to_lmdb(self,dataset): 315 | 316 | subgraph_dicts=substructure_to_gt(self.subgraph_params,self.automorphism_fn,) 317 | 318 | print('********************************transfer to lmdb type**************************') 319 | 320 | path_lmdb_tensor = os.path.join(self.args['data_dir'],self.args['lmdb_root_dir'],self.args['pre_defined_path'] + self.args['lmdb_dir'], 'inner') 321 | print('saving to ', path_lmdb_tensor) 322 | if not os.path.exists(path_lmdb_tensor): 323 | os.makedirs(path_lmdb_tensor) 324 | lmdb_env_tensor = lmdb.open(path_lmdb_tensor, map_size=1e12) 325 | txn_tensor = lmdb_env_tensor.begin(write=True) 326 | 327 | for idx, data in tqdm(enumerate(dataset)): 328 | 329 | substructure_computed=self.compute_substructures(data,subgraph_dicts) 330 | 331 | substructure_tensor=self.transfer_subgraph_to_batchtensor_complete(substructure_computed) 332 | 333 | txn_tensor.put(key=str(idx).encode(), value=str(substructure_tensor).encode()) 334 | 335 | print('********************************commit substructures tensors**************************') 336 | txn_tensor.commit() 337 | lmdb_env_tensor.close() 338 | 339 | def pre_compute_neighbors_direct_to_lmdb(self,dataset): 340 | print('***********************************computing neighbors to lmdb type**************************') 341 | path_lmdb = os.path.join(self.args['data_dir'],self.args['lmdb_root_dir'],self.args['node_neighbor_path']+ self.args['lmdb_dir'], 'neighbor') 342 | print('saving to ', path_lmdb) 343 | if not os.path.exists(path_lmdb): 344 | os.makedirs(path_lmdb) 345 | lmdb_env = lmdb.open(path_lmdb, map_size=1e12) 346 | txn = lmdb_env.begin(write=True) 347 | max_size=0 348 | for idx, data in tqdm(enumerate(dataset)): 349 | substructure_computed,size=self.compute_neighbor(data) 350 | max_size=max(max_size,size) 351 | txn.put(key=str(idx).encode(), value=str(substructure_computed).encode()) 352 | txn.put(key='max_size'.encode(),value=str(max_size).encode()) 353 | print('********************************commit neighbors**************************') 354 | txn.commit() 355 | lmdb_env.close() 356 | 357 | 358 | def transfer_subgraph_to_batchtensor_complete(self,substructures): 359 | return transfer_subgraph_to_batchtensor_complete(substructures) 360 | 361 | 362 | 363 | def clean_cache(self): 364 | self.substructure_cache = {} 365 | -------------------------------------------------------------------------------- /deepgraph/data/substructure_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
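# This module wraps a PyG dataset for DeepGraph training: substructure tensors
# and node neighborhoods are pre-computed once into LMDB caches (by the rank-0
# process only, followed by dist.barrier()), re-read in __getitem__ via
# recomputeind (which maps a split-local index back to the index used when the
# caches were written), optionally subsampled and re-ordered by transform_sub,
# and finally run through the Graphormer-style preprocess_item before collation.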
3 | 4 | from torch_geometric.data import Dataset 5 | from sklearn.model_selection import train_test_split 6 | 7 | import torch 8 | import numpy as np 9 | 10 | from .wrapper import preprocess_item,graph_data_modification_single,encode_token_single_tensor,encode_token_single_tensor_with_adj 11 | 12 | 13 | from deepgraph.data.substructure_dataset_utils.substructure_transform import transform_sub 14 | 15 | import copy 16 | import lmdb 17 | import os 18 | 19 | import torch.distributed as dist 20 | 21 | 22 | class concat_dataset: 23 | def __init__(self,dataset_list): 24 | self.dataset_list=dataset_list 25 | def __getitem__(self,idx): 26 | if idx 0: 153 | if args.recache_lmdb or not os.path.exists(os.path.join(args['data_dir'], 154 | args['lmdb_root_dir'], 155 | args['pre_defined_path']+ args['lmdb_dir'], 156 | 'inner')): 157 | print('************************* Warning recache substructures**************************') 158 | if not dist.is_initialized() or dist.get_rank() == 0: 159 | self.transform.pre_compute_substructures_direct_to_lmdb(self.dataset) 160 | 161 | if dist.is_initialized(): 162 | dist.barrier() 163 | if len(args['neighbor_type']) > 0: 164 | if args.recache_lmdb or not os.path.exists(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['node_neighbor_path']+ args['lmdb_dir'],'neighbor')): 165 | print('************************* Warning recache neighbors**************************') 166 | if not dist.is_initialized() or dist.get_rank() == 0: 167 | self.transform.pre_compute_neighbors_direct_to_lmdb(self.dataset) 168 | 169 | if dist.is_initialized(): 170 | dist.barrier() 171 | 172 | 173 | 174 | # if self.lmdb_tensor_env is None: 175 | 176 | if len(args['custom_edge_list'])>0: 177 | self.lmdb_tensor_env=lmdb.open(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['pre_defined_path']+ args['lmdb_dir'], 'inner'),map_size=1e12, readonly=True, 178 | lock=False, readahead=False, meminit=False) 179 | else: 180 | self.lmdb_tensor_env=None 181 | if len(args['neighbor_type'])>0: 182 | self.lmdb_neighbor_env=lmdb.open(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['node_neighbor_path']+ args['lmdb_dir'],'neighbor'),map_size=1e12, readonly=True, 183 | lock=False, readahead=False, meminit=False) 184 | txn = self.lmdb_neighbor_env.begin(write=False) 185 | max_size=eval(txn.get(('max_size').encode())) 186 | self.transform.args['subgraph_max_size']=max(self.transform.args['subgraph_max_size'],max_size) 187 | else: 188 | self.lmdb_neighbor_env=None 189 | 190 | 191 | 192 | 193 | self.train_data.lmdb_tensor_env=self.lmdb_tensor_env 194 | self.train_data.lmdb_neighbor_env = self.lmdb_neighbor_env 195 | 196 | self.valid_data.lmdb_tensor_env = self.lmdb_tensor_env 197 | self.valid_data.lmdb_neighbor_env = self.lmdb_neighbor_env 198 | 199 | self.test_data.lmdb_tensor_env = self.lmdb_tensor_env 200 | self.test_data.lmdb_neighbor_env = self.lmdb_neighbor_env 201 | 202 | if train_idx is not None: 203 | self.recomputeind=recompute_idx_byindex(train_idx,valid_idx,test_idx,'inner') 204 | self.train_data.recomputeind=recompute_idx_byindex(train_idx,valid_idx,test_idx,'train') 205 | self.test_data.recomputeind = recompute_idx_byindex(train_idx,valid_idx,test_idx,'test') 206 | if args['valid_on_test']: 207 | self.valid_data.recomputeind = recompute_idx_byindex(train_idx, valid_idx, test_idx, 'test') 208 | else: 209 | self.valid_data.recomputeind = recompute_idx_byindex(train_idx, valid_idx, test_idx, 'valid') 210 | else: 211 | 
self.recomputeind=recompute_idx(len(self.train_data),len(self.valid_data),'train') 212 | self.train_data.recomputeind=recompute_idx(len(self.train_data),len(self.valid_data),'train') 213 | self.test_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'test') 214 | if args['valid_on_test']: 215 | self.valid_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'test') 216 | else: 217 | self.valid_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'valid') 218 | 219 | 220 | 221 | if 'use_transform_cache' in args and args.use_transform_cache: 222 | print('************************* Using transform cache **************************') 223 | 224 | transform_cache_path=os.path.join(args['data_dir'], args['lmdb_root_dir'], args['transform_cache_path'] + args['transform_dir']) 225 | if not os.path.exists(transform_cache_path): 226 | print(transform_cache_path+' do not exits; creating**************') 227 | os.makedirs(transform_cache_path) 228 | try: 229 | self.transform_cache_env=lmdb.open(transform_cache_path,map_size=1e12, readonly=True, 230 | lock=False, readahead=False, meminit=False) 231 | self.train_data.transform_cache_env = self.transform_cache_env 232 | self.valid_data.transform_cache_env = self.transform_cache_env 233 | self.test_data.transform_cache_env = self.transform_cache_env 234 | except: 235 | assert self.only_substructure # The only situation where use_transform_cache but no cache directory is generating cache 236 | 237 | 238 | def index_select(self, idx): 239 | dataset = copy.copy(self) 240 | if not isinstance(self.dataset,list): 241 | dataset.dataset = self.dataset.index_select(idx) 242 | else: 243 | dataset.dataset = [self.dataset[i] for i in idx.tolist()] 244 | if isinstance(idx, torch.Tensor): 245 | dataset.num_data = idx.size(0) 246 | else: 247 | dataset.num_data = idx.shape[0] 248 | dataset.__indices__ = idx 249 | dataset.train_data = None 250 | dataset.valid_data = None 251 | dataset.test_data = None 252 | 253 | return dataset 254 | 255 | def create_subset(self, subset): 256 | dataset = copy.copy(self) 257 | dataset.dataset = subset 258 | dataset.num_data = len(subset) 259 | dataset.__indices__ = None 260 | dataset.train_data = None 261 | dataset.valid_data = None 262 | dataset.test_data = None 263 | 264 | return dataset 265 | 266 | 267 | 268 | # @lru_cache(maxsize=16) 269 | def __getitem__(self, idx,parallel=False): 270 | 271 | if isinstance(idx, int): 272 | 273 | item = self.dataset[idx] 274 | 275 | if self.transform is not None: 276 | 277 | if (not self.use_transform_cache) or self.only_substructure: 278 | 279 | if (self.lmdb_tensor_env is not None): 280 | txn = self.lmdb_tensor_env.begin(write=False) 281 | tem=txn.get(str(self.recomputeind.compute(idx)).encode()) 282 | if tem is not None: 283 | substructure_tensor=eval(tem) 284 | else: 285 | raise ValueError("substructure_tensor not prepared") 286 | else: 287 | substructure_tensor=None 288 | if (self.lmdb_neighbor_env is not None): 289 | txn = self.lmdb_neighbor_env.begin(write=False) 290 | tem=txn.get(str(self.recomputeind.compute(idx)).encode()) 291 | if tem is not None: 292 | neighbor_tensor=eval(tem) 293 | else: 294 | raise ValueError("substructure_tensor not prepared") 295 | else: 296 | neighbor_tensor=None 297 | 298 | substructure_tensor=self.transform.substructure_as_tensor(substructure_tensor,neighbor_tensor) 299 | 300 | if self.only_substructure: 301 | #caching #transform_cache_number sampled substructures 302 | 303 | 
subgraph_tensor_res=[] 304 | sorted_adj_res=[] 305 | 306 | for i in range(self.transform_cache_number): 307 | 308 | 309 | subgraph_tensor, sorted_adj = self.transform(item, substructure_tensor) 310 | 311 | 312 | subgraph_tensor=subgraph_tensor.tolist() 313 | sorted_adj=sorted_adj.tolist() 314 | subgraph_tensor_res.append(subgraph_tensor) 315 | sorted_adj_res.append(sorted_adj) 316 | 317 | item.idx=idx 318 | item.subgraph_tensor_res = subgraph_tensor_res 319 | item.sorted_adj_res=sorted_adj_res 320 | item.y = item.y.reshape(-1) 321 | return item 322 | 323 | else: 324 | subgraph_tensor,sorted_adj = self.transform(item,substructure_tensor) 325 | 326 | else: 327 | txn = self.transform_cache_env.begin(write=False) 328 | id=np.random.randint(self.transform_cache_number) 329 | tem = txn.get((str(self.recomputeind.compute(idx))+'_'+str(id)).encode()) 330 | if tem is not None: 331 | subgraph_tensor, sorted_adj = eval(tem) 332 | else: 333 | raise ValueError("substructure_tensor not prepared") 334 | subgraph_tensor = torch.tensor(subgraph_tensor) 335 | sorted_adj = torch.tensor(sorted_adj) 336 | 337 | item = graph_data_modification_single(item, subgraph_tensor, sorted_adj) 338 | 339 | if self.extra_method == 'token': 340 | item = encode_token_single_tensor(item, self.args.atom_id_max, self.args.edge_id_max, ) 341 | elif self.extra_method == 'adj': 342 | item = encode_token_single_tensor_with_adj(item, local_attention_on_substructures=self.local_attention_on_substructures) 343 | 344 | if (not self.not_re_define) or (not hasattr(item, 'attn_bias')): 345 | item.idx = idx 346 | item.y = item.y.reshape(-1) 347 | 348 | item= preprocess_item(item,local_attention_on_substructures=self.local_attention_on_substructures,continuous_feature=self.continuous_feature) 349 | return item 350 | 351 | else: 352 | raise TypeError("index to a GraphormerPYGDataset can only be an integer.") 353 | 354 | def __len__(self): 355 | return self.num_data 356 | 357 | --------------------------------------------------------------------------------
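Illustrative sketch (not part of the repository): a minimal dense-tensor form of the indicator-matrix tricks used above. compute_edge_tensor_of_substructures keeps an edge inside a subgraph only when both endpoints are members, and generate_sorted_adj re-indexes each subgraph's adjacency into a fixed-size block via a one-hot selection matrix; the repository builds the same quantities from torch.sparse_coo_tensor constructions. The toy graph, the names member, P, node_order, and the hand-picked orderings below are hypothetical and exist only for this example.

import torch

# Toy graph: 4 nodes, undirected path 0-1-2-3 stored as directed edge pairs.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
num_nodes = 4

# Two subgraphs given as (subgraph_id, node_id) pairs, matching the layout of
# the first two rows of the subgraph_tensor built in transform_sub.__call__.
subgraph_tensor = torch.tensor([[0, 0, 0, 1, 1, 1],
                                [0, 1, 2, 1, 2, 3]])
num_subgraphs = int(subgraph_tensor[0].max()) + 1

# Dense membership indicator: member[k, v] = 1 iff node v belongs to subgraph k.
member = torch.zeros(num_subgraphs, num_nodes)
member[subgraph_tensor[0], subgraph_tensor[1]] = 1.0

# Dense adjacency of the whole graph.
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1.0

# Keep edge (u, v) inside subgraph k only when both endpoints belong to k; this
# is the broadcasted product used in compute_edge_tensor_of_substructures.
per_subgraph_adj = member.unsqueeze(2) * adj.unsqueeze(0) * member.unsqueeze(1)
edge_tensor = per_subgraph_adj.nonzero().T        # rows: subgraph_id, u, v

# Re-index each subgraph into a fixed-size block: a one-hot selection matrix P
# with P[pos, node] = 1 gives P @ A @ P.T, the adjacency restricted to (and
# ordered by) the chosen nodes, as in generate_sorted_adj.
subgraph_max_size = 3
node_order = [[0, 1, 2], [1, 2, 3]]               # hypothetical per-subgraph orders
sorted_adj = torch.zeros(num_subgraphs, subgraph_max_size, subgraph_max_size)
for k, order in enumerate(node_order):
    P = torch.zeros(subgraph_max_size, num_nodes)
    P[torch.arange(len(order)), torch.tensor(order)] = 1.0
    sorted_adj[k] = P @ adj @ P.T

print(edge_tensor)        # shape [3, 8]: 4 directed edges inside each subgraph
print(sorted_adj.long())  # two 3x3 blocks, each the adjacency of a 3-node path

In the repository the same products are formed from sparse index tensors that are densified just before the multiplication, which keeps the bookkeeping for many variable-size subgraphs in (id, node) index form until the last step.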