├── deepgraph
│ ├── data
│ │ ├── substructure_dataset_utils
│ │ │ ├── __init__.py
│ │ │ ├── LICENSE
│ │ │ ├── utils_encoding.py
│ │ │ ├── neighbor_extractors.py
│ │ │ ├── utils.py
│ │ │ ├── utils_graph_processing.py
│ │ │ └── substructure_transform.py
│ │ ├── ogb_datasets
│ │ │ ├── __init__.py
│ │ │ └── ogb_dataset_lookup_table.py
│ │ ├── pyg_datasets
│ │ │ ├── __init__.py
│ │ │ └── pyg_dataset_lookup_table.py
│ │ ├── __init__.py
│ │ ├── algos.pyx
│ │ ├── subsampling
│ │ │ ├── sampler.py
│ │ │ └── sampling.py
│ │ ├── dataset.py
│ │ ├── collator.py
│ │ ├── wrapper.py
│ │ └── substructure_dataset.py
│ ├── tasks
│ │ ├── __init__.py
│ │ └── is2re.py
│ ├── models
│ │ └── __init__.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── deepgraph_graph_encoder_layer.py
│ │ ├── deepgraph_layers.py
│ │ └── deepgraph_graph_encoder.py
│ ├── criterions
│ │ ├── __init__.py
│ │ ├── contrastive_loss.py
│ │ ├── l1_loss.py
│ │ ├── multiclass_cross_entropy.py
│ │ ├── multilabel_multiclass_cross_entropy.py
│ │ ├── binary_logloss.py
│ │ ├── node_multiclass_cross_entropy.py
│ │ └── mae_deltapos.py
│ ├── __init__.py
│ ├── pretrain
│ │ └── __init__.py
│ └── evaluate
│   ├── cache_data.py
│   └── evaluate.py
├── overview.png
├── train.py
├── LICENSE
└── README.md
/deepgraph/data/substructure_dataset_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhao-ht/DeepGraph/HEAD/overview.png
--------------------------------------------------------------------------------
/deepgraph/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
--------------------------------------------------------------------------------
/deepgraph/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .deepgraph import DeepGraphModel
5 |
--------------------------------------------------------------------------------
/deepgraph/data/ogb_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .ogb_dataset_lookup_table import OGBDatasetLookupTable
5 |
--------------------------------------------------------------------------------
/deepgraph/data/pyg_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .pyg_dataset_lookup_table import PYGDatasetLookupTable
5 |
6 |
--------------------------------------------------------------------------------
/deepgraph/data/__init__.py:
--------------------------------------------------------------------------------
1 | DATASET_REGISTRY = {}
2 |
3 | def register_dataset(name: str):
4 | def register_dataset_func(func):
5 | DATASET_REGISTRY[name] = func()
6 | return register_dataset_func
7 |
8 |
--------------------------------------------------------------------------------
/deepgraph/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .multihead_attention import MultiheadAttention
5 | from .deepgraph_layers import GraphNodeFeature, GraphAttnBias
6 | from .deepgraph_graph_encoder_layer import DeepGraphGraphEncoderLayer
7 | from .deepgraph_graph_encoder import DeepGraphGraphEncoder, init_graphormer_params
8 |
--------------------------------------------------------------------------------
/deepgraph/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from pathlib import Path
5 | import importlib
6 |
7 | # automatically import any Python files in the criterions/ directory
8 | for file in sorted(Path(__file__).parent.glob("*.py")):
9 | if not file.name.startswith("_"):
10 | importlib.import_module("deepgraph.criterions." + file.name[:-3])
11 |
--------------------------------------------------------------------------------
/deepgraph/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import importlib
5 |
6 |
7 | try:
8 | import torch
9 |
10 | torch.multiprocessing.set_start_method("fork", force=True)
11 | except:
12 | import sys
13 |
14 | print(
15 | "Your OS does not support multiprocessing based on fork, please use num_workers=0",
16 | file=sys.stderr,
17 | flush=True,
18 | )
19 |
20 | import deepgraph.criterions
21 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from rdkit import Chem
3 | import graph_tool
4 | import torch_geometric
5 | import joblib
6 | from ogb.graphproppred import PygGraphPropPredDataset
7 | import warnings
8 |
9 | warnings.filterwarnings('ignore')
10 | # import fairseq
11 | # from fairseq.distributed import utils as distributed_utils
12 | # from fairseq.logging import meters, metrics, progress_bar # noqa
13 |
14 | # from fairseq.logging import metrics
15 |
16 | # sys.modules["fairseq.distributed_utils"] = distributed_utils
17 | # sys.modules["fairseq.meters"] = meters
18 | # sys.modules["fairseq.metrics"] = metrics
19 | # sys.modules["fairseq.progress_bar"] = progress_bar
20 | from fairseq_cli.train import cli_main
21 | import logging
22 | logging.getLogger().setLevel(logging.INFO)
23 |
24 |
25 |
26 | if __name__ == "__main__":
27 | cli_main()
--------------------------------------------------------------------------------
/deepgraph/data/substructure_dataset_utils/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Giorgos Bouritsas
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/deepgraph/pretrain/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.hub import load_state_dict_from_url
2 | import torch.distributed as dist
3 |
4 | PRETRAINED_MODEL_URLS = {
5 | "pcqm4mv1_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_best_pcqm4mv1.pt",
6 | "pcqm4mv2_graphormer_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv2/checkpoint_best_pcqm4mv2.pt",
7 | "oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt",
8 | "pcqm4mv1_graphormer_base_for_molhiv":"https://szheng.blob.core.windows.net/graphormer/modelzoo/pcqm4mv1/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
9 | }
10 |
11 | def load_pretrained_model(pretrained_model_name):
12 | if pretrained_model_name not in PRETRAINED_MODEL_URLS:
13 | raise ValueError(f"Unknown pretrained model name {pretrained_model_name}")
14 | if not dist.is_initialized():
15 | return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"]
16 | else:
17 | pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"]
18 | dist.barrier()
19 | return pretrained_model
20 |
--------------------------------------------------------------------------------
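A minimal usage sketch for load_pretrained_model above, assuming the deepgraph package is importable and the checkpoint URL in PRETRAINED_MODEL_URLS is still reachable (the call downloads the checkpoint and returns its "model" state dict):

from deepgraph.pretrain import load_pretrained_model

state_dict = load_pretrained_model("pcqm4mv1_graphormer_base")  # the "model" sub-dict of the checkpoint
print(len(state_dict), "parameter tensors")
print(next(iter(state_dict)))                                   # name of the first parameter tensor
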
/deepgraph/data/substructure_dataset_utils/utils_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | from sklearn.preprocessing import KBinsDiscretizer, OneHotEncoder, MinMaxScaler, StandardScaler
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 |
9 | def find_max_id(dataset_list):
10 | atom_id_max = 0
11 | edge_id_max = 0
12 | for dataset in dataset_list:
13 | for id,graph in tqdm(enumerate(dataset)):
14 | try:
15 | atom_id_max = max(atom_id_max, graph.x.max())
16 | edge_id_max = max(edge_id_max, graph.edge_attr.max())
17 | except:
18 | pass
19 |
20 | return int(atom_id_max),int(edge_id_max)
21 |
22 |
23 | class one_hot_unique:
24 |
25 | def __init__(self, tensor_list, **kwargs):
26 | tensor_list = torch.cat(tensor_list, 0)
27 | self.d = list()
28 | self.corrs = dict()
29 | for col in range(tensor_list.shape[1]):
30 | uniques, corrs = np.unique(tensor_list[:, col], return_inverse=True, axis=0)
31 | self.d.append(len(uniques))
32 | self.corrs[col] = corrs
33 | return
34 |
35 | def fit(self, tensor_list):
36 | pointer = 0
37 | encoded_tensors = list()
38 | for tensor in tensor_list:
39 | n = tensor.shape[0]
40 | for col in range(tensor.shape[1]):
41 | translated = torch.LongTensor(self.corrs[col][pointer:pointer+n]).unsqueeze(1)
42 | encoded = torch.cat((encoded, translated), 1) if col > 0 else translated
43 | encoded_tensors.append(encoded)
44 | pointer += n
45 | return encoded_tensors
46 |
47 |
48 | class one_hot_max:
49 |
50 | def __init__(self, tensor_list, **kwargs):
51 | tensor_list = torch.cat(tensor_list,0)
52 | self.d = [int(tensor_list[:,i].max()+1) for i in range(tensor_list.shape[1])]
53 |
54 | def fit(self, tensor_list):
55 | return tensor_list
56 |
57 |
58 |
--------------------------------------------------------------------------------
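A toy illustration (not part of the repository) of how the encoders above summarize categorical feature columns, assuming the package and its dependencies (torch, numpy, scikit-learn) are importable:

import torch
from deepgraph.data.substructure_dataset_utils.utils_encoding import one_hot_max, one_hot_unique

# two "graphs", each with two categorical feature columns
feats = [torch.tensor([[0, 2], [1, 2]]), torch.tensor([[3, 0], [1, 1]])]

enc = one_hot_max(feats)
print(enc.d)            # per-column vocabulary sizes taken from the max value: [4, 3]
print(enc.fit(feats))   # one_hot_max leaves the tensors unchanged

enc_u = one_hot_unique(feats)
print(enc_u.d)          # number of distinct values per column: [3, 3]
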
/deepgraph/criterions/contrastive_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | import torch.nn as nn
8 | from fairseq.logging import metrics
9 | # from fairseq import metrics
10 | from fairseq.criterions import FairseqCriterion, register_criterion
11 | import numpy as np
12 |
13 | @register_criterion("contrustive_loss", dataclass=FairseqDataclass)
14 | class GraphPretrainContrustiveLoss(FairseqCriterion):
15 | """
16 | Implementation for the contrastive loss used in graphormer model pretraining.
17 | """
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 | sample_size = sample["nsamples"]
28 |
29 | with torch.no_grad():
30 | natoms = sample['net_input']['batched_data']['data']['x'].shape[1]
31 |
32 | x1 = model(sample['net_input']['batched_data']['data1'])
33 | x2 = model(sample['net_input']['batched_data']['data2'])
34 |
35 | T = 0.1
36 | batch_size, _ = x1.size()
37 | x1_abs = x1.norm(dim=1)
38 | x2_abs = x2.norm(dim=1)
39 |
40 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs).type(torch.float64)
41 | sim_matrix = torch.exp(sim_matrix / T)
42 | pos_sim = sim_matrix[np.arange(batch_size), np.arange(batch_size)]
43 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
44 | loss = - torch.log(loss).mean().type(torch.float16)
45 |
46 | logging_output = {
47 | "loss": loss.data,
48 | "sample_size": x1.size(0),
49 | "nsentences": sample_size,
50 | "ntokens": natoms,
51 | }
52 | return loss, sample_size, logging_output
53 |
54 | @staticmethod
55 | def reduce_metrics(logging_outputs) -> None:
56 | """Aggregate logging outputs from data parallel training."""
57 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
58 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
59 |
60 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=6)
61 |
62 | @staticmethod
63 | def logging_outputs_can_be_summed() -> bool:
64 | """
65 | Whether the logging outputs returned by `forward` can be summed
66 | across workers prior to calling `reduce_metrics`. Setting this
67 | to True will improve distributed training speed.
68 | """
69 | return True
70 |
--------------------------------------------------------------------------------
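The forward above is an NT-Xent style contrastive objective over two views x1 and x2 of the same batch of graphs. A standalone sketch of the same computation on random embeddings (toy data; temperature T = 0.1 as in the criterion):

import torch

def nt_xent(x1, x2, T=0.1):
    # cosine-similarity matrix between every pair (i, j), scaled by temperature
    x1_abs, x2_abs = x1.norm(dim=1), x2.norm(dim=1)
    sim = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum("i,j->ij", x1_abs, x2_abs)
    sim = torch.exp(sim / T)
    pos = sim.diagonal()                          # matching pairs (i, i) are the positives
    return -(pos / (sim.sum(dim=1) - pos)).log().mean()

x1, x2 = torch.randn(8, 16), torch.randn(8, 16)
print(nt_xent(x1, x2).item())
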
/deepgraph/data/algos.pyx:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import cython
5 | from cython.parallel cimport prange, parallel
6 | cimport numpy
7 | import numpy
8 |
9 | def floyd_warshall(adjacency_matrix):
10 |
11 | (nrows, ncols) = adjacency_matrix.shape
12 | assert nrows == ncols
13 | cdef unsigned int n = nrows
14 |
15 | adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True)
16 | assert adj_mat_copy.flags['C_CONTIGUOUS']
17 | cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy
18 | cdef numpy.ndarray[long, ndim=2, mode='c'] path = -1 * numpy.ones([n, n], dtype=numpy.int64)
19 |
20 | cdef unsigned int i, j, k
21 | cdef long M_ij, M_ik, cost_ikkj
22 | cdef long* M_ptr = &M[0,0]
23 | cdef long* M_i_ptr
24 | cdef long* M_k_ptr
25 |
26 | # set unreachable nodes distance to 510
27 | for i in range(n):
28 | for j in range(n):
29 | if i == j:
30 | M[i][j] = 0
31 | elif M[i][j] == 0:
32 | M[i][j] = 510
33 |
34 | # Floyd-Warshall main loop
35 | for k in range(n):
36 | M_k_ptr = M_ptr + n*k
37 | for i in range(n):
38 | M_i_ptr = M_ptr + n*i
39 | M_ik = M_i_ptr[k]
40 | for j in range(n):
41 | cost_ikkj = M_ik + M_k_ptr[j]
42 | M_ij = M_i_ptr[j]
43 | if M_ij > cost_ikkj:
44 | M_i_ptr[j] = cost_ikkj
45 | path[i][j] = k
46 |
47 | # set unreachable path to 510
48 | for i in range(n):
49 | for j in range(n):
50 | if M[i][j] >= 510:
51 | path[i][j] = 510
52 | M[i][j] = 510
53 |
54 | return M, path
55 |
56 |
57 | def get_all_edges(path, i, j):
58 | cdef int k = path[i][j]
59 | if k == -1:
60 | return []
61 | else:
62 | return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)
63 |
64 |
65 | def gen_edge_input(max_dist, path, edge_feat):
66 |
67 | (nrows, ncols) = path.shape
68 | assert nrows == ncols
69 | cdef unsigned int n = nrows
70 | cdef unsigned int max_dist_copy = max_dist
71 |
72 | path_copy = path.astype(long, order='C', casting='safe', copy=True)
73 | edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True)
74 | assert path_copy.flags['C_CONTIGUOUS']
75 | assert edge_feat_copy.flags['C_CONTIGUOUS']
76 |
77 | cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64)
78 | cdef unsigned int i, j, k, num_path, cur
79 |
80 | for i in range(n):
81 | for j in range(n):
82 | if i == j:
83 | continue
84 | if path_copy[i][j] == 510:
85 | continue
86 | path = [i] + get_all_edges(path_copy, i, j) + [j]
87 | num_path = len(path) - 1
88 | for k in range(num_path):
89 | edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :]
90 |
91 | return edge_fea_all
92 |
--------------------------------------------------------------------------------
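floyd_warshall above is a Cython kernel; the pure-NumPy reference below (for understanding and spot-checking only, far too slow for training) follows the same convention of capping unreachable distances at 510 and recording an intermediate node of each shortest path:

import numpy as np

def floyd_warshall_reference(adj):
    # dist/path reference for the Cython floyd_warshall above
    n = adj.shape[0]
    M = adj.astype(np.int64)
    path = -np.ones((n, n), dtype=np.int64)
    M[M == 0] = 510                  # unreachable pairs start at the cap
    np.fill_diagonal(M, 0)           # zero distance to self
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if M[i, k] + M[k, j] < M[i, j]:
                    M[i, j] = M[i, k] + M[k, j]
                    path[i, j] = k   # remember an intermediate node on the shortest i-j path
    path[M >= 510] = 510
    M[M >= 510] = 510
    return M, path

adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])   # path graph 0 - 1 - 2
dist, path = floyd_warshall_reference(adj)
print(dist)                                         # dist[0, 2] == 2
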
/deepgraph/data/substructure_dataset_utils/neighbor_extractors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_sparse import SparseTensor # for propagation
3 |
4 | def k_hop_subgraph(edge_index, num_nodes, num_hops):
5 | # return k-hop subgraphs for all nodes in the graph
6 | row, col = edge_index
7 | sparse_adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))
8 | hop_masks = [torch.eye(num_nodes, dtype=torch.bool, device=edge_index.device)] # each one contains <= i hop masks
9 | hop_indicator = row.new_full((num_nodes, num_nodes), -1)
10 | hop_indicator[hop_masks[0]] = 0
11 | for i in range(num_hops):
12 | next_mask = sparse_adj.matmul(hop_masks[i].float()) > 0
13 | hop_masks.append(next_mask)
14 | hop_indicator[(hop_indicator==-1) & next_mask] = i+1
15 | hop_indicator = hop_indicator.T # N x N
16 | node_mask = (hop_indicator >= 0) # N x N dense mask matrix
17 | return node_mask, hop_indicator
18 |
19 |
20 | from torch_cluster import random_walk
21 | def random_walk_subgraph(edge_index, num_nodes, walk_length, p=1, q=1, repeat=1, cal_hops=True, max_hops=10):
22 | """
23 | p (float, optional): Likelihood of immediately revisiting a node in the
24 | walk. (default: :obj:`1`) Setting it to a high value (> max(q, 1)) ensures
25 | that we are less likely to sample an already visited node in the following two steps.
26 | q (float, optional): Control parameter to interpolate between
27 | breadth-first strategy and depth-first strategy (default: :obj:`1`)
28 | if q > 1, the random walk is biased towards nodes close to node t.
29 | if q < 1, the walk is more inclined to visit nodes which are further away from the node t.
30 | Typical values: p, q ∈ {0.25, 0.50, 1, 2, 4}.
31 | A common strategy is to fix p and tune q.
32 |
33 |
34 | repeat: restart the random walk many times and combine together for the result
35 |
36 | """
37 | row, col = edge_index
38 | start = torch.arange(num_nodes, device=edge_index.device)
39 | walks = [random_walk(row, col,
40 | start=start,
41 | walk_length=walk_length,
42 | p=p, q=q,
43 | num_nodes=num_nodes) for _ in range(repeat)]
44 | walk = torch.cat(walks, dim=-1)
45 | node_mask = row.new_empty((num_nodes, num_nodes), dtype=torch.bool)
46 | # print(walk.shape)
47 | node_mask.fill_(False)
48 | node_mask[start.repeat_interleave((walk_length+1)*repeat), walk.reshape(-1)] = True
49 | if cal_hops: # this is fast enough
50 | sparse_adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))
51 | hop_masks = [torch.eye(num_nodes, dtype=torch.bool, device=edge_index.device)]
52 | hop_indicator = row.new_full((num_nodes, num_nodes), -1)
53 | hop_indicator[hop_masks[0]] = 0
54 | for i in range(max_hops):
55 | next_mask = sparse_adj.matmul(hop_masks[i].float())>0
56 | hop_masks.append(next_mask)
57 | hop_indicator[(hop_indicator==-1) & next_mask] = i+1
58 | if hop_indicator[node_mask].min() != -1:
59 | break
60 | return node_mask, hop_indicator
61 | return node_mask, None
62 |
63 |
64 |
--------------------------------------------------------------------------------
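k_hop_subgraph above relies on torch_sparse for the propagation; the dense-tensor sketch below restates the same k-hop logic in plain torch on a toy undirected path graph (illustration only, quadratic memory):

import torch

def k_hop_masks_dense(edge_index, num_nodes, num_hops):
    # node_mask[i, j]: is node j within num_hops of node i; hop_indicator: hop count or -1
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    reach = torch.eye(num_nodes, dtype=torch.bool)
    hop_indicator = torch.full((num_nodes, num_nodes), -1, dtype=torch.long)
    hop_indicator[reach] = 0
    for hop in range(1, num_hops + 1):
        reach = (adj @ reach.float()) > 0                  # nodes reachable in exactly `hop` steps
        hop_indicator[(hop_indicator == -1) & reach] = hop
    hop_indicator = hop_indicator.T
    return hop_indicator >= 0, hop_indicator

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # path graph 0 - 1 - 2 - 3
node_mask, hops = k_hop_masks_dense(edge_index, num_nodes=4, num_hops=2)
print(hops[0])   # tensor([ 0,  1,  2, -1]): node 3 lies outside the 2-hop neighborhood of node 0
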
/deepgraph/criterions/l1_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | import torch.nn as nn
8 | from fairseq.logging import metrics
9 | # from fairseq import metrics
10 | from fairseq.criterions import FairseqCriterion, register_criterion
11 |
12 |
13 | @register_criterion("l1_loss", dataclass=FairseqDataclass)
14 | class GraphPredictionL1Loss(FairseqCriterion):
15 | """
16 | Implementation for the L1 loss (MAE loss) used in graphormer model training.
17 | """
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 | sample_size = sample["nsamples"]
28 |
29 | with torch.no_grad():
30 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
31 |
32 | logits = model(**sample["net_input"])
33 | logits = logits[:, 0, :]
34 | targets = model.get_targets(sample, [logits])
35 |
36 | loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)])
37 |
38 | logging_output = {
39 | "loss": loss.data,
40 | "sample_size": logits.size(0),
41 | "nsentences": sample_size,
42 | "ntokens": natoms,
43 | }
44 | return loss, sample_size, logging_output
45 |
46 | @staticmethod
47 | def reduce_metrics(logging_outputs) -> None:
48 | """Aggregate logging outputs from data parallel training."""
49 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
50 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
51 |
52 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=6)
53 |
54 | @staticmethod
55 | def logging_outputs_can_be_summed() -> bool:
56 | """
57 | Whether the logging outputs returned by `forward` can be summed
58 | across workers prior to calling `reduce_metrics`. Setting this
59 | to True will improve distributed training speed.
60 | """
61 | return True
62 |
63 |
64 | @register_criterion("l1_loss_with_flag", dataclass=FairseqDataclass)
65 | class GraphPredictionL1LossWithFlag(GraphPredictionL1Loss):
66 | """
67 | Implementation for the L1 loss (MAE loss) with FLAG perturbation used in graphormer model training.
68 | """
69 |
70 | def perturb_forward(self, model, sample, perturb, reduce=True):
71 | """Compute the loss for the given sample.
72 |
73 | Returns a tuple with three elements:
74 | 1) the loss
75 | 2) the sample size, which is used as the denominator for the gradient
76 | 3) logging outputs to display while training
77 | """
78 | sample_size = sample["nsamples"]
79 |
80 | batch_data = sample["net_input"]["batched_data"]["x"]
81 | with torch.no_grad():
82 | natoms = batch_data.shape[1]
83 | logits = model(**sample["net_input"], perturb=perturb)[:, 0, :]
84 | targets = model.get_targets(sample, [logits])
85 | loss = nn.L1Loss(reduction="sum")(logits, targets[: logits.size(0)])
86 |
87 | logging_output = {
88 | "loss": loss.data,
89 | "sample_size": logits.size(0),
90 | "nsentences": sample_size,
91 | "ntokens": natoms,
92 | }
93 | return loss, sample_size, logging_output
94 |
--------------------------------------------------------------------------------
/deepgraph/evaluate/cache_data.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from fairseq import checkpoint_utils, utils, options, tasks
4 | from fairseq.logging import progress_bar
5 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf
6 |
7 | import os
8 |
9 | import sys
10 | from os import path
11 |
12 | from multiprocessing import Pool
13 |
14 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
15 |
16 | import logging
17 | import lmdb
18 |
19 | logging.root.setLevel(logging.INFO)
20 | logging.basicConfig(level=logging.INFO)
21 |
22 | def f(x):
23 | return x * x
24 |
25 | if __name__ == '__main__':
26 | with Pool(5) as p:
27 | print(p.map(f, [1, 2, 3]))
28 |
29 |
30 |
31 | parser = options.get_training_parser()
32 |
33 | res_lis = []
34 | args = options.parse_args_and_arch(parser, modify_parser=None)
35 | logger = logging.getLogger(__name__)
36 |
37 | checkpoint_path = None
38 | logger.info(f"evaluating checkpoint file {checkpoint_path}")
39 |
40 |
41 |
42 |
43 | cfg = convert_namespace_to_omegaconf(args)
44 | np.random.seed(cfg.common.seed)
45 | utils.set_torch_seed(cfg.common.seed)
46 |
47 | # initialize task
48 | task = tasks.setup_task(cfg.task)
49 | model = task.build_model(cfg.model)
50 |
51 |
52 | split='inner'
53 | task.load_dataset(split)
54 | batch_iterator = task.get_batch_iterator(
55 | dataset=task.dataset(split),
56 | max_tokens=cfg.dataset.max_tokens_valid,
57 | max_sentences=cfg.dataset.batch_size_valid,
58 | max_positions=utils.resolve_max_positions(
59 | task.max_positions(),
60 | model.max_positions(),
61 | ),
62 | ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
63 | required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
64 | seed=cfg.common.seed,
65 | num_workers=cfg.dataset.num_workers,
66 | epoch=0,
67 | data_buffer_size=cfg.dataset.data_buffer_size,
68 | disable_iterator_cache=False,
69 | )
70 | itr = batch_iterator.next_epoch_itr(
71 | shuffle=False, set_dataset_epoch=False
72 | )
73 | progress = progress_bar.progress_bar(
74 | itr,
75 | log_format=cfg.common.log_format,
76 | log_interval=cfg.common.log_interval,
77 | default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple")
78 | )
79 |
80 |
81 | path_lmdb = os.path.join(task.dm.dataset.args['data_dir'],task.dm.dataset.args['lmdb_root_dir'],task.dm.dataset.args['transform_cache_path']+ task.dm.dataset.args['transform_dir'])
82 | print('saving to ', path_lmdb)
83 |
84 | if not os.path.exists(path_lmdb):
85 | os.makedirs(path_lmdb)
86 | lmdb_env = lmdb.open(path_lmdb, map_size=1e12)
87 | txn = lmdb_env.begin(write=True)
88 |
89 |
90 |
91 | for data in progress:
92 |
93 | idx=data["net_input"]["batched_data"]['idx']
94 | subgraph_tensor_res=data["net_input"]["batched_data"]['subgraph_tensor_res']
95 | sorted_adj_res=data["net_input"]["batched_data"]['sorted_adj_res']
96 |
97 | for i in range(len(idx)):
98 | id_cur=idx[i]
99 | subgraph_tensor = subgraph_tensor_res[i]
100 | sorted_adj = sorted_adj_res[i]
101 | assert len(subgraph_tensor)==args.transform_cache_number
102 | if len(subgraph_tensor)>1:
103 | assert subgraph_tensor[0] != subgraph_tensor[1]
104 | for sample_id in range(args.transform_cache_number):
105 | result = [subgraph_tensor[sample_id], sorted_adj[sample_id]]
106 | txn.put(key=(str(id_cur) + '_' + str(sample_id)).encode(), value=str(result).encode())
107 |
108 |
109 | print('********************************commit substructure sampling caching **************************')
110 | txn.commit()
111 | lmdb_env.close()
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
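If the cached entries need to be inspected later, they can be read back from the same LMDB directory; the path below is a placeholder, and the key layout simply mirrors whatever the writer above produced:

import lmdb

path_lmdb = "/path/to/transform_cache"   # placeholder: the directory the script above writes to
env = lmdb.open(path_lmdb, readonly=True, lock=False)
with env.begin() as txn:
    with txn.cursor() as cur:
        for key, value in cur:
            print(key.decode(), len(value), "bytes")
            break                        # peek at the first cached entry only
env.close()
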
/deepgraph/data/subsampling/sampler.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Data
2 | from deepgraph.data.subsampling.sampling import *
3 |
4 | import re
5 |
6 |
7 |
8 |
9 | class Subgraphs_Sampler(object):
10 | def __init__(self,
11 | sampling_mode=None,
12 | minimum_redundancy=2,
13 | shortest_path_mode_stride=2,
14 | random_mode_sampling_rate=0.3,
15 | random_init=False,
16 | only_unused_nodes=True):
17 | super().__init__()
18 |
19 | self.random_init = random_init
20 | self.sampling_mode = sampling_mode
21 | self.minimum_redundancy = minimum_redundancy
22 | self.shortest_path_mode_stride = shortest_path_mode_stride
23 | self.random_mode_sampling_rate = random_mode_sampling_rate
24 | self.only_unused_nodes=only_unused_nodes
25 |
26 |
27 | def __call__(self, edge_index,subgraphs_nodes_tensor,num_nodes,must_select_list):
28 |
29 | if subgraphs_nodes_tensor.shape[1]>0:
30 | selected_subgraphs, node_selected_times = subsampling_subgraphs_general(edge_index,
31 | subgraphs_nodes_tensor,
32 | num_nodes,
33 | sampling_mode=self.sampling_mode,
34 | random_init=self.random_init,
35 | minimum_redundancy=self.minimum_redundancy,
36 | shortest_path_mode_stride=self.shortest_path_mode_stride,
37 | random_mode_sampling_rate=self.random_mode_sampling_rate,
38 | must_select_list=must_select_list,only_unused_nodes=self.only_unused_nodes)
39 | else:
40 | selected_subgraphs=[]
41 | node_selected_times=None
42 |
43 | return selected_subgraphs,node_selected_times
44 |
45 |
46 |
47 |
48 | def subsampling_subgraphs_general(edge_index, subgraphs_nodes, num_nodes=None, sampling_mode='shortest_path', random_init=False, minimum_redundancy=0,
49 | shortest_path_mode_stride=2, random_mode_sampling_rate=0.5,must_select_list=None,only_unused_nodes=None):
50 |
51 | assert sampling_mode in ['shortest_path', 'random', 'min_set_cover','min_set_cover_random']
52 | if sampling_mode == 'random':
53 | selected_subgraphs, node_selected_times = random_sampling_general(subgraphs_nodes, num_nodes=num_nodes, rate=random_mode_sampling_rate, minimum_redundancy=minimum_redundancy,must_select_list=must_select_list)
54 | elif sampling_mode == 'shortest_path':
55 | selected_subgraphs, node_selected_times = shortest_path_sampling_general(edge_index, subgraphs_nodes, num_nodes=num_nodes, minimum_redundancy=minimum_redundancy,
56 | stride=max(1, shortest_path_mode_stride), random_init=random_init)
57 | elif sampling_mode in ['min_set_cover']:
58 |
59 | selected_subgraphs, node_selected_times = min_set_cover_sampling_general(subgraphs_nodes,
60 | minimum_redundancy=minimum_redundancy, random_init=random_init,num_nodes=num_nodes,must_select_list=must_select_list,only_unused_nodes=only_unused_nodes)
61 |
62 | elif sampling_mode in ['min_set_cover_random']:
63 |
64 | selected_subgraphs, node_selected_times = min_set_cover_random_sampling_general(subgraphs_nodes,
65 | minimum_redundancy=minimum_redundancy, random_init=random_init,num_nodes=num_nodes,must_select_list=must_select_list,only_unused_nodes=only_unused_nodes)
66 | else:
67 | raise ValueError('not supported sampling mode')
68 |
69 | return selected_subgraphs, node_selected_times
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
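The 'min_set_cover' mode delegates to min_set_cover_sampling_general in sampling.py (not shown here); the idea behind it is greedy set cover over the node sets touched by each candidate subgraph. A self-contained toy version of that idea, without the redundancy, random-init, or must-select handling of the real sampler:

def greedy_set_cover(subgraph_nodes, num_nodes):
    # greedily pick subgraphs until every node is covered at least once
    uncovered = set(range(num_nodes))
    selected = []
    while uncovered:
        # the subgraph covering the most still-uncovered nodes wins this round
        best = max(range(len(subgraph_nodes)), key=lambda s: len(subgraph_nodes[s] & uncovered))
        if not subgraph_nodes[best] & uncovered:
            break                        # remaining nodes appear in no subgraph
        selected.append(best)
        uncovered -= subgraph_nodes[best]
    return selected

subgraphs = [{0, 1, 2}, {2, 3}, {3, 4, 5}, {1, 4}]
print(greedy_set_cover(subgraphs, num_nodes=6))   # [0, 2]: two subgraphs cover all six nodes
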
/deepgraph/criterions/multiclass_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | from torch.nn import functional
8 | from fairseq.logging import metrics
9 | # from fairseq import metrics
10 | from fairseq.criterions import FairseqCriterion, register_criterion
11 |
12 |
13 | @register_criterion("multiclass_cross_entropy", dataclass=FairseqDataclass)
14 | class GraphPredictionMulticlassCrossEntropy(FairseqCriterion):
15 | """
16 | Implementation for the multi-class log loss used in graphormer model training.
17 | """
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 | sample_size = sample["nsamples"]
28 |
29 | with torch.no_grad():
30 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
31 |
32 | logits = model(**sample["net_input"])
33 | logits = logits[:, 0, :]
34 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
35 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum()
36 |
37 | loss = functional.cross_entropy(
38 | logits, targets.reshape(-1), reduction="sum"
39 | )
40 |
41 | logging_output = {
42 | "loss": loss.data,
43 | "sample_size": sample_size,
44 | "nsentences": sample_size,
45 | "ntokens": natoms,
46 | "ncorrect": ncorrect,
47 | }
48 | return loss, sample_size, logging_output
49 |
50 | @staticmethod
51 | def reduce_metrics(logging_outputs) -> None:
52 | """Aggregate logging outputs from data parallel training."""
53 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
54 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
55 |
56 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
57 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
58 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
59 | metrics.log_scalar(
60 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1
61 | )
62 |
63 | @staticmethod
64 | def logging_outputs_can_be_summed() -> bool:
65 | """
66 | Whether the logging outputs returned by `forward` can be summed
67 | across workers prior to calling `reduce_metrics`. Setting this
68 | to True will improve distributed training speed.
69 | """
70 | return True
71 |
72 |
73 | @register_criterion("multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass)
74 | class GraphPredictionMulticlassCrossEntropyWithFlag(GraphPredictionMulticlassCrossEntropy):
75 | """
76 | Implementation for the multi-class log loss used in graphormer model training.
77 | """
78 |
79 | def forward(self, model, sample, reduce=True):
80 | """Compute the loss for the given sample.
81 |
82 | Returns a tuple with three elements:
83 | 1) the loss
84 | 2) the sample size, which is used as the denominator for the gradient
85 | 3) logging outputs to display while training
86 | """
87 | sample_size = sample["nsamples"]
88 | perturb = sample.get("perturb", None)
89 |
90 | with torch.no_grad():
91 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
92 |
93 | logits = model(**sample["net_input"], perturb=perturb)
94 | logits = logits[:, 0, :]
95 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
96 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum()
97 |
98 | loss = functional.cross_entropy(
99 | logits, targets.reshape(-1), reduction="sum"
100 | )
101 |
102 | logging_output = {
103 | "loss": loss.data,
104 | "sample_size": sample_size,
105 | "nsentences": sample_size,
106 | "ntokens": natoms,
107 | "ncorrect": ncorrect,
108 | }
109 | return loss, sample_size, logging_output
110 |
--------------------------------------------------------------------------------
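The criterion above reads the graph-level prediction from position 0 of the encoder output (the virtual graph token) and applies a summed cross-entropy. A toy sketch of that readout with invented shapes (batch of 4 graphs, 10 token slots, 3 classes):

import torch
from torch.nn import functional

encoder_out = torch.randn(4, 10, 3)        # [batch, 1 + max_nodes, num_classes]
targets = torch.tensor([0, 2, 1, 2])

logits = encoder_out[:, 0, :]              # graph-level readout: the token at position 0
loss = functional.cross_entropy(logits, targets, reduction="sum")
ncorrect = (logits.argmax(dim=-1) == targets).sum()
print(loss.item(), int(ncorrect))
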
/deepgraph/criterions/multilabel_multiclass_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | from torch.nn import functional
8 | from fairseq.logging import metrics
9 | # from fairseq import metrics
10 | from fairseq.criterions import FairseqCriterion, register_criterion
11 |
12 |
13 | @register_criterion("multilabel_multiclass_cross_entropy", dataclass=FairseqDataclass)
14 | class GraphPredictionMultiLabelMulticlassCrossEntropy(FairseqCriterion):
15 | """
16 | Implementation for the multi-class log loss used in graphormer model training.
17 | """
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 |
28 | logits = model(**sample["net_input"])
29 |
30 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
31 |
32 |
33 | logits=logits.reshape([logits.shape[0]*logits.shape[1],logits.shape[2]])
34 | targets=targets
35 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum()
36 |
37 | loss = functional.cross_entropy(
38 | logits, targets.reshape(-1), reduction="sum"
39 | )
40 |
41 | sample_size = sample["nsamples"]
42 |
43 | with torch.no_grad():
44 | natoms = targets.shape[0]*targets.shape[1]
45 |
46 |
47 | logging_output = {
48 | "loss": float(loss.data),
49 | "sample_size": natoms,
50 | "nsentences": sample_size,
51 | "ntokens": natoms,
52 | "ncorrect": ncorrect,
53 | }
54 | return loss, natoms, logging_output
55 |
56 | @staticmethod
57 | def reduce_metrics(logging_outputs) -> None:
58 | """Aggregate logging outputs from data parallel training."""
59 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
60 | sample_size = sum(log.get("ntokens", 0) for log in logging_outputs)
61 |
62 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
63 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
64 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
65 | metrics.log_scalar(
66 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1
67 | )
68 |
69 | @staticmethod
70 | def logging_outputs_can_be_summed() -> bool:
71 | """
72 | Whether the logging outputs returned by `forward` can be summed
73 | across workers prior to calling `reduce_metrics`. Setting this
74 | to True will improve distributed training speed.
75 | """
76 | return True
77 |
78 |
79 | @register_criterion("multilabel_multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass)
80 | class GraphPredictionMultiLabelMulticlassCrossEntropyWithFlag(GraphPredictionMultiLabelMulticlassCrossEntropy):
81 | """
82 | Implementation for the multi-class log loss used in graphormer model training.
83 | """
84 |
85 | def forward(self, model, sample, reduce=True):
86 | """Compute the loss for the given sample.
87 |
88 | Returns a tuple with three elements:
89 | 1) the loss
90 | 2) the sample size, which is used as the denominator for the gradient
91 | 3) logging outputs to display while training
92 | """
93 | sample_size = sample["nsamples"]
94 | perturb = sample.get("perturb", None)
95 |
96 | with torch.no_grad():
97 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
98 |
99 | logits = model(**sample["net_input"], perturb=perturb)
100 | logits = logits[:, 0, :,:]
101 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
102 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum()
103 |
104 | loss = functional.cross_entropy(
105 | logits, targets.reshape(-1), reduction="sum"
106 | )
107 |
108 | logging_output = {
109 | "loss": loss.data,
110 | "sample_size": sample_size,
111 | "nsentences": sample_size,
112 | "ntokens": natoms,
113 | "ncorrect": ncorrect,
114 | }
115 | return loss, sample_size, logging_output
116 |
--------------------------------------------------------------------------------
/deepgraph/criterions/binary_logloss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | from torch.nn import functional
8 | from fairseq.logging import metrics
9 | from fairseq import utils
10 | # from fairseq import metrics, utils
11 | from fairseq.criterions import FairseqCriterion, register_criterion
12 |
13 |
14 | @register_criterion("binary_logloss", dataclass=FairseqDataclass)
15 | class GraphPredictionBinaryLogLoss(FairseqCriterion):
16 | """
17 | Implementation for the binary log loss used in graphormer model training.
18 | """
19 |
20 | def forward(self, model, sample, reduce=True):
21 | """Compute the loss for the given sample.
22 |
23 | Returns a tuple with three elements:
24 | 1) the loss
25 | 2) the sample size, which is used as the denominator for the gradient
26 | 3) logging outputs to display while training
27 | """
28 | sample_size = sample["nsamples"]
29 |
30 | with torch.no_grad():
31 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
32 |
33 | logits = model(**sample["net_input"])
34 | logits = logits[:, 0, :]
35 | targets = model.get_targets(sample, [logits])
36 | preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1)
37 |
38 | logits_flatten = logits.reshape(-1)
39 | targets_flatten = targets[: logits.size(0)].reshape(-1)
40 | mask = ~torch.isnan(targets_flatten)
41 | loss = functional.binary_cross_entropy_with_logits(
42 | logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum"
43 | )
44 |
45 | logging_output = {
46 | "loss": loss.data,
47 | "sample_size": torch.sum(mask.type(torch.int64)),
48 | "nsentences": sample_size,
49 | "ntokens": natoms,
50 | "ncorrect": (preds == targets[:preds.size(0)]).sum(),
51 | }
52 | return loss, sample_size, logging_output
53 |
54 | @staticmethod
55 | def reduce_metrics(logging_outputs) -> None:
56 | """Aggregate logging outputs from data parallel training."""
57 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
58 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
59 |
60 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
61 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
62 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
63 | metrics.log_scalar(
64 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1
65 | )
66 |
67 | @staticmethod
68 | def logging_outputs_can_be_summed() -> bool:
69 | """
70 | Whether the logging outputs returned by `forward` can be summed
71 | across workers prior to calling `reduce_metrics`. Setting this
72 | to True will improve distributed training speed.
73 | """
74 | return True
75 |
76 |
77 | @register_criterion("binary_logloss_with_flag", dataclass=FairseqDataclass)
78 | class GraphPredictionBinaryLogLossWithFlag(GraphPredictionBinaryLogLoss):
79 | """
80 | Implementation for the binary log loss used in graphormer model training.
81 | """
82 |
83 | def forward(self, model, sample, reduce=True):
84 | """Compute the loss for the given sample.
85 |
86 | Returns a tuple with three elements:
87 | 1) the loss
88 | 2) the sample size, which is used as the denominator for the gradient
89 | 3) logging outputs to display while training
90 | """
91 | sample_size = sample["nsamples"]
92 | perturb = sample.get("perturb", None)
93 |
94 | batch_data = sample["net_input"]["batched_data"]["x"]
95 | with torch.no_grad():
96 | natoms = batch_data.shape[1]
97 | logits = model(**sample["net_input"], perturb=perturb)[:, 0, :]
98 | targets = model.get_targets(sample, [logits])
99 | preds = torch.where(torch.sigmoid(logits) < 0.5, 0, 1)
100 |
101 | logits_flatten = logits.reshape(-1)
102 | targets_flatten = targets[: logits.size(0)].reshape(-1)
103 | mask = ~torch.isnan(targets_flatten)
104 | loss = functional.binary_cross_entropy_with_logits(
105 | logits_flatten[mask].float(), targets_flatten[mask].float(), reduction="sum"
106 | )
107 |
108 | logging_output = {
109 | "loss": loss.data,
110 | "sample_size": logits.size(0),
111 | "nsentences": sample_size,
112 | "ntokens": natoms,
113 | "ncorrect": (preds == targets[:preds.size(0)]).sum(),
114 | }
115 | return loss, sample_size, logging_output
116 |
--------------------------------------------------------------------------------
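The binary log loss above skips missing labels by masking NaN targets before calling binary_cross_entropy_with_logits (multi-task molecule datasets leave unknown labels as NaN). A toy sketch of that masking with invented values:

import torch
from torch.nn import functional

logits = torch.randn(4)                                  # one logit per (graph, task) entry
targets = torch.tensor([1.0, 0.0, float("nan"), 1.0])    # NaN marks a missing label

mask = ~torch.isnan(targets)                             # keep only labelled entries
loss = functional.binary_cross_entropy_with_logits(
    logits[mask], targets[mask], reduction="sum"
)
print(loss.item(), "over", int(mask.sum()), "labelled entries")
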
/deepgraph/criterions/node_multiclass_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from fairseq.dataclass.configs import FairseqDataclass
5 |
6 | import torch
7 | from torch.nn import functional
8 | from fairseq.logging import metrics
9 | # from fairseq import metrics
10 | from fairseq.criterions import FairseqCriterion, register_criterion
11 |
12 |
13 | @register_criterion("node_multiclass_cross_entropy", dataclass=FairseqDataclass)
14 | class GraphNodePredictionMulticlassCrossEntropy(FairseqCriterion):
15 | """
16 | Implementation for the multi-class log loss used in graphormer model training.
17 | """
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 |
28 | logits = model(**sample["net_input"])
29 | # logits = logits[:, 0, :]
30 | logits = logits[:, 1:, :]
31 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
32 |
33 | if 'sub_adj_mask' in sample['net_input']['batched_data']:
34 | node_ind=(sample['net_input']['batched_data']['sub_adj_mask'][:,:,0]==1)
35 | node_ind_y=node_ind[:,0:targets.shape[1]]
36 | else:
37 | node_ind=~(sample['net_input']['batched_data']['attn_bias'][:,1,:].isinf())[:,1:]
38 | node_ind_y=node_ind
39 | logits_node=logits[node_ind]
40 |
41 | targets_node=targets[node_ind_y]-1
42 | ncorrect = (torch.argmax(logits_node, dim=-1).reshape(-1) == targets_node.reshape(-1)).sum()
43 |
44 | loss = functional.cross_entropy(
45 | logits_node, targets_node.reshape(-1), reduction="sum"
46 | )
47 |
48 |
49 | sample_size = sample["nsamples"]
50 |
51 | with torch.no_grad():
52 | natoms = int(node_ind_y.sum())
53 |
54 |
55 | logging_output = {
56 | "loss": float(loss.data),
57 | "sample_size": natoms,
58 | "nsentences": sample_size,
59 | "ntokens": natoms,
60 | "ncorrect": ncorrect,
61 | }
62 | return loss, natoms, logging_output
63 |
64 | @staticmethod
65 | def reduce_metrics(logging_outputs) -> None:
66 | """Aggregate logging outputs from data parallel training."""
67 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68 | sample_size = sum(log.get("ntokens", 0) for log in logging_outputs)
69 |
70 | metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
71 | if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
72 | ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
73 | metrics.log_scalar(
74 | "accuracy", 100.0 * ncorrect / sample_size, sample_size, round=1
75 | )
76 |
77 | @staticmethod
78 | def logging_outputs_can_be_summed() -> bool:
79 | """
80 | Whether the logging outputs returned by `forward` can be summed
81 | across workers prior to calling `reduce_metrics`. Setting this
82 | to True will improve distributed training speed.
83 | """
84 | return True
85 |
86 |
87 | @register_criterion("node_multiclass_cross_entropy_with_flag", dataclass=FairseqDataclass)
88 | class GraphNodePredictionMulticlassCrossEntropyWithFlag(GraphNodePredictionMulticlassCrossEntropy):
89 | """
90 | Implementation for the multi-class log loss used in graphormer model training.
91 | """
92 |
93 | def forward(self, model, sample, reduce=True):
94 | """Compute the loss for the given sample.
95 |
96 | Returns a tuple with three elements:
97 | 1) the loss
98 | 2) the sample size, which is used as the denominator for the gradient
99 | 3) logging outputs to display while training
100 | """
101 | sample_size = sample["nsamples"]
102 | perturb = sample.get("perturb", None)
103 |
104 | with torch.no_grad():
105 | natoms = sample["net_input"]["batched_data"]["x"].shape[1]
106 |
107 | logits = model(**sample["net_input"], perturb=perturb)
108 | logits = logits[:, 0, :]
109 | targets = model.get_targets(sample, [logits])[: logits.size(0)]
110 | ncorrect = (torch.argmax(logits, dim=-1).reshape(-1) == targets.reshape(-1)).sum()
111 |
112 | loss = functional.cross_entropy(
113 | logits, targets.reshape(-1), reduction="sum"
114 | )
115 |
116 | logging_output = {
117 | "loss": loss.data,
118 | "sample_size": sample_size,
119 | "nsentences": sample_size,
120 | "ntokens": natoms,
121 | "ncorrect": ncorrect,
122 | }
123 | return loss, sample_size, logging_output
124 |
--------------------------------------------------------------------------------
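GraphNodePredictionMulticlassCrossEntropy above drops the graph token (logits[:, 1:, :]), keeps only real node positions via a boolean mask, and shifts the node targets by -1 before the cross-entropy. The sketch below mimics that masking and shifting with an invented convention (target 0 marks a padded slot) instead of the attn_bias / sub_adj_mask bookkeeping used in the real code:

import torch
from torch.nn import functional

logits = torch.randn(2, 5, 4)              # [batch, max_nodes, num_classes] after dropping the graph token
targets = torch.tensor([[2, 1, 3, 0, 0],   # node labels offset by +1; 0 marks padding (assumed convention)
                        [1, 2, 0, 0, 0]])
node_ind = targets > 0                     # keep only real nodes

logits_node = logits[node_ind]             # [num_real_nodes, num_classes]
targets_node = targets[node_ind] - 1       # undo the +1 offset
loss = functional.cross_entropy(logits_node, targets_node, reduction="sum")
print(loss.item(), "over", int(node_ind.sum()), "nodes")
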
/deepgraph/criterions/mae_deltapos.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Callable, Mapping, Sequence, Tuple
5 | from numpy import mod
6 | import torch
7 | from torch import Tensor
8 | import torch.nn.functional as F
9 |
10 |
11 | from fairseq.logging import metrics
12 | # from fairseq import metrics
13 | from fairseq.criterions import FairseqCriterion, register_criterion
14 |
15 |
16 | @register_criterion("mae_deltapos")
17 | class IS2RECriterion(FairseqCriterion):
18 | e_thresh = 0.02
19 | e_mean = -1.4729953244844094
20 | e_std = 2.2707848125378405
21 | d_mean = [0.1353900283575058, 0.06877671927213669, 0.08111362904310226]
22 | d_std = [1.7862379550933838, 1.78688645362854, 0.8023099899291992]
23 |
24 | def __init__(self, task, cfg):
25 | super().__init__(task)
26 | self.node_loss_weight = cfg.node_loss_weight
27 | self.min_node_loss_weight = cfg.min_node_loss_weight
28 | self.max_update = cfg.max_update
29 | self.node_loss_weight_range = max(
30 | 0, self.node_loss_weight - self.min_node_loss_weight
31 | )
32 |
33 | def forward(
34 | self,
35 | model: Callable[..., Tuple[Tensor, Tensor, Tensor]],
36 | sample: Mapping[str, Mapping[str, Tensor]],
37 | reduce=True,
38 | ):
39 | update_num = model.num_updates
40 | assert update_num >= 0
41 | node_loss_weight = (
42 | self.node_loss_weight
43 | - self.node_loss_weight_range * update_num / self.max_update
44 | )
45 |
46 | valid_nodes = sample["net_input"]["atoms"].ne(0).sum()
47 | output, node_output, node_target_mask = model(
48 | **sample["net_input"],
49 | )
50 |
51 | relaxed_energy = sample["targets"]["relaxed_energy"]
52 | relaxed_energy = relaxed_energy.float()
53 | relaxed_energy = (relaxed_energy - self.e_mean) / self.e_std
54 | sample_size = relaxed_energy.numel()
55 | loss = F.l1_loss(output.float().view(-1), relaxed_energy, reduction="none")
56 | with torch.no_grad():
57 | energy_within_threshold = (loss.detach() * self.e_std < self.e_thresh).sum()
58 | loss = loss.sum()
59 |
60 | deltapos = sample["targets"]["deltapos"].float()
61 | deltapos = (deltapos - deltapos.new_tensor(self.d_mean)) / deltapos.new_tensor(
62 | self.d_std
63 | )
64 | deltapos *= node_target_mask
65 | node_output *= node_target_mask
66 | target_cnt = node_target_mask.sum(dim=[1, 2])
67 | node_loss = (
68 | F.l1_loss(node_output.float(), deltapos, reduction="none")
69 | .mean(dim=-1)
70 | .sum(dim=-1)
71 | / target_cnt
72 | ).sum()
73 |
74 | logging_output = {
75 | "loss": loss.detach(),
76 | "energy_within_threshold": energy_within_threshold,
77 | "node_loss": node_loss.detach(),
78 | "sample_size": sample_size,
79 | "nsentences": sample_size,
80 | "num_nodes": valid_nodes.detach(),
81 | "node_loss_weight": node_loss_weight * sample_size,
82 | }
83 | return loss + node_loss_weight * node_loss, sample_size, logging_output
84 |
85 | @staticmethod
86 | def reduce_metrics(logging_outputs: Sequence[Mapping]) -> None:
87 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
88 | energy_within_threshold_sum = sum(
89 | log.get("energy_within_threshold", 0) for log in logging_outputs
90 | )
91 | node_loss_sum = sum(log.get("node_loss", 0) for log in logging_outputs)
92 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
93 |
94 | mean_loss = (loss_sum / sample_size) * IS2RECriterion.e_std
95 | energy_within_threshold = energy_within_threshold_sum / sample_size
96 | mean_node_loss = (node_loss_sum / sample_size) * sum(IS2RECriterion.d_std) / 3.0
97 | mean_n_nodes = (
98 | sum([log.get("num_nodes", 0) for log in logging_outputs]) / sample_size
99 | )
100 | node_loss_weight = (
101 | sum([log.get("node_loss_weight", 0) for log in logging_outputs])
102 | / sample_size
103 | )
104 |
105 | metrics.log_scalar("loss", mean_loss, sample_size, round=6)
106 | metrics.log_scalar("ewth", energy_within_threshold, sample_size, round=6)
107 | metrics.log_scalar("node_loss", mean_node_loss, sample_size, round=6)
108 | metrics.log_scalar("nodes_per_graph", mean_n_nodes, sample_size, round=6)
109 | metrics.log_scalar("node_loss_weight", node_loss_weight, sample_size, round=6)
110 |
111 | @staticmethod
112 | def logging_outputs_can_be_summed() -> bool:
113 | """
114 | Whether the logging outputs returned by `forward` can be summed
115 | across workers prior to calling `reduce_metrics`. Setting this
116 | to True will improve distributed training speed.
117 | """
118 | return True
119 |
--------------------------------------------------------------------------------
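IS2RECriterion above normalizes the relaxed energy with the dataset statistics e_mean / e_std and counts predictions whose error, mapped back to the original energy units, falls under e_thresh. A small sketch of that bookkeeping with made-up numbers:

import torch
import torch.nn.functional as F

e_mean, e_std, e_thresh = -1.4729953244844094, 2.2707848125378405, 0.02

relaxed_energy = torch.tensor([-1.45, -3.10])    # ground-truth relaxed energies
output = torch.tensor([0.02, -0.80])             # model predictions in normalized units

target = (relaxed_energy - e_mean) / e_std                    # normalize as in forward()
per_sample = F.l1_loss(output, target, reduction="none")
within = (per_sample * e_std < e_thresh).sum()                # error back in energy units vs. the 0.02 threshold
print(per_sample.sum().item(), int(within))
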
/deepgraph/evaluate/evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from fairseq import checkpoint_utils, utils, options, tasks
5 | from fairseq.logging import progress_bar
6 | from fairseq.dataclass.utils import convert_namespace_to_omegaconf
7 | import ogb
8 | import sys
9 | import os
10 | from pathlib import Path
11 | from sklearn.metrics import roc_auc_score
12 |
13 | import sys
14 | from os import path
15 |
16 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
17 | from pretrain import load_pretrained_model
18 |
19 | import logging
20 |
21 | def eval(args, use_pretrained, checkpoint_path=None, logger=None):
22 | cfg = convert_namespace_to_omegaconf(args)
23 | np.random.seed(cfg.common.seed)
24 | utils.set_torch_seed(cfg.common.seed)
25 |
26 | # initialize task
27 | task = tasks.setup_task(cfg.task)
28 | model = task.build_model(cfg.model)
29 |
30 | # load checkpoint
31 | if use_pretrained:
32 | model_state = load_pretrained_model(cfg.task.pretrained_model_name)
33 | else:
34 | model_state = torch.load(checkpoint_path)["model"]
35 | model.load_state_dict(
36 | model_state, strict=True, model_cfg=cfg.model
37 | )
38 | del model_state
39 |
40 | model.to(torch.cuda.current_device())
41 | # load dataset
42 | split = args.split
43 | task.load_dataset(split)
44 | batch_iterator = task.get_batch_iterator(
45 | dataset=task.dataset(split),
46 | max_tokens=cfg.dataset.max_tokens_valid,
47 | max_sentences=cfg.dataset.batch_size_valid,
48 | max_positions=utils.resolve_max_positions(
49 | task.max_positions(),
50 | model.max_positions(),
51 | ),
52 | ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
53 | required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
54 | seed=cfg.common.seed,
55 | num_workers=cfg.dataset.num_workers,
56 | epoch=0,
57 | data_buffer_size=cfg.dataset.data_buffer_size,
58 | disable_iterator_cache=False,
59 | )
60 | itr = batch_iterator.next_epoch_itr(
61 | shuffle=False, set_dataset_epoch=False
62 | )
63 | progress = progress_bar.progress_bar(
64 | itr,
65 | log_format=cfg.common.log_format,
66 | log_interval=cfg.common.log_interval,
67 | default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple")
68 | )
69 |
70 | # infer
71 | y_pred = []
72 | y_true = []
73 | with torch.no_grad():
74 | model.eval()
75 | for i, sample in enumerate(progress):
76 | sample = utils.move_to_cuda(sample)
77 | if args.criterion=='multilabel_multiclass_cross_entropy':
78 | y = model(**sample["net_input"]).squeeze(0).argmax(1)
79 | y_pred.extend(y.detach().cpu())
80 | y_true.extend(sample["target"].detach().cpu())
81 | else:
82 | y = model(**sample["net_input"])[:, 0, :].reshape(-1)
83 | y_pred.extend(y.detach().cpu())
84 | y_true.extend(sample["target"].detach().cpu().reshape(-1)[:y.shape[0]])
85 | torch.cuda.empty_cache()
86 |
87 | # save predictions
88 | y_pred = torch.Tensor(y_pred)
89 | y_true = torch.Tensor(y_true)
90 |
91 | # evaluate pretrained models
92 | if use_pretrained:
93 | if cfg.task.pretrained_model_name == "pcqm4mv1_graphormer_base":
94 | evaluator = ogb.lsc.PCQM4MEvaluator()
95 | input_dict = {'y_pred': y_pred, 'y_true': y_true}
96 | result_dict = evaluator.eval(input_dict)
97 | logger.info(f'PCQM4Mv1Evaluator: {result_dict}')
98 | elif cfg.task.pretrained_model_name == "pcqm4mv2_graphormer_base":
99 | evaluator = ogb.lsc.PCQM4Mv2Evaluator()
100 | input_dict = {'y_pred': y_pred, 'y_true': y_true}
101 | result_dict = evaluator.eval(input_dict)
102 | logger.info(f'PCQM4Mv2Evaluator: {result_dict}')
103 | else:
104 | if args.metric == "auc":
105 | auc = roc_auc_score(y_true, y_pred)
106 | logger.info(f"auc: {auc}")
107 | return auc
108 | elif args.metric == "mae":
109 | mae = np.mean(np.abs(y_true - y_pred).numpy())
110 | logger.info(f"mae: {mae}")
111 | return mae
112 | else:
113 | raise ValueError(f"Unsupported metric {args.metric}")
114 |
115 | def main():
116 | parser = options.get_training_parser()
117 | parser.add_argument(
118 | "--split",
119 | type=str,
120 | )
121 | parser.add_argument(
122 | "--metric",
123 | type=str,
124 | )
125 | res_lis = []
126 | args = options.parse_args_and_arch(parser, modify_parser=None)
127 | logger = logging.getLogger(__name__)
128 | if args.pretrained_model_name != "none":
129 | eval(args, True, logger=logger)
130 | elif hasattr(args, "save_dir"):
131 |
132 | # for checkpoint_fname in ['checkpoint_best.pt','checkpoint_last.pt']:
133 | for checkpoint_fname in os.listdir(args.save_dir):
134 | checkpoint_path = Path(args.save_dir) / checkpoint_fname
135 | logger.info(f"evaluating checkpoint file {checkpoint_path}")
136 |             res = eval(args, False, checkpoint_path, logger)
137 |             print('result:', res)
138 |             res_lis.append(res)
139 |         print('best:', np.max(res_lis), 'at index', np.argmax(res_lis), '| worst:', np.min(res_lis), 'at index', np.argmin(res_lis))
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
/deepgraph/data/dataset.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import ogb
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 | from fairseq.data import data_utils, FairseqDataset, BaseWrapperDataset
8 | import pkg_resources
9 | from .wrapper import MyPygGraphPropPredDataset
10 | from .collator import collator,collatorcontrust,collator_adj,collator_node_label,collator_adj_cache
11 |
12 | from typing import Optional, Union
13 | from torch_geometric.data import Data as PYGDataset
14 | from dgl.data import DGLDataset
15 | from .pyg_datasets import PYGDatasetLookupTable
16 | from .substructure_dataset import SubstructureDataset
17 | from .ogb_datasets import OGBDatasetLookupTable
18 |
19 |
20 |
21 | class BatchedDataDataset(FairseqDataset):
22 | def __init__(
23 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024
24 | ):
25 | super().__init__()
26 | self.dataset = dataset
27 | self.max_node = max_node
28 | self.multi_hop_max_dist = multi_hop_max_dist
29 | self.spatial_pos_max = spatial_pos_max
30 |
31 | def __getitem__(self, index):
32 | item = self.dataset[int(index)]
33 | return item
34 |
35 | def __len__(self):
36 | return len(self.dataset)
37 |
38 | def collater(self, samples):
39 | return collator(
40 | samples,
41 | max_node=self.max_node,
42 | multi_hop_max_dist=self.multi_hop_max_dist,
43 | spatial_pos_max=self.spatial_pos_max,
44 | )
45 |
46 | class BatchedDataDataset_Substructure(FairseqDataset):
47 | def __init__(
48 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024
49 | ):
50 | super().__init__()
51 | self.dataset = dataset
52 | self.max_node = max_node
53 | self.multi_hop_max_dist = multi_hop_max_dist
54 | self.spatial_pos_max = spatial_pos_max
55 |
56 | def __getitem__(self, index):
57 | item = self.dataset[int(index)]
58 | return item
59 |
60 | def __len__(self):
61 | return len(self.dataset)
62 |
63 | def collater(self, samples):
64 | return collator_adj(
65 | samples,
66 | max_node=self.max_node,
67 | multi_hop_max_dist=self.multi_hop_max_dist,
68 | spatial_pos_max=self.spatial_pos_max,
69 | )
70 |
71 |
72 | class BatchedDataDataset_Substructure_Cache(BatchedDataDataset_Substructure):
73 | def collater(self, samples):
74 | return collator_adj_cache(
75 | samples,
76 | max_node=self.max_node,
77 | multi_hop_max_dist=self.multi_hop_max_dist,
78 | spatial_pos_max=self.spatial_pos_max,
79 | )
80 |
81 |
82 | class BatchedDataContrustDataset(FairseqDataset):
83 | def __init__(
84 | self, dataset, max_node=128, multi_hop_max_dist=5, spatial_pos_max=1024
85 | ):
86 | super().__init__()
87 | self.dataset = dataset
88 | self.max_node = max_node
89 | self.multi_hop_max_dist = multi_hop_max_dist
90 | self.spatial_pos_max = spatial_pos_max
91 |
92 | def __getitem__(self, index):
93 | item = self.dataset[int(index)]
94 | return item
95 |
96 | def __len__(self):
97 | return len(self.dataset)
98 |
99 | def collater(self, samples):
100 | return collatorcontrust(
101 | samples,
102 | max_node=self.max_node,
103 | multi_hop_max_dist=self.multi_hop_max_dist,
104 | spatial_pos_max=self.spatial_pos_max,
105 | )
106 |
107 |
108 | class TargetDataset(FairseqDataset):
109 | def __init__(self, dataset,node_level_task=False):
110 | super().__init__()
111 | self.dataset = dataset
112 | self.node_level_task = node_level_task
113 |
114 | @lru_cache(maxsize=16)
115 | def __getitem__(self, index):
116 | return self.dataset[index].y
117 |
118 | def __len__(self):
119 | return len(self.dataset)
120 |
121 | def collater(self, samples):
122 | if self.node_level_task:
123 | return collator_node_label(samples)
124 | else:
125 | return torch.stack(samples, dim=0)
126 |
127 |
128 | class GraphormerDataset:
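    """Resolve the underlying dataset (a user-provided PyG dataset, a PyG
    lookup-table entry, or an OGB lookup-table entry) and expose its
    train/valid/test splits and index arrays."""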
129 | def __init__(
130 | self,
131 | dataset: Optional[Union[PYGDataset, DGLDataset]] = None,
132 | dataset_spec: Optional[str] = None,
133 | dataset_source: Optional[str] = None,
134 | seed: int = 0,
135 | train_idx = None,
136 | valid_idx = None,
137 | test_idx = None,
138 | args=None,
139 | **kwargs
140 | ):
141 | super().__init__()
142 | if dataset is not None:
143 | if dataset_source == "pyg":
144 | self.dataset = SubstructureDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
145 | else:
146 | raise ValueError("customized dataset can only have source pyg")
147 |
148 | elif dataset_source == "pyg":
149 | self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed=seed,args=args,**kwargs)
150 | elif dataset_source == "ogb":
151 | self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed=seed,args=args,**kwargs)
152 | else:
153 |             raise ValueError("unsupported dataset_source: expected 'pyg' or 'ogb'")
154 | self.setup()
155 |
156 | def setup(self):
157 | self.train_idx = self.dataset.train_idx
158 | self.valid_idx = self.dataset.valid_idx
159 | self.test_idx = self.dataset.test_idx
160 |
161 | self.dataset_train = self.dataset.train_data
162 | self.dataset_val = self.dataset.valid_data
163 | self.dataset_test = self.dataset.test_data
164 |
165 |
166 |
167 | class EpochShuffleDataset(BaseWrapperDataset):
168 | def __init__(self, dataset, num_samples, seed):
169 | super().__init__(dataset)
170 | self.num_samples = num_samples
171 | self.seed = seed
172 | self.set_epoch(1)
173 |
174 | def set_epoch(self, epoch):
175 | with data_utils.numpy_seed(self.seed + epoch - 1):
176 | self.sort_order = np.random.permutation(self.num_samples)
177 |
178 | def ordered_indices(self):
179 | return self.sort_order
180 |
181 | @property
182 | def can_reuse_epoch_itr_across_epochs(self):
183 | return False
184 |
185 |
186 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Are More Layers Beneficial to Graph Transformers?
2 |
3 |
4 | ## Introduction
5 |
6 | This repository contains the code for our work [Are More Layers Beneficial to Graph Transformers?](https://openreview.net/pdf?id=uagC-X9XMi8), published at ICLR 2023.
7 |
8 | ![overview](overview.png)
9 |
10 |
11 |
12 |
13 |
14 | ## Installation
15 | To run DeepGraph, please clone the repository to your local machine and install the required dependencies using the script provided.
16 | #### Note
17 | Please note that we use CUDA 10.2 and Python 3.7. If you are using a different CUDA or Python version, please adjust the package versions as necessary.
18 | #### Environment
19 | ```
20 | conda create -n DeepGraph python=3.7
21 |
22 | source activate DeepGraph
23 |
24 | pip3 install torch==1.9.0+cu102 torchaudio torchvision -f https://download.pytorch.org/whl/cu102/torch_stable.html --user
25 |
26 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl
27 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl
28 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl
29 | wget https://data.pyg.org/whl/torch-1.9.0%2Bcu102/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl
30 | pip install torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl
31 | pip install torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl
32 | pip install torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl
33 | pip install torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl
34 |
35 | pip install torch-geometric==1.7.2
36 |
37 | conda install -c conda-forge/label/cf202003 graph-tool
38 |
39 | pip install lmdb
40 | pip install tensorboardX==2.4.1
41 | pip install ogb==1.3.2
42 | pip install rdkit-pypi==2021.9.3
43 | pip install dgl==0.7.2 -f https://data.dgl.ai/wheels/repo.html
44 | pip install tqdm
45 | pip install wandb
46 | pip install networkx
47 | pip install setuptools==59.5.0
48 | pip install multiprocess
49 |
50 | git clone -b 0.12.2-release https://github.com/facebookresearch/fairseq
51 | cd fairseq
52 | pip install ./
53 | python setup.py build_ext --inplace
54 | cd ..
55 |
56 | ```
57 |
58 |
59 | ## Run the Application
60 |
61 | We provide the training scripts for ZINC, CLUSTER, PATTERN and PCQM4M-LSC.
62 |
63 |
64 | #### Script for ZINC
65 | ```
66 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/zinc --ddp-backend=legacy_ddp --dataset-name zinc --dataset-source pyg --data-dir dataset/ --task graph_prediction_substructure --id-type cycle_graph+path_graph+star_graph+k_neighborhood --ks [8,4,6,2] --sampling-redundancy 6 --valid-on-test --criterion l1_loss --arch graphormer_slim --deepnorm --num-classes 1 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 640000 --total-num-update 2560000 --lr 2e-4 --end-learning-rate 1e-6 --batch-size 16 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 80 --encoder-ffn-embed-dim 80 --encoder-attention-heads 8 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3
67 |
68 | ```
69 |
70 | #### Script for CLUSTER
71 |
72 | ```
73 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/CLUSTER --ddp-backend=legacy_ddp --dataset-name CLUSTER --dataset-source pyg --node-level-task --data-dir dataset/ --task graph_prediction_substructure --id-type random_walk --ks [10] --sampling-redundancy 2 --valid-on-test --criterion node_multiclass_cross_entropy --arch graphormer_slim --deepnorm --num-classes 6 --num-workers 16 --attention-dropout 0.4 --act-dropout 0.4 --dropout 0.4 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 300 --total-num-update 40000 --lr 5e-4 --end-learning-rate 5e-4 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 48 --encoder-ffn-embed-dim 96 --encoder-attention-heads 8 --max-nodes 1024 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric
74 |
75 | ```
76 |
77 | #### Script for PATTERN
78 |
79 | ```
80 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/PATTERN --ddp-backend=legacy_ddp --dataset-name PATTERN --dataset-source pyg --node-level-task --data-dir dataset/ --task graph_prediction_substructure --id-type random_walk --ks [10] --sampling-redundancy 2 --valid-on-test --criterion node_multiclass_cross_entropy --arch graphormer_slim --deepnorm --num-classes 2 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 --optimizer adam --adam-betas '(0.9,0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 6000 --total-num-update 40000 --lr 2e-4 --end-learning-rate 1e-6 --batch-size 64 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 80 --encoder-ffn-embed-dim 80 --encoder-attention-heads 8 --max-nodes 512 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3 --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric
81 |
82 | ```
83 |
84 | #### Script for PCQM4M
85 | ```
86 | CUDA_VISIBLE_DEVICES=0 python train.py --user-dir ./deepgraph --save-dir ckpts/pcqm4m --ddp-backend=legacy_ddp --dataset-name pcqm4m --dataset-source ogb --data-dir dataset/ --task graph_prediction_substructure --id-type cycle_graph+path_graph+star_graph+k_neighborhood --ks [8,4,6,2] --sampling-redundancy 6 --criterion l1_loss --arch graphormer_base --deepnorm --num-classes 1 --num-workers 16 --attention-dropout 0.1 --act-dropout 0.1 --dropout 0.0 --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm 5.0 --weight-decay 0.01 --lr-scheduler polynomial_decay --power 1 --warmup-updates 50000 --total-num-update 500000 --lr 1e-4 --end-learning-rate 1e-5 --batch-size 100 --fp16 --data-buffer-size 20 --encoder-layers 48 --encoder-embed-dim 768 --encoder-ffn-embed-dim 768 --encoder-attention-heads 32 --max-epoch 10000 --keep-best-checkpoints 2 --keep-last-epochs 3
87 |
88 | ```
89 |
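#### Evaluating checkpoints

The repository also includes an evaluation entry point, `deepgraph/evaluate/evaluate.py`, which rebuilds the model from the training flags, loads either a named pretrained model or every checkpoint under `--save-dir`, and reports `auc` or `mae` on the chosen `--split`. The command below is a minimal sketch for the ZINC setup above; the paths are placeholders, and the model/task flags must match the ones used for training.
```
CUDA_VISIBLE_DEVICES=0 python deepgraph/evaluate/evaluate.py --user-dir ./deepgraph --save-dir ckpts/zinc --dataset-name zinc --dataset-source pyg --data-dir dataset/ --task graph_prediction_substructure --id-type cycle_graph+path_graph+star_graph+k_neighborhood --ks [8,4,6,2] --sampling-redundancy 6 --criterion l1_loss --arch graphormer_slim --deepnorm --num-classes 1 --encoder-layers 48 --encoder-embed-dim 80 --encoder-ffn-embed-dim 80 --encoder-attention-heads 8 --batch-size 16 --split valid --metric mae
```
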
90 | ## Guidance for Further Development
91 |
92 | Our code is organized into several folders and is compatible with other Fairseq-based frameworks: you can integrate DeepGraph into such a framework by merging the folders, and you can extend the code by adding new modules to the corresponding folder (a minimal criterion sketch is given after the directory tree below).
93 | ```
94 | deepgraph
95 | ├─criterions
96 | ├─data
97 | │ ├─ogb_datasets
98 | │ ├─pyg_datasets
99 | │ ├─subsampling
100 | │ └─substructure_dataset_utils
101 | ├─evaluate
102 | ├─models
103 | ├─modules
104 | ├─pretrain
105 | └─tasks
106 | ```
107 |
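As a concrete, hedged example of such an extension, the sketch below adds a new loss under `deepgraph/criterions` through fairseq's standard criterion API. The file name, the criterion name, and the assumed `net_input`/`target` layout (graph-level targets read from the graph token, as in the existing graph-level criterions) are illustrative assumptions, not part of the released code.
```
# deepgraph/criterions/my_mse_loss.py  (hypothetical file name)
import torch.nn.functional as F
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("my_mse_loss")  # select at training time with --criterion my_mse_loss
class MyMSELoss(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # Assumed layout: model output is [batch, nodes + 1, num_classes] and the
        # graph-level prediction sits at the graph token (index 0).
        logits = model(**sample["net_input"])[:, 0, :].reshape(-1)
        targets = sample["target"].float().reshape(-1)[: logits.numel()]
        sample_size = targets.numel()
        loss = F.mse_loss(logits[:sample_size], targets, reduction="sum")
        logging_output = {
            "loss": loss.data,
            "sample_size": sample_size,
            "nsentences": sample_size,
            "ntokens": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar("loss", loss_sum / max(sample_size, 1), sample_size, round=3)
```
Depending on how `deepgraph/criterions/__init__.py` discovers modules, you may also need to import the new file there.
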
108 |
109 | ## Citation
110 |
111 | Please cite this paper if you find our work useful:
112 | ```
113 | @inproceedings{zhaomore,
114 | title={Are More Layers Beneficial to Graph Transformers?},
115 | author={Zhao, Haiteng and Ma, Shuming and Zhang, Dongdong and Deng, Zhi-Hong and Wei, Furu},
116 | booktitle={International Conference on Learning Representations},
117 |   year={2023}
118 | }
119 | ```
120 |
121 |
--------------------------------------------------------------------------------
/deepgraph/modules/deepgraph_graph_encoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates.
5 | #
6 | # This source code is licensed under the MIT license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from typing import Callable, Optional
10 |
11 | import torch
12 | import torch.nn as nn
13 | from fairseq import utils
14 | from fairseq.modules import LayerNorm
15 | from fairseq.modules.fairseq_dropout import FairseqDropout
16 | from fairseq.modules.quant_noise import quant_noise
17 |
18 | from .multihead_attention import MultiheadAttention,MultiheadAttentionRe
19 | import math
20 |
21 | class DeepGraphGraphEncoderLayer(nn.Module):
22 | def __init__(
23 | self,
24 | embedding_dim: int = 768,
25 | ffn_embedding_dim: int = 3072,
26 | num_attention_heads: int = 8,
27 | dropout: float = 0.1,
28 | attention_dropout: float = 0.1,
29 | activation_dropout: float = 0.1,
30 | activation_fn: str = "relu",
31 | export: bool = False,
32 | q_noise: float = 0.0,
33 | qn_block_size: int = 8,
34 | init_fn: Callable = None,
35 | pre_layernorm: bool = False,
36 | deepnorm: bool = False,
37 | encoder_layers: int = 12,
38 | reattention:bool=False
39 | ) -> None:
40 | super().__init__()
41 |
42 | if init_fn is not None:
43 | init_fn()
44 |
45 | # Initialize parameters
46 | self.embedding_dim = embedding_dim
47 | self.num_attention_heads = num_attention_heads
48 | self.attention_dropout = attention_dropout
49 | self.q_noise = q_noise
50 | self.qn_block_size = qn_block_size
51 | self.pre_layernorm = pre_layernorm
52 |
53 | self.dropout_module = FairseqDropout(
54 | dropout, module_name=self.__class__.__name__
55 | )
56 | self.activation_dropout_module = FairseqDropout(
57 | activation_dropout, module_name=self.__class__.__name__
58 | )
59 |
60 | # Initialize blocks
61 | self.activation_fn = utils.get_activation_fn(activation_fn)
62 | self.self_attn = self.build_self_attention(
63 | self.embedding_dim,
64 | num_attention_heads,
65 | dropout=attention_dropout,
66 | self_attention=True,
67 | q_noise=q_noise,
68 | qn_block_size=qn_block_size,
69 | reattention=reattention
70 | )
71 |
72 | # layer norm associated with the self attention layer
73 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)
74 |
75 | self.fc1 = self.build_fc1(
76 | self.embedding_dim,
77 | ffn_embedding_dim,
78 | q_noise=q_noise,
79 | qn_block_size=qn_block_size,
80 | )
81 | self.fc2 = self.build_fc2(
82 | ffn_embedding_dim,
83 | self.embedding_dim,
84 | q_noise=q_noise,
85 | qn_block_size=qn_block_size,
86 | )
87 |
88 | # layer norm associated with the position wise feed-forward NN
89 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
90 |
91 |
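        # DeepNorm-style scaling for very deep encoders: the residual branch is
        # multiplied by shortcut_scale = (2 * encoder_layers) ** 0.25, and
        # deepnorm_init() divides the value/output projections and FFN weights
        # by fixup_scale = (8 * encoder_layers) ** 0.25.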
92 | if deepnorm:
93 | # self.shortcut_scale = math.pow(math.pow(cfg.encoder_layers, 4) * cfg.decoder_layers, 0.0625) * 0.81
94 | # if utils.safe_getattr(cfg, "deepnorm_encoder_only", False):
95 | self.shortcut_scale = math.pow(2.0 * encoder_layers, 0.25)
96 | # print('deepnorm shortcut_scale ',self.shortcut_scale)
97 | else:
98 | self.shortcut_scale = 1.0
99 |
100 | if deepnorm:
101 | # self.fixup_scale = math.pow(math.pow(cfg.encoder_layers, 4) * cfg.decoder_layers, 0.0625) / 1.15
102 | # if utils.safe_getattr(cfg, "deepnorm_encoder_only", False):
103 | self.fixup_scale = math.pow(8.0 * encoder_layers, 0.25)
104 | # print('deepnorm fixup_scale ', self.fixup_scale)
105 | self.deepnorm_init()
106 |
107 | def deepnorm_init(self):
108 | def rescale(param):
109 | param.div_(self.fixup_scale)
110 |
111 | rescale(self.self_attn.v_proj.weight.data)
112 | rescale(self.self_attn.v_proj.bias.data)
113 | rescale(self.self_attn.out_proj.weight.data)
114 | rescale(self.self_attn.out_proj.bias.data)
115 |
116 | rescale(self.fc1.weight.data)
117 | rescale(self.fc2.weight.data)
118 | rescale(self.fc1.bias.data)
119 | rescale(self.fc2.bias.data)
120 | # print('deepnorm init')
121 |
122 | def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
123 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
124 |
125 | def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
126 | return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
127 |
128 | def build_self_attention(
129 | self,
130 | embed_dim,
131 | num_attention_heads,
132 | dropout,
133 | self_attention,
134 | q_noise,
135 | qn_block_size,
136 | reattention=False
137 | ):
138 | if not reattention:
139 | return MultiheadAttention(
140 | embed_dim,
141 | num_attention_heads,
142 | dropout=dropout,
143 | self_attention=True,
144 | q_noise=q_noise,
145 | qn_block_size=qn_block_size,
146 | )
147 | else:
148 | return MultiheadAttentionRe(
149 | embed_dim,
150 | num_attention_heads,
151 | dropout=dropout,
152 | self_attention=True,
153 | q_noise=q_noise,
154 | qn_block_size=qn_block_size,
155 | )
156 |
157 |
158 | def forward(
159 | self,
160 | x: torch.Tensor,
161 | self_attn_bias: Optional[torch.Tensor] = None,
162 | self_attn_mask: Optional[torch.Tensor] = None,
163 | self_attn_padding_mask: Optional[torch.Tensor] = None,
164 | ):
165 | """
166 | LayerNorm is applied either before or after the self-attention/ffn
167 | modules similar to the original Transformer implementation.
168 | """
169 | # x: T x B x C
170 | residual = x
171 | if self.pre_layernorm:
172 | x = self.self_attn_layer_norm(x)
173 | x, attn = self.self_attn(
174 | query=x,
175 | key=x,
176 | value=x,
177 | attn_bias=self_attn_bias,
178 | key_padding_mask=self_attn_padding_mask,
179 | need_weights=True,
180 | attn_mask=self_attn_mask,
181 | )
182 | x = self.dropout_module(x)
183 | x = residual* self.shortcut_scale + x
184 | if not self.pre_layernorm:
185 | x = self.self_attn_layer_norm(x)
186 |
187 | residual = x
188 | if self.pre_layernorm:
189 | x = self.final_layer_norm(x)
190 | x = self.activation_fn(self.fc1(x))
191 | x = self.activation_dropout_module(x)
192 | x = self.fc2(x)
193 | x = self.dropout_module(x)
194 | x = residual * self.shortcut_scale + x
195 | if not self.pre_layernorm:
196 | x = self.final_layer_norm(x)
197 | return x, attn
198 |
--------------------------------------------------------------------------------
/deepgraph/modules/deepgraph_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates.
5 | #
6 | # This source code is licensed under the MIT license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | def init_params(module, n_layers):
16 | if isinstance(module, nn.Linear):
17 | module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
18 | if module.bias is not None:
19 | module.bias.data.zero_()
20 | if isinstance(module, nn.Embedding):
21 | module.weight.data.normal_(mean=0.0, std=0.02)
22 |
23 |
24 | class GraphNodeFeature(nn.Module):
25 | """
26 | Compute node features for each node in the graph.
27 | """
28 |
29 | def __init__(
30 | self, num_heads, num_atoms, num_in_degree, num_out_degree, hidden_dim, n_layers
31 | ):
32 | super(GraphNodeFeature, self).__init__()
33 | self.num_heads = num_heads
34 | self.num_atoms = num_atoms
35 |
36 | # 1 for graph token
37 | self.atom_encoder = nn.Embedding(num_atoms + 1, hidden_dim, padding_idx=0)
38 | self.in_degree_encoder = nn.Embedding(num_in_degree, hidden_dim, padding_idx=0)
39 | self.out_degree_encoder = nn.Embedding(
40 | num_out_degree, hidden_dim, padding_idx=0
41 | )
42 |
43 | self.graph_token = nn.Embedding(1, hidden_dim)
44 |
45 | self.apply(lambda module: init_params(module, n_layers=n_layers))
46 |
47 | def forward(self, batched_data):
48 | x, in_degree, out_degree = (
49 | batched_data["x"],
50 | batched_data["in_degree"],
51 | batched_data["out_degree"],
52 | )
53 | n_graph, n_node = x.size()[:2]
54 |
55 |         # node feature + graph token
56 | node_feature = self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden]
57 |
58 | # if self.flag and perturb is not None:
59 | # node_feature += perturb
60 |
61 | node_feature = (
62 | node_feature
63 | + self.in_degree_encoder(in_degree)
64 | + self.out_degree_encoder(out_degree)
65 | )
66 |
67 | graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
68 |
69 | graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)
70 |
71 | return graph_node_feature
72 |
73 |
74 | class GraphAttnBias(nn.Module):
75 | """
76 | Compute attention bias for each head.
77 | """
78 |
79 | def __init__(
80 | self,
81 | num_heads,
82 | num_atoms,
83 | num_edges,
84 | num_spatial,
85 | num_edge_dis,
86 | hidden_dim,
87 | edge_type,
88 | multi_hop_max_dist,
89 | n_layers,
90 | ):
91 | super(GraphAttnBias, self).__init__()
92 | self.num_heads = num_heads
93 | self.multi_hop_max_dist = multi_hop_max_dist
94 |
95 | self.edge_encoder = nn.Embedding(num_edges + 1, num_heads, padding_idx=0)
96 | self.edge_type = edge_type
97 | if self.edge_type == "multi_hop":
98 | self.edge_dis_encoder = nn.Embedding(
99 | num_edge_dis * num_heads * num_heads, 1
100 | )
101 | self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)
102 |
103 | self.graph_token_virtual_distance = nn.Embedding(1, num_heads)
104 |
105 | self.apply(lambda module: init_params(module, n_layers=n_layers))
106 |
107 | def forward(self, batched_data):
108 | attn_bias, spatial_pos, x = (
109 | batched_data["attn_bias"],
110 | batched_data["spatial_pos"],
111 | batched_data["x"],
112 | )
113 | # in_degree, out_degree = batched_data.in_degree, batched_data.in_degree
114 | edge_input, attn_edge_type = (
115 | batched_data["edge_input"],
116 | batched_data["attn_edge_type"],
117 | )
118 |
119 | n_graph, n_node = x.size()[:2]
120 | graph_attn_bias = attn_bias.clone()
121 | graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
122 | 1, self.num_heads, 1, 1
123 | ) # [n_graph, n_head, n_node+1, n_node+1]
124 |
125 | # spatial pos
126 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
127 | spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
128 | graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
129 |
130 | # reset spatial pos here
131 | t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
132 | graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
133 | graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
134 |
135 | # edge feature
136 | if self.edge_type == "multi_hop":
137 | spatial_pos_ = spatial_pos.clone()
138 | spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
139 | # set 1 to 1, x > 1 to x - 1
140 | spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
141 | if self.multi_hop_max_dist > 0:
142 | spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
143 | edge_input = edge_input[:, :, :, : self.multi_hop_max_dist, :]
144 | # [n_graph, n_node, n_node, max_dist, n_head]
145 | edge_input = self.edge_encoder(edge_input).mean(-2)
146 | max_dist = edge_input.size(-2)
147 | edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(
148 | max_dist, -1, self.num_heads
149 | )
150 | edge_input_flat = torch.bmm(
151 | edge_input_flat,
152 | self.edge_dis_encoder.weight.reshape(
153 | -1, self.num_heads, self.num_heads
154 | )[:max_dist, :, :],
155 | )
156 | edge_input = edge_input_flat.reshape(
157 | max_dist, n_graph, n_node, n_node, self.num_heads
158 | ).permute(1, 2, 3, 0, 4)
159 | edge_input = (
160 | edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))
161 | ).permute(0, 3, 1, 2)
162 | else:
163 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
164 | edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
165 |
166 | graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + edge_input
167 | graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
168 |
169 | return graph_attn_bias
170 |
171 |
172 | class SubStructure_Adj_Encoder(nn.Module):
173 | def __init__(self,subgraph_max_size,output_size):
174 | super(SubStructure_Adj_Encoder,self).__init__()
175 | self.subgraph_max_size=subgraph_max_size
176 | self.output_size=output_size
177 |
178 |         self.encoder = nn.Sequential(nn.Linear(subgraph_max_size * subgraph_max_size, output_size),
179 |                                      nn.ReLU(inplace=True),
180 |                                      nn.Linear(output_size, output_size))
181 |
182 |
183 | def forward(self,sorted_adj):
184 | input=sorted_adj.reshape([-1,self.subgraph_max_size*self.subgraph_max_size])
185 | output=self.encoder(input)
186 |
187 | result=output.reshape([sorted_adj.shape[0],sorted_adj.shape[1],-1])
188 | return result
--------------------------------------------------------------------------------
/deepgraph/data/pyg_datasets/pyg_dataset_lookup_table.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Optional
5 | import graph_tool
6 | from torch_geometric.datasets import *
7 | from torch_geometric.data import Dataset
8 | from ..substructure_dataset import SubstructureDataset
9 | import torch.distributed as dist
10 | import torch
11 |
12 | class MyQM7b(QM7b):
13 | def download(self):
14 | if not dist.is_initialized() or dist.get_rank() == 0:
15 | super(MyQM7b, self).download()
16 | if dist.is_initialized():
17 | dist.barrier()
18 |
19 | def process(self):
20 | if not dist.is_initialized() or dist.get_rank() == 0:
21 | super(MyQM7b, self).process()
22 | if dist.is_initialized():
23 | dist.barrier()
24 |
25 |
26 | class MyQM9(QM9):
27 | def download(self):
28 | if not dist.is_initialized() or dist.get_rank() == 0:
29 | super(MyQM9, self).download()
30 | if dist.is_initialized():
31 | dist.barrier()
32 |
33 | def process(self):
34 | if not dist.is_initialized() or dist.get_rank() == 0:
35 | super(MyQM9, self).process()
36 | if dist.is_initialized():
37 | dist.barrier()
38 |
39 | class MyZINC(ZINC):
40 | def download(self):
41 | if not dist.is_initialized() or dist.get_rank() == 0:
42 | super(MyZINC, self).download()
43 | if dist.is_initialized():
44 | dist.barrier()
45 |
46 | def process(self):
47 | if not dist.is_initialized() or dist.get_rank() == 0:
48 | super(MyZINC, self).process()
49 | if dist.is_initialized():
50 | dist.barrier()
51 |
52 | class MYZINC_UNI(torch.utils.data.Dataset):
53 | def __init__(self,data_list):
54 | self.data_list=data_list
55 |
56 | def __len__(self):
57 | return len(self.data_list)
58 |
59 | def __getitem__(self, idx):
60 | data=self.data_list[idx]
61 | return data
62 |
63 |
64 | class MyMoleculeNet(MoleculeNet):
65 | def download(self):
66 | if not dist.is_initialized() or dist.get_rank() == 0:
67 | super(MyMoleculeNet, self).download()
68 | if dist.is_initialized():
69 | dist.barrier()
70 |
71 | def process(self):
72 | if not dist.is_initialized() or dist.get_rank() == 0:
73 | super(MyMoleculeNet, self).process()
74 | if dist.is_initialized():
75 | dist.barrier()
76 |
77 |
78 |
79 |
80 | class PYGDatasetLookupTable:
81 | @staticmethod
82 | def GetPYGDataset(dataset_spec: str, seed: int,args=None,**kwargs) -> Optional[Dataset]:
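        """Resolve ``dataset_spec`` ("<name>" or "<name>:key=value,...") into a
        SubstructureDataset. Data is read from ``kwargs['data_dir']``; when
        ``args['valid_on_test']`` is set, the test split is also used for
        validation."""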
83 | split_result = dataset_spec.split(":")
84 | if len(split_result) == 2:
85 | name, params = split_result[0], split_result[1]
86 | params = params.split(",")
87 | elif len(split_result) == 1:
88 | name = dataset_spec
89 | params = []
90 | inner_dataset = None
91 | num_class = 1
92 |
93 | train_set = None
94 | valid_set = None
95 | test_set = None
96 |
97 | root = "dataset"
98 | if name == "qm7b":
99 | inner_dataset = MyQM7b(root=kwargs['data_dir'])
100 | elif name == "qm9":
101 | inner_dataset = MyQM9(root=kwargs['data_dir'])
102 | elif name == "zinc_full":
103 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=False)
104 | train_set = MyZINC(root=kwargs['data_dir'],subset=False, split="train")
105 | valid_set = MyZINC(root=kwargs['data_dir'],subset=False, split="val")
106 | test_set = MyZINC(root=kwargs['data_dir'],subset=False, split="test")
107 | elif name == "zinc":
108 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True)
109 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train")
110 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val")
111 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test")
112 | elif name == "zinc_val_on_test":
113 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True)
114 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train")
115 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test")
116 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test")
117 | elif name == 'zinc_uniform':
118 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True)
119 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train")
120 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val")
121 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test")
122 | total_data=[]
123 | for i in range(len(train_set)):
124 | total_data.append(train_set[i])
125 | for i in range(len(valid_set)):
126 | total_data.append(valid_set[i])
127 | for i in range(len(test_set)):
128 | total_data.append(test_set[i])
129 | train_data,valid_data,test_data=torch.utils.data.random_split(total_data, [len(train_set), len(valid_set),len(test_set)])
130 | train_set=MYZINC_UNI(train_data)
131 | valid_set = MYZINC_UNI(valid_data)
132 | test_set = MYZINC_UNI(test_data)
133 | inner_dataset=train_set
134 | elif name == 'zinc_uniform_val_on_test':
135 | inner_dataset = MyZINC(root=kwargs['data_dir'],subset=True)
136 | train_set = MyZINC(root=kwargs['data_dir'],subset=True, split="train")
137 | valid_set = MyZINC(root=kwargs['data_dir'],subset=True, split="val")
138 | test_set = MyZINC(root=kwargs['data_dir'],subset=True, split="test")
139 | total_data=[]
140 | for i in range(len(train_set)):
141 | total_data.append(train_set[i])
142 | for i in range(len(valid_set)):
143 | total_data.append(valid_set[i])
144 | for i in range(len(test_set)):
145 | total_data.append(test_set[i])
146 | train_data,valid_data,test_data=torch.utils.data.random_split(total_data, [len(train_set), len(valid_set),len(test_set)])
147 | train_set=MYZINC_UNI(train_data)
148 | valid_set = MYZINC_UNI(test_data)
149 | test_set = MYZINC_UNI(test_data)
150 | inner_dataset=train_set
151 | elif name == "moleculenet":
152 | nm = None
153 | for param in params:
154 | name, value = param.split("=")
155 | if name == "name":
156 | nm = value
157 | inner_dataset = MyMoleculeNet(root=kwargs['data_dir'], name=nm)
158 |
159 | elif name in ["CLUSTER","PATTERN"]:
160 | train_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='train')
161 | valid_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='val')
162 | test_set = GNNBenchmarkDataset(root=kwargs['data_dir'], name=name, split='test')
163 |
164 | else:
165 | raise ValueError(f"Unknown dataset name {name} for pyg source.")
166 |
167 | if args['valid_on_test']:
168 | valid_set = test_set
169 |
170 | if train_set is not None:
171 | result= SubstructureDataset(
172 | None,
173 | seed,
174 | None,
175 | None,
176 | None,
177 | train_set,
178 | valid_set,
179 | test_set,
180 | args['not_re_define'],
181 | args=args
182 | )
183 | else:
184 | result= (
185 | None
186 | if inner_dataset is None
187 | else SubstructureDataset(inner_dataset, seed)
188 | )
189 |
190 | return result
191 |
--------------------------------------------------------------------------------
/deepgraph/data/collator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 |
6 |
7 | def pad_1d_unsqueeze(x, padlen):
8 | x = x + 1 # pad id = 0
9 | xlen = x.size(0)
10 | if xlen < padlen:
11 | new_x = x.new_zeros([padlen], dtype=x.dtype)
12 | new_x[:xlen] = x
13 | x = new_x
14 | return x.unsqueeze(0)
15 |
16 |
17 | def pad_2d_unsqueeze(x, padlen):
18 | x = x + 1 # pad id = 0
19 | xlen, xdim = x.size()
20 | if xlen < padlen:
21 | new_x = x.new_zeros([padlen, xdim], dtype=x.dtype)
22 | new_x[:xlen, :] = x
23 | x = new_x
24 | return x.unsqueeze(0)
25 |
26 | def pad_sub_adjs_unsqueeze(x, padlen):
27 | xlen, xdim1,xdim2 = x.size()
28 | if xlen < padlen:
29 | new_x = x.new_zeros([padlen, xdim1,xdim2], dtype=x.dtype)
30 | new_x[:xlen, :,:] = x
31 | x = new_x
32 | return x.unsqueeze(0)
33 |
34 | def pad_attn_bias_unsqueeze(x, padlen):
35 | xlen = x.size(0)
36 | if xlen < padlen:
37 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float("-inf"))
38 | new_x[:xlen, :xlen] = x
39 | new_x[xlen:, :xlen] = 0
40 | x = new_x
41 | return x.unsqueeze(0)
42 |
43 |
44 | def pad_edge_type_unsqueeze(x, padlen):
45 | xlen = x.size(0)
46 | if xlen < padlen:
47 | new_x = x.new_zeros([padlen, padlen, x.size(-1)], dtype=x.dtype)
48 | new_x[:xlen, :xlen, :] = x
49 | x = new_x
50 | return x.unsqueeze(0)
51 |
52 |
53 | def pad_spatial_pos_unsqueeze(x, padlen):
54 | x = x + 1
55 | xlen = x.size(0)
56 | if xlen < padlen:
57 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype)
58 | new_x[:xlen, :xlen] = x
59 | x = new_x
60 | return x.unsqueeze(0)
61 |
62 |
63 | def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3):
64 | x = x + 1
65 | xlen1, xlen2, xlen3, xlen4 = x.size()
66 | if xlen1 < padlen1 or xlen2 < padlen2 or xlen3 < padlen3:
67 | new_x = x.new_zeros([padlen1, padlen2, padlen3, xlen4], dtype=x.dtype)
68 | new_x[:xlen1, :xlen2, :xlen3, :] = x
69 | x = new_x
70 | return x.unsqueeze(0)
71 |
72 | def pad_3d_sequences(xs, padlen1, padlen2, padlen3):
73 | result=torch.zeros([len(xs),padlen1,padlen2,padlen3,xs[0].shape[-1]], dtype=xs[0].dtype)
74 | for i,x in enumerate(xs):
75 | result[i,:x.shape[0], :x.shape[1], :x.shape[2], :]=x+1
76 | return result
77 |
78 |
79 |
80 | def collator(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20):
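    """Batch preprocessed graphs: drop graphs larger than ``max_node``, shift
    integer features by +1 so 0 is the padding id, pad per-graph tensors to the
    largest graph in the batch, and set the attention bias to -inf where the
    spatial distance is at least ``spatial_pos_max``."""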
81 | items = [item for item in items if item is not None and item.x.size(0) <= max_node]
82 | items = [
83 | (
84 | item.idx,
85 | item.attn_bias,
86 | item.attn_edge_type,
87 | item.spatial_pos,
88 | item.in_degree,
89 | item.out_degree,
90 | item.x,
91 | item.edge_input[:, :, :multi_hop_max_dist, :],
92 | item.y,
93 | )
94 | for item in items
95 | ]
96 | (
97 | idxs,
98 | attn_biases,
99 | attn_edge_types,
100 | spatial_poses,
101 | in_degrees,
102 | out_degrees,
103 | xs,
104 | edge_inputs,
105 | ys,
106 | ) = zip(*items)
107 |
108 | for idx, _ in enumerate(attn_biases):
109 | attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float("-inf")
110 | max_node_num = max(i.size(0) for i in xs)
111 | max_dist = max(i.size(-2) for i in edge_inputs)
112 | y = torch.cat(ys)
113 | x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs])
114 | edge_input = torch.cat(
115 | [pad_3d_unsqueeze(i, max_node_num, max_node_num, max_dist) for i in edge_inputs]
116 | )
117 | attn_bias = torch.cat(
118 | [pad_attn_bias_unsqueeze(i, max_node_num + 1) for i in attn_biases]
119 | )
120 | attn_edge_type = torch.cat(
121 | [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types]
122 | )
123 | spatial_pos = torch.cat(
124 | [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses]
125 | )
126 | in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees])
127 |
128 | return dict(
129 | idx=torch.LongTensor(idxs),
130 | attn_bias=attn_bias,
131 | attn_edge_type=attn_edge_type,
132 | spatial_pos=spatial_pos,
133 | in_degree=in_degree,
134 | out_degree=in_degree, # for undirected graph
135 | x=x,
136 | edge_input=edge_input,
137 | y=y,
138 | )
139 |
140 | def collator_adj(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20):
141 | items = [item for item in items if item is not None and item.x.size(0) <= max_node]
142 | items = [
143 | (
144 | item.idx,
145 | item.attn_bias,
146 | item.attn_edge_type,
147 | item.spatial_pos,
148 | item.in_degree,
149 | item.out_degree,
150 | item.x,
151 | item.edge_input[:, :, :multi_hop_max_dist, :],
152 | item.y,
153 | item.sorted_adj,
154 | item.sub_adj_mask
155 | )
156 | for item in items
157 | ]
158 | (
159 | idxs,
160 | attn_biases,
161 | attn_edge_types,
162 | spatial_poses,
163 | in_degrees,
164 | out_degrees,
165 | xs,
166 | edge_inputs,
167 | ys,
168 | sorted_adjs,
169 | sub_adj_masks
170 | ) = zip(*items)
171 |
172 | for idx, _ in enumerate(attn_biases):
173 | attn_biases[idx][1:, 1:][spatial_poses[idx] >= spatial_pos_max] = float("-inf")
174 | max_node_num = max(i.size(0) for i in xs)
175 | max_dist = max(i.size(-2) for i in edge_inputs)
176 | if sorted_adjs[0] is not None:
177 | max_subs = max(i.size(0) for i in sorted_adjs)
178 | y = torch.cat(ys)
179 | x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs])
180 | # edge_input = torch.cat(
181 | # [pad_3d_unsqueeze(i, max_node_num, max_node_num, max_dist) for i in edge_inputs]
182 | # )
183 | edge_input = pad_3d_sequences(edge_inputs, max_node_num, max_node_num, max_dist)
184 | attn_bias = torch.cat(
185 | [pad_attn_bias_unsqueeze(i, max_node_num + 1) for i in attn_biases]
186 | )
187 | attn_edge_type = torch.cat(
188 | [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types]
189 | )
190 | spatial_pos = torch.cat(
191 | [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses]
192 | )
193 | in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees])
194 |
195 |
196 | if sorted_adjs[0] is None:
197 | sorted_adj=None
198 | else:
199 | sorted_adj=torch.cat([pad_sub_adjs_unsqueeze(i, max_subs) for i in sorted_adjs])
200 | if sub_adj_masks[0] is None:
201 | sub_adj_mask=None
202 | else:
203 | sub_adj_mask = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in sub_adj_masks])
204 |
205 | return dict(
206 | idx=torch.LongTensor(idxs),
207 | attn_bias=attn_bias,
208 | attn_edge_type=attn_edge_type,
209 | spatial_pos=spatial_pos,
210 | in_degree=in_degree,
211 | out_degree=in_degree, # for undirected graph
212 | x=x,
213 | edge_input=edge_input,
214 | y=y,
215 | sorted_adj=sorted_adj,
216 | sub_adj_mask=sub_adj_mask
217 | )
218 |
219 | def collator_adj_cache(items, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20):
220 | items = [
221 | (
222 | item.idx,
223 | item.subgraph_tensor_res,
224 | item.sorted_adj_res
225 | )
226 | for item in items
227 | ]
228 |
229 | (
230 | idx,
231 | subgraph_tensor_res,
232 | sorted_adj_res,
233 | ) = zip(*items)
234 |
235 | return dict(
236 | idx=idx,
237 | subgraph_tensor_res=subgraph_tensor_res,
238 | sorted_adj_res=sorted_adj_res,
239 | )
240 |
241 |
242 |
243 | def collator_node_label(labels):
244 | max_node_num = max(i.size(0) for i in labels)
245 | label_padded=torch.cat([pad_1d_unsqueeze(i,max_node_num) for i in labels])
246 |
247 | return label_padded
248 |
249 |
250 |
251 | def collatorcontrust(datas, max_node=512, multi_hop_max_dist=20, spatial_pos_max=20):
252 | datas_data=[a_data.data for a_data in datas]
253 | datas_data1 = [a_data.data1 for a_data in datas]
254 | datas_data2 = [a_data.data2 for a_data in datas]
255 | result=[]
256 | for items in [datas_data,datas_data1,datas_data2]:
257 | result.append(collator(items, max_node=max_node, multi_hop_max_dist=multi_hop_max_dist, spatial_pos_max=spatial_pos_max))
258 | res_dic={}
259 | res_dic['data']=result[0]
260 | res_dic['data1'] = result[1]
261 | res_dic['data2'] = result[2]
262 | return res_dic
263 |
264 |
265 |
--------------------------------------------------------------------------------
/deepgraph/tasks/is2re.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from pathlib import Path
5 | from typing import Sequence, Union
6 |
7 | import pickle
8 | from functools import lru_cache
9 |
10 | import lmdb
11 |
12 | import numpy as np
13 | import torch
14 | from torch import Tensor
15 | from fairseq.data import (
16 | FairseqDataset,
17 | BaseWrapperDataset,
18 | NestedDictionaryDataset,
19 | data_utils,
20 | )
21 | from fairseq.tasks import FairseqTask, register_task
22 |
23 | from ..data.dataset import EpochShuffleDataset
24 |
25 | class LMDBDataset:
26 | def __init__(self, db_path):
27 | super().__init__()
28 | assert Path(db_path).exists(), f"{db_path}: No such file or directory"
29 | self.env = lmdb.Environment(
30 | db_path,
31 | map_size=(1024 ** 3) * 256,
32 | subdir=False,
33 | readonly=True,
34 | readahead=True,
35 | meminit=False,
36 | )
37 | self.len: int = self.env.stat()["entries"]
38 |
39 | def __len__(self):
40 | return self.len
41 |
42 | @lru_cache(maxsize=16)
43 | # def __getitem__(self, idx: int) -> dict[str, Union[Tensor, float]]:
44 | def __getitem__(self, idx: int) :
45 | if idx < 0 or idx >= self.len:
46 | raise IndexError
47 | data = pickle.loads(self.env.begin().get(f"{idx}".encode()))
48 | return dict(
49 | pos=torch.as_tensor(data["pos"]).float(),
50 | pos_relaxed=torch.as_tensor(data["pos_relaxed"]).float(),
51 | cell=torch.as_tensor(data["cell"]).float().view(3, 3),
52 | atoms=torch.as_tensor(data["atomic_numbers"]).long(),
53 | tags=torch.as_tensor(data["tags"]).long(),
54 | relaxed_energy=data["y_relaxed"], # python float
55 | )
56 |
57 |
58 | class PBCDataset:
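    """Expand each structure into its 8 neighbouring periodic cells in the
    xy-plane, keep the copies that lie within ``self.cutoff`` of the
    (tag-filtered) source atoms, and return positions, atom numbers, tags,
    a real/copy mask, per-atom displacements to the relaxed structure
    (``deltapos``), and the relaxed energy."""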
59 | def __init__(self, dataset: LMDBDataset):
60 | self.dataset = dataset
61 | self.cell_offsets = torch.tensor(
62 | [
63 | [-1, -1, 0],
64 | [-1, 0, 0],
65 | [-1, 1, 0],
66 | [0, -1, 0],
67 | [0, 1, 0],
68 | [1, -1, 0],
69 | [1, 0, 0],
70 | [1, 1, 0],
71 | ],
72 | ).float()
73 | self.n_cells = self.cell_offsets.size(0)
74 | self.cutoff = 8
75 | self.filter_by_tag = True
76 |
77 | def __len__(self):
78 | return len(self.dataset)
79 |
80 | @lru_cache(maxsize=16)
81 | def __getitem__(self, idx):
82 | data = self.dataset[idx]
83 |
84 | pos = data["pos"]
85 | pos_relaxed = data["pos_relaxed"]
86 | cell = data["cell"]
87 | atoms = data["atoms"]
88 | tags = data["tags"]
89 |
90 | offsets = torch.matmul(self.cell_offsets, cell).view(self.n_cells, 1, 3)
91 | expand_pos = (pos.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets).view(
92 | -1, 3
93 | )
94 | expand_pos_relaxed = (
95 |             pos_relaxed.unsqueeze(0).expand(self.n_cells, -1, -1) + offsets
96 | ).view(-1, 3)
97 | src_pos = pos[tags > 1] if self.filter_by_tag else pos
98 |
99 | dist: Tensor = (src_pos.unsqueeze(1) - expand_pos.unsqueeze(0)).norm(dim=-1)
100 | used_mask = (dist < self.cutoff).any(dim=0) & tags.ne(2).repeat(
101 | self.n_cells
102 | ) # not copy ads
103 | used_expand_pos = expand_pos[used_mask]
104 | used_expand_pos_relaxed = expand_pos_relaxed[used_mask]
105 |
106 | used_expand_tags = tags.repeat(self.n_cells)[
107 | used_mask
108 |         ]  # original implementation uses zeros; needs testing
109 | return dict(
110 | pos=torch.cat([pos, used_expand_pos], dim=0),
111 | atoms=torch.cat([atoms, atoms.repeat(self.n_cells)[used_mask]]),
112 | tags=torch.cat([tags, used_expand_tags]),
113 | real_mask=torch.cat(
114 | [
115 | torch.ones_like(tags, dtype=torch.bool),
116 | torch.zeros_like(used_expand_tags, dtype=torch.bool),
117 | ]
118 | ),
119 | deltapos=torch.cat(
120 | [pos_relaxed - pos, used_expand_pos_relaxed - used_expand_pos], dim=0
121 | ),
122 | relaxed_energy=data["relaxed_energy"],
123 | )
124 |
125 |
126 | def pad_1d(samples: Sequence[Tensor], fill=0, multiplier=8):
127 | max_len = max(x.size(0) for x in samples)
128 | max_len = (max_len + multiplier - 1) // multiplier * multiplier
129 | n_samples = len(samples)
130 | out = torch.full(
131 | (n_samples, max_len, *samples[0].shape[1:]), fill, dtype=samples[0].dtype
132 | )
133 | for i in range(n_samples):
134 | x_len = samples[i].size(0)
135 | out[i][:x_len] = samples[i]
136 | return out
137 |
138 |
139 | class AtomDataset(FairseqDataset):
140 | def __init__(self, dataset, keyword):
141 | super().__init__()
142 | self.dataset = dataset
143 | self.keyword = keyword
144 | self.atom_list = [
145 | 1,
146 | 5,
147 | 6,
148 | 7,
149 | 8,
150 | 11,
151 | 13,
152 | 14,
153 | 15,
154 | 16,
155 | 17,
156 | 19,
157 | 20,
158 | 21,
159 | 22,
160 | 23,
161 | 24,
162 | 25,
163 | 26,
164 | 27,
165 | 28,
166 | 29,
167 | 30,
168 | 31,
169 | 32,
170 | 33,
171 | 34,
172 | 37,
173 | 38,
174 | 39,
175 | 40,
176 | 41,
177 | 42,
178 | 43,
179 | 44,
180 | 45,
181 | 46,
182 | 47,
183 | 48,
184 | 49,
185 | 50,
186 | 51,
187 | 52,
188 | 55,
189 | 72,
190 | 73,
191 | 74,
192 | 75,
193 | 76,
194 | 77,
195 | 78,
196 | 79,
197 | 80,
198 | 81,
199 | 82,
200 | 83,
201 | ]
202 | # fill others as unk
203 | unk_idx = len(self.atom_list) + 1
204 | self.atom_mapper = torch.full((128,), unk_idx)
205 | for idx, atom in enumerate(self.atom_list):
206 |             self.atom_mapper[atom] = idx + 1  # reserve 0 for padding
207 |
208 | @lru_cache(maxsize=16)
209 | def __getitem__(self, index):
210 | atoms: Tensor = self.dataset[index][self.keyword]
211 | return self.atom_mapper[atoms]
212 |
213 | def __len__(self):
214 | return len(self.dataset)
215 |
216 | def collater(self, samples):
217 | return pad_1d(samples)
218 |
219 |
220 | class KeywordDataset(FairseqDataset):
221 | def __init__(self, dataset, keyword, is_scalar=False, pad_fill=0):
222 | super().__init__()
223 | self.dataset = dataset
224 | self.keyword = keyword
225 | self.is_scalar = is_scalar
226 | self.pad_fill = pad_fill
227 |
228 | @lru_cache(maxsize=16)
229 | def __getitem__(self, index):
230 | return self.dataset[index][self.keyword]
231 |
232 | def __len__(self):
233 | return len(self.dataset)
234 |
235 | def collater(self, samples):
236 | if self.is_scalar:
237 | return torch.tensor(samples)
238 | return pad_1d(samples, fill=self.pad_fill)
239 |
240 |
241 | @register_task("is2re")
242 | class IS2RETask(FairseqTask):
243 | @classmethod
244 | def add_args(cls, parser):
245 | parser.add_argument("data", metavar="FILE", help="directory for data")
246 |
247 | @property
248 | def target_dictionary(self):
249 | return None
250 |
251 | def load_dataset(self, split, combine=False, **kwargs):
252 | assert split in [
253 | "train",
254 | "val_id",
255 | "val_ood_ads",
256 | "val_ood_cat",
257 | "val_ood_both",
258 | "test_id",
259 | "test_ood_ads",
260 | "test_ood_cat",
261 | "test_ood_both",
262 | ], "invalid split: {}!".format(split)
263 | print(" > Loading {} ...".format(split))
264 |
265 | db_path = str(Path(self.cfg.data) / split / "data.lmdb")
266 | lmdb_dataset = LMDBDataset(db_path)
267 | pbc_dataset = PBCDataset(lmdb_dataset)
268 |
269 | atoms = AtomDataset(pbc_dataset, "atoms")
270 | tags = KeywordDataset(pbc_dataset, "tags")
271 | real_mask = KeywordDataset(pbc_dataset, "real_mask")
272 |
273 | pos = KeywordDataset(pbc_dataset, "pos")
274 |
275 | relaxed_energy = KeywordDataset(pbc_dataset, "relaxed_energy", is_scalar=True)
276 | deltapos = KeywordDataset(pbc_dataset, "deltapos")
277 |
278 | dataset = NestedDictionaryDataset(
279 | {
280 | "net_input": {
281 | "pos": pos,
282 | "atoms": atoms,
283 | "tags": tags,
284 | "real_mask": real_mask,
285 | },
286 | "targets": {
287 | "relaxed_energy": relaxed_energy,
288 | "deltapos": deltapos,
289 | },
290 | },
291 | sizes=[np.zeros(len(atoms))],
292 | )
293 |
294 | if split == "train":
295 | dataset = EpochShuffleDataset(
296 | dataset,
297 | num_samples=len(atoms),
298 | seed=self.cfg.seed,
299 | )
300 |
301 | print("| Loaded {} with {} samples".format(split, len(dataset)))
302 | self.datasets[split] = dataset
303 |
--------------------------------------------------------------------------------
/deepgraph/data/wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import numpy as np
6 | import joblib
7 | from ogb.graphproppred import PygGraphPropPredDataset
8 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
9 | from functools import lru_cache
10 | import pyximport
11 | import torch.distributed as dist
12 |
13 | pyximport.install(setup_args={"include_dirs": np.get_include()})
14 | from . import algos
15 | import sys
16 | from tqdm import tqdm
17 |
18 | @torch.jit.script
19 | def convert_to_single_emb(x, offset: int = 10):
20 | feature_num = x.size(1) if len(x.size()) > 1 else 1
21 | feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long)
22 | x = x + feature_offset
23 | return x
24 |
25 |
26 | def encode_token_single_tensor(graph, atom_id_max, edge_id_max):
27 | cur_id = graph.x.shape[0]
28 | graph.cur_id = cur_id
29 | _, index = np.unique(graph.identifiers[0, :], return_index=True)
30 | index = np.sort(index)
31 | token_ids = graph.identifiers[2, index].unsqueeze(1) + atom_id_max + 1
32 | graph.x = torch.cat([graph.x, token_ids.repeat([1, graph.x.shape[1]])], 0).type(torch.int64)
33 |
34 | edges = torch.cat([graph.identifiers[0, :].unsqueeze(0) + cur_id, graph.identifiers[1, :].unsqueeze(0)], 0).long()
35 | graph.edge_index = torch.cat([graph.edge_index, edges, edges[[1, 0], :]], 1)
36 |
37 | if graph.edge_attr is not None:
38 | edge_attr = graph.identifiers[2, :].unsqueeze(1).repeat([1, graph.edge_attr.shape[1]]) + edge_id_max + 2
39 | graph.edge_attr = torch.cat([graph.edge_attr, edge_attr, edge_attr], 0).long()
40 |
41 | graph.sorted_adj = None
42 |
43 | graph.sub_adj_mask = torch.ones([graph.x.shape[0],1]).long()
44 | graph.sub_adj_mask[0:graph.cur_id]=0
45 | graph.sub_adj_mask=graph.sub_adj_mask.long()
46 |
47 | return graph
48 |
49 |
50 | def encode_token_single_tensor_with_adj(graph, local_attention_on_substructures):
51 | cur_id = graph.x.shape[0]
52 | graph.cur_id = cur_id
53 |
54 | _, index = np.unique(graph.identifiers[0, :], return_index=True)
55 | index = np.sort(index)
56 | token_ids = -1 * torch.ones_like(graph.identifiers[2, index].unsqueeze(1))
57 |
58 | graph.sorted_adj = torch.cat([torch.zeros([graph.num_nodes,
59 | graph.sorted_adj.shape[1],
60 | graph.sorted_adj.shape[2]]), graph.sorted_adj], 0)
61 |
62 | graph.x = torch.cat([graph.x, token_ids.repeat([1, graph.x.shape[1]])], 0).type(torch.int64)
63 | graph.sub_adj_mask = (graph.x < 0).long()
64 | if graph.sub_adj_mask.shape[1] > 1:
65 | graph.sub_adj_mask = graph.sub_adj_mask[:, 0].unsqueeze(1)
66 |
67 |
68 | # expanded edges is used for attention mask, not for position embedding
69 | edges = torch.cat([graph.identifiers[0, :].unsqueeze(0) + cur_id, graph.identifiers[1, :].unsqueeze(0)], 0).long()
70 | graph.edge_index = torch.cat([graph.edge_index, edges, edges[[1, 0], :]], 1)
71 |
72 | if not local_attention_on_substructures:
73 | # expanded edge_attr is actually not used in local_attention_on_substructures
74 | edge_id = graph.edge_attr.shape[0]
75 | graph.edge_id = edge_id
76 | if graph.edge_attr is not None:
77 | edge_attr = (-1 * torch.ones_like(graph.identifiers[2].unsqueeze(1))) \
78 | .repeat([1, graph.edge_attr.shape[1]]) \
79 | .type(torch.int64) + edge_id_max + 2
80 | graph.edge_attr = torch.cat([graph.edge_attr, edge_attr, edge_attr], 0).long()
81 |
82 | return graph
83 |
84 | def graph_data_modification_single(data,substructures,sorted_adj=None):
85 |
86 | if data.x.dim()==1:
87 | data.x=data.x.unsqueeze(1)
88 | if data.y.dim() == 2:
89 | data.y = data.y.squeeze(1)
90 | if hasattr(data, 'edge_features'):
91 | data.edge_attr=data.edge_features
92 | del(data.edge_features)
93 |     if hasattr(data, 'degrees'):
94 |         del(data.degrees)
95 | if hasattr(data,'graph_size'):
96 | del (data.graph_size)
97 | if data.edge_attr is not None and data.edge_attr.dim()==1:
98 | data.edge_attr=data.edge_attr.unsqueeze(1)
99 | assert hasattr(data,'edge_attr')
100 | assert hasattr(data, 'edge_index')
101 | setattr(data,'identifiers', substructures)
102 | setattr(data,'sorted_adj',sorted_adj)
103 |
104 | return data
105 |
106 |
107 |
108 | def preprocess_item(item,local_attention_on_substructures=False,continuous_feature=False):
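    """Precompute Graphormer inputs for a single graph: single-embedding
    node/edge indices, all-pairs shortest-path distances (``spatial_pos``) via
    algos.floyd_warshall, per-hop edge features along those paths
    (``edge_input``), node degrees, and the attention bias; when
    ``local_attention_on_substructures`` is set, attention is restricted to the
    original nodes plus substructure edges."""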
109 | max_dist_const=510
110 | substructure_dist_const=max_dist_const-1
111 |
112 | edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x
113 |
114 | N = x.size(0)
115 | x = convert_to_single_emb(x)
116 |
117 | # node adj matrix [N, N] bool
118 | adj_matrix = torch.zeros([N, N], dtype=torch.bool)
119 | adj_matrix[edge_index[0, :], edge_index[1, :]] = True
120 | if local_attention_on_substructures:
121 |
122 | adj = adj_matrix[0:item.cur_id, 0:item.cur_id]
123 | else:
124 | adj = adj_matrix[0:,0:]
125 |
126 |
127 |
128 | if edge_attr is None:
129 | attn_edge_type = -1*torch.ones([N, N, 1], dtype=torch.long)
130 | else:
131 | if len(edge_attr.size()) == 1:
132 | edge_attr = edge_attr[:, None]
133 | attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
134 | edge_attr=edge_attr.long()
135 | attn_edge_type[edge_index[0, :edge_attr.size(0)], edge_index[1, :edge_attr.size(0)]] = (
136 | convert_to_single_emb(edge_attr) + 1
137 | )
138 |
139 |
140 | #compute attention mask (attn_bias)
141 | if local_attention_on_substructures:
142 | adj_matrix = adj_matrix.float()
143 | adj_matrix[0:item.cur_id, 0:item.cur_id] = 1
144 | attn_bias=1-adj_matrix
145 | attn_bias[attn_bias>0]=float('-inf')
146 | attn_bias_res = torch.zeros([N + 1, N + 1], dtype=torch.float)
147 | attn_bias_res[1:,1:]=attn_bias
148 | attn_bias=attn_bias_res
149 | else:
150 | attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token
151 |
152 |
153 |
154 | shortest_path_result, path = algos.floyd_warshall(adj.numpy())
155 | max_dist = np.amax(shortest_path_result)
156 | edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
157 | spatial_pos = torch.from_numpy((shortest_path_result)).long()
158 |
159 | #Expand the size of edge_input and spatial_pos
160 | if local_attention_on_substructures:
161 | edge_input_new=(-1*np.ones([N,N,edge_input.shape[2],edge_input.shape[3]])).astype(np.int64)
162 | edge_input_new[0:item.cur_id,0:item.cur_id,:,:]=edge_input
163 | edge_input=edge_input_new
164 | spatial_pos_new=(max_dist_const*torch.ones(N,N)).long()
165 | spatial_pos_new[0:item.cur_id,0:item.cur_id]=spatial_pos
166 | spatial_pos=spatial_pos_new
167 |
168 |
169 | # combine
170 | item.x = x.long()
171 | item.attn_bias = attn_bias
172 | item.attn_edge_type = attn_edge_type
173 | item.spatial_pos = spatial_pos
174 | item.in_degree = adj.long().sum(dim=1).view(-1)
175 | item.out_degree = item.in_degree # for undirected graph
176 | item.edge_input = torch.from_numpy(edge_input).long()
177 |
178 |
179 | return item
180 |
181 |
182 |
183 | def preprocess_item_local_attention(item,local_attention_on_substructures=False,continuous_feature=False):
184 | max_dist_const=510
185 | substructure_dist_const=max_dist_const-1
186 |
187 | edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x
188 | N = x.size(0)
189 | x = convert_to_single_emb(x)
190 |
191 | # node adj matrix [N, N] bool
192 | adj = torch.zeros([N, N], dtype=torch.bool)
193 | adj[edge_index[0, :], edge_index[1, :]] = True
194 |
195 | # edge feature here
196 | if edge_attr is None:
197 | attn_edge_type = -1*torch.ones([N, N, 1], dtype=torch.long)
198 | else:
199 | if len(edge_attr.size()) == 1:
200 | edge_attr = edge_attr[:, None]
201 | attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
202 | edge_attr=edge_attr.long()
203 | attn_edge_type[edge_index[0, :], edge_index[1, :]] = (
204 | convert_to_single_emb(edge_attr) + 1
205 | )
206 |
207 |
208 | if local_attention_on_substructures:
209 | adj_matrix = torch.sparse_coo_tensor(item.edge_index, torch.ones_like(item.edge_index[0]),
210 | [item.x.shape[0], item.x.shape[0]]
211 | ).to_dense().float()
212 | adj_matrix[0:item.cur_id, 0:item.cur_id] = 1
213 | attn_bias=1-adj_matrix
214 | attn_bias[attn_bias>0]=float('-inf')
215 | attn_bias_res = torch.zeros([N + 1, N + 1], dtype=torch.float)
216 | attn_bias_res[1:,1:]=attn_bias
217 | attn_bias=attn_bias_res
218 | else:
219 | attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token
220 |
221 |
222 | if local_attention_on_substructures:
223 | adj_new=torch.zeros_like(adj)
224 | adj_new[0:item.cur_id, 0:item.cur_id]=adj[0:item.cur_id, 0:item.cur_id]
225 | adj=adj_new
226 | shortest_path_result, path = algos.floyd_warshall(adj.numpy())
227 | max_dist = np.amax(shortest_path_result)
228 | edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
229 | spatial_pos = torch.from_numpy((shortest_path_result)).long()
230 |
231 | if local_attention_on_substructures:
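232 |         # pairs that are unreachable under the masked adjacency (distance == max_dist_const) but
233 |         # adjacent in the dense adj_matrix are relabeled with the dedicated substructure distance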
232 | mask=torch.tensor(shortest_path_result == max_dist_const) & adj_matrix.bool()
233 | shortest_path_result[mask]=substructure_dist_const
234 |
235 | # combine
236 | item.x = x.long()
237 | item.attn_bias = attn_bias
238 | item.attn_edge_type = attn_edge_type
239 | item.spatial_pos = spatial_pos
240 | item.in_degree = adj.long().sum(dim=1).view(-1)
241 | item.out_degree = item.in_degree # for undirected graph
242 | item.edge_input = torch.from_numpy(edge_input).long()
243 |
244 |
245 | return item
246 |
247 |
248 | class MyPygPCQM4MDataset(PygPCQM4Mv2Dataset):
249 | def download(self):
250 | super(MyPygPCQM4MDataset, self).download()
251 |
252 | def process(self):
253 | super(MyPygPCQM4MDataset, self).process()
254 |
255 | @lru_cache(maxsize=16)
256 | def __getitem__(self, idx):
257 | item = self.get(self.indices()[idx])
258 | item.idx = idx
259 | return preprocess_item(item)
260 |
261 |
262 | class MyPygGraphPropPredDataset(PygGraphPropPredDataset):
263 |     def download(self):
264 |         if not dist.is_initialized() or dist.get_rank() == 0:
265 |             super(MyPygGraphPropPredDataset, self).download()
266 |         if dist.is_initialized():
267 |             dist.barrier()
268 |
269 |     def process(self):
270 |         if not dist.is_initialized() or dist.get_rank() == 0:
271 |             super(MyPygGraphPropPredDataset, self).process()
272 |         if dist.is_initialized():
273 |             dist.barrier()
272 |
273 | @lru_cache(maxsize=16)
274 | def __getitem__(self, idx):
275 | item = self.get(self.indices()[idx])
276 | item.idx = idx
277 | item.y = item.y.reshape(-1)
278 | return preprocess_item(item)
279 |
--------------------------------------------------------------------------------
/deepgraph/modules/deepgraph_graph_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates.
5 | #
6 | # This source code is licensed under the MIT license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from typing import Optional, Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | from fairseq.modules import FairseqDropout, LayerDropModuleList, LayerNorm
14 | from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
15 |
16 | from .multihead_attention import MultiheadAttention
17 | from .deepgraph_layers import GraphNodeFeature, GraphAttnBias,SubStructure_Adj_Encoder
18 | from .deepgraph_graph_encoder_layer import DeepGraphGraphEncoderLayer
19 |
20 |
21 | def init_graphormer_params(module):
22 | """
23 | Initialize the weights specific to the Graphormer Model.
24 | """
25 |
26 | def normal_(data):
27 | # with FSDP, module params will be on CUDA, so we cast them back to CPU
28 | # so that the RNG is consistent with and without FSDP
29 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
30 |
31 | if isinstance(module, nn.Linear):
32 | normal_(module.weight.data)
33 | if module.bias is not None:
34 | module.bias.data.zero_()
35 | if isinstance(module, nn.Embedding):
36 | normal_(module.weight.data)
37 | if module.padding_idx is not None:
38 | module.weight.data[module.padding_idx].zero_()
39 | if isinstance(module, MultiheadAttention):
40 | normal_(module.q_proj.weight.data)
41 | normal_(module.k_proj.weight.data)
42 | normal_(module.v_proj.weight.data)
43 |
44 |
45 | class DeepGraphGraphEncoder(nn.Module):
46 | def __init__(
47 | self,
48 | num_atoms: int,
49 | num_in_degree: int,
50 | num_out_degree: int,
51 | num_edges: int,
52 | num_spatial: int,
53 | num_edge_dis: int,
54 | edge_type: str,
55 | multi_hop_max_dist: int,
56 | num_encoder_layers: int = 12,
57 | embedding_dim: int = 768,
58 | ffn_embedding_dim: int = 768,
59 | num_attention_heads: int = 32,
60 | dropout: float = 0.1,
61 | attention_dropout: float = 0.1,
62 | activation_dropout: float = 0.1,
63 | layerdrop: float = 0.0,
64 | encoder_normalize_before: bool = False,
65 | pre_layernorm: bool = False,
66 | apply_graphormer_init: bool = False,
67 | activation_fn: str = "gelu",
68 | embed_scale: float = None,
69 | freeze_embeddings: bool = False,
70 | n_trans_layers_to_freeze: int = 0,
71 | export: bool = False,
72 | traceable: bool = False,
73 | q_noise: float = 0.0,
74 | qn_block_size: int = 8,
75 | deepnorm:bool = False,
76 | encoder_layers: int = 12,
77 | encode_adj: bool = False,
78 | subgraph_max_size=None,
79 | reattention=False
80 | ) -> None:
81 |
82 | super().__init__()
83 | self.dropout_module = FairseqDropout(
84 | dropout, module_name=self.__class__.__name__
85 | )
86 | self.layerdrop = layerdrop
87 | self.embedding_dim = embedding_dim
88 | self.apply_graphormer_init = apply_graphormer_init
89 | self.traceable = traceable
90 |
91 | self.graph_node_feature = GraphNodeFeature(
92 | num_heads=num_attention_heads,
93 | num_atoms=num_atoms,
94 | num_in_degree=num_in_degree,
95 | num_out_degree=num_out_degree,
96 | hidden_dim=embedding_dim,
97 | n_layers=num_encoder_layers,
98 | )
99 |
100 | self.graph_attn_bias = GraphAttnBias(
101 | num_heads=num_attention_heads,
102 | num_atoms=num_atoms,
103 | num_edges=num_edges,
104 | num_spatial=num_spatial,
105 | num_edge_dis=num_edge_dis,
106 | edge_type=edge_type,
107 | multi_hop_max_dist=multi_hop_max_dist,
108 | hidden_dim=embedding_dim,
109 | n_layers=num_encoder_layers,
110 | )
111 |
112 | self.encode_adj=encode_adj
113 | if self.encode_adj:
114 | self.adj_encoder =SubStructure_Adj_Encoder(subgraph_max_size=subgraph_max_size,
115 | output_size=embedding_dim)
116 |
117 |
118 | self.embed_scale = embed_scale
119 |
120 | if q_noise > 0:
121 | self.quant_noise = apply_quant_noise_(
122 | nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
123 | q_noise,
124 | qn_block_size,
125 | )
126 | else:
127 | self.quant_noise = None
128 |
129 | if encoder_normalize_before:
130 | self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
131 | else:
132 | self.emb_layer_norm = None
133 |
134 | if pre_layernorm:
135 | self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
136 |
137 | if self.layerdrop > 0.0:
138 | self.layers = LayerDropModuleList(p=self.layerdrop)
139 | else:
140 | self.layers = nn.ModuleList([])
141 | self.layers.extend(
142 | [
143 | self.build_graphormer_graph_encoder_layer(
144 | embedding_dim=self.embedding_dim,
145 | ffn_embedding_dim=ffn_embedding_dim,
146 | num_attention_heads=num_attention_heads,
147 | dropout=self.dropout_module.p,
148 | attention_dropout=attention_dropout,
149 | activation_dropout=activation_dropout,
150 | activation_fn=activation_fn,
151 | export=export,
152 | q_noise=q_noise,
153 | qn_block_size=qn_block_size,
154 | pre_layernorm=pre_layernorm,
155 | deepnorm=deepnorm,
156 | encoder_layers=encoder_layers,
157 | reattention=reattention
158 | )
159 | for _ in range(num_encoder_layers)
160 | ]
161 | )
162 |
163 | # Apply initialization of model params after building the model
164 | if self.apply_graphormer_init:
165 | print('reiniting by graphormer_init')
166 | self.apply(init_graphormer_params)
167 |
168 | def freeze_module_params(m):
169 | if m is not None:
170 | for p in m.parameters():
171 | p.requires_grad = False
172 |
173 | if freeze_embeddings:
174 | raise NotImplementedError("Freezing embeddings is not implemented yet.")
175 |
176 | for layer in range(n_trans_layers_to_freeze):
177 | freeze_module_params(self.layers[layer])
178 |
179 | def build_graphormer_graph_encoder_layer(
180 | self,
181 | embedding_dim,
182 | ffn_embedding_dim,
183 | num_attention_heads,
184 | dropout,
185 | attention_dropout,
186 | activation_dropout,
187 | activation_fn,
188 | export,
189 | q_noise,
190 | qn_block_size,
191 | pre_layernorm,
192 | deepnorm,
193 | encoder_layers,
194 | reattention=False
195 | ):
196 | return DeepGraphGraphEncoderLayer(
197 | embedding_dim=embedding_dim,
198 | ffn_embedding_dim=ffn_embedding_dim,
199 | num_attention_heads=num_attention_heads,
200 | dropout=dropout,
201 | attention_dropout=attention_dropout,
202 | activation_dropout=activation_dropout,
203 | activation_fn=activation_fn,
204 | export=export,
205 | q_noise=q_noise,
206 | qn_block_size=qn_block_size,
207 | pre_layernorm=pre_layernorm,
208 | deepnorm=deepnorm,
209 | encoder_layers=encoder_layers,
210 | reattention=reattention
211 | )
212 |
213 | def forward(
214 | self,
215 | batched_data,
216 | perturb=None,
217 | last_state_only: bool = False,
218 | token_embeddings: Optional[torch.Tensor] = None,
219 | attn_mask: Optional[torch.Tensor] = None,
220 | ) -> Tuple[torch.Tensor, torch.Tensor]:
221 | is_tpu = False
222 | # compute padding mask. This is needed for multi-head attention
223 | data_x = batched_data["x"]
224 | n_graph, n_node = data_x.size()[:2]
225 |         padding_mask = (data_x[:, :, 0]).eq(0)  # B x T
226 | padding_mask_cls = torch.zeros(
227 | n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype
228 | )
229 | padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)
230 |         # B x (T+1)
231 |
232 | if token_embeddings is not None:
233 | x = token_embeddings
234 | else:
235 | x = self.graph_node_feature(batched_data)
236 |
237 |         if self.encode_adj:
238 |             # prepend a zero column for the graph token, then locate the token positions
239 |             # flagged as substructure tokens (value 2 in sub_adj_mask)
240 |             sub_adj_mask = torch.cat([torch.zeros_like(batched_data['sub_adj_mask'][:, 0, :].unsqueeze(1)),
241 |                                       batched_data['sub_adj_mask']], 1)
242 |             index = (sub_adj_mask == 2).type(x.dtype)
243 |             start_index = index.sum(0).squeeze(1).nonzero().min()
244 |
245 |             # encode the sorted adjacency of each substructure and overwrite the node embeddings
246 |             # at substructure-token positions with the resulting adjacency embeddings
247 |             adj_emb = self.adj_encoder(batched_data['sorted_adj'][:, start_index:, :, :])
248 |             adj_emb = torch.cat([torch.zeros([adj_emb.shape[0], 1 + start_index, adj_emb.shape[2]]).to(adj_emb.device),
249 |                                  adj_emb], 1).type(x.dtype)
250 |
251 |             x = x * (1 - index) + adj_emb * index
247 |
248 | if perturb is not None:
249 | #ic(torch.mean(torch.abs(x[:, 1, :])))
250 | #ic(torch.mean(torch.abs(perturb)))
251 | x[:, 1:, :] += perturb
252 |
253 | # x: B x T x C
254 |
255 | attn_bias = self.graph_attn_bias(batched_data)
256 |
257 | if self.embed_scale is not None:
258 | x = x * self.embed_scale
259 |
260 | if self.quant_noise is not None:
261 | x = self.quant_noise(x)
262 |
263 | if self.emb_layer_norm is not None:
264 | x = self.emb_layer_norm(x)
265 |
266 | x = self.dropout_module(x)
267 |
268 | # account for padding while computing the representation
269 |
270 | # B x T x C -> T x B x C
271 | x = x.transpose(0, 1)
272 |
273 | inner_states = []
274 | if not last_state_only:
275 | inner_states.append(x)
276 |
277 | atts=[]
278 |
279 | for layer in self.layers:
280 | x, att = layer(
281 | x,
282 | self_attn_padding_mask=padding_mask,
283 | self_attn_mask=attn_mask,
284 | self_attn_bias=attn_bias,
285 | )
286 | if not last_state_only:
287 | inner_states.append(x)
288 | atts.append(att)
289 | graph_rep = x[0, :, :]
290 |
291 | if last_state_only:
292 | inner_states = [x]
293 |
294 | if self.traceable:
295 | return torch.stack(inner_states), graph_rep,atts
296 | else:
297 | return inner_states, graph_rep,atts
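298 |
299 | # Note on the forward outputs above: `inner_states` is a list of T x B x C hidden states
300 | # (one entry per layer plus the embedded input, or a single entry when last_state_only=True),
301 | # `graph_rep` is the B x C representation of the graph token at position 0, and `atts`
302 | # collects the per-layer attention weights (it stays empty when last_state_only=True).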
298 |
--------------------------------------------------------------------------------
/deepgraph/data/substructure_dataset_utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils_graph_processing import subgraph_isomorphism_edge_counts, subgraph_isomorphism_vertex_counts, induced_edge_automorphism_orbits, edge_automorphism_orbits, automorphism_orbits,subgraph_isomorphism_vertex_extraction,subgraph_counts2ids,subgraph_2token
3 |
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch_geometric.data import Data
8 | import networkx as nx
10 | import glob
11 | import re
12 | import types
13 | from ast import literal_eval
14 | from omegaconf import open_dict
15 | import lmdb
16 | from tqdm import tqdm
17 | import time
18 |
19 | def get_custom_edge_list(ks, substructure_type=None, filename=None):
20 | '''
21 | Instantiates a list of `edge_list`s representing substructures
22 | of type `substructure_type` with sizes specified by `ks`.
23 | '''
24 | if substructure_type is None and filename is None:
25 | raise ValueError('You must specify either a type or a filename where to read substructures from.')
26 | edge_lists = []
27 | for k in ks:
28 | if substructure_type is not None:
29 | graphs_nx = getattr(nx, substructure_type)(k)
30 | else:
31 | graphs_nx = nx.read_graph6(os.path.join(filename, 'graph{}c.g6'.format(k)))
32 | if isinstance(graphs_nx, list) or isinstance(graphs_nx, types.GeneratorType):
33 | edge_lists += [list(graph_nx.edges) for graph_nx in graphs_nx]
34 | else:
35 | edge_lists.append(list(graphs_nx.edges))
36 | return edge_lists
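37 |
38 | # Illustrative example (edge ordering may vary across networkx versions):
39 | # get_custom_edge_list([3, 4], 'cycle_graph') returns the edge lists of the 3-cycle and the
40 | # 4-cycle, e.g. [[(0, 1), (0, 2), (1, 2)], [(0, 1), (0, 3), (1, 2), (2, 3)]].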
37 |
38 |
39 | def process_arguments_substructure(args):
40 | with open_dict(args):
41 | args['k']=literal_eval(args['ks'])
42 | args['subgraph_max_size']=0
43 | tem= (args['id_type'])
44 | del args['id_type']
45 | args['id_type']=tem.split('+')
46 | tem= (args['must_select_sub'])
47 | del args['must_select_sub']
48 | args['must_select_sub']=tem.split('+')
49 |
50 |
51 | with open_dict(args):
52 | del args['custom_edge_list']
53 | args['custom_edge_list']=[]
54 |
55 | if args['extra_method']=='feature':
56 | extract_id_fn = subgraph_counts2ids
57 | else:
58 | extract_id_fn = subgraph_2token
59 | ###### choose the function that computes the automorphism group and the orbits #######
60 | if args['edge_automorphism'] == 'induced':
61 | automorphism_fn = induced_edge_automorphism_orbits if args['id_scope'] == 'local' else automorphism_orbits
62 | elif args['edge_automorphism'] == 'line_graph':
63 | automorphism_fn = edge_automorphism_orbits if args['id_scope'] == 'local' else automorphism_orbits
64 | else:
65 | raise NotImplementedError
66 |
67 | ###### choose the function that computes the subgraph isomorphisms #######
68 | if args['extra_method']=='feature':
69 | count_fn = subgraph_isomorphism_edge_counts if args['id_scope'] == 'local' else subgraph_isomorphism_vertex_counts
70 | else:
71 | count_fn = subgraph_isomorphism_vertex_extraction
72 |
73 | ###### choose the substructures: usually loaded from networkx,
74 | ###### except for 'all_simple_graphs' where they need to be precomputed,
75 | ###### or when a custom edge list is provided in the input by the user
76 | with open_dict(args):
77 | del (args['custom_edge_list'])
78 | args['custom_edge_list']=[]
79 | args['must_select_list']=[]
80 | args['neighbor_type']=[]
81 | args['neighbor_size'] = []
82 | args['pre_defined_path']=args['dataset_name']+'_'+'substructure'
83 | args['node_neighbor_path']=args['dataset_name']+'_'+'neighbor'
84 | args['transform_cache_path']=args['dataset_name']+'_'+args['sampling_mode']+'_'+str(args['sampling_redundancy'])+'_'+str(args['transform_cache_number'])
85 |
86 | for number,id_type in enumerate(args['id_type']):
87 |
88 | if id_type in ['cycle_graph','complete_graph',
89 | 'binomial_tree',
90 | 'star_graph',
91 | 'nonisomorphic_trees']:
92 | # args['k'] = args['k'][0]
93 | k_max = args['k'][number]
94 | k_min = 2 if id_type == 'star_graph' else 3
95 | args['subgraph_max_size']=max(args['subgraph_max_size'],k_max+1 if id_type=='star_graph' else k_max)
96 | with open_dict(args):
97 | add_sub=get_custom_edge_list(list(range(k_min, k_max + 1)), id_type)
98 | if id_type in args['must_select_sub']:
99 | args['must_select_list']+=list(range(len(args['custom_edge_list']),
100 | len(args['custom_edge_list'])+len(add_sub)
101 | ))
102 | args['custom_edge_list'] +=add_sub
103 | args['pre_defined_path'] +='_'+id_type+'_'+str(args['k'][number])
104 | args['transform_cache_path'] += '_'+id_type+'_'+str(args['k'][number])
105 |
106 |
107 | elif id_type in ['path_graph']:
108 | k_min = args['k'][number]
109 | k_max = 8
110 | args['subgraph_max_size'] = max(args['subgraph_max_size'], k_max+1)
111 | with open_dict(args):
112 | add_sub=get_custom_edge_list(list(range(k_min, k_max + 1)), id_type)
113 | if id_type in args['must_select_sub']:
114 | args['must_select_list']+=list(range(len(args['custom_edge_list']),
115 | len(args['custom_edge_list'])+len(add_sub)
116 | ))
117 | args['custom_edge_list'] +=add_sub
118 | args['pre_defined_path'] += '_' + id_type + '_' + str(args['k'][number])
119 | args['transform_cache_path'] += '_' + id_type + '_' + str(args['k'][number])
120 |
121 | elif id_type in ['k_neighborhood', 'random_walk']:
122 | with open_dict(args):
123 | args['neighbor_type'].append(id_type)
124 | args['neighbor_size'].append(args['k'][number])
125 | args['node_neighbor_path']+='_'+id_type+'_'+str(args['k'][number])
126 | args['transform_cache_path'] += '_' + id_type + '_' + str(args['k'][number])
127 |
128 |
129 |
130 | elif id_type in ['cycle_graph_chosen_k',
131 | 'path_graph_chosen_k',
132 | 'complete_graph_chosen_k',
133 | 'binomial_tree_chosen_k',
134 | 'star_graph_chosen_k',
135 | 'nonisomorphic_trees_chosen_k']:
136 | print('warning: not complete subgraph_max_size for ',id_type)
137 | with open_dict(args):
138 | args['custom_edge_list'] += get_custom_edge_list(args['k'], id_type.replace('_chosen_k', ''))
139 |
140 | elif id_type in ['all_simple_graphs']:
141 | # args['k'] = args['k'][0]
142 | k_max = args['k'][number]
143 | k_min = 3
144 | args['subgraph_max_size'] = max(args['subgraph_max_size'], k_max)
145 | filename = os.path.join(args['root_folder'], 'all_simple_graphs')
146 | with open_dict(args):
147 | args['custom_edge_list'] += get_custom_edge_list(list(range(k_min, k_max + 1)), filename=filename)
148 |
149 | elif id_type in ['all_simple_graphs_chosen_k']:
150 | print('warning: not complete subgraph_max_size for ', id_type)
151 | filename = os.path.join(args['root_folder'], 'all_simple_graphs')
152 | with open_dict(args):
153 | args['custom_edge_list'] += get_custom_edge_list(args['k'], filename=filename)
154 |
155 | elif id_type in ['diamond_graph']:
156 | print('warning: not complete subgraph_max_size for ', id_type)
157 | # args['k'] = None
158 | graph_nx = nx.diamond_graph()
159 | with open_dict(args):
160 | args['custom_edge_list'] += [list(graph_nx.edges)]
161 |
162 | elif id_type == 'custom':
163 | print('warning: not complete subgraph_max_size for ', id_type)
164 | assert args['custom_edge_list'] is not None, "Custom edge list must be provided."
165 |
166 | else:
167 | raise NotImplementedError("Identifiers {} are not currently supported.".format(id_type))
168 |
169 | return args, extract_id_fn, count_fn, automorphism_fn
170 |
171 |
172 |
173 |
174 |
175 |
176 | def transfer_subgraph_to_batchtensor_complete(substructures):
177 | id_lis=[]
178 | note_lis=[]
179 | type_list = []
180 | cur_id=0
181 | for id_type,subtype in enumerate(substructures):
182 | for data in subtype:
183 | id_lis=id_lis+[cur_id]*len(data)
184 | type_list=type_list+[int(id_type)]*len(data)
185 | note_lis=note_lis+list(data)
186 | cur_id += 1
187 |
188 | return note_lis,id_lis,type_list
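189 |
190 | # Illustrative example: for substructures = [[(0, 1, 2)], [(1, 2)]] (one match of a first
191 | # substructure type and one match of a second type), the function returns
192 | # note_lis = [0, 1, 2, 1, 2], id_lis = [0, 0, 0, 1, 1], type_list = [0, 0, 0, 1, 1],
193 | # i.e. parallel lists of (node, subgraph-id, type-id) over all matched substructures.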
189 |
190 |
191 | def _prepare_process(data, subgraph_dicts, subgraph_params, ex_fn, cnt_fn):
192 |     if data.edge_index.shape[1] == 0 and cnt_fn.__name__ == 'subgraph_isomorphism_edge_counts':
193 |         # edgeless graphs get an empty identifier matrix (orbit_partition_sizes must be available in the calling scope)
194 |         setattr(data, 'identifiers', torch.zeros((0, sum(orbit_partition_sizes))).long())
195 |         new_data = data
196 |     else:
197 |         new_data = ex_fn(cnt_fn, data, subgraph_dicts, subgraph_params)
198 |
199 |     return new_data
198 |
199 |
200 | def substructure_to_gt(subgraph_params,automorphism_fn):
201 | subgraph_dicts = []
202 | if 'edge_list' not in subgraph_params:
203 | raise ValueError('Edge list not provided.')
204 | for edge_list in subgraph_params['edge_list']:
205 | subgraph, orbit_partition, orbit_membership, aut_count = \
206 | automorphism_fn(edge_list=edge_list,
207 | only_graph=True,
208 | directed=subgraph_params['directed'],
209 | directed_orbits=subgraph_params['directed_orbits'])
210 | subgraph_dicts.append({'subgraph': subgraph})
211 | return subgraph_dicts
212 |
213 |
214 |
215 |
216 |
217 |
218 | def load_dataset(data_file):
219 | '''
220 | Loads dataset from `data_file`.
221 | '''
222 | print("Loading dataset from {}".format(data_file))
223 | dataset_obj = torch.load(data_file)
224 | graphs_ptg = dataset_obj[0]
225 | num_classes = dataset_obj[1]
226 | orbit_partition_sizes = dataset_obj[2]
227 |
228 | return graphs_ptg, num_classes, orbit_partition_sizes
229 |
230 |
231 | def try_downgrading(data_folder, id_type, induced, directed_orbits, k, k_min):
232 |     '''
233 |     Extracts the substructures of size up to `k`, if a collection of substructures
234 |     of size larger than `k` has already been computed.
235 |     '''
236 | found_data_filename, k_found = find_id_filename(data_folder, id_type, induced, directed_orbits, k)
237 | if found_data_filename is not None:
238 | graphs_ptg, num_classes, orbit_partition_sizes = load_dataset(found_data_filename)
239 | print("Downgrading k from dataset {}...".format(found_data_filename))
240 | graphs_ptg, orbit_partition_sizes = downgrade_k(graphs_ptg, k, orbit_partition_sizes, k_min)
241 | return True, graphs_ptg, num_classes, orbit_partition_sizes
242 | else:
243 | return False, None, None, None
244 |
245 |
246 | def find_id_filename(data_folder, id_type, induced, directed_orbits, k):
247 |     '''
248 |     Looks for existing precomputed datasets in `data_folder` with counts for substructure
249 |     `id_type` of size at least `k`.
250 |     '''
251 | if induced:
252 | if directed_orbits:
253 | pattern = os.path.join(data_folder, '{}_induced_directed_orbits_[0-9]*.pt'.format(id_type))
254 | else:
255 | pattern = os.path.join(data_folder, '{}_induced_[0-9]*.pt'.format(id_type))
256 | else:
257 | if directed_orbits:
258 | pattern = os.path.join(data_folder, '{}_directed_orbits_[0-9]*.pt'.format(id_type))
259 | else:
260 | pattern = os.path.join(data_folder, '{}_[0-9]*.pt'.format(id_type))
261 | filenames = glob.glob(pattern)
262 | for name in filenames:
263 | k_found = int(re.findall(r'\d+', name)[-1])
264 | if k_found >= k:
265 | return name, k_found
266 | return None, None
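267 |
268 | # For example, with id_type='cycle_graph', induced=True and directed_orbits=False the glob
269 | # pattern is '<data_folder>/cycle_graph_induced_[0-9]*.pt'; the substructure size k is parsed
270 | # from the last integer in the matched filename.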
267 |
268 | def downgrade_k(dataset, k, orbit_partition_sizes, k_min):
269 | '''
270 |     Downgrades `dataset` by keeping only the orbits of the requested substructures.
271 | '''
272 | feature_vector_size = sum(orbit_partition_sizes[0:k-k_min+1])
273 | graphs_ptg = list()
274 | for data in dataset:
275 | new_data = Data()
276 | for attr in data.__iter__():
277 | name, value = attr
278 | setattr(new_data, name, value)
279 | setattr(new_data, 'identifiers', data.identifiers[:, 0:feature_vector_size])
280 | graphs_ptg.append(new_data)
281 | return graphs_ptg, orbit_partition_sizes[0:k-k_min+1]
282 |
283 |
--------------------------------------------------------------------------------
/deepgraph/data/ogb_datasets/ogb_dataset_lookup_table.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Optional
5 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
6 | from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset
7 | from ogb.graphproppred import PygGraphPropPredDataset
8 | from torch_geometric.data import Dataset
9 | from ..substructure_dataset import SubstructureDataset
10 | import torch.distributed as dist
11 | import os
12 | import torch
13 | from torch_scatter import scatter
14 | import numpy as np
15 | from torchvision import transforms
16 |
17 |
18 |
19 | class MyPygPCQM4Mv2Dataset(PygPCQM4Mv2Dataset):
20 | def download(self):
21 | if not dist.is_initialized() or dist.get_rank() == 0:
22 | super(MyPygPCQM4Mv2Dataset, self).download()
23 | if dist.is_initialized():
24 | dist.barrier()
25 |
26 | def process(self):
27 | if not dist.is_initialized() or dist.get_rank() == 0:
28 | super(MyPygPCQM4Mv2Dataset, self).process()
29 | if dist.is_initialized():
30 | dist.barrier()
31 |
32 |
33 | class MyPygPCQM4MDataset(PygPCQM4MDataset):
34 | def download(self):
35 | if not dist.is_initialized() or dist.get_rank() == 0:
36 | super(MyPygPCQM4MDataset, self).download()
37 | if dist.is_initialized():
38 | dist.barrier()
39 |
40 | def process(self):
41 | if not dist.is_initialized() or dist.get_rank() == 0:
42 | super(MyPygPCQM4MDataset, self).process()
43 | if dist.is_initialized():
44 | dist.barrier()
45 |
46 |
47 | class MyPygGraphPropPredDataset(PygGraphPropPredDataset):
48 | def download(self):
49 | if not dist.is_initialized() or dist.get_rank() == 0:
50 | super(MyPygGraphPropPredDataset, self).download()
51 | if dist.is_initialized():
52 | dist.barrier()
53 |
54 | def process(self):
55 | if not dist.is_initialized() or dist.get_rank() == 0:
56 | super(MyPygGraphPropPredDataset, self).process()
57 | if dist.is_initialized():
58 | dist.barrier()
59 |
60 |
61 |
62 | def extract_node_feature(data, reduce='add'):
63 | if reduce in ['mean', 'max', 'add']:
64 | data.x = scatter(data.edge_attr,
65 | data.edge_index[0],
66 | dim=0,
67 | dim_size=data.num_nodes,
68 | reduce=reduce)
69 | else:
70 | raise Exception('Unknown Aggregation Type')
71 | return data
72 |
73 |
74 | def get_vocab_mapping(seq_list, num_vocab):
75 | '''
76 | Input:
77 | seq_list: a list of sequences
78 | num_vocab: vocabulary size
79 | Output:
80 | vocab2idx:
81 | A dictionary that maps vocabulary into integer index.
82 |             Additionally, we also index '__UNK__' and '__EOS__'
83 | '__UNK__' : out-of-vocabulary term
84 | '__EOS__' : end-of-sentence
85 | idx2vocab:
86 | A list that maps idx to actual vocabulary.
87 | '''
88 |
89 | vocab_cnt = {}
90 | vocab_list = []
91 | for seq in seq_list:
92 | for w in seq:
93 | if w in vocab_cnt:
94 | vocab_cnt[w] += 1
95 | else:
96 | vocab_cnt[w] = 1
97 | vocab_list.append(w)
98 |
99 | cnt_list = np.array([vocab_cnt[w] for w in vocab_list])
100 | topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab]
101 |
102 | print('Coverage of top {} vocabulary:'.format(num_vocab))
103 | print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list))
104 |
105 | vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)}
106 | idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]
107 |
108 | # print(topvocab)
109 | # print([vocab_list[v] for v in topvocab[:10]])
110 | # print([vocab_list[v] for v in topvocab[-10:]])
111 |
112 | vocab2idx['__UNK__'] = num_vocab
113 | idx2vocab.append('__UNK__')
114 |
115 | vocab2idx['__EOS__'] = num_vocab + 1
116 | idx2vocab.append('__EOS__')
117 |
118 | # test the correspondence between vocab2idx and idx2vocab
119 | for idx, vocab in enumerate(idx2vocab):
120 | assert(idx == vocab2idx[vocab])
121 |
122 | # test that the idx of '__EOS__' is len(idx2vocab) - 1.
123 | # This fact will be used in decode_arr_to_seq, when finding __EOS__
124 | assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1)
125 |
126 | return vocab2idx, idx2vocab
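127 |
128 | # Illustrative example: get_vocab_mapping([['a', 'b', 'a'], ['b', 'c']], num_vocab=2) yields
129 | # vocab2idx = {'a': 0, 'b': 1, '__UNK__': 2, '__EOS__': 3} and
130 | # idx2vocab = ['a', 'b', '__UNK__', '__EOS__'] (the printed coverage is 4/5 = 0.8).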
127 |
128 |
129 | def augment_edge(data):
130 | '''
131 | Input:
132 | data: PyG data object
133 | Output:
134 | data (edges are augmented in the following ways):
135 |             data.edge_index: next-token edges are added, together with the corresponding inverse edges.
136 |             data.edge_attr (torch.Long):
137 |                 data.edge_attr[:,0]: whether it is an AST edge (0) or a next-token edge (1)
138 |                 data.edge_attr[:,1]: whether it is in the original direction (0) or the inverse direction (1)
139 | '''
140 |
141 | ##### AST edge
142 | edge_index_ast = data.edge_index
143 | edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))
144 |
145 | ##### Inverse AST edge
146 | edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim=0)
147 | edge_attr_ast_inverse = torch.cat(
148 | [torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim=1)
149 |
150 | ##### Next-token edge
151 |
152 | ## Obtain attributed nodes and get their indices in dfs order
153 | # attributed_node_idx = torch.where(data.node_is_attributed.view(-1,) == 1)[0]
154 | # attributed_node_idx_in_dfs_order = attributed_node_idx[torch.argsort(data.node_dfs_order[attributed_node_idx].view(-1,))]
155 |
156 | ## Since the nodes are already sorted in dfs ordering in our case, we can just do the following.
157 | attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1, ) == 1)[0]
158 |
159 | ## build next token edge
160 | # Given: attributed_node_idx_in_dfs_order
161 | # [1, 3, 4, 5, 8, 9, 12]
162 | # Output:
163 | # [[1, 3, 4, 5, 8, 9]
164 | # [3, 4, 5, 8, 9, 12]
165 | edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]],
166 | dim=0)
167 | edge_attr_nextoken = torch.cat(
168 | [torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim=1)
169 |
170 | ##### Inverse next-token edge
171 | edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim=0)
172 | edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))
173 |
174 | data.edge_index = torch.cat(
175 | [edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim=1)
176 | data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse],
177 | dim=0)
178 |
179 | return data
180 |
181 | def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
182 | '''
183 | Input:
184 | seq: A list of words
185 | output: add y_arr (torch.Tensor)
186 | '''
187 |
188 | augmented_seq = seq[:max_seq_len] + ['__EOS__'] * max(0, max_seq_len - len(seq))
189 | return torch.tensor([[vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] for w in augmented_seq]], dtype = torch.long)
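190 |
191 | # Illustrative example: with vocab2idx = {'x': 0, 'y': 1, '__UNK__': 2, '__EOS__': 3},
192 | # encode_seq_to_arr(['x', 'y'], vocab2idx, max_seq_len=4) returns tensor([[0, 1, 3, 3]]);
193 | # out-of-vocabulary words map to the '__UNK__' index (2).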
190 |
191 |
192 | def encode_code2(data, vocab2idx, max_seq_len):
193 | '''
194 | Input:
195 | data: PyG graph object
196 | output: add y_arr to data
197 | '''
198 |
199 | # PyG >= 1.5.0
200 | data.x=torch.cat([data.x[:,0].unsqueeze(1),data.node_depth,data.x[:,1].unsqueeze(1)],1)
201 | seq = data.y
202 |
203 | # PyG = 1.4.3
204 | # seq = data.y[0]
205 | data.y_ori=seq
206 | data.y = encode_seq_to_arr(seq, vocab2idx, max_seq_len)
207 |
208 | return data
209 |
210 |
211 |
212 | class OGBDatasetLookupTable:
213 | @staticmethod
214 | def GetOGBDataset(dataset_name: str, seed: int,args=None,**kwargs) -> Optional[Dataset]:
215 |
216 | inner_dataset = None
217 | train_idx = None
218 | valid_idx = None
219 | test_idx = None
220 | if dataset_name == "ogbg-molhiv":
221 | folder_name = dataset_name.replace("-", "_")
222 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}")
223 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}")
224 | inner_dataset = MyPygGraphPropPredDataset(dataset_name,root=kwargs['data_dir'])
225 | idx_split = inner_dataset.get_idx_split()
226 | train_idx = idx_split["train"]
227 | valid_idx = idx_split["valid"]
228 | test_idx = idx_split["test"]
229 | elif dataset_name == "ogbg-molpcba":
230 | folder_name = dataset_name.replace("-", "_")
231 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}")
232 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}")
233 | inner_dataset = MyPygGraphPropPredDataset(dataset_name,root=kwargs['data_dir'])
234 | idx_split = inner_dataset.get_idx_split()
235 | train_idx = idx_split["train"]
236 | valid_idx = idx_split["valid"]
237 | test_idx = idx_split["test"]
238 | elif dataset_name == "pcqm4mv2":
239 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m-v2')}")
240 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m-v2','RELEASE_v1.txt')}")
241 | inner_dataset = MyPygPCQM4Mv2Dataset()
242 | idx_split = inner_dataset.get_idx_split()
243 | train_idx = idx_split["train"]
244 | valid_idx = idx_split["valid"]
245 | test_idx = idx_split["test-dev"]
246 | elif dataset_name == "pcqm4m":
247 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021')}")
248 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021','RELEASE_v1.txt')}")
249 | inner_dataset = MyPygPCQM4MDataset(root=kwargs['data_dir'])
250 | idx_split = inner_dataset.get_idx_split()
251 | train_idx = idx_split["train"]
252 | valid_idx = idx_split["valid"]
253 | test_idx = idx_split["test"]
254 | elif dataset_name == 'pcqm4m_contrust_pretraining':
255 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021')}")
256 | os.system(f"touch {os.path.join(kwargs['data_dir'],'pcqm4m_kddcup2021','RELEASE_v1.txt')}")
257 | inner_dataset = MyPygPCQM4MDataset(root=kwargs['data_dir'])
258 | idx_split = inner_dataset.get_idx_split()
259 | train_idx = idx_split["train"]
260 | valid_idx = idx_split["valid"]
261 | test_idx = idx_split["test"]
262 |
263 | elif dataset_name in ['ogbg-ppa']:
264 | from functools import partial
265 | transform = partial(extract_node_feature, reduce='add')
266 | folder_name = dataset_name.replace("-", "_")
267 | os.system(f"mkdir -p {os.path.join(kwargs['data_dir'],folder_name)}")
268 | os.system(f"touch {os.path.join(kwargs['data_dir'],folder_name,'RELEASE_v1.txt')}")
269 | inner_dataset = MyPygGraphPropPredDataset(name=dataset_name, root=kwargs['data_dir'], transform=transform)
270 | idx_split = inner_dataset.get_idx_split()
271 | train_idx = idx_split["train"]
272 | valid_idx = idx_split["valid"]
273 | test_idx = idx_split["test"]
274 |
275 | elif dataset_name in ['ogbg-code2']:
276 | inner_dataset = PygGraphPropPredDataset(name=dataset_name, root=kwargs['data_dir'])
277 | idx_split = inner_dataset.get_idx_split()
278 | train_idx = idx_split["train"]
279 | valid_idx = idx_split["valid"]
280 | test_idx = idx_split["test"]
281 | split_idx = inner_dataset.get_idx_split()
282 | #
283 | # ### building vocabulary for sequence predition. Only use training data.
284 | vocab2idx, idx2vocab = get_vocab_mapping([inner_dataset.data.y[i] for i in split_idx['train']], 5000)
285 | #
286 | # ### set the transform function
287 | # # augment_edge: add next-token edge as well as inverse edges. add edge attributes.
288 | # # encode_y_to_arr: add y_arr to PyG data object, indicating the array representation of a sequence.
289 | inner_dataset.transform = transforms.Compose([
290 | augment_edge, lambda data: encode_code2(data, vocab2idx, 5)
291 | ])
292 |
293 | else:
294 | raise ValueError(f"Unknown dataset name {dataset_name} for ogb source.")
295 |
296 |
297 | result=None if inner_dataset is None else SubstructureDataset(
298 | inner_dataset, seed, train_idx, valid_idx, test_idx,args=args
299 | )
300 |
301 | return result
302 |
--------------------------------------------------------------------------------
/deepgraph/data/subsampling/sampling.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import torch
3 | import numpy as np
4 | import time
5 |
6 |
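7 | # Input convention used by the *_general samplers below (inferred from the code): `subgraphs_nodes`
8 | # is an index tensor whose first row holds subgraph ids and whose second row holds the node ids
9 | # belonging to each subgraph; a third row, when present, holds the substructure type, with -1
10 | # apparently marking the neighbor substructures (cf. min_set_cover_random_sampling_general).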
7 |
8 | def random_sampling(subgraphs_nodes, rate=0.5, minimum_redundancy=0, num_nodes=None):
9 | if num_nodes is None:
10 | num_nodes = subgraphs_nodes[1].max() + 1
11 | while True:
12 | selected = np.random.choice(num_nodes, int(num_nodes*rate), replace=False)
13 | node_selected_times = torch.bincount(subgraphs_nodes[1][check_values_in_set(subgraphs_nodes[0], selected)], minlength=num_nodes)
14 | if node_selected_times.min() >= minimum_redundancy:
15 | # rate += 0.1 # enlarge the sampling rate
16 | break
17 | return selected, node_selected_times
18 |
19 |
20 | def random_sampling_general(subgraphs_nodes, rate=0.5, minimum_redundancy=2,num_nodes=None,must_select_list=None):
21 |
22 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes) # number_subgraph* number_nodes True when node j in subgraph i
23 |
24 |
25 |
26 | node_selected_times = torch.zeros(num_nodes)
27 | selected_all = []
28 |
29 | if must_select_list is not None:
30 | for selected in must_select_list:
31 | try:
32 | selected_all.append(selected)
33 | node_selected_times[subgraphs_nodes_mask[selected]] += 1
34 | except:
35 | print('selected ',selected)
36 | print('num_nodes ',num_nodes)
37 | print('subgraphs_nodes_mask',subgraphs_nodes_mask)
38 | raise ValueError("error ")
39 |
40 | num_subgraphs = int(subgraphs_nodes[0].max())+1
41 | init_rate = min(subgraphs_nodes.shape[0]/num_subgraphs,1)
42 | rate=min(init_rate*minimum_redundancy,1)
43 |
44 | for i in range(int(subgraphs_nodes[0].max())+1):
45 | selected_all = np.random.choice(num_subgraphs, int(num_subgraphs*rate), replace=False)
46 | node_selected_times = subgraphs_nodes_mask[selected_all].sum(0)
47 |
48 | if node_selected_times.min() >= minimum_redundancy or rate >=1:
49 | break
50 |
51 | rate=min(rate+init_rate,1)
52 |
53 |
54 |
55 |
56 | return list(selected_all), node_selected_times
57 |
58 |
59 | # Approach 1: based on shortest-path distance
60 | def shortest_path_sampling(edge_index, subgraphs_nodes, stride=2, minimum_redundancy=0, random_init=False, num_nodes=None):
61 | if num_nodes is None:
62 | num_nodes = subgraphs_nodes[1].max() + 1
63 | G = nx.from_edgelist(edge_index.t().numpy())
64 | G.add_nodes_from(range(num_nodes))
65 | if random_init:
66 | farthest = np.random.choice(num_nodes) # here can also choose the one with highest degree
67 | else:
68 | subgraph_size = torch.bincount(subgraphs_nodes[0], minlength=num_nodes)
69 | farthest = subgraph_size.argmax().item()
70 |
71 | distance = np.ones(num_nodes)*1e10
72 | selected = []
73 | node_selected_times = torch.zeros(num_nodes)
74 |
75 | for i in range(num_nodes):
76 | selected.append(farthest)
77 | node_selected_times[subgraphs_nodes[1][subgraphs_nodes[0] == farthest]] += 1
78 | length_shortest_dict = nx.single_source_shortest_path_length(G, farthest)
79 | length_shortest = np.ones(num_nodes)*1e10
80 | length_shortest[list(length_shortest_dict.keys())] = list(length_shortest_dict.values())
81 | mask = length_shortest < distance
82 | distance[mask] = length_shortest[mask]
83 |
84 | if (distance.max() < stride) and (node_selected_times.min() >= minimum_redundancy): # stop criterion
85 | break
86 | farthest = np.argmax(distance)
87 | return selected, node_selected_times
88 |
89 |
90 | # Approach 1: based on shortest-path distance (general version)
91 | def shortest_path_sampling_general(edge_index, subgraphs_nodes, stride=2, minimum_redundancy=0, random_init=False,
92 | num_nodes=None):
93 | if num_nodes is None:
94 | num_nodes = subgraphs_nodes[1].max() + 1
95 | G = nx.from_edgelist(edge_index.t().numpy())
96 | G.add_nodes_from(range(num_nodes))
97 | if random_init:
98 | farthest = np.random.choice(num_nodes) # here can also choose the one with highest degree
99 | else:
100 | subgraph_size = torch.bincount(subgraphs_nodes[0], minlength=num_nodes)
101 | farthest = subgraph_size.argmax().item()
102 |
103 | distance = np.ones(num_nodes) * 1e10
104 | selected = []
105 | node_selected_times = torch.zeros(num_nodes)
106 |
107 | for i in range(num_nodes):
108 | selected.append(farthest)
109 | node_selected_times[subgraphs_nodes[1][subgraphs_nodes[0] == farthest]] += 1
110 | length_shortest_dict = nx.single_source_shortest_path_length(G, farthest)
111 | length_shortest = np.ones(num_nodes) * 1e10
112 | length_shortest[list(length_shortest_dict.keys())] = list(length_shortest_dict.values())
113 | mask = length_shortest < distance
114 | distance[mask] = length_shortest[mask]
115 |
116 | if (distance.max() < stride) and (node_selected_times.min() >= minimum_redundancy): # stop criterion
117 | break
118 | farthest = np.argmax(distance)
119 | return selected, node_selected_times
120 |
121 |
122 |
123 | def check_values_in_set(x, set, approach=1):
124 | assert min(x.shape) > 0
125 | assert min(set.shape) > 0
126 | if approach == 0:
127 | mask = sum(x==i for i in set)
128 | else:
129 | mapper = torch.zeros(max(x.max()+1, set.max()+1), dtype=torch.bool)
130 | mapper[set] = True
131 | mask = mapper[x]
132 | return mask
133 |
134 | ############################################### use dense input ##################################################
135 | ### this part is hard to change
136 |
137 | def min_set_cover_sampling(edge_index, subgraphs_nodes_mask, random_init=False, minimum_redundancy=2):
138 |
139 | num_nodes = subgraphs_nodes_mask.size(0)
140 | if random_init:
141 | selected = np.random.choice(num_nodes)
142 | else:
143 | selected = subgraphs_nodes_mask.sum(-1).argmax().item() # choose the maximum subgraph size one to remove randomness
144 |
145 | node_selected_times = torch.zeros(num_nodes)
146 | selected_all = []
147 |
148 | for i in range(num_nodes):
149 | # selected_subgraphs[selected] = True
150 | selected_all.append(selected)
151 | node_selected_times[subgraphs_nodes_mask[selected]] += 1
152 | if node_selected_times.min() >= minimum_redundancy: # stop criterion
153 | break
154 | # calculate how many unused nodes in each subgraph (greedy set cover)
155 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool())
156 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1)
157 | scores = num_unused_nodes
158 | scores[selected_all] = 0
159 | selected = np.argmax(scores).item()
160 |
161 | return selected_all, node_selected_times
162 |
163 |
164 | def transfer_edge_index_to_mask(subgraphs_nodes,num_nodes):
165 | res=torch.zeros([int(subgraphs_nodes[0].max()+1),num_nodes])
166 | res[subgraphs_nodes[0,:],subgraphs_nodes[1,:]]=1
167 | out=res>0
168 | return out
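169 |
170 | # Illustrative example: transfer_edge_index_to_mask(torch.tensor([[0, 0, 1], [0, 2, 1]]), num_nodes=3)
171 | # returns tensor([[True, False, True], [False, True, False]]), i.e. a (num_subgraphs x num_nodes)
172 | # boolean membership mask.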
169 |
170 | def min_set_cover_sampling_general(subgraphs_nodes, random_init=False, minimum_redundancy=2,num_nodes=None,must_select_list=None,only_unused_nodes=True):
171 |
172 |
173 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes)
174 |
175 | node_selected_times = torch.zeros(num_nodes)
176 | selected_all = []
177 |
178 | if must_select_list is not None:
179 | for selected in must_select_list:
180 | try:
181 | selected_all.append(selected)
182 | node_selected_times[subgraphs_nodes_mask[selected]] += 1
183 | except:
184 | print('selected ',selected)
185 | print('num_nodes ',num_nodes)
186 | print('subgraphs_nodes_mask',subgraphs_nodes_mask)
187 | raise ValueError("error ")
188 | # num_nodes = subgraphs_nodes_mask.size(0)
189 | if random_init:
190 | selected = np.random.choice(int(subgraphs_nodes[0].max())+1)
191 | else:
192 | selected = subgraphs_nodes_mask.sum(-1).argmax().item() # choose the maximum subgraph size one to remove randomness
193 |
194 | for i in range(int(subgraphs_nodes[0].max())+1):
195 | # selected_subgraphs[selected] = True
196 | selected_all.append(selected)
197 | node_selected_times[subgraphs_nodes_mask[selected]] += 1
198 | if node_selected_times.min() >= minimum_redundancy: # stop criterion
199 | break
200 | # calculate how many unused nodes in each subgraph (greedy set cover)
201 | if only_unused_nodes:
202 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool())
203 | else:
204 | unused_nodes = ((node_selected_times - minimum_redundancy)<0)
205 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1)
206 | scores = num_unused_nodes
207 | scores[selected_all] = 0
208 | if scores.sum()==0:
209 | break
210 | selected = np.argmax(scores).item()
211 |
212 | return selected_all, node_selected_times
213 |
214 | def min_set_cover_random_sampling_general(subgraphs_nodes, random_init=False, minimum_redundancy=2,num_nodes=None,must_select_list=None,must_select_list_ratio=0.2,only_unused_nodes=True,ratio_init_drop=0.5,ratio_each_iter=0.5,balance_different_substructure=True):
215 |
216 |
217 |     # The algorithm:
218 |     # 1. At the beginning, randomly drop some substructures to add randomness (controlled by ratio_init_drop).
219 |     # 2. In each iteration, randomly sample from the top-K scoring substructures (controlled by ratio_each_iter).
220 |     # 3. The number of sampled substructures adapts to the total number of substructures; its expectation is the cover number plus one.
221 |     # 4. Balance the different substructure types (controlled by balance_different_substructure).
222 | # times = []
223 | # times.append(time.time())
224 |
225 | num_subgraphs = int(subgraphs_nodes[0].max())+1
226 | number_each_iter = max(int(num_nodes*subgraphs_nodes.shape[0] / num_subgraphs), 1)
227 |
228 | subgraphs_nodes_mask=transfer_edge_index_to_mask(subgraphs_nodes,num_nodes)
229 |
230 | init_drop=np.random.choice(num_subgraphs, int(num_subgraphs*ratio_init_drop), replace=False)
231 |
232 |
233 | if balance_different_substructure:
234 |
235 |         # this requires that the geometric (predefined) substructures come before the neighbor substructures
236 |
237 | ids=subgraphs_nodes[0,subgraphs_nodes[2]==-1]
238 | if ids.shape[0]>0:
239 | number_predefined=int(ids[0])
240 | number_neighbor = num_subgraphs-number_predefined
241 | if number_predefined>number_neighbor:
242 | predefined_drop=np.random.choice(number_predefined,(number_predefined-number_neighbor), replace=False)
243 | init_drop=np.concatenate([init_drop,predefined_drop],0)
244 |
245 |
246 |
247 | node_selected_times = torch.zeros(num_nodes)
248 | selected_all = []
249 |
250 | if must_select_list is not None:
251 | try:
252 | number_must_select=max(int(num_nodes*must_select_list_ratio),1)
253 | if len(must_select_list)>=number_must_select:
254 | must_select_list=[must_select_list[i] for i in np.random.choice(len(must_select_list), number_must_select, replace=False)]
255 | selected_all += must_select_list
256 | node_selected_times += subgraphs_nodes_mask[must_select_list].sum(0)
257 | except:
258 | print('selected ',must_select_list)
259 | print('num_nodes ',num_nodes)
260 | print('subgraphs_nodes_mask',subgraphs_nodes_mask)
261 | raise ValueError("error ")
262 | # num_nodes = subgraphs_nodes_mask.size(0)
263 | if random_init:
264 | selected = [np.random.choice(int(subgraphs_nodes[0].max())+1)]
265 | else:
266 | selected = [subgraphs_nodes_mask.sum(-1).argmax().item()] # choose the maximum subgraph size one to remove randomness
267 |
268 |
269 | for i in range(int(subgraphs_nodes[0].max())+1):
270 | # selected_subgraphs[selected] = True
271 | selected_all+=selected
272 | node_selected_times += subgraphs_nodes_mask[selected].sum(0)
273 | if node_selected_times.min() >= minimum_redundancy: # stop criterion
274 | break
275 | # calculate how many unused nodes in each subgraph (greedy set cover)
276 | if only_unused_nodes:
277 | unused_nodes = ~ ((node_selected_times - node_selected_times.min()).bool())
278 | else:
279 | unused_nodes = ((node_selected_times - minimum_redundancy)<0)
280 | num_unused_nodes = (subgraphs_nodes_mask & unused_nodes).sum(-1)
281 | scores = num_unused_nodes
282 | scores[selected_all] = 0
283 |
284 | scores[init_drop] = 0
285 |
286 | if scores.sum()==0 :
287 | break
288 | number_lefted=int((scores>0).sum())
289 | _,index=torch.topk(scores, min(int(number_each_iter/ratio_each_iter),number_lefted))
290 | selected=index[np.random.choice(len(index), min(number_each_iter,number_lefted), replace=False)].tolist()
291 |
292 |
293 | return selected_all, node_selected_times
294 |
295 |
--------------------------------------------------------------------------------
/deepgraph/data/substructure_dataset_utils/utils_graph_processing.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric.utils import to_undirected,remove_self_loops
3 | import networkx as nx
4 |
5 | import graph_tool as gt
6 | import graph_tool.topology as gt_topology
7 |
8 | import torch
9 | import numpy as np
10 | import time
11 |
12 | def automorphism_orbits(edge_list, print_msgs=False, only_graph=False,**kwargs):
13 |
14 | ##### vertex automorphism orbits #####
15 |
16 | directed=kwargs['directed'] if 'directed' in kwargs else False
17 |
18 | graph = gt.Graph(directed=directed)
19 | graph.add_edge_list(edge_list)
20 |
21 | # gt.stats.remove_self_loops(graph)
22 | # gt.stats.remove_parallel_edges(graph)
23 |
24 | if only_graph:
25 |
26 | return graph, None, None, None
27 |
28 |
29 | # compute the vertex automorphism group
30 | aut_group = gt_topology.subgraph_isomorphism(graph, graph, induced=False, subgraph=True, generator=False)
31 |
32 |
33 |
34 | orbit_membership = {}
35 | for v in graph.get_vertices():
36 | orbit_membership[v] = v
37 |
38 |
39 |
40 | # whenever two nodes can be mapped via some automorphism, they are assigned the same orbit
41 | for aut in aut_group:
42 | for original, vertex in enumerate(aut):
43 | role = min(original, orbit_membership[vertex])
44 | orbit_membership[vertex] = role
45 |
46 |
47 |
48 | orbit_membership_list = [[],[]]
49 | for vertex, om_curr in orbit_membership.items():
50 | orbit_membership_list[0].append(vertex)
51 | orbit_membership_list[1].append(om_curr)
52 |
53 |
54 |
55 | # make orbit list contiguous (i.e. 0,1,2,...O)
56 | _, contiguous_orbit_membership = np.unique(orbit_membership_list[1], return_inverse = True)
57 |
58 | orbit_membership = {vertex: contiguous_orbit_membership[i] for i,vertex in enumerate(orbit_membership_list[0])}
59 |
60 |
61 |
62 | orbit_partition = {}
63 | for vertex, orbit in orbit_membership.items():
64 | orbit_partition[orbit] = [vertex] if orbit not in orbit_partition else orbit_partition[orbit]+[vertex]
65 |
66 | aut_count = len(aut_group)
67 |
68 |
69 |
70 | if print_msgs:
71 | print('Orbit partition of given substructure: {}'.format(orbit_partition))
72 | print('Number of orbits: {}'.format(len(orbit_partition)))
73 | print('Automorphism count: {}'.format(aut_count))
74 |
75 |
76 |
77 | return graph, orbit_partition, orbit_membership, aut_count
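78 |
79 | # Illustrative example (assuming the usual graph-tool behavior): for the path graph
80 | # edge_list = [(0, 1), (1, 2)], the two endpoints share one orbit and the middle node gets its own,
81 | # i.e. orbit_partition == {0: [0, 2], 1: [1]} with aut_count == 2.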
78 |
79 |
80 | def induced_edge_automorphism_orbits(edge_list, **kwargs):
81 |
82 | ##### induced edge automorphism orbits (according to the vertex automorphism group) #####
83 |
84 | directed=kwargs['directed'] if 'directed' in kwargs else False
85 | directed_orbits=kwargs['directed_orbits'] if 'directed_orbits' in kwargs else False
86 |
87 | graph, orbit_partition, orbit_membership, aut_count = automorphism_orbits(edge_list=edge_list,
88 | directed=directed,
89 | print_msgs=False)
90 | edge_orbit_partition = dict()
91 | edge_orbit_membership = dict()
92 | edge_orbits2inds = dict()
93 | ind = 0
94 |
95 | if not directed:
96 | edge_list = to_undirected(torch.tensor(graph.get_edges()).transpose(1,0)).transpose(1,0).tolist()
97 |
98 | # infer edge automorphisms from the vertex automorphisms
99 | for i,edge in enumerate(edge_list):
100 | if directed_orbits:
101 | edge_orbit = (orbit_membership[edge[0]], orbit_membership[edge[1]])
102 | else:
103 | edge_orbit = frozenset([orbit_membership[edge[0]], orbit_membership[edge[1]]])
104 | if edge_orbit not in edge_orbits2inds:
105 | edge_orbits2inds[edge_orbit] = ind
106 | ind_edge_orbit = ind
107 | ind += 1
108 | else:
109 | ind_edge_orbit = edge_orbits2inds[edge_orbit]
110 |
111 | if ind_edge_orbit not in edge_orbit_partition:
112 | edge_orbit_partition[ind_edge_orbit] = [tuple(edge)]
113 | else:
114 | edge_orbit_partition[ind_edge_orbit] += [tuple(edge)]
115 |
116 | edge_orbit_membership[i] = ind_edge_orbit
117 |
118 | print('Edge orbit partition of given substructure: {}'.format(edge_orbit_partition))
119 | print('Number of edge orbits: {}'.format(len(edge_orbit_partition)))
120 | print('Graph (vertex) automorphism count: {}'.format(aut_count))
121 |
122 | return graph, edge_orbit_partition, edge_orbit_membership, aut_count
123 |
124 |
125 | def subgraph_isomorphism_vertex_counts(edge_index, **kwargs):
126 |
127 | ##### vertex structural identifiers #####
128 |
129 | subgraph_dict, induced, num_nodes = kwargs['subgraph_dict'], kwargs['induced'], kwargs['num_nodes']
130 | directed = kwargs['directed'] if 'directed' in kwargs else False
131 |
132 | G_gt = gt.Graph(directed=directed)
133 | edge_list=list(edge_index.transpose(1,0).cpu().numpy())
134 | edge_list=[[int(edge_list[i][0]),int(edge_list[i][1])] for i in range(len(edge_list))]
135 | G_gt.add_edge_list(edge_list)
136 | gt.stats.remove_self_loops(G_gt)
137 | gt.stats.remove_parallel_edges(G_gt)
138 |
139 | # compute all subgraph isomorphisms
140 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True, generator=True)
141 |
142 | ## num_nodes should be explicitly set for the following edge case:
143 | ## when there is an isolated vertex whose index is larger
144 | ## than the maximum available index in the edge_index
145 |
146 | counts = np.zeros((num_nodes, len(subgraph_dict['orbit_partition'])))
147 | for sub_iso_curr in sub_iso:
148 | for i,node in enumerate(sub_iso_curr):
149 | # increase the count for each orbit
150 | counts[node, subgraph_dict['orbit_membership'][i]] +=1
151 | counts = counts/subgraph_dict['aut_count']
152 |
153 | counts = torch.tensor(counts)
154 |
155 | return counts
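156 |
157 | # Descriptive note: `counts` has shape (num_nodes, num_orbits); entry (v, o) is the number of
158 | # matched copies of the substructure in which node v plays a role from orbit o, divided by the
159 | # automorphism count so that each distinct copy is counted once.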
156 |
157 |
158 | def subgraph_isomorphism_vertex_extraction(edge_index,input_graph=None, **kwargs):
159 | ##### vertex structural identifiers #####
160 |
161 | subgraph_dict, induced, num_nodes = kwargs['subgraph_dict'], kwargs['induced'], kwargs['num_nodes']
162 | directed = kwargs['directed'] if 'directed' in kwargs else False
163 |
164 | if input_graph is not None:
165 | G_gt=input_graph
166 | else:
167 | G_gt = gt.Graph(directed=directed)
168 | edge_list = list(edge_index.transpose(1, 0).cpu().numpy())
169 | edge_list = [[int(edge_list[i][0]), int(edge_list[i][1])] for i in range(len(edge_list))]
170 | G_gt.add_edge_list(edge_list)
171 | gt.stats.remove_self_loops(G_gt)
172 | gt.stats.remove_parallel_edges(G_gt)
173 |
174 | # compute all subgraph isomorphisms
175 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True,
176 | generator=True)
177 |
178 | ## num_nodes should be explicitly set for the following edge case:
179 | ## when there is an isolated vertex whose index is larger
180 | ## than the maximum available index in the edge_index
181 | subgraphs = set()
182 | for sub_iso_curr in sub_iso:
183 | tem = list(sub_iso_curr.a)
184 | tem.sort()
185 | subgraphs.add(tuple(tem))
186 | return subgraphs
187 |
188 |
189 | def subgraph_isomorphism_edge_counts(edge_index, **kwargs):
190 |
191 | ##### edge structural identifiers #####
192 |
193 | subgraph_dict, induced = kwargs['subgraph_dict'], kwargs['induced']
194 | directed = kwargs['directed'] if 'directed' in kwargs else False
195 |
196 | edge_index = edge_index.transpose(1,0).cpu().numpy()
197 | edge_dict = {}
198 | for i, edge in enumerate(edge_index):
199 | edge_dict[tuple(edge)] = i
200 |
201 | if not directed:
202 | subgraph_edges = to_undirected(torch.tensor(subgraph_dict['subgraph'].get_edges().tolist()).transpose(1,0)).transpose(1,0).tolist()
203 |
204 |
205 | G_gt = gt.Graph(directed=directed)
206 | G_gt.add_edge_list(list(edge_index))
207 | gt.stats.remove_self_loops(G_gt)
208 | gt.stats.remove_parallel_edges(G_gt)
209 |
210 | # compute all subgraph isomorphisms
211 | sub_iso = gt_topology.subgraph_isomorphism(subgraph_dict['subgraph'], G_gt, induced=induced, subgraph=True, generator=True)
212 |
213 |
214 | counts = np.zeros((edge_index.shape[0], len(subgraph_dict['orbit_partition'])))
215 |
216 | for sub_iso_curr in sub_iso:
217 | mapping = sub_iso_curr.get_array()
218 |
219 | for i,edge in enumerate(subgraph_edges):
220 |
221 |             # for every (directed) edge of the substructure pattern, find the edge of the
222 |             # host graph it is mapped to (via the matched endpoints), and increase that
223 |             # edge's count for the pattern edge's orbit; subgraph_edges already contains
224 |             # both directions, so reverse edges are covered as well
225 |
226 | edge_orbit = subgraph_dict['orbit_membership'][i]
227 | mapped_edge = tuple([mapping[edge[0]], mapping[edge[1]]])
228 | counts[edge_dict[mapped_edge], edge_orbit] += 1
229 |
230 | counts = counts/subgraph_dict['aut_count']
231 |
232 | counts = torch.tensor(counts)
233 |
234 | return counts
235 |
236 |
237 |
238 | #----------------------- line graph edge automorphism: deprecated
239 |
240 |
241 |
242 | def edge_automorphism_orbits(edge_list, **kwargs):
243 |
244 | ##### edge automorphism orbits according to the line graph #####
245 |
246 | directed=kwargs['directed'] if 'directed' in kwargs else False
247 |
248 | graph_nx = nx.from_edgelist(edge_list)
249 | graph = gt.Graph(directed=directed)
250 | graph.add_edge_list(edge_list)
251 | gt.stats.remove_self_loops(graph)
252 | gt.stats.remove_parallel_edges(graph)
253 | aut_group = gt_topology.subgraph_isomorphism(graph, graph, induced=False, subgraph=True, generator=False)
254 | aut_count = len(aut_group)
255 |
256 | ##### compute line graph vertex automorphism orbits #####
257 |
258 | graph_nx_line = nx.line_graph(graph_nx)
259 | mapping = {node: i for i,node in enumerate(graph_nx_line.nodes)}
260 | inverse_mapping = {i: node for i,node in enumerate(graph_nx_line.nodes)}
261 |
262 | graph_nx_line = nx.relabel_nodes(graph_nx_line, mapping)
263 | line_graph = gt.Graph(directed=directed)
264 | line_graph.add_edge_list(list(graph_nx_line.edges))
265 |
266 | gt.stats.remove_self_loops(line_graph)
267 | gt.stats.remove_parallel_edges(line_graph)
268 |
269 | aut_group_edges = gt_topology.subgraph_isomorphism(line_graph, line_graph, induced=False, subgraph=True, generator=False)
270 |
271 | orbit_membership = {}
272 | for v in line_graph.get_vertices():
273 | orbit_membership[v] = v
274 |
275 | for aut in aut_group_edges:
276 | for original, vertex in enumerate(aut):
277 | role = min(original, orbit_membership[vertex])
278 | orbit_membership[vertex] = role
279 |
280 | orbit_membership_list = [[],[]]
281 | for vertex, om_curr in orbit_membership.items():
282 | orbit_membership_list[0].append(vertex)
283 | orbit_membership_list[1].append(om_curr)
284 |
285 | _, contiguous_orbit_membership = np.unique(orbit_membership_list[1], return_inverse = True)
286 |
287 | orbit_membership = {vertex: contiguous_orbit_membership[i] for i,vertex in enumerate(orbit_membership_list[0])}
288 |
289 | orbit_partition= {}
290 | for vertex, orbit in orbit_membership.items():
291 | orbit_partition[orbit] = [inverse_mapping[vertex]] if orbit not in orbit_partition else orbit_partition[orbit]+[inverse_mapping[vertex]]
292 |
293 | ##### transfer line graph vertex automorphism orbits to original edges #####
294 |
295 | orbit_membership_new = {}
296 | for i,edge in enumerate(graph.get_edges()):
297 | mapped_edge = mapping[tuple(edge)] if tuple(edge) in mapping else mapping[tuple([edge[1],edge[0]])]
298 | orbit_membership_new[i] = orbit_membership[mapped_edge]
299 |
300 | print('Edge orbit partition of given substructure: {}'.format(orbit_partition))
301 | print('Number of edge orbits: {}'.format(len(orbit_partition)))
302 | print('Graph (vertex) automorphism count: {}'.format(aut_count))
303 |
304 | return graph, orbit_partition, orbit_membership_new, aut_count
305 |
306 |
307 |
308 |
309 |
310 | def subgraph_counts2ids(count_fn, data, subgraph_dicts, subgraph_params):
311 | #### Remove self loops and then assign the structural identifiers by computing subgraph isomorphisms ####
312 |
313 | if hasattr(data, 'edge_features'):
314 | edge_index, edge_features = remove_self_loops(data.edge_index, data.edge_features)
315 | setattr(data, 'edge_features', edge_features)
316 | else:
317 | edge_index = remove_self_loops(data.edge_index)[0]
318 |
319 | num_nodes = data.x.shape[0]
320 | identifiers = None
321 | for subgraph_dict in subgraph_dicts:
322 | kwargs = {'subgraph_dict': subgraph_dict,
323 | 'induced': subgraph_params['induced'],
324 | 'num_nodes': num_nodes,
325 | 'directed': subgraph_params['directed']}
326 | counts = count_fn(edge_index, **kwargs)
327 | identifiers = counts if identifiers is None else torch.cat((identifiers, counts), 1)
328 | setattr(data, 'edge_index', edge_index)
329 | setattr(data, 'identifiers', identifiers.long())
330 |
331 | return data
332 |
333 |
334 | def subgraph_2token(count_fn, data, subgraph_dicts, subgraph_params):
335 | #### Remove self loops and then assign the structural identifiers by computing subgraph isomorphisms ####
336 |
337 | if hasattr(data, 'edge_features'):
338 | edge_index, edge_features = remove_self_loops(data.edge_index, data.edge_features)
339 | setattr(data, 'edge_features', edge_features)
340 | else:
341 | edge_index = remove_self_loops(data.edge_index)[0]
342 |
343 | G_gt = gt.Graph(directed=subgraph_params['directed'])
344 | edge_list = list(edge_index.transpose(1, 0).cpu().numpy())
345 | edge_list = [[int(edge_list[i][0]), int(edge_list[i][1])] for i in range(len(edge_list))]
346 | G_gt.add_edge_list(edge_list)
347 | # gt.stats.remove_self_loops(G_gt)
348 | gt.stats.remove_parallel_edges(G_gt)
349 |
350 | num_nodes = data.x.shape[0]
351 | identifiers = []
352 | for subgraph_dict in subgraph_dicts:
353 | kwargs = {'subgraph_dict': subgraph_dict,
354 | 'induced': subgraph_params['induced'],
355 | 'num_nodes': num_nodes,
356 | 'directed': subgraph_params['directed']}
357 | subgraphs = count_fn(edge_index, input_graph=G_gt, **kwargs)
358 | identifiers.append(subgraphs)
359 | # setattr(data, 'edge_index', edge_index)
360 | # setattr(data, 'identifiers', identifiers)
361 |
362 | return identifiers
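363 |
364 |
365 | if __name__ == "__main__":
366 |     # Minimal sanity-check sketch (illustrative only; assumes graph-tool and
367 |     # torch are available). The pattern graph and keyword values below are
368 |     # hypothetical example inputs, not values used elsewhere in the repository:
369 |     # K4 contains exactly 4 distinct triangles.
370 |     pattern = gt.Graph(directed=False)
371 |     pattern.add_edge_list([(0, 1), (1, 2), (2, 0)])
372 |     k4_edges = torch.tensor([[0, 0, 0, 1, 1, 2],
373 |                              [1, 2, 3, 2, 3, 3]]).long()
374 |     triangles = subgraph_isomorphism_vertex_extraction(
375 |         k4_edges, subgraph_dict={'subgraph': pattern}, induced=True, num_nodes=4)
376 |     print(triangles)
377 |     assert len(triangles) == 4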
--------------------------------------------------------------------------------
/deepgraph/data/substructure_dataset_utils/substructure_transform.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | import torch
5 |
6 | from .utils import process_arguments_substructure, substructure_to_gt,transfer_subgraph_to_batchtensor_complete
7 |
8 | from tqdm import tqdm
9 |
10 | from deepgraph.data.subsampling.sampler import Subgraphs_Sampler
11 | import lmdb
12 |
13 | import numpy as np
14 | from deepgraph.data.substructure_dataset_utils.neighbor_extractors import k_hop_subgraph,random_walk_subgraph
15 |
16 | from collections import deque
17 |
18 |
19 |
20 | def graph_data_modification(dataset,substructures,args):
21 | new_dataset=[]
22 | for i in tqdm(range(len(dataset))):
23 | if dataset[i].x.dim()==1:
24 | dataset[i].x=dataset[i].x.unsqueeze(1)
25 | if dataset[i].y.dim() == 2:
26 | dataset[i].y = dataset[i].y.squeeze(1)
27 | if hasattr(dataset[i], 'edge_features'):
28 | dataset[i].edge_attr=dataset[i].edge_features
29 | del(dataset[i].edge_features)
30 |         if hasattr(dataset[i], 'degrees'):
31 |             del dataset[i].degrees
32 | if hasattr(dataset[i],'graph_size'):
33 | del (dataset[i].graph_size)
34 | if dataset[i].edge_attr.dim()==1:
35 | dataset[i].edge_attr=dataset[i].edge_attr.unsqueeze(1)
36 | assert hasattr(dataset[i],'edge_attr')
37 | assert hasattr(dataset[i], 'edge_index')
38 | new_data=dataset[i]
39 | setattr(new_data,'identifiers', substructures[i])
40 | new_dataset.append(new_data)
41 | return new_dataset
42 |
43 |
44 |
45 |
46 | def find_edges_of_substructure(substructures,prototype):
47 | edges_res=[]
48 | for nodes in substructures:
49 | for edge in prototype:
50 | edges_res.append([nodes[edge[0]],nodes[edge[1]]])
51 | return edges_res
52 |
53 | def transfer_subgraph_to_batchtensor(substructures,must_select_list):
54 | id_lis=[]
55 | note_lis=[]
56 | cur_id=0
57 | must_select_list_data=[]
58 | for id_sub,subtype in enumerate(substructures):
59 | for data in subtype:
60 | id_lis=id_lis+[cur_id]*len(data)
61 | if id_sub in must_select_list:
62 | must_select_list_data.append(cur_id)
63 | note_lis=note_lis+list(data)
64 |
65 | cur_id += 1
66 |
67 | res=torch.tensor([id_lis,note_lis])
68 | return res,must_select_list_data
69 |
70 |
71 |
72 |
73 |
74 |
75 | def select_result(substructures,selected_subgraphs):
76 | id_lis=[]
77 | note_lis=[]
78 | cur_id=0
79 | substructures_new=[]
80 | for subtype in substructures:
81 | subtype_new=set()
82 | for data in subtype:
83 | if cur_id in selected_subgraphs:
84 | subtype_new.add(data)
85 | cur_id+=1
86 | substructures_new.append(subtype_new)
87 | return substructures_new
88 |
89 | def select_result_as_tensor(subgraph_tensor,selected_subgraphs):
90 | selected_index=(subgraph_tensor[0].unsqueeze(1) ==torch.tensor(selected_subgraphs).unsqueeze(0)).any(1)
91 | return subgraph_tensor[:,selected_index]
92 |
93 | def re_tag(subgraph_ids):
94 |     # re-index subgraph ids so that they become consecutive integers starting at 0
95 |     diff = ((subgraph_ids[1:] - subgraph_ids[0:-1]) > 0).long()
96 |     csum = diff.cumsum(dim=0)
97 |     return torch.cat([torch.zeros([1]) if subgraph_ids.shape[0] > 0 else torch.tensor([]), csum]).long()
98 |
99 |
100 |
101 | def bfs(graph_adj, start_node, sort_func):
102 |     visited = set()              # membership test for visited nodes
103 |     order = []                   # nodes in the order they are visited
104 |     queue = deque([start_node])  # enqueue the starting node
105 |
106 |     while queue:
107 |         node = queue.popleft()
108 |
109 |         if node not in visited:
110 |             # mark the node as visited and record the visiting order
111 |             visited.add(node)
112 |             order.append(node)
113 |
114 |             neighbors = (graph_adj[node].nonzero())[0]
115 |             # np.in1d expects an array-like, not a set, for element-wise filtering
116 |             neighbors_not_visited = neighbors[~np.in1d(neighbors, list(visited))]
117 |
118 |             sorted_neighbors_not_visited = sort_func(neighbors_not_visited)
119 |             queue.extend(sorted_neighbors_not_visited.tolist())
120 |
121 |     return order
122 |
123 | def sort_node(node_list, degrees,only_minimum=False):
124 |
125 | if only_minimum:
126 | #compute minimum over all the nodes
127 | min_indices = np.argwhere(np.array(degrees) == np.min(degrees)).flatten()
128 | res= np.random.permutation(min_indices)
129 | else:
130 | degrees=degrees[node_list]
131 |
132 | ind_per = np.random.permutation(len(degrees))
133 | node_list=node_list[ind_per]
134 | degrees=degrees[ind_per]
135 |
136 | ind_sort=np.argsort(degrees)
137 | res= node_list[ind_sort]
138 |
139 | return res
140 |
141 |
142 | def generate_sorted_adj(edge_tensor,node_sorted,subgraph_max_size):
143 |
144 |     # Build, for every subgraph, the adjacency matrix of its nodes re-indexed by
145 |     # the ordering in `node_sorted` and padded to `subgraph_max_size`.
146 |
147 | if edge_tensor[0].shape[0]>0:
148 | number_subgraph = int(edge_tensor[0, :].max()) + 1
149 | max_node = int(edge_tensor[1:, :].max()) + 1
150 | # edges = torch.cat([edge_tensor, edge_tensor[[0, 2, 1], :]], 1)
151 | edge_sparse = torch.sparse_coo_tensor(edge_tensor, torch.ones_like(edge_tensor[0]), [number_subgraph, max_node, max_node]
152 | ).to_dense().long()
153 | node_sorted_id = []
154 | node_sorted_sequence = []
155 | node_sorted_cat = []
156 | for i, item in enumerate(node_sorted):
157 | node_sorted_cat += item
158 | node_sorted_sequence += list(range(len(item)))
159 | node_sorted_id += [i] * len(item)
160 |
161 | node_sorted_tensor = torch.sparse_coo_tensor([node_sorted_id, node_sorted_sequence, node_sorted_cat],
162 | torch.ones([len(node_sorted_id)]),
163 | [number_subgraph, subgraph_max_size, max_node]
164 | ).to_dense().long()
165 |
166 | adjs = node_sorted_tensor @ edge_sparse @ (node_sorted_tensor.transpose(1, 2))
167 |
168 | return adjs
169 | else:
170 | return torch.ones([0,subgraph_max_size,subgraph_max_size])
171 |
172 | class transform_sub:
173 | def __init__(self,args):
174 | args, extract_ids_fn, count_fn, automorphism_fn = process_arguments_substructure(args)
175 | self.args= args
176 | self.extract_ids_fn=extract_ids_fn
177 | self.count_fn=count_fn
178 | self.automorphism_fn=automorphism_fn
179 | self.subgraph_params = {'induced': args['induced'],
180 | 'edge_list': args['custom_edge_list'],
181 | 'directed': args['directed'],
182 | 'directed_orbits': args['directed_orbits'],
183 | }
184 | self.must_select_list=args.must_select_list
185 | self.substructure_cache={}
186 | if args.subsampling:
187 | self.sampler = Subgraphs_Sampler(
188 | sampling_mode=args.sampling_mode,
189 | minimum_redundancy=args.sampling_redundancy,
190 | shortest_path_mode_stride=args.sampling_stride,
191 | random_mode_sampling_rate=args.sampling_random_rate,
192 | random_init=True,
193 | only_unused_nodes=(not self.args.not_only_unused_nodes))
194 | def __call__(self, data,substructure_tensor_provided):
195 |
196 | if substructure_tensor_provided is None:
197 | substructures = self.compute_substructures(data)
198 | substructure_tensor = transfer_subgraph_to_batchtensor_complete(substructures)
199 | subgraph_tensor = torch.Tensor([substructure_tensor[1], substructure_tensor[0],substructure_tensor[2]]).long()
200 | else:
201 | subgraph_tensor=substructure_tensor_provided
202 |
203 | if self.args.subsampling:
204 |
205 | must_select_list_data=[]
206 | for i in self.must_select_list:
207 | must_select_list_data+=subgraph_tensor[0][subgraph_tensor[2]==i].unique().long().tolist()
208 |
209 | selected_subgraphs, node_selected_times = self.sampler(data.edge_index, subgraph_tensor, data.num_nodes,must_select_list_data)
210 |
211 | subgraph_tensor = select_result_as_tensor(subgraph_tensor, selected_subgraphs)
212 | subgraph_tensor=torch.cat([re_tag(subgraph_tensor[0,:]).unsqueeze(0),subgraph_tensor[1:,:]])
213 |
214 | edge_tensor=self.compute_edge_tensor_of_substructures(data,subgraph_tensor)
215 |
216 | node_sorted=self.sort_subgraphs_randomly(data,subgraph_tensor)
217 |
218 | sorted_adj=generate_sorted_adj(edge_tensor,node_sorted,self.args['subgraph_max_size'])
219 |
220 |
221 |
222 | return subgraph_tensor,sorted_adj
223 |
224 |
225 | def compute_edge_tensor_of_substructures(self,data,subgraph_tensor):
226 |
227 | if subgraph_tensor[0].shape[0] > 0 and data.edge_index[0].shape[0]>0:
228 | try:
229 | num_subgraph=int(subgraph_tensor[0].max()+1)
230 | max_node_number=max(int(data.edge_index.max()+1),int(subgraph_tensor[1].max()+1))
231 | sub_tensor=torch.sparse_coo_tensor(subgraph_tensor[[0,1],:],
232 | torch.ones_like(subgraph_tensor[0]),
233 | [num_subgraph,max_node_number]).to_dense()
234 | edg_tensor=torch.sparse_coo_tensor(data.edge_index,torch.ones_like(data.edge_index[0]),[max_node_number,max_node_number]).to_dense()
235 |
236 | res=(sub_tensor.unsqueeze(1))*(edg_tensor.unsqueeze(0))*(sub_tensor.unsqueeze(2))
237 |             except Exception:
238 |                 # log the inputs that triggered the failure, then re-raise
239 |                 # instead of falling through to an undefined `res`
240 |                 print(subgraph_tensor)
241 |                 print(data.edge_index)
242 |                 raise
243 |
244 |             return res.nonzero().T
245 |
246 | else:
247 | return torch.zeros([3,0])
248 |
249 |
250 |
251 | def sort_subgraphs_randomly(self,data,subgraph_tensor):
252 |
253 |
254 | graph_adj = np.zeros((data.x.shape[0],data.x.shape[0]),dtype=np.int64)
255 | graph_adj[data.edge_index[0,:],data.edge_index[1,:]]=1
256 |
257 | if subgraph_tensor[0].shape[0]>0:
258 | nodes_record=[]
259 | for i in range(int(subgraph_tensor[0].max())+1):
260 |
261 | subgraph_node=subgraph_tensor[1][subgraph_tensor[0] == i]
262 |
263 | subgraph_adj=graph_adj[subgraph_node][:,subgraph_node]
264 |
265 | degrees = subgraph_adj.sum(0)
266 |
267 |                 if subgraph_adj.sum() > 0:
268 |                     # order nodes by a randomized, degree-guided BFS from a minimum-degree root
269 |                     root = sort_node(None, degrees, only_minimum=True)[0]
270 |                     nodes = [n for n in bfs(subgraph_adj, start_node=root, sort_func=lambda node_array: sort_node(node_array, degrees))]
271 |                     nodes_record.append(subgraph_node[nodes])
272 |                 else:
273 |                     # keep positional alignment with subgraph ids for edge-free subgraphs
274 |                     nodes_record.append(subgraph_node)
275 |             return nodes_record
276 | else:
277 | return [[]]
278 |
279 | def compute_substructures(self,data,subgraph_dicts=None):
280 | if subgraph_dicts is None:
281 | subgraph_dicts = substructure_to_gt(self.subgraph_params, self.automorphism_fn, )
282 | substructures = self.extract_ids_fn(self.count_fn, data, subgraph_dicts, self.subgraph_params)
283 | return substructures
284 |
285 |
286 |
287 | def compute_neighbor(self,data):
288 | neighbor=None
289 | if 'k_neighborhood' in self.args['neighbor_type']:
290 | neighbor=k_hop_subgraph(data.edge_index,data.x.shape[0],self.args['neighbor_size'][self.args['neighbor_type'].index('k_neighborhood')])
291 | elif 'random_walk' in self.args['neighbor_type']:
292 | neighbor = random_walk_subgraph(data.edge_index, data.x.shape[0],
293 | self.args['neighbor_size'][self.args['neighbor_type'].index('random_walk')])
294 | neighbor_nodes = neighbor[0].nonzero().T
295 | return neighbor_nodes.tolist(),int(neighbor[0].T.sum(0).max())
296 |
297 | def substructure_as_tensor(self,substructure_tensor=None,neighbor_tensor=None):
298 | result=None
299 | if substructure_tensor is not None:
300 | result = torch.Tensor([substructure_tensor[1], substructure_tensor[0],substructure_tensor[2]]).long()
301 | if neighbor_tensor is not None:
302 | if substructure_tensor is not None and result.shape[1]>0:
303 | ids=(torch.Tensor([neighbor_tensor[0]])+result[0].max()+1).long()
304 | neighbor_tensor=torch.cat([ids,torch.Tensor([neighbor_tensor[1]]).long(),-torch.ones_like(ids)],0)
305 | result=torch.cat([result,neighbor_tensor],1)
306 | else:
307 | ids=(torch.Tensor([neighbor_tensor[0]])).long()
308 | type = (-1* torch.ones_like(ids)).long()
309 | result=torch.cat([ids,torch.Tensor([neighbor_tensor[1]]).long(),type],0)
310 | if result is None:
311 | result=torch.ones([3,0])
312 | return result
313 |
314 | def pre_compute_substructures_direct_to_lmdb(self,dataset):
315 |
316 | subgraph_dicts=substructure_to_gt(self.subgraph_params,self.automorphism_fn,)
317 |
318 | print('********************************transfer to lmdb type**************************')
319 |
320 | path_lmdb_tensor = os.path.join(self.args['data_dir'],self.args['lmdb_root_dir'],self.args['pre_defined_path'] + self.args['lmdb_dir'], 'inner')
321 | print('saving to ', path_lmdb_tensor)
322 | if not os.path.exists(path_lmdb_tensor):
323 | os.makedirs(path_lmdb_tensor)
324 | lmdb_env_tensor = lmdb.open(path_lmdb_tensor, map_size=1e12)
325 | txn_tensor = lmdb_env_tensor.begin(write=True)
326 |
327 | for idx, data in tqdm(enumerate(dataset)):
328 |
329 | substructure_computed=self.compute_substructures(data,subgraph_dicts)
330 |
331 | substructure_tensor=self.transfer_subgraph_to_batchtensor_complete(substructure_computed)
332 |
333 | txn_tensor.put(key=str(idx).encode(), value=str(substructure_tensor).encode())
334 |
335 | print('********************************commit substructures tensors**************************')
336 | txn_tensor.commit()
337 | lmdb_env_tensor.close()
338 |
339 | def pre_compute_neighbors_direct_to_lmdb(self,dataset):
340 | print('***********************************computing neighbors to lmdb type**************************')
341 | path_lmdb = os.path.join(self.args['data_dir'],self.args['lmdb_root_dir'],self.args['node_neighbor_path']+ self.args['lmdb_dir'], 'neighbor')
342 | print('saving to ', path_lmdb)
343 | if not os.path.exists(path_lmdb):
344 | os.makedirs(path_lmdb)
345 | lmdb_env = lmdb.open(path_lmdb, map_size=1e12)
346 | txn = lmdb_env.begin(write=True)
347 | max_size=0
348 | for idx, data in tqdm(enumerate(dataset)):
349 | substructure_computed,size=self.compute_neighbor(data)
350 | max_size=max(max_size,size)
351 | txn.put(key=str(idx).encode(), value=str(substructure_computed).encode())
352 | txn.put(key='max_size'.encode(),value=str(max_size).encode())
353 | print('********************************commit neighbors**************************')
354 | txn.commit()
355 | lmdb_env.close()
356 |
357 |
358 | def transfer_subgraph_to_batchtensor_complete(self,substructures):
359 | return transfer_subgraph_to_batchtensor_complete(substructures)
360 |
361 |
362 |
363 | def clean_cache(self):
364 | self.substructure_cache = {}
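365 |
366 | if __name__ == "__main__":
367 |     # Minimal sketch (illustrative only): degree-guided BFS ordering of a
368 |     # small path graph 0-1-2-3 with the helpers above. The adjacency matrix
369 |     # here is hypothetical example input; the root is one of the two
370 |     # degree-1 endpoints, picked at random by sort_node(only_minimum=True).
371 |     adj = np.zeros((4, 4), dtype=np.int64)
372 |     for u, v in [(0, 1), (1, 2), (2, 3)]:
373 |         adj[u, v] = adj[v, u] = 1
374 |     degrees = adj.sum(0)
375 |     root = sort_node(None, degrees, only_minimum=True)[0]
376 |     order = bfs(adj, start_node=root, sort_func=lambda nodes: sort_node(nodes, degrees))
377 |     assert sorted(order) == [0, 1, 2, 3]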
365 |
--------------------------------------------------------------------------------
/deepgraph/data/substructure_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from torch_geometric.data import Dataset
5 | from sklearn.model_selection import train_test_split
6 |
7 | import torch
8 | import numpy as np
9 |
10 | from .wrapper import preprocess_item,graph_data_modification_single,encode_token_single_tensor,encode_token_single_tensor_with_adj
11 |
12 |
13 | from deepgraph.data.substructure_dataset_utils.substructure_transform import transform_sub
14 |
15 | import copy
16 | import lmdb
17 | import os
18 |
19 | import torch.distributed as dist
20 |
21 |
22 | class concat_dataset:
23 | def __init__(self,dataset_list):
24 | self.dataset_list=dataset_list
25 | def __getitem__(self,idx):
26 | if idx 0:
153 | if args.recache_lmdb or not os.path.exists(os.path.join(args['data_dir'],
154 | args['lmdb_root_dir'],
155 | args['pre_defined_path']+ args['lmdb_dir'],
156 | 'inner')):
157 | print('************************* Warning recache substructures**************************')
158 | if not dist.is_initialized() or dist.get_rank() == 0:
159 | self.transform.pre_compute_substructures_direct_to_lmdb(self.dataset)
160 |
161 | if dist.is_initialized():
162 | dist.barrier()
163 | if len(args['neighbor_type']) > 0:
164 | if args.recache_lmdb or not os.path.exists(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['node_neighbor_path']+ args['lmdb_dir'],'neighbor')):
165 | print('************************* Warning recache neighbors**************************')
166 | if not dist.is_initialized() or dist.get_rank() == 0:
167 | self.transform.pre_compute_neighbors_direct_to_lmdb(self.dataset)
168 |
169 | if dist.is_initialized():
170 | dist.barrier()
171 |
172 |
173 |
174 | # if self.lmdb_tensor_env is None:
175 |
176 | if len(args['custom_edge_list'])>0:
177 | self.lmdb_tensor_env=lmdb.open(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['pre_defined_path']+ args['lmdb_dir'], 'inner'),map_size=1e12, readonly=True,
178 | lock=False, readahead=False, meminit=False)
179 | else:
180 | self.lmdb_tensor_env=None
181 | if len(args['neighbor_type'])>0:
182 | self.lmdb_neighbor_env=lmdb.open(os.path.join(args['data_dir'],args['lmdb_root_dir'],args['node_neighbor_path']+ args['lmdb_dir'],'neighbor'),map_size=1e12, readonly=True,
183 | lock=False, readahead=False, meminit=False)
184 | txn = self.lmdb_neighbor_env.begin(write=False)
185 | max_size=eval(txn.get(('max_size').encode()))
186 | self.transform.args['subgraph_max_size']=max(self.transform.args['subgraph_max_size'],max_size)
187 | else:
188 | self.lmdb_neighbor_env=None
189 |
190 |
191 |
192 |
193 | self.train_data.lmdb_tensor_env=self.lmdb_tensor_env
194 | self.train_data.lmdb_neighbor_env = self.lmdb_neighbor_env
195 |
196 | self.valid_data.lmdb_tensor_env = self.lmdb_tensor_env
197 | self.valid_data.lmdb_neighbor_env = self.lmdb_neighbor_env
198 |
199 | self.test_data.lmdb_tensor_env = self.lmdb_tensor_env
200 | self.test_data.lmdb_neighbor_env = self.lmdb_neighbor_env
201 |
202 | if train_idx is not None:
203 | self.recomputeind=recompute_idx_byindex(train_idx,valid_idx,test_idx,'inner')
204 | self.train_data.recomputeind=recompute_idx_byindex(train_idx,valid_idx,test_idx,'train')
205 | self.test_data.recomputeind = recompute_idx_byindex(train_idx,valid_idx,test_idx,'test')
206 | if args['valid_on_test']:
207 | self.valid_data.recomputeind = recompute_idx_byindex(train_idx, valid_idx, test_idx, 'test')
208 | else:
209 | self.valid_data.recomputeind = recompute_idx_byindex(train_idx, valid_idx, test_idx, 'valid')
210 | else:
211 | self.recomputeind=recompute_idx(len(self.train_data),len(self.valid_data),'train')
212 | self.train_data.recomputeind=recompute_idx(len(self.train_data),len(self.valid_data),'train')
213 | self.test_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'test')
214 | if args['valid_on_test']:
215 | self.valid_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'test')
216 | else:
217 | self.valid_data.recomputeind = recompute_idx(len(self.train_data), len(self.valid_data), 'valid')
218 |
219 |
220 |
221 | if 'use_transform_cache' in args and args.use_transform_cache:
222 | print('************************* Using transform cache **************************')
223 |
224 | transform_cache_path=os.path.join(args['data_dir'], args['lmdb_root_dir'], args['transform_cache_path'] + args['transform_dir'])
225 | if not os.path.exists(transform_cache_path):
226 |                 print(transform_cache_path+' does not exist; creating**************')
227 | os.makedirs(transform_cache_path)
228 | try:
229 | self.transform_cache_env=lmdb.open(transform_cache_path,map_size=1e12, readonly=True,
230 | lock=False, readahead=False, meminit=False)
231 | self.train_data.transform_cache_env = self.transform_cache_env
232 | self.valid_data.transform_cache_env = self.transform_cache_env
233 | self.test_data.transform_cache_env = self.transform_cache_env
234 | except:
235 |                 assert self.only_substructure  # use_transform_cache without an existing cache directory is only valid while the cache is being generated
236 |
237 |
238 | def index_select(self, idx):
239 | dataset = copy.copy(self)
240 | if not isinstance(self.dataset,list):
241 | dataset.dataset = self.dataset.index_select(idx)
242 | else:
243 | dataset.dataset = [self.dataset[i] for i in idx.tolist()]
244 | if isinstance(idx, torch.Tensor):
245 | dataset.num_data = idx.size(0)
246 | else:
247 | dataset.num_data = idx.shape[0]
248 | dataset.__indices__ = idx
249 | dataset.train_data = None
250 | dataset.valid_data = None
251 | dataset.test_data = None
252 |
253 | return dataset
254 |
255 | def create_subset(self, subset):
256 | dataset = copy.copy(self)
257 | dataset.dataset = subset
258 | dataset.num_data = len(subset)
259 | dataset.__indices__ = None
260 | dataset.train_data = None
261 | dataset.valid_data = None
262 | dataset.test_data = None
263 |
264 | return dataset
265 |
266 |
267 |
268 | # @lru_cache(maxsize=16)
269 | def __getitem__(self, idx,parallel=False):
270 |
271 | if isinstance(idx, int):
272 |
273 | item = self.dataset[idx]
274 |
275 | if self.transform is not None:
276 |
277 | if (not self.use_transform_cache) or self.only_substructure:
278 |
279 | if (self.lmdb_tensor_env is not None):
280 | txn = self.lmdb_tensor_env.begin(write=False)
281 | tem=txn.get(str(self.recomputeind.compute(idx)).encode())
282 | if tem is not None:
283 | substructure_tensor=eval(tem)
284 | else:
285 | raise ValueError("substructure_tensor not prepared")
286 | else:
287 | substructure_tensor=None
288 | if (self.lmdb_neighbor_env is not None):
289 | txn = self.lmdb_neighbor_env.begin(write=False)
290 | tem=txn.get(str(self.recomputeind.compute(idx)).encode())
291 | if tem is not None:
292 | neighbor_tensor=eval(tem)
293 | else:
294 |                             raise ValueError("neighbor_tensor not prepared")
295 | else:
296 | neighbor_tensor=None
297 |
298 | substructure_tensor=self.transform.substructure_as_tensor(substructure_tensor,neighbor_tensor)
299 |
300 | if self.only_substructure:
301 |                         # cache transform_cache_number sampled substructure tensors for this graph
302 |
303 | subgraph_tensor_res=[]
304 | sorted_adj_res=[]
305 |
306 | for i in range(self.transform_cache_number):
307 |
308 |
309 | subgraph_tensor, sorted_adj = self.transform(item, substructure_tensor)
310 |
311 |
312 | subgraph_tensor=subgraph_tensor.tolist()
313 | sorted_adj=sorted_adj.tolist()
314 | subgraph_tensor_res.append(subgraph_tensor)
315 | sorted_adj_res.append(sorted_adj)
316 |
317 | item.idx=idx
318 | item.subgraph_tensor_res = subgraph_tensor_res
319 | item.sorted_adj_res=sorted_adj_res
320 | item.y = item.y.reshape(-1)
321 | return item
322 |
323 | else:
324 | subgraph_tensor,sorted_adj = self.transform(item,substructure_tensor)
325 |
326 | else:
327 | txn = self.transform_cache_env.begin(write=False)
328 | id=np.random.randint(self.transform_cache_number)
329 | tem = txn.get((str(self.recomputeind.compute(idx))+'_'+str(id)).encode())
330 | if tem is not None:
331 | subgraph_tensor, sorted_adj = eval(tem)
332 | else:
333 |                         raise ValueError("transform cache entry not prepared for this index")
334 | subgraph_tensor = torch.tensor(subgraph_tensor)
335 | sorted_adj = torch.tensor(sorted_adj)
336 |
337 | item = graph_data_modification_single(item, subgraph_tensor, sorted_adj)
338 |
339 | if self.extra_method == 'token':
340 | item = encode_token_single_tensor(item, self.args.atom_id_max, self.args.edge_id_max, )
341 | elif self.extra_method == 'adj':
342 | item = encode_token_single_tensor_with_adj(item, local_attention_on_substructures=self.local_attention_on_substructures)
343 |
344 | if (not self.not_re_define) or (not hasattr(item, 'attn_bias')):
345 | item.idx = idx
346 | item.y = item.y.reshape(-1)
347 |
348 | item= preprocess_item(item,local_attention_on_substructures=self.local_attention_on_substructures,continuous_feature=self.continuous_feature)
349 | return item
350 |
351 | else:
352 | raise TypeError("index to a GraphormerPYGDataset can only be an integer.")
353 |
354 | def __len__(self):
355 | return self.num_data
356 |
357 |
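358 |
359 | if __name__ == "__main__":
360 |     # Minimal sketch (illustrative only) of the lmdb round-trip used by this
361 |     # dataset: entries are stored as Python-literal strings and recovered with
362 |     # eval. The temporary directory and toy payload are hypothetical examples.
363 |     import tempfile
364 |     with tempfile.TemporaryDirectory() as tmp:
365 |         env = lmdb.open(tmp, map_size=2 ** 20)
366 |         with env.begin(write=True) as txn:
367 |             txn.put(key=str(0).encode(), value=str([[0, 1], [1, 2]]).encode())
368 |         with env.begin(write=False) as txn:
369 |             restored = torch.tensor(eval(txn.get(str(0).encode())))
370 |         env.close()
371 |         assert restored.shape == (2, 2)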
--------------------------------------------------------------------------------