├── .idea
├── .gitignore
├── Graph-Trans.iml
├── dictionaries
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── vcs.xml
├── README.md
├── data
├── __init__.py
├── algos.pyx
├── collator.py
├── dataset.py
├── ogb_datasets
│ ├── __init__.py
│ └── ogb_dataset_lookup_table.py
├── pyg_datasets
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── pyg_dataset.cpython-36.pyc
│ │ └── pyg_dataset_lookup_table.cpython-36.pyc
│ ├── pyg_dataset.py
│ └── pyg_dataset_lookup_table.py
├── smiles
│ └── smiles_dataset.py
└── wrapper.py
├── graphtrasformer
├── __pycache__
│ ├── architectures.cpython-36.pyc
│ ├── gt_layers.cpython-36.pyc
│ ├── gt_models.cpython-36.pyc
│ └── layers.cpython-36.pyc
├── architectures.py
├── gnn_layers.py
├── gt_layers.py
├── gt_models.py
├── layer_tests.py
└── layers.py
├── gt_dataset.py
├── run.py
└── utils
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Graph-Trans.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/dictionaries:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Graph-Transformer Framework
4 |
5 | Source code for the paper "**[Transformer for Graphs: An Overview from Architecture Perspective](https://arxiv.org/pdf/2202.08455.pdf)**"
6 |
7 |
8 | We provide a comprehensive review of various Graph Transformer models from the architectural design perspective.
9 | We first disassemble the existing models and identify three typical ways to incorporate the graph
10 | information into the vanilla Transformer:
11 | - GNNs as Auxiliary Modules
12 | - Improved Positional Embedding from Graphs
13 | - Improved Attention Matrix from Graphs
14 |
15 | We implement representative components from each of the three groups and conduct a comprehensive comparison on a variety of well-known graph benchmarks to investigate the actual performance gain of each component.
16 |
17 |
18 |
19 |
20 |
21 | ## Running
22 |
23 | - Training and evaluation (see the annotations in the code for details of each option):
24 | ```
25 | $python run.py --seed ${CUSTOMIZED_SEED} \
26 | --model_scale ${CUSTOMIZED_SCALE} \
27 | --data_name ${CUSTOMIZED_DATASET} \
28 | --use_super_node ${True/False} \
29 | --node_level_modules ${CUSTOMIZED_NODE_MODULES} \
30 | --attn_level_modules ${CUSTOMIZED_ATTENTION_MODULES} \
31 | --attn_mask_modules ${CUSTOMIZED_MASK_MODULES} \
32 | --use_gnn_layers ${True/False} \
33 | --gnn_insert_pos ${CUSTOMIZED_GNN_POSITION} \
34 | --gnn_type ${CUSTOMIZED_GNN} \
35 | --sampling_algo ${CUSTOMIZED_SAMPLING_ALGORITHMS}
36 | ```
37 | - Example 1: Transformer with degree positional embedding, spatial encoding and shortest-path edge encoding
38 |
39 | ```
40 | $python run.py --seed 1024 \
41 | --model_scale small \
42 | --data_name ZINC \
43 | --use_super_node True \
44 | --node_level_modules degree \
45 | --attn_level_modules spatial,spe \
46 | ```
47 | - Example 2: Transformer with a 1-hop attention mask
48 | ```
49 | $python run.py --seed 1024 \
50 | --model_scale middle \
51 | --data_name flickr \
52 | --use_super_node True \
53 | --node_level_modules eig,svd \
54 | --attn_mask_modules 1hop \
55 | --sampling_algo shadowkhop \
56 | --depth 2 \
57 | --num_neighbors 10
58 | ```
59 | - Example 3: Transformer with GIN layers before Transformer layers
60 | ```
61 | $python run.py --seed 1024 \
62 | --model_scale large \
63 | --data_name ZINC \
64 | --use_super_node True \
65 | --use_gnn_layers True \
66 | --gnn_insert_pos before \
67 | --gnn_type GIN
68 | ```
69 |
70 |
71 |
72 |
73 | ## Requirements
74 | - Python 3.x
75 | - pytorch >=1.5.0
76 | - torch-geometric >=2.0.3
77 | - transformers >= 4.8.2
78 | - tensorflow >= 2.3.1
79 | - scikit-learn >= 0.23.2
80 | - ogb >= 1.3.2
81 | - datasets >=1.8.0
82 |
83 | ## Results
84 | Please refer to our [paper](https://arxiv.org/pdf/2202.08455.pdf)
85 |
86 | ## Reference
87 | Please cite the paper whenever our graph transformer is used to produce published results or incorporated into other software:
88 | ```
89 | @article{min2022transformer,
90 | title={Transformer for Graphs: An Overview from Architecture Perspective},
91 | author={Min, Erxue and Chen, Runfa and Bian, Yatao and Xu, Tingyang and Zhao, Kangfei and Huang, Wenbing and Zhao, Peilin and Huang, Junzhou and Ananiadou, Sophia and Rong, Yu},
92 | journal={arXiv preprint arXiv:2202.08455},
93 | year={2022}
94 | }
95 | ```
96 |
97 |
98 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
DATASET_REGISTRY = {}


def register_dataset(name: str):
    """Return a decorator that registers a dataset under ``name``.

    The decorated callable is invoked immediately with no arguments and its
    *return value* is stored in ``DATASET_REGISTRY[name]``. The callable itself
    is returned unchanged, so the decorated name remains usable.

    Args:
        name: key under which the dataset is registered.
    """
    def register_dataset_func(func):
        DATASET_REGISTRY[name] = func()
        # Return func so the decorated name is not rebound to None.
        return func

    return register_dataset_func
7 |
--------------------------------------------------------------------------------
/data/algos.pyx:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import cython
5 | from cython.parallel cimport prange, parallel
6 | cimport numpy
7 | import numpy
8 |
def floyd_warshall(adjacency_matrix):
    """All-pairs shortest-path lengths via the Floyd-Warshall algorithm.

    Args:
        adjacency_matrix: square 2-D array. Off-diagonal zero entries mean
            "no edge"; nonzero entries are used as edge lengths. It is copied
            with ``astype(long, casting='safe')``, so dtypes that cannot be
            safely cast to C ``long`` are rejected by numpy.

    Returns:
        (M, path):
            M[i][j]: shortest distance from i to j; 0 on the diagonal; 510 for
                unreachable pairs (any distance >= 510 is clamped to 510 and
                treated as unreachable).
            path[i][j]: the last intermediate node k that improved the (i, j)
                distance; 0 if it was never improved; 510 if unreachable.

    NOTE(review): a path value of 0 is ambiguous. It means both "no
    intermediate node" and "node 0 is the intermediate". get_all_edges treats
    0 as the former, so shortest paths routed through node 0 are reconstructed
    as a direct i->j hop.
    """

    (nrows, ncols) = adjacency_matrix.shape
    assert nrows == ncols
    cdef unsigned int n = nrows

    adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True)
    assert adj_mat_copy.flags['C_CONTIGUOUS']
    cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy
    cdef numpy.ndarray[long, ndim=2, mode='c'] path = numpy.zeros([n, n], dtype=numpy.int64)

    cdef unsigned int i, j, k
    cdef long M_ij, M_ik, cost_ikkj
    # raw pointer into M's C-contiguous buffer; row r starts at M_ptr + n*r
    cdef long* M_ptr = &M[0,0]
    cdef long* M_i_ptr
    cdef long* M_k_ptr

    # set unreachable nodes distance to 510 (diagonal is forced to 0)
    for i in range(n):
        for j in range(n):
            if i == j:
                M[i][j] = 0
            elif M[i][j] == 0:
                M[i][j] = 510

    # Floyd-Warshall relaxation: allow k as an intermediate node
    for k in range(n):
        M_k_ptr = M_ptr + n*k
        for i in range(n):
            M_i_ptr = M_ptr + n*i
            M_ik = M_i_ptr[k]
            for j in range(n):
                cost_ikkj = M_ik + M_k_ptr[j]
                M_ij = M_i_ptr[j]
                if M_ij > cost_ikkj:
                    M_i_ptr[j] = cost_ikkj
                    # remember the intermediate used for this improvement
                    path[i][j] = k

    # set unreachable path to 510 (also clamps any distance >= 510)
    for i in range(n):
        for j in range(n):
            if M[i][j] >= 510:
                path[i][j] = 510
                M[i][j] = 510

    return M, path
55 |
56 |
def get_all_edges(path, i, j):
    """Recursively list the intermediate nodes on the shortest path i -> j.

    Args:
        path: the ``path`` matrix returned by floyd_warshall.
        i, j: source and target node indices.

    Returns:
        List of intermediate node indices, excluding i and j. It is empty
        when path[i][j] == 0.

    NOTE(review): callers must skip unreachable pairs (path value 510)
    themselves, as gen_edge_input does. This function would otherwise recurse
    using 510 as a node index. Recursion depth grows with the number of hops.
    """
    cdef unsigned int k = path[i][j]
    if k == 0:
        # 0 is read as "direct hop" (see the ambiguity note in floyd_warshall)
        return []
    else:
        return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)
63 |
64 |
def gen_edge_input(max_dist, path, edge_feat):
    """Gather edge features along every shortest path.

    Args:
        max_dist: number of hop slots allocated per (i, j) pair.
        path: ``path`` matrix from floyd_warshall, shape (n, n).
        edge_feat: per-edge features indexed as edge_feat[u, v, :], i.e. a
            3-D array whose last axis has edge_feat.shape[-1] features.

    Returns:
        int64 array of shape (n, n, max_dist, edge_feat.shape[-1]). For every
        ordered pair i != j that is reachable (path value != 510), slot k holds
        the features of the k-th edge on the reconstructed shortest path.
        Unused slots, the diagonal and unreachable pairs stay -1.

    NOTE(review): a path with more than max_dist hops would index past the
    third axis. Presumably this raises IndexError because Cython buffer
    bounds checking is on by default. Callers should pass max_dist >= the
    longest path length.
    """

    (nrows, ncols) = path.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_dist_copy = max_dist

    path_copy = path.astype(long, order='C', casting='safe', copy=True)
    edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True)
    assert path_copy.flags['C_CONTIGUOUS']
    assert edge_feat_copy.flags['C_CONTIGUOUS']

    cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64)
    cdef unsigned int i, j, k, num_path, cur

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if path_copy[i][j] == 510:
                # unreachable pair: leave all slots at -1
                continue
            # node sequence i, <intermediates>, j (rebinds the `path` argument name)
            path = [i] + get_all_edges(path_copy, i, j) + [j]
            num_path = len(path) - 1
            for k in range(num_path):
                edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :]

    return edge_fea_all
92 |
--------------------------------------------------------------------------------
/data/collator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | from torch_geometric.data import Batch,Data
6 |
def pad_1d_unsqueeze(x, padlen):
    """Shift a 1-D tensor by +1, right-pad it with zeros to ``padlen`` and add a batch dim.

    The +1 shift reserves 0 as the padding id. Tensors already at least
    ``padlen`` long are only shifted and unsqueezed.
    """
    shifted = x + 1  # pad id = 0
    length = shifted.size(0)
    if length >= padlen:
        return shifted.unsqueeze(0)
    padded = shifted.new_zeros([padlen], dtype=shifted.dtype)
    padded[:length] = shifted
    return padded.unsqueeze(0)
15 |
16 |
def pad_2d_unsqueeze(x, padlen):
    """Shift a (rows, dim) tensor by +1, zero-pad the rows to ``padlen`` and add a batch dim.

    The +1 shift reserves 0 as the padding id for categorical features.
    """
    shifted = x + 1  # pad id = 0
    rows, width = shifted.size()
    if rows >= padlen:
        return shifted.unsqueeze(0)
    padded = shifted.new_zeros([padlen, width], dtype=shifted.dtype)
    padded[:rows, :] = shifted
    return padded.unsqueeze(0)
25 |
26 |
def pad_attn_bias_unsqueeze(x, padlen):
    """Pad a square attention-bias matrix to (padlen, padlen) and add a batch dim.

    New cells are -inf, except the padded rows over the original columns,
    which are set to 0 so that no padded row is entirely -inf. The original
    values occupy the top-left block.
    """
    size = x.size(0)
    if size >= padlen:
        return x.unsqueeze(0)
    padded = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float("-inf"))
    padded[:size, :size] = x
    padded[size:, :size] = 0
    return padded.unsqueeze(0)
35 |
36 |
def pad_edge_type_unsqueeze(x, padlen):
    """Zero-pad an (n, n, d) edge-type tensor to (padlen, padlen, d) and add a batch dim."""
    n = x.size(0)
    if n >= padlen:
        return x.unsqueeze(0)
    channels = x.size(-1)
    padded = x.new_zeros([padlen, padlen, channels], dtype=x.dtype)
    padded[:n, :n, :] = x
    return padded.unsqueeze(0)
44 |
def pad_pos_emb_unsqueeze(x, padlen):
    """Zero-pad the rows of a (rows, dim) tensor to ``padlen`` and add a batch dim.

    Unlike pad_2d_unsqueeze, values are not shifted, which suits continuous
    features such as positional embeddings.
    """
    rows, width = x.size()
    if rows >= padlen:
        return x.unsqueeze(0)
    padded = x.new_zeros([padlen, width], dtype=x.dtype)
    padded[:rows, :] = x
    return padded.unsqueeze(0)
52 |
53 |
def pad_spatial_pos_unsqueeze(x, padlen):
    """Shift a square matrix by +1, zero-pad it to (padlen, padlen) and add a batch dim.

    The +1 shift keeps 0 free as the padding value.
    """
    shifted = x + 1
    n = shifted.size(0)
    if n >= padlen:
        return shifted.unsqueeze(0)
    padded = shifted.new_zeros([padlen, padlen], dtype=shifted.dtype)
    padded[:n, :n] = shifted
    return padded.unsqueeze(0)
62 |
def pad_adj_unsqueeze(x, padlen):
    """Zero-pad a square (adjacency-like) matrix to (padlen, padlen) and add a batch dim."""
    n = x.size(0)
    if n >= padlen:
        return x.unsqueeze(0)
    padded = x.new_zeros([padlen, padlen], dtype=x.dtype)
    padded[:n, :n] = x
    return padded.unsqueeze(0)
70 |
71 |
def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3):
    """Shift a 4-D tensor by +1, zero-pad its first three axes and add a batch dim.

    The last (feature) axis is kept as is. If none of the first three axes is
    shorter than its target length, the shifted tensor is returned unpadded.
    """
    shifted = x + 1
    d1, d2, d3, feat = shifted.size()
    needs_padding = d1 < padlen1 or d2 < padlen2 or d3 < padlen3
    if not needs_padding:
        return shifted.unsqueeze(0)
    padded = shifted.new_zeros([padlen1, padlen2, padlen3, feat], dtype=shifted.dtype)
    padded[:d1, :d2, :d3, :] = shifted
    return padded.unsqueeze(0)
80 |
81 |
def collator(items, args):
    """Filter, pad and stack preprocessed graph items into one batch dict.

    Items that are ``None`` or have more than ``args.max_node`` nodes are
    dropped. Per-item fields holding a plain ``int`` placeholder instead of a
    tensor produce ``None`` in the output. Each item's ``attn_bias`` tensor is
    modified in place: pairs with ``spatial_pos >= spatial_pos_max`` are set
    to -inf.

    Args:
        items: preprocessed graph items exposing the attributes unpacked below
            (``idx``, ``attn_bias``, ``attn_edge_type``, ``spatial_pos``, ...).
        args: options namespace. Reads ``max_node``, ``multi_hop_max_dist``,
            ``spatial_pos_max``, ``use_super_node``, ``node_feature_type`` and
            ``use_gnn_layers``.

    Returns:
        dict of batched tensors (``None`` for disabled features), with the
        keys listed in the ``return`` statement.

    Raises:
        ValueError: if no item survives the filtering.
    """
    max_node = args.max_node
    multi_hop_max_dist = args.multi_hop_max_dist
    spatial_pos_max = args.spatial_pos_max
    # With a super node, row/column 0 of attn_bias belongs to it, so the
    # real-node block starts at this offset.
    node_offset = int(args.use_super_node)

    items = [item for item in items if item is not None and item.x.size(0) <= max_node]
    if not items:
        raise ValueError(
            f"collator received no usable items (all None or larger than max_node={max_node})"
        )
    items = [
        (
            item.idx,
            item.attn_bias,
            item.attn_edge_type,
            item.spatial_pos,
            item.in_degree,
            item.out_degree,
            item.x,
            item.edge_input,
            item.y,
            item.adj,
            item.adj_norm,
            item.edge_index,
            item.eig_pos_emb,
            item.svd_pos_emb,
            item.root_n_id,
        )
        for item in items
    ]
    (
        idxs,
        attn_biases,
        attn_edge_types,
        spatial_poses,
        in_degrees,
        out_degrees,
        xs,
        edge_inputs,
        ys,
        adjs,
        adj_norms,
        edge_indexs,
        eig_pos_embs,
        svd_pos_embs,
        root_n_ids,
    ) = zip(*items)

    # In place: forbid attention between node pairs whose spatial_pos reaches the cap.
    for attn_bias_i, spatial_pos_i in zip(attn_biases, spatial_poses):
        attn_bias_i[node_offset:, node_offset:][spatial_pos_i >= spatial_pos_max] = float("-inf")

    ns = [x.size(0) for x in xs]  # node count of each graph
    max_node_num = max(ns)
    # x_mask[b, :ns[b]] = 1 marks the real (non-padding) nodes of graph b
    x_mask = torch.zeros(len(xs), max_node_num)
    for row, n in enumerate(ns):
        x_mask[row, :n] = 1

    y = torch.cat(ys)
    root_n_id = torch.tensor(root_n_ids)

    if args.node_feature_type == 'cate':
        # categorical node features: shifted by +1 so that 0 is the pad id
        x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs])
    else:
        x = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in xs])

    if isinstance(edge_inputs[0], int):
        edge_input = None
        attn_edge_type = None
    else:
        # NOTE(review): max_dist is measured before slicing to multi_hop_max_dist,
        # so the padded hop axis can be wider than the sliced tensors need.
        max_dist = max(i.size(-2) for i in edge_inputs)
        edge_input = torch.cat(
            [pad_3d_unsqueeze(i[:, :, :multi_hop_max_dist, :], max_node_num, max_node_num, max_dist) for i in edge_inputs]
        )
        attn_edge_type = torch.cat(
            [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types]
        )

    attn_bias = torch.cat(
        [pad_attn_bias_unsqueeze(i, max_node_num + node_offset) for i in attn_biases]
    )

    in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees]) if not isinstance(in_degrees[0], int) else None
    adj = torch.cat([pad_adj_unsqueeze(a, max_node_num) for a in adjs])

    adj_norm = torch.cat([pad_adj_unsqueeze(a, max_node_num) for a in adj_norms]) if not isinstance(adj_norms[0], int) else None

    # All items share the same field layout, so inspecting the first one suffices.
    spatial_pos = torch.cat(
        [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses]
    ) if not isinstance(spatial_poses[0], int) else None

    batch_edge_index = Batch.from_data_list(
        [Data(edge_index=ei, num_nodes=ns[k]) for k, ei in enumerate(edge_indexs)]
    ).edge_index if args.use_gnn_layers else None

    eig_pos_embs = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in eig_pos_embs]) if not isinstance(eig_pos_embs[0], int) else None

    svd_pos_embs = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in svd_pos_embs]) if not isinstance(svd_pos_embs[0], int) else None

    return dict(
        idx=torch.LongTensor(idxs),
        attn_bias=attn_bias,
        attn_edge_type=attn_edge_type,
        spatial_pos=spatial_pos,
        in_degree=in_degree,
        out_degree=in_degree,  # for undirected graph
        x=x,
        edge_input=edge_input,
        x_mask=x_mask,
        ns=torch.LongTensor(ns),  # node number in each graph
        labels=y.squeeze(),
        adj=adj,
        adj_norm=adj_norm,
        edge_index=batch_edge_index,
        eig_pos_emb=eig_pos_embs,
        svd_pos_emb=svd_pos_embs,
        root_n_id=root_n_id,
    )
202 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional, Union
3 | from torch_geometric.data import Data as PYGDataset
4 | from dgl.data import DGLDataset
5 | from .pyg_datasets import PYGDatasetLookupTable, GraphormerPYGDataset
6 | from .ogb_datasets import OGBDatasetLookupTable
7 |
8 |
9 |
10 |
11 |
class GraphormerDataset:
    """Hold a GraphormerPYGDataset and expose its train/valid/test splits.

    The wrapped dataset comes from ``dataset`` directly when given. Otherwise
    it is looked up by ``dataset_spec`` in the ``"pyg"`` or ``"ogb"`` table
    selected by ``dataset_source``.
    """

    def __init__(
        self,
        dataset: Optional[Union[PYGDataset, DGLDataset]] = None,
        dataset_spec: Optional[str] = None,
        dataset_source: Optional[str] = None,
        seed: int = 0,
        train_idx=None,
        valid_idx=None,
        test_idx=None,
        args=None,
    ):
        """Build the wrapped dataset and populate the split attributes.

        Args:
            dataset: an already-loaded dataset to wrap (takes precedence).
            dataset_spec: dataset name for the lookup tables.
            dataset_source: ``"pyg"`` or ``"ogb"``; used when ``dataset`` is None.
            seed: random seed forwarded to the wrapped dataset / lookup.
            train_idx, valid_idx, test_idx: optional explicit split indices
                (only used together with ``dataset``).
            args: options namespace forwarded to GraphormerPYGDataset when
                ``dataset`` is given. Defaults to ``None``.

        Raises:
            ValueError: if ``dataset`` is None and ``dataset_source`` is
                neither ``"pyg"`` nor ``"ogb"``.
        """
        super().__init__()
        if dataset is not None:
            # GraphormerPYGDataset's positional order is (dataset, args, seed, ...),
            # so the split indices must be passed by keyword.
            self.dataset = GraphormerPYGDataset(
                dataset,
                args,
                seed=seed,
                train_idx=train_idx,
                valid_idx=valid_idx,
                test_idx=test_idx,
            )
        elif dataset_source == "pyg":
            self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed)
        elif dataset_source == "ogb":
            self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed)
        else:
            raise ValueError(
                f"Unknown dataset source {dataset_source!r}; expected 'pyg' or 'ogb' "
                "when no dataset object is given."
            )
        self.setup()

    def setup(self):
        """Copy split indices and split datasets from the wrapped dataset."""
        self.train_idx = self.dataset.train_idx
        self.valid_idx = self.dataset.valid_idx
        self.test_idx = self.dataset.test_idx

        self.dataset_train = self.dataset.train_data
        self.dataset_val = self.dataset.valid_data
        self.dataset_test = self.dataset.test_data
41 |
--------------------------------------------------------------------------------
/data/ogb_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .ogb_dataset_lookup_table import OGBDatasetLookupTable
5 |
--------------------------------------------------------------------------------
/data/ogb_datasets/ogb_dataset_lookup_table.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional
3 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
4 | from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset
5 | from ogb.graphproppred import PygGraphPropPredDataset
6 | from torch_geometric.data import Dataset
7 | from ..pyg_datasets import GraphormerPYGDataset
8 | import torch.distributed as dist
9 | import os
10 |
11 |
12 |
13 |
14 |
class MyPygGraphPropPredDataset(PygGraphPropPredDataset):
    """OGB graph-property dataset whose download/process steps run on rank 0 only.

    Under an initialised ``torch.distributed`` group, only rank 0 executes the
    parent-class step, after which every rank meets at a barrier. Without a
    process group the step simply runs.
    """

    def _run_on_rank_zero(self, step):
        # `step` is the bound parent-class method (download or process).
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            step()
        if distributed:
            dist.barrier()

    def download(self):
        self._run_on_rank_zero(super().download)

    def process(self):
        self._run_on_rank_zero(super().process)
27 |
28 |
class OGBDatasetLookupTable:
    """Build GraphormerPYGDataset objects for the supported OGB datasets."""

    @staticmethod
    def GetOGBDataset(dataset_name: str, seed: int, args=None) -> Optional[Dataset]:
        """Load ``dataset_name`` and wrap it together with its official OGB split.

        Args:
            dataset_name: one of ``"ogbg-molhiv"``, ``"ogbg-molpcba"``,
                ``"pcqm4mv2"`` or ``"pcqm4m"``.
            seed: forwarded to GraphormerPYGDataset.
            args: options namespace forwarded to GraphormerPYGDataset.
                Defaults to ``None``.

        Returns:
            A GraphormerPYGDataset built from the dataset and its index split.

        Raises:
            ValueError: for an unsupported ``dataset_name``. This is raised
                before anything is created on disk.
        """
        from pathlib import Path

        # name -> (folder under ./dataset, dataset factory, key of the test split)
        specs = {
            "ogbg-molhiv": ("ogbg_molhiv", lambda: MyPygGraphPropPredDataset("ogbg-molhiv"), "test"),
            "ogbg-molpcba": ("ogbg_molpcba", lambda: MyPygGraphPropPredDataset("ogbg-molpcba"), "test"),
            # NOTE(review): the OGB-LSC datasets use the imported base classes,
            # which lack the rank-0-only download/process guard of the ogbg ones.
            "pcqm4mv2": ("pcqm4m-v2", PygPCQM4Mv2Dataset, "test-dev"),
            "pcqm4m": ("pcqm4m_kddcup2021", PygPCQM4MDataset, "test"),
        }
        if dataset_name not in specs:
            raise ValueError(f"Unknown dataset name {dataset_name} for ogb source.")
        folder_name, make_dataset, test_key = specs[dataset_name]

        # Pre-create the dataset folder with a RELEASE_v1.txt marker
        # (presumably so OGB treats an existing folder as current — confirm).
        release_dir = Path("dataset") / folder_name
        release_dir.mkdir(parents=True, exist_ok=True)
        (release_dir / "RELEASE_v1.txt").touch()

        inner_dataset = make_dataset()
        idx_split = inner_dataset.get_idx_split()
        # GraphormerPYGDataset's positional order is (dataset, args, seed, ...).
        return GraphormerPYGDataset(
            inner_dataset,
            args,
            seed=seed,
            train_idx=idx_split["train"],
            valid_idx=idx_split["valid"],
            test_idx=idx_split[test_key],
        )
79 |
--------------------------------------------------------------------------------
/data/pyg_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .pyg_dataset_lookup_table import PYGDatasetLookupTable
5 | from .pyg_dataset import GraphormerPYGDataset
6 |
--------------------------------------------------------------------------------
/data/pyg_datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/pyg_datasets/__pycache__/pyg_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/pyg_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/pyg_datasets/__pycache__/pyg_dataset_lookup_table.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/pyg_dataset_lookup_table.cpython-36.pyc
--------------------------------------------------------------------------------
/data/pyg_datasets/pyg_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 | from torch_sparse import SparseTensor
7 | from torch_geometric.data import Data, Batch
8 | from torch_geometric.data import Dataset
9 | from sklearn.model_selection import train_test_split
10 | from typing import List
11 | import torch
12 | import numpy as np
13 |
14 | from ..wrapper import preprocess_item
15 | from .. import algos
16 |
17 | import copy
18 | from functools import lru_cache
19 |
20 | from typing import Callable, List, NamedTuple, Optional, Tuple, Union
21 |
22 | import torch
23 | from torch import Tensor
24 | from torch_sparse import SparseTensor
25 |
26 |
class EdgeIndex(NamedTuple):
    """Sampled edge_index together with optional edge ids and the (src, dst) size."""

    edge_index: Tensor
    e_id: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a new EdgeIndex whose tensors are moved/cast via ``Tensor.to``."""
        moved_index = self.edge_index.to(*args, **kwargs)
        moved_ids = None if self.e_id is None else self.e_id.to(*args, **kwargs)
        return EdgeIndex(edge_index=moved_index, e_id=moved_ids, size=self.size)
36 |
37 |
class Adj(NamedTuple):
    """Sampled transposed adjacency (SparseTensor) plus optional edge ids and size."""

    adj_t: SparseTensor
    e_id: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a new Adj whose tensors are moved/cast via their ``.to`` methods."""
        moved_adj = self.adj_t.to(*args, **kwargs)
        moved_ids = None if self.e_id is None else self.e_id.to(*args, **kwargs)
        return Adj(adj_t=moved_adj, e_id=moved_ids, size=self.size)
47 |
48 |
49 |
50 |
51 |
class GraphormerPYGDataset(Dataset):
    """Wrap a PyG dataset, split it into train/valid/test and preprocess items lazily.

    The split is obtained in one of three ways:

    * no ``train_idx`` and no ``train_set``: random split of
      ``range(len(dataset))`` seeded by ``seed``. ``num_data // 10`` items go
      to test, ``num_data // 5`` (drawn from the remainder) to valid, and the
      rest to train.
    * ``train_set``/``valid_set``/``test_set``: each becomes its own subset and
      the ``*_idx`` attributes stay ``None``.
    * ``train_idx``/``valid_idx``/``test_idx``: subsets are selected from
      ``dataset`` by index.

    Every subset keeps this object's ``args``, ``seed`` and ``x_norm_func``.
    Items are passed through ``preprocess_item(item, x_norm_func, args=args)``.
    """

    def __init__(
        self,
        dataset: Dataset,
        args,
        seed: int = 0,
        train_idx=None,
        valid_idx=None,
        test_idx=None,
        train_set=None,
        valid_set=None,
        test_set=None,
        x_norm_func=lambda x: x,
    ):
        self.args = args
        self.dataset = dataset
        if self.dataset is not None:
            self.num_data = len(self.dataset)
        self.seed = seed
        self.x_norm_func = x_norm_func
        if train_idx is None and train_set is None:
            train_valid_idx, test_idx = train_test_split(
                np.arange(self.num_data),
                test_size=self.num_data // 10,
                random_state=seed,
            )
            train_idx, valid_idx = train_test_split(
                train_valid_idx, test_size=self.num_data // 5, random_state=seed
            )
            self.train_idx = torch.from_numpy(train_idx)
            self.valid_idx = torch.from_numpy(valid_idx)
            self.test_idx = torch.from_numpy(test_idx)
            self.train_data = self.index_select(self.train_idx)
            self.valid_data = self.index_select(self.valid_idx)
            self.test_data = self.index_select(self.test_idx)
        elif train_set is not None:
            self.num_data = len(train_set) + len(valid_set) + len(test_set)
            self.train_data = self.create_subset(train_set)
            self.valid_data = self.create_subset(valid_set)
            self.test_data = self.create_subset(test_set)
            self.train_idx = None
            self.valid_idx = None
            self.test_idx = None
        else:
            self.num_data = len(train_idx) + len(valid_idx) + len(test_idx)
            self.train_idx = train_idx
            self.valid_idx = valid_idx
            self.test_idx = test_idx
            self.train_data = self.index_select(self.train_idx)
            self.valid_data = self.index_select(self.valid_idx)
            self.test_data = self.index_select(self.test_idx)
        self.__indices__ = None

    def _clear_splits(self):
        """Reset split bookkeeping so a subset does not carry nested splits."""
        self.train_data = None
        self.valid_data = None
        self.test_data = None
        self.train_idx = None
        self.valid_idx = None
        self.test_idx = None

    def index_select(self, idx):
        """Return a shallow copy restricted to ``self.dataset.index_select(idx)``.

        Args:
            idx: index tensor or array-like with a ``shape`` attribute.
        """
        dataset = copy.copy(self)
        dataset.dataset = self.dataset.index_select(idx)
        if isinstance(idx, torch.Tensor):
            dataset.num_data = idx.size(0)
        else:
            dataset.num_data = idx.shape[0]
        dataset.__indices__ = idx
        dataset._clear_splits()
        return dataset

    def create_subset(self, subset):
        """Return a shallow copy that serves items from ``subset``.

        This mirrors index_select: the copy keeps ``args``, ``seed`` and
        ``x_norm_func``, and no random split is computed for it, so
        arbitrarily small subsets are accepted.
        """
        dataset = copy.copy(self)
        dataset.dataset = subset
        dataset.num_data = len(subset)
        dataset.__indices__ = None
        dataset._clear_splits()
        return dataset

    # NOTE(review): lru_cache on a method is shared by all instances and keeps
    # every instance it has seen alive (``self`` is part of the key). Cached
    # items are returned by reference, so in-place edits of a returned item persist.
    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        """Fetch item ``idx``, normalise its label shape and preprocess it.

        Raises:
            TypeError: if ``idx`` is not an ``int``.
        """
        if not isinstance(idx, int):
            raise TypeError("index to a GraphormerPYGDataset can only be an integer.")
        item = self.dataset[idx]
        item.idx = idx
        # multi-target labels -> (1, num_targets); single target -> flat vector
        item.y = item.y.reshape(1, -1) if item.y.shape[-1] > 1 else item.y.reshape(-1)
        return preprocess_item(item, self.x_norm_func, args=self.args)

    def __len__(self):
        return self.num_data
145 |
146 |
147 |
class Graphtrans_Sampling_Dataset(Dataset):#shadowhop sampling
    """Node-level dataset: one ShaDow-style k-hop ego subgraph per target node.

    Each item is built around ``node_idx[idx]`` with
    ``torch_sparse.ego_k_hop_sample_adj`` (``depth`` hops, ``num_neighbors``
    per node, optional replacement). Node- and edge-level attributes of
    ``data`` are sliced accordingly, and the result is passed through
    ``preprocess_item``.
    """
    def __init__(self,
                 data,
                 node_idx,
                 depth: int, num_neighbors: int,
                 replace: bool = False,
                 x_norm_func = lambda x:x,
                 args=None
                 ):
        """
        Args:
            data: full graph with either ``edge_index`` or ``adj_t``, plus ``num_nodes``.
            node_idx: target nodes as an index tensor, a boolean mask, or None (all nodes).
            depth: number of hops to sample.
            num_neighbors: neighbours sampled per node per hop.
            replace: sample with replacement.
            x_norm_func: feature-normalisation callable forwarded to preprocess_item.
            args: options namespace forwarded to preprocess_item.
        """

        self.data = data#copy.copy(data)
        self.depth = depth
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.x_norm_func = x_norm_func
        self.args=args

        if data.edge_index is not None:
            self.is_sparse_tensor = False
            row, col = data.edge_index.cpu()
            # store original edge ids as values so sampled edges can be mapped back
            self.adj_t = SparseTensor(
                row=row, col=col, value=torch.arange(col.size(0)),
                sparse_sizes=(data.num_nodes, data.num_nodes)).t()
        else:
            self.is_sparse_tensor = True
            self.adj_t = data.adj_t.cpu()

        if node_idx is None:
            node_idx = torch.arange(self.adj_t.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            # boolean mask -> index tensor
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
        self.node_idx = node_idx
        self.num_data = len(self.node_idx)


    # NOTE(review): lru_cache on a method keys on ``self`` and keeps instances alive.
    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        # NOTE(review): this is a 0-d tensor for an int idx; ego_k_hop_sample_adj
        # is usually fed a 1-D index tensor — confirm a scalar is accepted.
        n_id = self.node_idx[idx]

        rowptr, col, value = self.adj_t.csr()
        out = torch.ops.torch_sparse.ego_k_hop_sample_adj(
            rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)
        # rowptr/col now describe the sampled subgraph; n_id are its node ids in `data`
        rowptr, col, n_id, e_id, ptr, root_n_id = out

        adj_t = SparseTensor(rowptr=rowptr, col=col,
                             value=value[e_id] if value is not None else None,
                             sparse_sizes=(n_id.numel(), n_id.numel()),
                             is_sorted=True)

        batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),
                      ptr=ptr)
        batch.root_n_id = root_n_id

        if self.is_sparse_tensor:
            batch.adj_t = adj_t
        else:
            # e_id is rebound here to the stored values, i.e. original edge ids
            row, col, e_id = adj_t.t().coo()
            batch.edge_index = torch.stack([row, col], dim=0)

        # slice node-level attributes by n_id and edge-level ones by e_id; copy the rest
        # NOTE(review): in the adj_t branch e_id is the sampler's e_id, not an
        # original edge id — confirm edge attributes line up in that case.
        for k, v in self.data:
            if k in ['edge_index', 'adj_t', 'num_nodes']:
                continue
            if k == 'y' and v.size(0) == self.data.num_nodes:
                # keep only the label(s) of the root node(s)
                batch[k] = v[n_id][root_n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                batch[k] = v[n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:
                batch[k] = v[e_id]
            else:
                batch[k] = v

        item = batch
        item.idx = self.node_idx[idx]
        return preprocess_item(item,x_norm_func=self.x_norm_func,args=self.args)

    def __len__(self):
        return self.num_data
225 |
226 |
227 |
228 |
229 | class Graphtrans_Sampling_Dataset_v2(Dataset):#sage sampling +induced subgraph
230 | def __init__(self,
231 | data,
232 | node_idx,
233 | depth: int,
234 | num_neighbors,
235 | replace: bool = False,
236 | x_norm_func = lambda x:x,
237 | args=None
238 | ):
239 |
240 | self.data = copy.copy(data)
241 | self.depth = depth
242 | if isinstance(num_neighbors,int):
243 | self.num_neighbors = [num_neighbors]+(depth-1)*[1]
244 | self.replace = replace
245 | self.x_norm_func = x_norm_func
246 | self.args=args
247 |
248 |
249 | if data.edge_index is not None:
250 | self.is_sparse_tensor = False
251 | row, col = data.edge_index.cpu()
252 | self.adj_t = SparseTensor(
253 | row=row, col=col, value=torch.arange(col.size(0)),
254 | sparse_sizes=(data.num_nodes, data.num_nodes)).t()
255 | else:
256 | self.is_sparse_tensor = True
257 | self.adj_t = data.adj_t.cpu()
258 |
259 |
260 | if node_idx.dtype == torch.bool:
261 | node_idx = node_idx.nonzero(as_tuple=False).view(-1)
262 | self.node_idx = node_idx
263 | self.num_data = len(self.node_idx)
264 |
265 |
266 | def __getitem__(self, idx):
267 |
268 | n_id = self.node_idx[idx].reshape(1)
269 | root_n_id=0
270 | hop_node_nums = [n_id.shape[0]]
271 |
272 | for size in self.num_neighbors:
273 | adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
274 | hop_node_nums.append(n_id.shape[0])
275 | n_hops = torch.ones(n_id.shape)+len(self.num_neighbors)
276 | for i,hop_offset in enumerate(hop_node_nums):
277 | n_hops[:hop_offset]-=1
278 |
279 | adj_t,_ = self.adj_t.saint_subgraph(n_id)
280 | row, col, e_id = adj_t.t().coo()
281 | edge_index = torch.stack([row, col])
282 |
283 |
284 | item = Data(x = self.data.x[n_id],edge_index=edge_index)
285 | for k, v in self.data:
286 | if k in ['edge_index', 'adj_t', 'num_nodes']:
287 | continue
288 | if k == 'y' and v.size(0) == self.data.num_nodes:
289 | item[k] = v[n_id][root_n_id].reshape(1)
290 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
291 | item[k] = v[n_id]
292 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:
293 | item[k] = v[e_id]
294 | else:
295 | item[k] = v
296 | item.root_n_id = root_n_id
297 |
298 | item.idx = self.node_idx[idx]
299 | item.n_hops = n_hops
300 | return preprocess_item(item,x_norm_func=self.x_norm_func,args=self.args)
301 |
def __len__(self):
    """Return how many seed nodes (``self.num_data``) this dataset covers."""
    count = self.num_data
    return count
304 |
305 |
306 |
--------------------------------------------------------------------------------
/data/pyg_datasets/pyg_dataset_lookup_table.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from typing import Optional
5 | from torch_geometric.datasets import *
6 | from torch_geometric.data import Dataset
7 | from .pyg_dataset import GraphormerPYGDataset
8 | import torch.distributed as dist
9 | import os
10 |
class MyQM7b(QM7b):
    """``QM7b`` whose download/processing is done only by rank 0.

    When ``torch.distributed`` is initialised, the other ranks skip the work
    and wait at a barrier until rank 0 is done. Without distributed training
    it behaves exactly like ``QM7b``.
    """

    def download(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().download()
        if distributed:
            dist.barrier()

    def process(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().process()
        if distributed:
            dist.barrier()
23 |
24 |
class MyQM9(QM9):
    """``QM9`` whose download/processing is done only by rank 0.

    When ``torch.distributed`` is initialised, the other ranks skip the work
    and wait at a barrier until rank 0 is done. Without distributed training
    it behaves exactly like ``QM9``.
    """

    def download(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().download()
        if distributed:
            dist.barrier()

    def process(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().process()
        if distributed:
            dist.barrier()
37 |
class MyZINC(ZINC):
    """``ZINC`` whose download/processing is done only by rank 0.

    When ``torch.distributed`` is initialised, the other ranks skip the work
    and wait at a barrier until rank 0 is done. Without distributed training
    it behaves exactly like ``ZINC``.
    """

    def download(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().download()
        if distributed:
            dist.barrier()

    def process(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().process()
        if distributed:
            dist.barrier()
50 |
51 |
class MyMoleculeNet(MoleculeNet):
    """``MoleculeNet`` whose download/processing is done only by rank 0.

    When ``torch.distributed`` is initialised, the other ranks skip the work
    and wait at a barrier until rank 0 is done. Without distributed training
    it behaves exactly like ``MoleculeNet``.
    """

    def download(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().download()
        if distributed:
            dist.barrier()

    def process(self):
        distributed = dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            super().process()
        if distributed:
            dist.barrier()
64 |
65 |
66 |
class PYGDatasetLookupTable:
    """Factory mapping a dataset spec string to a ``GraphormerPYGDataset``."""

    @staticmethod
    def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]:
        """Build the PyG dataset described by ``dataset_spec``.

        Args:
            dataset_spec: ``"<name>"`` or ``"<name>:<key>=<value>,..."``,
                e.g. ``"moleculenet:name=BBBP"``. Supported names: ``qm7b``,
                ``qm9``, ``zinc``, ``zinc-subset`` and ``moleculenet`` (the
                latter requires a ``name=`` parameter).
            seed: forwarded to ``GraphormerPYGDataset``.

        Returns:
            A ``GraphormerPYGDataset``. For ZINC it is built from the
            predefined train/val/test splits.

        Raises:
            ValueError: if the spec is malformed or the name is unknown.
        """
        name, sep, param_str = dataset_spec.partition(":")
        if ":" in param_str:
            raise ValueError(
                f"Malformed dataset spec {dataset_spec!r}: expected "
                "'<name>' or '<name>:<key>=<value>,...'."
            )
        params = param_str.split(",") if sep else []

        # Validate before touching the filesystem so unknown names do not
        # leave empty cache directories behind.
        known_names = ("qm7b", "qm9", "zinc", "zinc-subset", "moleculenet")
        if name not in known_names:
            raise ValueError(f"Unknown dataset name {name} for pyg source.")

        folder_name = name.replace("-", "_")
        root = "dataset/" + folder_name
        # os.makedirs instead of shelling out with an unescaped name.
        os.makedirs(root, exist_ok=True)

        inner_dataset = None
        train_set = None
        valid_set = None
        test_set = None

        if name == "qm7b":
            inner_dataset = MyQM7b(root=root)
        elif name == "qm9":
            inner_dataset = MyQM9(root=root)
        elif name in ("zinc", "zinc-subset"):
            subset = name == "zinc-subset"
            inner_dataset = MyZINC(root=root, subset=subset)
            train_set = MyZINC(root=root, subset=subset, split="train")
            valid_set = MyZINC(root=root, subset=subset, split="val")
            test_set = MyZINC(root=root, subset=subset, split="test")
        else:  # "moleculenet"
            molecule_name = None
            for param in params:
                key_value = param.split("=")
                if len(key_value) != 2:
                    raise ValueError(
                        f"Malformed parameter {param!r} in dataset spec "
                        f"{dataset_spec!r}; expected '<key>=<value>'."
                    )
                key, value = key_value
                if key == "name":
                    molecule_name = value
            if molecule_name is None:
                raise ValueError(
                    "The moleculenet dataset spec needs a 'name=<dataset>' "
                    "parameter, e.g. 'moleculenet:name=BBBP'."
                )
            inner_dataset = MyMoleculeNet(root=root, name=molecule_name)

        if train_set is not None:
            return GraphormerPYGDataset(
                None,
                seed,
                None,
                None,
                None,
                train_set,
                valid_set,
                test_set,
            )
        return (
            None
            if inner_dataset is None
            else GraphormerPYGDataset(inner_dataset, seed)
        )
130 |
--------------------------------------------------------------------------------
/data/smiles/smiles_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from sklearn.model_selection import train_test_split
5 | import torch
6 | import numpy as np
7 |
8 | from ..wrapper import preprocess_item
9 | from .. import algos
10 | from ..pyg_datasets import GraphormerPYGDataset
11 |
12 | from ogb.utils.mol import smiles2graph
13 |
14 |
class GraphormerSMILESDataset(GraphormerPYGDataset):
    """Graphormer dataset built from a CSV file of SMILES strings.

    Each SMILES string is converted with OGB's ``smiles2graph`` and then
    passed through ``preprocess_item``. Indices are split randomly into a
    test set (``num_data // 10`` items), a validation set (``num_data // 5``
    items taken from the remainder) and a training set (everything else).

    NOTE(review): assumes the CSV holds one SMILES string per row, so that
    ``np.genfromtxt`` yields a 1-D array of strings — confirm for your data.
    NOTE(review): ``index_select`` is inherited from ``GraphormerPYGDataset``;
    verify it can handle a numpy-array ``self.dataset``.
    """

    def __init__(
        self,
        dataset: str,
        num_class: int,
        max_node: int,
        multi_hop_max_dist: int,
        spatial_pos_max: int,
        x_norm_func=None,
        args=None,
    ):
        """
        Args:
            dataset: path to the CSV file.
            num_class: number of prediction targets.
            max_node: upper bound for ``self.max_node``.
            multi_hop_max_dist: upper bound for ``self.multi_hop_max_dist``.
            spatial_pos_max: upper bound for ``self.spatial_pos_max``.
            x_norm_func: forwarded to ``preprocess_item``.
            args: run configuration forwarded to ``preprocess_item``. It is
                read there, so it must be supplied before items are fetched.
        """
        self.dataset = np.genfromtxt(dataset, delimiter=",", dtype=str)
        num_data = len(self.dataset)
        self.num_class = num_class
        self.x_norm_func = x_norm_func
        self.args = args
        self.__get_graph_metainfo(max_node, multi_hop_max_dist, spatial_pos_max)
        # train_test_split takes the data to split positionally and the split
        # size via ``test_size``; a positional size is treated as another array.
        all_idx = np.arange(num_data)
        train_valid_idx, test_idx = train_test_split(all_idx, test_size=num_data // 10)
        train_idx, valid_idx = train_test_split(train_valid_idx, test_size=num_data // 5)
        self.train_idx = train_idx
        self.valid_idx = valid_idx
        self.test_idx = test_idx
        self.__indices__ = None
        self.train_data = self.index_select(train_idx)
        self.valid_data = self.index_select(valid_idx)
        self.test_data = self.index_select(test_idx)

    def __get_graph_metainfo(
        self, max_node: int, multi_hop_max_dist: int, spatial_pos_max: int
    ):
        """Set ``max_node``, ``multi_hop_max_dist`` and ``spatial_pos_max``.

        Each is the minimum of the given cap and what the data needs: the
        largest molecule's node count, and the largest Floyd-Warshall
        distance (as returned by ``algos.floyd_warshall``) over all molecules.
        """
        max_num_nodes = 0
        max_dist = 0
        for smiles in self.dataset:
            # smiles2graph returns a dict of numpy arrays, not a PyG graph.
            graph = smiles2graph(smiles)
            num_nodes = int(graph["num_nodes"])
            max_num_nodes = max(max_num_nodes, num_nodes)
            if num_nodes == 0:
                continue  # np.amax on an empty distance matrix would fail
            dense_adj = np.zeros((num_nodes, num_nodes), dtype=np.int64)
            src, dst = graph["edge_index"]
            dense_adj[src, dst] = 1
            shortest_path_result, _ = algos.floyd_warshall(dense_adj)
            max_dist = max(max_dist, int(np.amax(shortest_path_result)))
        self.max_node = min(max_node, max_num_nodes)
        self.multi_hop_max_dist = min(multi_hop_max_dist, max_dist)
        self.spatial_pos_max = min(spatial_pos_max, max_dist)

    def __getitem__(self, idx):
        """Convert the ``idx``-th SMILES string into a preprocessed graph item.

        Raises:
            TypeError: if ``idx`` is not an integer.
        """
        if not isinstance(idx, (int, np.integer)):
            raise TypeError("index to a GraphormerPYGDataset can only be an integer.")
        # Local import: this module does not import torch_geometric's Data.
        from torch_geometric.data import Data

        idx = int(idx)
        graph = smiles2graph(self.dataset[idx])
        item = Data(
            x=torch.from_numpy(graph["node_feat"]).long(),
            edge_index=torch.from_numpy(graph["edge_index"]).long(),
            edge_attr=torch.from_numpy(graph["edge_feat"]).long(),
            num_nodes=int(graph["num_nodes"]),
        )
        item.idx = idx
        return preprocess_item(item, x_norm_func=self.x_norm_func, args=self.args)
61 |
--------------------------------------------------------------------------------
/data/wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | from utils import utils
4 | import torch
5 | import numpy as np
6 | from ogb.graphproppred import PygGraphPropPredDataset
7 | #from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset
8 | from functools import lru_cache
9 | import pyximport
10 | import torch.distributed as dist
11 | from torch_geometric.utils import to_undirected,add_self_loops
12 | from torch_geometric.data import Data
13 |
14 | pyximport.install(setup_args={"include_dirs": np.get_include()})
15 | from . import algos
16 | from utils.utils import *
17 | from copy import deepcopy
18 |
@torch.jit.script
def convert_to_single_emb(x, offset: int = 512):
    """Shift each feature column of ``x`` into its own index range.

    Column ``j`` is increased by ``1 + j * offset``, so several categorical
    features can share one embedding table without colliding and index 0 is
    never produced (e.g. to serve as a padding index). A 1-D ``x`` is treated
    as a single column.
    """
    if len(x.size()) > 1:
        num_features = x.size(1)
    else:
        num_features = 1
    column_shift = torch.arange(0, num_features * offset, offset, dtype=torch.long) + 1
    return x + column_shift
25 |
26 |
27 | def preprocess_item(raw_item,x_norm_func,args):
28 | edge_attr, edge_index, x,y,idx = raw_item.edge_attr, raw_item.edge_index, raw_item.x,raw_item.y,raw_item.idx
29 | root_n_id=raw_item.root_n_id if 'root_n_id' in raw_item.to_dict().keys() else -1
30 |
31 |
32 |
33 | N = x.size(0)
34 | if args.node_feature_type=='cate':
35 | x = convert_to_single_emb(x)
36 | elif args.node_feature_type=='dense':
37 | x = x_norm_func(x)
38 | else:
39 | raise ValueError('node feature type error')
40 |
41 | # node adj matrix [N, N] bool
42 | try:
43 | edge_index = to_undirected(edge_index)
44 | except:
45 | print(edge_index)
46 | assert 1==2
47 |
48 | adj = torch.zeros([N, N], dtype=torch.bool)
49 | adj[edge_index[0, :], edge_index[1, :]] = True
50 |
51 | adj_w_sl = adj.clone()#adj with self loop
52 | adj_w_sl[torch.arange(N),torch.arange(N)]=1
53 |
54 | #positional bias
55 | if 'degree' in args.node_level_modules:
56 | in_degree = adj.long().sum(dim=1).view(-1)
57 | else:
58 | in_degree = 0
59 |
60 | if 'eig' in args.node_level_modules:
61 | if N0
23 | sign = 1 if sign else -1
24 | return self.embeddings(pos * sign)
25 |
class SVD_Embedding(nn.Module):
    """Linearly project SVD positional encodings to ``hidden_dim``.

    ``batched_data['svd_pos_emb']`` is indexed as (batch, node, channel). The
    first ``svd_dim`` channels form the "u" part and the remaining channels
    the "v" part. Each call draws one random sign (also at eval time): "u" is
    multiplied by it and "v" by its negation before the projection.
    NOTE(review): SVD sign ambiguity usually flips u_i and v_i *together*;
    the opposite signs here may be unintended — confirm.
    """

    def __init__(self, svd_dim, hidden_dim):
        super(SVD_Embedding, self).__init__()
        self.svd_dim = svd_dim
        self.embeddings = nn.Linear(svd_dim * 2, hidden_dim)

    def forward(self, batched_data):
        pos = batched_data['svd_pos_emb']
        flip = 1 if torch.randn(1)[0] > 0 else -1
        k = self.svd_dim
        u_part = pos[:, :, :k] * flip
        v_part = pos[:, :, k:] * (-flip)
        return self.embeddings(torch.cat([u_part, v_part], dim=-1))
39 |
40 |
class WL_Role_Embedding(nn.Module):
    """Linearly project ``batched_data['wl_role_ids']`` to ``hidden_dim``.

    Despite its name, ``embeddings`` is an ``nn.Linear``, so the input's last
    dimension must equal ``max_index`` (presumably a one-hot/dense encoding
    of WL role ids — TODO confirm with the data pipeline).
    """

    def __init__(self, max_index, hidden_dim):
        super(WL_Role_Embedding, self).__init__()
        self.embeddings = nn.Linear(max_index, hidden_dim)

    def forward(self, batched_data):
        return self.embeddings(batched_data['wl_role_ids'])
48 |
class Inti_Pos_Embedding(nn.Module):
    """Linearly project ``batched_data['init_pos_ids']`` to ``hidden_dim``.

    ``embeddings`` is an ``nn.Linear``, so the input's last dimension must
    equal ``max_index`` (presumably a dense/one-hot position encoding —
    TODO confirm with the data pipeline).
    """

    def __init__(self, max_index, hidden_dim):
        super(Inti_Pos_Embedding, self).__init__()
        self.embeddings = nn.Linear(max_index, hidden_dim)

    def forward(self, batched_data):
        return self.embeddings(batched_data['init_pos_ids'])
56 |
class Hop_Dis_Embedding(nn.Module):
    """Linearly project ``batched_data['hop_dis_ids']`` to ``hidden_dim``.

    ``embeddings`` is an ``nn.Linear``, so the input's last dimension must
    equal ``max_index`` (presumably a dense/one-hot hop-distance encoding —
    TODO confirm with the data pipeline).
    """

    def __init__(self, max_index, hidden_dim):
        super(Hop_Dis_Embedding, self).__init__()
        self.embeddings = nn.Linear(max_index, hidden_dim)

    def forward(self, batched_data):
        return self.embeddings(batched_data['hop_dis_ids'])
64 |
class DegreeEncoder(nn.Module):
    """Sum of learned in-degree and out-degree embeddings per node.

    Both tables declare index 0 as ``padding_idx``. ``n_layers`` is only
    forwarded to ``init_params`` for parameter initialisation.
    """

    def __init__(self,
                 num_in_degree,
                 num_out_degree,
                 hidden_dim,
                 n_layers  # for parameter initialization
                 ):
        super(DegreeEncoder, self).__init__()
        self.in_degree_encoder = nn.Embedding(num_in_degree, hidden_dim, padding_idx=0)
        self.out_degree_encoder = nn.Embedding(num_out_degree, hidden_dim, padding_idx=0)

        def _init(module):
            init_params(module, n_layers=n_layers)

        self.apply(_init)

    def forward(self, batched_data):
        in_emb = self.in_degree_encoder(batched_data["in_degree"])
        out_emb = self.out_degree_encoder(batched_data["out_degree"])
        return in_emb + out_emb
83 |
84 |
class AddSuperNode(nn.Module):
    """Prepend one learned "graph token" to every graph in the batch.

    ``node_feature`` is (n_graph, n_node, hidden_dim); the result is
    (n_graph, n_node + 1, hidden_dim) with the token at position 0.
    """

    def __init__(self, hidden_dim):
        super(AddSuperNode, self).__init__()
        self.graph_token = nn.Embedding(1, hidden_dim)

    def forward(self, node_feature):
        batch_size = node_feature.size()[0]
        token = self.graph_token.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        return torch.cat([token, node_feature], dim=1)
96 |
97 |
98 |
99 |
100 |
class NodeFeatureEncoder(nn.Module):
    """Encode raw node features ``batched_data["x"]`` into ``hidden_dim`` vectors.

    * ``feat_type='dense'`` (needs ``feat_dim``): ``nn.Linear(feat_dim, hidden_dim)``.
    * ``feat_type='cate'`` (needs ``num_atoms``): ``nn.Embedding(num_atoms + 1,
      hidden_dim, padding_idx=0)``; the embeddings are summed over the
      second-to-last axis (presumably the per-node feature columns).

    Raises:
        ValueError: if ``feat_type`` is neither of the above or its size
            argument is ``None``.
    """

    def __init__(
            self,
            feat_type,
            hidden_dim,
            n_layers,
            num_atoms=None,
            feat_dim=None
    ):
        super(NodeFeatureEncoder, self).__init__()

        self.feat_type = feat_type

        wants_dense = feat_type == 'dense' and feat_dim is not None
        wants_cate = feat_type == 'cate' and num_atoms is not None
        if wants_dense:
            self.feature_encoder = nn.Linear(feat_dim, hidden_dim)
        elif wants_cate:
            # One extra row; the original author notes it is for the graph token.
            self.feature_encoder = nn.Embedding(num_atoms + 1, hidden_dim, padding_idx=0)
        else:
            raise ValueError('conflict feature type')

        self.apply(lambda module: init_params(module, n_layers=n_layers))

    def forward(self, batched_data):
        encoded = self.feature_encoder(batched_data["x"])
        if self.feat_type == 'cate':
            # Original note: result is [n_graph, n_node, n_hidden].
            encoded = encoded.sum(dim=-2)
        return encoded
132 |
133 |
def getAttnMasks(batched_data, attn_mask_modules, use_super_node, num_heads):
    """Build a per-head 0/1 attention mask from the graph adjacency.

    Args:
        batched_data: dict whose ``'adj'`` entry is a batched adjacency
            (n_graph, n_node, n_node); any non-zero entry counts as an edge.
        attn_mask_modules: ``'1hop'``: every head sees the adjacency;
            ``'nhop'``: head ``i`` sees ``adj ** (i + 1)`` binarised (pairs
            joined by a walk of exactly ``i + 1`` steps); anything else leaves
            the mask all ones.
        use_super_node: if true, row/column 0 is reserved for the super node
            and left unmasked.
        num_heads: number of attention heads.

    Returns:
        Float tensor (n_graph, num_heads, n_node + s, n_node + s) with
        ``s = int(use_super_node)``, on the device of ``adj``.
    """
    adj = batched_data['adj'].bool().float()
    offset = int(use_super_node)
    n_graph, n_row, n_col = adj.shape[0], adj.shape[1], adj.shape[2]
    mask = torch.ones(n_graph, num_heads, n_row + offset, n_col + offset).to(adj.device)

    if attn_mask_modules == '1hop':
        mask[:, :, offset:, offset:] = adj.unsqueeze(1).expand(-1, num_heads, -1, -1)
    elif attn_mask_modules == 'nhop':
        powers = [torch.matrix_power(adj, hop).unsqueeze(1) for hop in range(1, num_heads + 1)]
        mask[:, :, offset:, offset:] = torch.cat(powers, dim=1).bool().float()

    return mask
150 |
151 |
class GraphAttnHopBias(nn.Module):
    """Per-head attention bias from powers of the normalised adjacency.

    For each hop ``k`` in ``1..n_hops`` the layer computes
    ``adj_norm ** k``. The stacked powers are mixed into ``num_heads`` biases
    with a learned (n_hops, num_heads) matrix. With a super node, its row and
    column use a constant 1 for every hop before mixing.
    """

    def __init__(
            self,
            num_heads,
            n_hops,
            use_super_node
    ):
        super(GraphAttnHopBias, self).__init__()
        self.num_heads = num_heads
        self.use_super_node = use_super_node
        self.hop_bias = nn.Parameter(torch.randn(n_hops, num_heads))
        self.n_hops = n_hops

    def forward(self, batched_data):
        x, adj, _unused_attn_bias = (
            batched_data["x"],
            batched_data['adj_norm'],
            batched_data['attn_bias'],
        )
        offset = int(self.use_super_node)
        n_graph, n_row, n_col = adj.shape[0], adj.shape[1], adj.shape[2]

        hop_features = torch.ones(n_graph, n_row + offset, n_col + offset, self.n_hops).to(x.device)
        powers = [torch.matrix_power(adj, hop).unsqueeze(-1) for hop in range(1, self.n_hops + 1)]
        hop_features[:, offset:, offset:, :] = torch.cat(powers, dim=-1)

        # (n_graph, T, T, n_hops) @ (n_hops, n_head) -> (n_graph, n_head, T, T)
        return torch.matmul(hop_features, self.hop_bias).permute(0, 3, 1, 2)
181 |
182 |
183 |
184 |
class GraphAttnSpatialBias(nn.Module):  # refer to Graphormer
    """Per-head attention bias from integer spatial (distance) ids.

    ``batched_data['spatial_pos']`` (n_graph, n_node, n_node) is embedded into
    ``num_heads`` scalars per node pair (id 0 is the padding index). With a
    super node, a learned per-head "virtual distance" is added to the super
    node's row and column. The incoming ``attn_bias`` is used as the starting
    value and added once more at the end (original note: "reset pad -inf").
    """

    def __init__(
            self,
            num_heads,
            num_spatial,
            n_layers,
            use_super_node
    ):
        super(GraphAttnSpatialBias, self).__init__()
        self.num_heads = num_heads
        self.use_super_node = use_super_node

        self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)

        if use_super_node:
            self.graph_token_virtual_distance = nn.Embedding(1, num_heads)

        self.apply(lambda module: init_params(module, n_layers=n_layers))

    def forward(self, batched_data):
        base_bias, spatial_pos, _unused_x = (
            batched_data["attn_bias"],
            batched_data["spatial_pos"],
            batched_data["x"],
        )
        offset = int(self.use_super_node)

        # (n_graph, T, T) -> (n_graph, n_head, T, T); repeat returns a copy.
        head_bias = base_bias.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

        # (n_graph, n_node, n_node, n_head) -> (n_graph, n_head, n_node, n_node)
        spatial = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
        head_bias[:, :, offset:, offset:] = head_bias[:, :, offset:, offset:] + spatial

        if self.use_super_node:
            virtual = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
            head_bias[:, :, 1:, 0] = head_bias[:, :, 1:, 0] + virtual
            head_bias[:, :, 0, :] = head_bias[:, :, 0, :] + virtual

        return head_bias + base_bias.unsqueeze(1)
230 |
231 |
232 |
class GraphAttnEdgeBias(nn.Module):  # refer to Graphormer
    """
    Compute attention bias for each head from edge features. We do not need
    to consider the super node in this module; the result covers real nodes
    only: [n_graph, n_head, n_node, n_node].

    * ``edge_type == "multi_hop"``: edge features along each shortest path
      (``batched_data["edge_input"]``) are embedded, mixed per distance with
      a learned (n_head x n_head) matrix, summed, and divided by the adjusted
      path length derived from ``spatial_pos``.
    * otherwise: embeddings of ``batched_data["attn_edge_type"]`` are averaged
      over the second-to-last axis.

    If ``attn_edge_type`` is ``None`` a zero bias is returned.
    """

    def __init__(
            self,
            num_heads,
            num_edges,
            num_edge_dis,
            edge_type,
            multi_hop_max_dist,
            n_layers,
    ):
        super(GraphAttnEdgeBias, self).__init__()
        self.num_heads = num_heads
        self.multi_hop_max_dist = multi_hop_max_dist
        # NOTE(review): the original author flagged possible issues with this table.
        self.edge_encoder = nn.Embedding(num_edges + 1, num_heads, padding_idx=0)
        self.edge_type = edge_type
        if self.edge_type == "multi_hop":
            # Flat storage of one (num_heads x num_heads) matrix per distance.
            self.edge_dis_encoder = nn.Embedding(
                num_edge_dis * num_heads * num_heads, 1
            )

        self.apply(lambda module: init_params(module, n_layers=n_layers))

    def forward(self, batched_data):
        attn_bias, spatial_pos, x = (
            batched_data["attn_bias"],
            batched_data["spatial_pos"],
            batched_data["x"],
        )
        edge_input, attn_edge_type = (
            batched_data["edge_input"],
            batched_data["attn_edge_type"],
        )

        n_graph, n_node = x.size()[:2]

        if attn_edge_type is None:
            edge_input = torch.zeros(n_graph, self.num_heads, n_node, n_node).to(x.device)
            return edge_input

        # edge feature
        if self.edge_type == "multi_hop":
            spatial_pos_ = spatial_pos.clone()
            spatial_pos_[spatial_pos_ == 0] = 1  # set pad to 1
            # set 1 to 1, x > 1 to x - 1
            spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
            # multi_hop_max_dist may be None (e.g. GraphTransformer's default);
            # treat it as "no cap" instead of failing on `None > 0`.
            if self.multi_hop_max_dist is not None and self.multi_hop_max_dist > 0:
                spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
                edge_input = edge_input[:, :, :, : self.multi_hop_max_dist, :]
            # [n_graph, n_node, n_node, max_dist, n_head]
            edge_input = self.edge_encoder(edge_input).mean(-2)
            max_dist = edge_input.size(-2)
            edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(
                max_dist, -1, self.num_heads
            )
            # Mix heads with one (n_head x n_head) matrix per distance.
            edge_input_flat = torch.bmm(
                edge_input_flat,
                self.edge_dis_encoder.weight.reshape(
                    -1, self.num_heads, self.num_heads
                )[:max_dist, :, :],
            )
            edge_input = edge_input_flat.reshape(
                max_dist, n_graph, n_node, n_node, self.num_heads
            ).permute(1, 2, 3, 0, 4)
            edge_input = (
                edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))
            ).permute(0, 3, 1, 2)
        else:
            # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
            edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)

        return edge_input  # [n_graph, n_head, n_node, n_node]
310 |
311 |
312 |
313 |
--------------------------------------------------------------------------------
/graphtrasformer/gt_models.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, Tuple
3 | from graphtrasformer.gnn_layers import *
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch_scatter import scatter
8 | from graphtrasformer.gt_layers import *
9 | from graphtrasformer.layers import *
10 | logger = logging.getLogger(__name__)
11 |
12 |
def init_graphormer_params(module):
    """
    Initialize the weights specific to the Graphormer Model (use with ``Module.apply``).

    * ``nn.Linear``: weight ~ N(0, 0.02), bias zeroed.
    * ``nn.Embedding``: weight ~ N(0, 0.02), padding row zeroed.
    * ``MultiheadAttention``: q/k/v projection weights ~ N(0, 0.02).

    Samples are drawn on a CPU copy and copied back, so the values come from
    the CPU random generator whatever the parameter's device.
    """
    def _normal_(tensor):
        sample = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sample.to(tensor.device))

    if isinstance(module, nn.Linear):
        _normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _normal_(projection.weight.data)
32 |
33 |
34 |
35 |
36 |
class GraphTransformer(nn.Module):
    """Graph transformer with pluggable node-level and attention-level encoders.

    ``forward`` proceeds as follows:
    1. Encode node features and add the node-level structural encodings
       (``node_level_modules``).
    2. Optionally prepend a super node.
    3. Sum the attention-level biases (``attn_level_modules``) into one
       per-head bias, and optionally apply adjacency-based masks
       (``attn_mask_modules``).
    4. Run the ``Transformer_Layer`` stack, optionally combined with GNN
       layers.

    The graph representation is the super-node state or, without a super
    node, the state of node ``batched_data['root_n_id']``.
    """

    def __init__(
            self,
            num_encoder_layers: int = 12,
            hidden_dim: int = 768,
            ffn_hidden_dim: int = 768*3,
            num_attn_heads: int = 32,
            emb_dropout: float = 0,
            dropout: float = 0.1,
            attn_dropout: float = 0.1,
            num_class: int = 2,
            encoder_normalize_before: bool = False,
            apply_graphormer_init: bool = False,
            activation_fn: str = "GELU",
            n_trans_layers_to_freeze: int = 0,
            traceable=False,

            use_super_node: bool = True,

            node_feature_type: str = 'cate',
            node_feature_dim: int = None,
            num_atoms: int = None,

            node_level_modules: tuple = ('degree',),
            attn_level_modules: tuple = ('spe', 'spatial'),
            attn_mask_modules: str = None,

            num_in_degree: int = None,
            num_out_degree: int = None,
            eig_pos_dim: int = None,
            svd_pos_dim: int = None,

            num_spatial: int = None,
            num_edges: int = None,
            num_edge_dis: int = None,
            edge_type: str = None,
            multi_hop_max_dist: int = None,
            num_hop_bias: int = None,

            use_gnn_layers: bool = False,
            gnn_insert_pos: str = 'before',
            num_gnn_layers: int = 1,
            gnn_type: str = 'GAT',
            gnn_dropout: float = 0.5
    ) -> None:
        """
        Notes:
            ``activation_fn`` defaults to "GELU", the spelling accepted by
            ``get_activation_function``. The former default "gelu" was
            rejected by it.

            ``node_level_modules`` / ``attn_level_modules`` may be an iterable
            of names or a single name as a str. The former default
            ``('degree')`` was a str and was iterated character by character.

        Raises:
            ValueError: for unknown module names or, when
                ``use_gnn_layers`` is set, an unknown ``gnn_insert_pos``.
        """
        super().__init__()
        self.emb_dropout = nn.Dropout(p=emb_dropout)
        self.hidden_dim = hidden_dim
        self.apply_graphormer_init = apply_graphormer_init
        self.traceable = traceable
        self.use_super_node = use_super_node
        self.use_gnn_layers = use_gnn_layers
        self.gnn_insert_pos = gnn_insert_pos
        self.num_attn_heads = num_attn_heads
        self.attn_mask_modules = attn_mask_modules

        # A bare string would otherwise be iterated character by character.
        if isinstance(node_level_modules, str):
            node_level_modules = (node_level_modules,) if node_level_modules else ()
        if isinstance(attn_level_modules, str):
            attn_level_modules = (attn_level_modules,) if attn_level_modules else ()

        if encoder_normalize_before:
            self.emb_layer_norm = nn.LayerNorm(self.hidden_dim)
        else:
            self.emb_layer_norm = None

        # node feature encoder
        self.node_feature_encoder = NodeFeatureEncoder(feat_type=node_feature_type,
                                                       hidden_dim=hidden_dim,
                                                       n_layers=num_encoder_layers,
                                                       num_atoms=num_atoms,
                                                       feat_dim=node_feature_dim
                                                       )

        if use_super_node:
            self.add_super_node = AddSuperNode(hidden_dim=hidden_dim)

        # node-level graph-structural feature encoders
        self.node_level_layers = nn.ModuleList([])
        for module_name in node_level_modules:
            if module_name == 'degree':
                layer = DegreeEncoder(num_in_degree=num_in_degree,
                                      num_out_degree=num_out_degree,
                                      hidden_dim=hidden_dim,
                                      n_layers=num_encoder_layers)
            elif module_name == 'eig':
                layer = Eig_Embedding(eig_dim=eig_pos_dim, hidden_dim=hidden_dim)
            elif module_name == 'svd':
                layer = SVD_Embedding(svd_dim=svd_pos_dim, hidden_dim=hidden_dim)
            else:
                raise ValueError('node level module error!')
            self.node_level_layers.append(layer)

        # attention-level graph-structural feature encoders
        self.attn_level_layers = nn.ModuleList([])
        for module_name in attn_level_modules:
            if module_name == 'spatial':
                layer = GraphAttnSpatialBias(num_heads=num_attn_heads,
                                             num_spatial=num_spatial,
                                             n_layers=num_encoder_layers,
                                             use_super_node=use_super_node)
            elif module_name == 'spe':
                layer = GraphAttnEdgeBias(num_heads=num_attn_heads,
                                          num_edges=num_edges,
                                          num_edge_dis=num_edge_dis,
                                          edge_type=edge_type,
                                          multi_hop_max_dist=multi_hop_max_dist,
                                          n_layers=num_encoder_layers)
            elif module_name == 'nhop':
                layer = GraphAttnHopBias(num_heads=num_attn_heads,
                                         n_hops=num_hop_bias,
                                         use_super_node=use_super_node)
            else:
                raise ValueError('attn level module error!')
            self.attn_level_layers.append(layer)

        # gnn layers
        if use_gnn_layers:
            if gnn_insert_pos == 'before':
                self.gnn_layers = Geometric_GNN(gnn_type=gnn_type,
                                                hidden_dim=hidden_dim,
                                                gnn_dropout=gnn_dropout,
                                                n_layers=num_gnn_layers,
                                                use_super_node=use_super_node)
            elif gnn_insert_pos in ('alter', 'parallel'):
                # One GNN block per transformer layer.
                self.gnn_layers = nn.ModuleList([Geometric_GNN(gnn_type=gnn_type,
                                                               hidden_dim=hidden_dim,
                                                               gnn_dropout=gnn_dropout,
                                                               n_layers=num_gnn_layers,
                                                               use_super_node=use_super_node)
                                                 for _ in range(num_encoder_layers)])
            else:
                # Fail here rather than with an AttributeError in forward().
                raise ValueError(
                    f"Unknown gnn_insert_pos {gnn_insert_pos!r}; "
                    "expected 'before', 'alter' or 'parallel'."
                )

        # transformer layers
        self.transformer_layers = nn.ModuleList([
            Transformer_Layer(
                num_heads=num_attn_heads,
                hidden_dim=hidden_dim,
                ffn_hidden_dim=ffn_hidden_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                temperature=1,
                activation_fn=activation_fn
            ) for _ in range(num_encoder_layers)
        ])

        self.output_layer_norm = nn.LayerNorm(hidden_dim)
        self.output_fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.output_fc2 = nn.Linear(hidden_dim, num_class)
        self.out_act_fn = get_activation_function(activation_fn)

        # Apply initialization of model params after building the model
        if self.apply_graphormer_init:
            self.apply(init_graphormer_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        # The layers live in self.transformer_layers (there is no self.layers).
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.transformer_layers[layer])

    def forward(
            self,
            batched_data,
            perturb=None,
            last_state_only: bool = False,
    ):
        """Run the model on a padded batch.

        Args:
            batched_data: dict of batched tensors. This method reads ``'x'``,
                ``'x_mask'`` and, without a super node, ``'root_n_id'``; the
                sub-modules read further keys.
            perturb: unused; kept for interface compatibility.
            last_state_only: if false, intermediate states are collected
                internally (they are not returned).

        Returns:
            ``{'logits': out}`` from the output head.
        """
        # ==============preparation==========================
        data_x = batched_data["x"]
        n_graph, n_node = data_x.size()[:2]

        # attention padding mask: B x T x T, or B x (T+1) x (T+1) with a super node
        padding_mask = batched_data['x_mask']
        if self.use_super_node:
            padding_mask_cls = torch.ones(
                n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype
            )
            padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)
        # Cast in both paths so matmul also works for bool/int masks without a super node.
        padding_mask = padding_mask.float()
        attn_mask = torch.matmul(padding_mask.unsqueeze(-1), padding_mask.unsqueeze(1)).long()
        self.attn_mask = attn_mask

        # x feature encode
        x = self.node_feature_encoder(batched_data)  # B x T x C
        for nl_layer in self.node_level_layers:
            node_bias = nl_layer(batched_data)
            x += node_bias
        # add the super node
        if self.use_super_node:
            x = self.add_super_node(x)  # B x T+1 x C

        # attention bias computation, B x H x (T+1) x (T+1) or B x H x T x T
        offset = int(self.use_super_node)
        attn_bias = torch.zeros(n_graph, self.num_attn_heads, n_node + offset, n_node + offset).to(data_x.device)
        for al_layer in self.attn_level_layers:
            bias = al_layer(batched_data)
            if bias.shape[-1] == attn_bias.shape[-1]:
                attn_bias += bias
            elif bias.shape[-1] == attn_bias.shape[-1] - 1:
                # Bias covers real nodes only: place it after the super node.
                attn_bias[:, :, offset:, offset:] = attn_bias[:, :, offset:, offset:] + bias
            else:
                raise ValueError('attention calculation error')

        # attention mask
        if self.attn_mask_modules in ('1hop', 'nhop'):
            adj_mask = getAttnMasks(batched_data, self.attn_mask_modules, self.use_super_node, self.num_attn_heads)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_attn_heads, -1, -1) * adj_mask

        # ===================data flow===============
        # input feature normalization and dropout
        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)
        x = self.emb_dropout(x)  # B x T+1 x C

        # gnn layers before transformer
        if self.use_gnn_layers and self.gnn_insert_pos == 'before':
            x = self.gnn_layers(batched_data, x)

        # graph transformer layers
        inner_states = []
        if not last_state_only:
            inner_states.append(x)
        for i, layer in enumerate(self.transformer_layers):

            if self.use_gnn_layers and self.gnn_insert_pos == 'parallel':
                x_graph = self.gnn_layers[i](batched_data, x)
            else:
                x_graph = 0

            # self-attention layer
            x, _ = layer.attention(
                x=x,
                mask=attn_mask,
                attn_bias=attn_bias,
            )

            if self.use_gnn_layers and self.gnn_insert_pos == 'alter':  # by default, gnn after mhsa
                x = self.gnn_layers[i](batched_data, x)

            x = x + x_graph

            # FFN layer
            x = layer.ffn_layer(x)
            if not last_state_only:
                inner_states.append(x)

        # output layers
        if self.use_super_node:
            graph_rep = x[:, 0, :].squeeze()  # B x C
        else:
            # center node
            root_n_id = batched_data['root_n_id']
            root_idx = (torch.arange(n_graph, device=x.device) * n_node + root_n_id).long()
            graph_rep = x.reshape(-1, x.shape[-1])[root_idx].squeeze()

        # output transformation
        out = self.output_layer_norm(self.out_act_fn(self.output_fc1(graph_rep)))
        out = self.output_fc2(out).squeeze()

        return {'logits': out}
312 |
313 |
314 |
--------------------------------------------------------------------------------
/graphtrasformer/layer_tests.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | from graphtrasformer.layers import *
if __name__ == '__main__':
    # Smoke test: one attention pass of a Transformer_Layer over a padded batch.
    batch_size, seq_len, valid_len, hidden_dim, num_heads = 8, 100, 80, 64, 4

    layer = Transformer_Layer(
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        ffn_hidden_dim=128,
        dropout=0.1,
        attn_dropout=0.1,
        temperature=1,
        activation_fn='GELU')

    # Positions >= valid_len are padding: zero features and masked out.
    x = torch.randn(batch_size, seq_len, hidden_dim)
    x[:, valid_len:, :] = 0
    mask = torch.zeros(batch_size, seq_len, seq_len)
    mask[:, :valid_len, :valid_len] = 1

    out, attn = layer.attention(x, mask)

    # Previously nothing was checked; fail loudly on shape or NaN/inf problems.
    assert out.shape == x.shape, out.shape
    assert torch.isfinite(out).all(), 'attention output contains NaN/inf'
    print('attention output', tuple(out.shape), 'attention weights', tuple(attn.shape))
27 |
28 |
29 |
--------------------------------------------------------------------------------
/graphtrasformer/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.nn import init
4 | import json
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 |
10 |
def get_activation_function(activation: str = 'PReLU') -> nn.Module:
    """Return a new activation for the given name.

    Names are matched case-insensitively, so e.g. both ``'GELU'`` and
    ``'gelu'`` work. Supported: ReLU, LeakyReLU (negative slope 0.1), PReLU,
    tanh, SELU, ELU, GELU and Linear (an identity function, not an
    ``nn.Module``). A fresh instance is built on every call, so modules with
    parameters (PReLU) are never shared.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    factories = {
        'relu': nn.ReLU,
        'leakyrelu': lambda: nn.LeakyReLU(0.1),
        'prelu': nn.PReLU,
        'tanh': nn.Tanh,
        'selu': nn.SELU,
        'elu': nn.ELU,
        'linear': lambda: (lambda x: x),
        'gelu': nn.GELU,
    }
    key = activation.lower() if isinstance(activation, str) else None
    if key not in factories:
        raise ValueError(f'Activation "{activation}" not supported.')
    return factories[key]()
30 |
31 |
32 |
33 |
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward sublayer with residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(fc2(Dropout(act(fc1(x))))))``.
    """

    def __init__(self, hidden_dim, ffn_hidden_dim, activation_fn="GELU", dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()

        self.fc1 = nn.Linear(hidden_dim, ffn_hidden_dim)
        self.fc2 = nn.Linear(ffn_hidden_dim, hidden_dim)
        self.act_dropout = nn.Dropout(dropout)
        self.dropout = nn.Dropout(dropout)
        self.ffn_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.ffn_act_func = get_activation_function(activation_fn)

    def forward(self, x):
        hidden = self.act_dropout(self.ffn_act_func(self.fc1(x)))
        hidden = self.dropout(self.fc2(hidden))
        return self.ffn_layer_norm(hidden + x)
53 |
54 |
55 |
class MultiheadAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention with residual and post-LayerNorm.

    Scores are ``(q / temperature) @ k^T / sqrt(d_k)`` plus an optional
    ``attn_bias``. Positions where ``mask == 0`` are filled with -1e12 before
    the softmax. ``mask`` may be (B, T, T) (shared by all heads) or
    (B, H, T, T) (per head). ``forward`` returns the normalised output and
    the attention weights (after attention dropout).
    """

    def __init__(self,
                 num_heads,
                 hidden_dim,
                 dropout=0.1,
                 attn_dropout=0.1,
                 temperature=1):
        super().__init__()
        self.d_k = hidden_dim // num_heads
        self.num_heads = num_heads  # number of heads
        self.temperature = temperature
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.a_proj = nn.Linear(hidden_dim, hidden_dim)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init of the k, v, q and output projection weights."""
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.a_proj.weight)

    def forward(self, x, mask=None, attn_bias=None):
        """
        Args:
            x: (B, T, hidden_dim) input.
            mask: optional (B, T, T) or (B, H, T, T) mask; 0 marks a
                disallowed query/key pair.
            attn_bias: optional additive bias broadcastable to (B, H, T, T).

        Returns:
            ``(out, attn)`` with ``out`` of shape (B, T, hidden_dim) and
            ``attn`` of shape (B, H, T, T).
        """
        residual = x
        batch_size = x.size(0)

        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)

        query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # ScaledDotProductAttention
        scores = torch.matmul(query / self.temperature, key.transpose(-2, -1)) \
            / math.sqrt(query.size(-1))

        if attn_bias is not None:
            scores = scores + attn_bias

        if mask is not None:
            if len(mask.shape) == 3:
                mask = mask.unsqueeze(1)  # share one mask across all heads
            # Mask on the mask values only. The former per-head path multiplied
            # scores by the mask and then masked every score equal to 0, which
            # also discarded legitimate zero logits. masked_fill broadcasts
            # over (B, 1, T, T) and (B, H, T, T) masks alike.
            scores = scores.masked_fill(mask == 0, -1e12)

        attn = self.attn_dropout(F.softmax(scores, dim=-1))

        out = torch.matmul(attn, value)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        out = self.dropout(self.a_proj(out))
        out += residual
        out = self.layer_norm(out)

        return out, attn
124 |
125 |
class Transformer_Layer(nn.Module):
    """One Transformer encoder layer: multi-head self-attention followed by a position-wise FFN.

    Args:
        num_heads: number of attention heads; must divide ``hidden_dim``.
        hidden_dim: model width.
        ffn_hidden_dim: intermediate width of the feed-forward sub-layer.
        dropout: dropout rate used by both the attention output and the feed-forward sub-layer.
        attn_dropout: dropout applied to the attention probabilities.
        temperature: extra query scaling forwarded to ``MultiheadAttention``.
        activation_fn: activation name for the feed-forward sub-layer.
    """

    def __init__(self,
                 num_heads,
                 hidden_dim,
                 ffn_hidden_dim,
                 dropout=0.1,
                 attn_dropout=0.1,
                 temperature=1,
                 activation_fn='GELU'):
        super().__init__()
        assert hidden_dim % num_heads == 0, 'hidden_dim must be divisible by num_heads'

        self.attention = MultiheadAttention(num_heads,
                                            hidden_dim,
                                            dropout,
                                            attn_dropout,
                                            temperature)
        # Pass the layer-level dropout through so the FFN honours it as well instead of
        # silently using its own default.
        self.ffn_layer = PositionwiseFeedForward(hidden_dim,
                                                 ffn_hidden_dim,
                                                 activation_fn=activation_fn,
                                                 dropout=dropout)

    def forward(self, x, attn_mask, attn_bias=None):
        """Run attention then the FFN; returns ``(x, attn)`` where ``attn`` are the attention probabilities."""
        x, attn = self.attention(x, mask=attn_mask, attn_bias=attn_bias)
        x = self.ffn_layer(x)

        return x, attn
151 |
152 |
153 |
--------------------------------------------------------------------------------
/gt_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import *
2 | import torch.nn as nn
3 | from tqdm import tqdm
4 |
5 |
6 | from ogb.graphproppred import PygGraphPropPredDataset
7 | from ogb.nodeproppred import PygNodePropPredDataset
8 | from ogb.graphproppred import Evaluator
9 | from datasets import load_dataset, load_metric
10 | from data.pyg_datasets.pyg_dataset import GraphormerPYGDataset, Graphtrans_Sampling_Dataset,Graphtrans_Sampling_Dataset_v2
11 |
12 |
def get_loss_and_metric(data_name):
    """Return the training loss, evaluation metric and task description for ``data_name``.

    Returns:
        ``(loss, metric, task_type, metric_name)`` where ``task_type`` is one of
        'regression', 'binary_classification', 'multi_classification' or
        'multi_binary_classification'.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset.
    """
    if data_name in ('ZINC', 'pcqm4mv2', 'QM7', 'QM9', 'ZINC-full'):
        # Mean absolute error is both the training loss and the reported metric.
        return nn.L1Loss(reduction='mean'), nn.L1Loss(reduction='mean'), 'regression', 'MAE'

    if data_name in ('UPFD',):
        return (nn.BCEWithLogitsLoss(reduction='mean'), load_metric("accuracy"),
                'binary_classification', 'accuracy')

    if data_name in ('ogbg-molhiv',):
        return (nn.BCEWithLogitsLoss(reduction='mean'), Evaluator(name=data_name),
                'binary_classification', 'ROC-AUC')

    if data_name in ('flickr', 'ogbn-products', 'ogbn-arxiv'):
        return (nn.CrossEntropyLoss(reduction='mean'), load_metric('accuracy'),
                'multi_classification', 'accuracy')

    if data_name in ('ogbg-molpcba',):
        return (nn.BCEWithLogitsLoss(reduction='mean'), Evaluator(name=data_name),
                'multi_binary_classification', 'AP')

    raise ValueError('no such dataset')
49 |
50 |
51 |
def normalization(data_list, mean, std):
    """Standardise the node features of every graph in ``data_list`` in place.

    For each element, its ``x`` attribute is replaced by ``(x - mean) / std``. The elements
    themselves (presumably PyG ``Data`` objects -- confirm against callers) are kept, so
    their other attributes such as edges and labels are preserved.

    Args:
        data_list: mutable sequence of objects exposing an ``x`` attribute.
        mean: value broadcastable against each ``x``.
        std: value broadcastable against each ``x``.

    Returns:
        The same ``data_list`` object.
    """
    for data in tqdm(data_list):
        # Update only the features; assigning the tensor to data_list[i] would throw away
        # the rest of the graph object.
        data.x = (data.x - mean) / std
    return data_list
56 |
57 |
def get_graph_level_dataset(name, param=None, seed=1024, set_default_params=False, args=None):
    """Load a graph-level benchmark and wrap it in ``GraphormerPYGDataset``.

    Dataset-specific training hyper-parameters are always written onto ``args``
    (``set_default_params`` is accepted but not consulted here).

    Args:
        name: 'ZINC', 'ZINC-full', 'ogbg-molpcba', 'ogbg-molhiv' or 'UPFD'.
        param: UPFD sub-dataset ('politifact' or 'gossipcop'); ignored otherwise.
        seed: forwarded to ``GraphormerPYGDataset``.
        set_default_params: unused.
        args: namespace mutated in place and forwarded to ``GraphormerPYGDataset``.

    Returns:
        ``(train_data, valid_data, test_data, inner_dataset)``; ``inner_dataset`` is None
        for datasets that come with ready-made train/val/test sets (ZINC, ZINC-full, UPFD).

    Raises:
        ValueError: for an unknown ``name`` or a UPFD ``param`` other than the two above.
    """
    path = 'dataset/' + name
    print(path)

    train_set = val_set = test_set = None
    inner_dataset = None
    train_idx = val_idx = test_idx = None
    split_names = ('train', 'val', 'test')

    if name in ('ZINC', 'ZINC-full'):
        # 'ZINC' loads with subset=True, 'ZINC-full' with subset=False.
        use_subset = name == 'ZINC'
        train_set, val_set, test_set = (
            ZINC(path, subset=use_subset, split=split) for split in split_names)
        overrides = {'node_feature_type': 'cate',
                     'num_class': 1,
                     'eval_steps': 1000,
                     'save_steps': 1000,
                     'greater_is_better': False,
                     'warmup_steps': 40000,
                     'max_steps': 400000}

    elif name in ('ogbg-molpcba', 'ogbg-molhiv'):
        inner_dataset = PygGraphPropPredDataset(name)
        idx_split = inner_dataset.get_idx_split()
        train_idx, val_idx, test_idx = idx_split["train"], idx_split["valid"], idx_split["test"]
        if name == 'ogbg-molpcba':
            overrides = {'node_feature_type': 'cate',
                         'num_class': 128,
                         'eval_steps': 2000,
                         'save_steps': 2000,
                         'greater_is_better': True,
                         'warmup_steps': 40000,
                         'max_steps': 1000000}
        else:
            overrides = {'node_feature_type': 'cate',
                         'num_class': 1,
                         'eval_steps': 1000,
                         'save_steps': 1000,
                         'greater_is_better': True,
                         'warmup_steps': 40000,
                         'max_steps': 1200000}

    elif name == 'UPFD' and param in ('politifact', 'gossipcop'):
        train_set, val_set, test_set = (
            UPFD(path, param, 'bert', split=split) for split in split_names)
        overrides = {'learning_rate': 1e-5,
                     'node_feature_type': 'dense',
                     'node_feature_dim': 768,
                     'greater_is_better': True}

    else:
        raise ValueError('no such dataset')

    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)

    dataset = GraphormerPYGDataset(
        dataset=inner_dataset,
        train_idx=train_idx,
        valid_idx=val_idx,
        test_idx=test_idx,
        train_set=train_set,
        valid_set=val_set,
        test_set=test_set,
        seed=seed,
        args=args
    )
    return dataset.train_data, dataset.valid_data, dataset.test_data, inner_dataset
154 |
155 |
def get_node_level_dataset(name, param=None, args=None):
    """Load a node-level benchmark and wrap its train/valid/test nodes in sampling datasets.

    ``args`` is mutated in place: ``num_neighbors`` is forced by ``args.sampling_algo``
    (10 for 'shadowkhop', 50 for 'sage'), and dataset-specific feature sizes, class
    counts and schedule lengths are written onto it.

    Args:
        name: 'flickr', 'ogbn-products' or 'ogbn-arxiv'.
        param: unused.
        args: namespace providing at least ``sampling_algo`` and ``depth``.

    Returns:
        ``(train_set, valid_set, test_set, dataset, args)``.

    Raises:
        ValueError: for an unknown ``args.sampling_algo``, an unknown ``name``, or one of the
            citation datasets ('cora', 'citeseer', 'dblp', 'pubmed'), for which no
            train/valid/test split is configured here.
    """
    path = 'dataset/' + name
    print(path)

    # Validate the sampler up front: any other value would leave the sampler class undefined.
    if args.sampling_algo == 'shadowkhop':
        args.num_neighbors = 10
    elif args.sampling_algo == 'sage':
        args.num_neighbors = 50
    else:
        raise ValueError(
            f"unknown sampling_algo {args.sampling_algo!r}; expected 'shadowkhop' or 'sage'")

    if name in ('cora', 'citeseer', 'dblp', 'pubmed'):
        raise ValueError(
            f'dataset {name!r} has no train/valid/test split configured for node-level training')

    if name == 'flickr':
        dataset = Flickr(path)
        train_idx = dataset.data.train_mask.nonzero().squeeze()
        valid_idx = dataset.data.val_mask.nonzero().squeeze()
        test_idx = dataset.data.test_mask.nonzero().squeeze()
        overrides = {'node_feature_dim': 500, 'num_class': 7,
                     'warmup_steps': 2000, 'max_steps': 100000}

    elif name in ('ogbn-products', 'ogbn-arxiv'):
        dataset = PygNodePropPredDataset(name=name)
        split_idx = dataset.get_idx_split()
        train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
        if name == 'ogbn-products':
            overrides = {'node_feature_dim': 100, 'num_class': 47,
                         'warmup_steps': 10000, 'max_steps': 400000}
        else:
            overrides = {'node_feature_dim': 128, 'num_class': 40,
                         'warmup_steps': 10000, 'max_steps': 800000}

    else:
        raise ValueError('no such dataset')

    # Settings shared by every supported node-level dataset.
    overrides.update(node_feature_type='dense',
                     encoder_normalize_before=True,
                     apply_graphormer_init=True,
                     greater_is_better=True)
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)

    def x_norm_func(x):
        # Node features are used as-is.
        return x

    if args.sampling_algo == 'shadowkhop':
        sampling_dataset_cls = Graphtrans_Sampling_Dataset
    else:
        sampling_dataset_cls = Graphtrans_Sampling_Dataset_v2

    def build(node_idx):
        # One sampling dataset over the full graph, restricted to the given target nodes.
        return sampling_dataset_cls(dataset.data,
                                    node_idx=node_idx,
                                    depth=args.depth,
                                    num_neighbors=args.num_neighbors,
                                    replace=False,
                                    x_norm_func=x_norm_func,
                                    args=args)

    train_set, valid_set, test_set = (build(idx) for idx in (train_idx, valid_idx, test_idx))

    return train_set, valid_set, test_set, dataset, args
261 |
262 |
263 |
264 |
265 |
266 |
267 |
# Placeholder entry point for manual testing; running this module directly does nothing.
if __name__=='__main__':
    pass
271 |
272 |
273 |
274 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.util import deprecation
2 | deprecation._PRINT_DEPRECATION_WARNINGS = False
3 | from sklearn.metrics import roc_auc_score, average_precision_score
4 | import tensorflow.compat.v1 as tf
5 | tf.disable_v2_behavior()
6 | from argparse import ArgumentParser, Namespace
7 | from data.collator import *
8 | from gt_dataset import *
9 |
10 | import gc
11 | from graphtrasformer.architectures import *
12 | import transformers
13 | from transformers import (
14 | AutoConfig,
15 | AutoModelForSequenceClassification,
16 | AutoTokenizer,
17 | DataCollatorWithPadding,
18 | EvalPrediction,
19 | HfArgumentParser,
20 | Trainer,
21 | TrainingArguments,
22 | default_data_collator,
23 | set_seed,
24 | )
25 |
26 |
27 | from sklearn import metrics
28 | import h5py
29 | import numpy as np
30 | import pandas as pd
31 | from tqdm import tqdm
32 | import logging
33 | import time
34 | import torch.onnx
35 | import os
36 |
37 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
38 | tf.logging.set_verbosity(tf.logging.ERROR)
39 | import warnings
40 | warnings.filterwarnings('ignore')
41 |
42 |
43 |
44 |
def str_tuple(string):
    """argparse ``type=`` converter: split a comma-separated string into a tuple of parts.

    No whitespace stripping is done, and an empty string yields ``('',)``.
    """
    pieces = string.split(',')
    return tuple(pieces)
47 |
def boolean_string(s):
    """argparse ``type=`` converter accepting exactly the strings 'True' and 'False'.

    Raises:
        ValueError: for any other value (argparse reports it as an invalid argument).
    """
    lookup = {'True': True, 'False': False}
    try:
        return lookup[s]
    except KeyError:
        raise ValueError('Not a valid boolean string') from None
52 |
def set_model_scale(model_scale, args):
    """Overwrite the Transformer size hyper-parameters on ``args`` with a named preset.

    Presets: 'mini', 'small', 'middle', 'large'. Any other ``model_scale`` leaves ``args``
    untouched. ``args`` is modified in place and also returned.
    """
    # name -> (num_encoder_layers, hidden_dim, ffn_hidden_dim, num_attn_heads)
    presets = {
        'mini': (3, 64, 64, 4),
        'small': (6, 80, 80, 8),
        'middle': (12, 80, 80, 8),
        'large': (12, 512, 512, 32),
    }
    if model_scale in presets:
        (args.num_encoder_layers,
         args.hidden_dim,
         args.ffn_hidden_dim,
         args.num_attn_heads) = presets[model_scale]
    return args
76 |
def parse_args():
    """Parse the command-line options of a training run.

    After parsing, ``set_model_scale`` is applied. For ``--model_scale`` in
    {'mini', 'small', 'middle', 'large'} it overwrites ``num_encoder_layers``,
    ``hidden_dim``, ``ffn_hidden_dim`` and ``num_attn_heads``. The corresponding
    command-line flags therefore only take effect for other scale names.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = ArgumentParser()

    parser.add_argument('--disable_tqdm',type=boolean_string,default=True)  # just for debugging

    parser.add_argument('--model_scale', type=str, default='small')  # 'mini' / 'small' / 'middle' / 'large' (see set_model_scale)
    parser.add_argument('--data_name',type=str,default='ZINC')
    parser.add_argument('--data_param',type=str,default=None)  # sub-dataset, e.g. 'politifact' / 'gossipcop' for UPFD
    # Basic Transformer parameters.
    parser.add_argument('--max_node', type=int, default=512)
    parser.add_argument('--num_encoder_layers', type=int, default=12)
    parser.add_argument('--hidden_dim', type=int, default=768)
    parser.add_argument('--ffn_hidden_dim', type=int, default=768*3)
    parser.add_argument('--num_attn_heads', type=int, default=32)
    parser.add_argument('--emb_dropout',type=float,default=0.0)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--attn_dropout', type=float, default=0.1)
    parser.add_argument('--num_class', type=int, default=1)
    parser.add_argument('--encoder_normalize_before', type=boolean_string, default=True)
    parser.add_argument('--apply_graphormer_init', type=boolean_string, default=True)
    parser.add_argument('--activation_fn', type=str, default='GELU')
    parser.add_argument('--n_trans_layers_to_freeze', type=int, default=0)
    parser.add_argument('--traceable', type=boolean_string, default=False)

    # Node-level (positional) embedding parameters.
    parser.add_argument('--use_super_node', type=boolean_string, default=True)
    parser.add_argument('--node_feature_type', type=str, default=None)  # 'cate' or 'dense'
    parser.add_argument('--node_feature_dim', type=int, default=None)  # only used for the 'dense' feature type
    parser.add_argument('--num_atoms', type=int, default=512*9)  # only used for the 'cate' feature type
    parser.add_argument('--node_level_modules', type=str_tuple, default=())  # comma-separated, e.g. 'degree,eig,svd'
    parser.add_argument('--eig_pos_dim',type=int, default=3)
    parser.add_argument('--svd_pos_dim',type=int, default=3)
    parser.add_argument('--num_in_degree', type=int, default=512)
    parser.add_argument('--num_out_degree', type=int, default=512)

    # Attention bias / attention mask parameters.
    parser.add_argument('--attn_level_modules', type=str_tuple, default=())  # comma-separated, e.g. 'spatial,spe,nhop'
    parser.add_argument('--attn_mask_modules',type=str, default=None)  # e.g. 'nhop'
    parser.add_argument('--num_edges', type=int, default=512*3)
    parser.add_argument('--num_spatial', type=int, default=512)
    parser.add_argument('--num_edge_dis', type=int, default=128)
    parser.add_argument('--spatial_pos_max', type=int, default=20)
    parser.add_argument('--edge_type', type=str, default=None)
    parser.add_argument('--multi_hop_max_dist', type=int, default=5)
    parser.add_argument('--num_hop_bias', type=int, default=3)  # typically 2, 3 or 4

    # GNN layer parameters: GNN layers can be inserted before, alternating with, or in parallel
    # to the self-attention layers. They are implemented with PyTorch Geometric, so data is
    # always converted between the GNN and self-attention layouts at each boundary.
    parser.add_argument('--use_gnn_layers', type=boolean_string, default=False)
    parser.add_argument('--gnn_insert_pos', type=str, default='before')  # 'before' / 'alter' / 'parallel'
    parser.add_argument('--num_gnn_layers', type=int, default=1)
    parser.add_argument('--gnn_type',type=str,default='GAT')  # GCN, SAGE, GAT, RGCN, ... (any GNN type supported by PyTorch Geometric)
    parser.add_argument('--gnn_dropout',type=float,default=0.5)

    # Neighbourhood-sampling parameters (node-level datasets).
    parser.add_argument('--depth',type=int,default=2)
    parser.add_argument('--num_neighbors', type=int,default=10)  # overridden per sampling_algo by get_node_level_dataset
    parser.add_argument('--sampling_algo',type=str,default='shadowkhop')  # or 'sage'

    # Training parameters, forwarded to the Hugging Face TrainingArguments/Trainer.
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--output_dir',type=str)  # replaced in __main__ by a directory derived from the run configuration
    parser.add_argument('--per_device_train_batch_size', type=int, default=256)
    parser.add_argument('--per_device_eval_batch_size', type=int, default=256)
    parser.add_argument('--gradient_accumulation_steps',type=int,default=1)
    parser.add_argument('--learning_rate',type=float, default=2e-4)
    parser.add_argument('--weight_decay',type=float,default=0.01)
    # argparse passes string defaults through `type`, so the two betas below end up as floats.
    parser.add_argument('--adam_beta1',type=float,default='0.9')
    parser.add_argument('--adam_beta2',type=float,default='0.999')
    parser.add_argument('--adam_epsilon',type=float,default=1e-8)
    parser.add_argument('--max_grad_norm',type=float,default=5.0)
    parser.add_argument('--num_train_epochs',type=int,default=300)
    parser.add_argument('--max_steps',type=int,default=400000)  # dataset-specific values are set in gt_dataset
    parser.add_argument('--lr_scheduler_type', type=str, default='linear')
    parser.add_argument('--warmup_steps',type=int,default=40000)
    parser.add_argument('--dataloader_num_workers',type=int,default=32)
    parser.add_argument('--evaluation_strategy',type=str,default='steps')
    parser.add_argument('--eval_steps',type=int,default=1000)
    parser.add_argument('--save_steps',type=int,default=1000)
    parser.add_argument('--greater_is_better',type=boolean_string,default=True)

    # When True, do not resume from existing checkpoints in output_dir.
    parser.add_argument('--rerun',type=boolean_string,default=False)

    args = parser.parse_args()
    # Applies the size preset in place (see docstring).
    set_model_scale(args.model_scale,args)

    return args
169 |
170 |
171 |
172 |
173 |
174 |
175 |
def expand_graph_level_dataset(dataset, N):
    """Repeat the graphs of a graph-level training set ``N`` times, in place.

    ``dataset.dataset`` is materialised into a list, tiled ``N`` times (items are shared,
    not copied), and ``dataset.num_data`` is scaled accordingly. Returns ``dataset``.
    """
    graphs = list(dataset.dataset)
    dataset.dataset = graphs * N
    dataset.num_data *= N
    return dataset
def expand_node_level_dataset(dataset, N):
    """Repeat the target node indices of a node-level training set ``N`` times, in place.

    ``dataset.node_idx`` (a 1-D index tensor, given the ``expand(N, -1)`` below) becomes
    ``N`` back-to-back copies of itself and ``dataset.num_data`` is scaled by ``N``.
    Returns ``dataset``.
    """
    tiled = dataset.node_idx.expand(N, -1)
    dataset.node_idx = tiled.reshape(-1)
    dataset.num_data *= N
    return dataset
185 |
186 |
187 |
188 |
189 |
if __name__ == '__main__':
    args = parse_args()

    # How many times each training set is repeated. The HF Trainer was observed to pause for
    # several seconds after every full pass of the training dataloader, so small training
    # sets are expanded manually to amortise that pause. Datasets without an entry
    # (e.g. 'UPFD') use a factor of 1 instead of failing with KeyError.
    expand_num_dict = {'flickr': 20,
                       'ZINC': 20,
                       'ogbn-products': 5,
                       'ogbg-molpcba': 1,
                       'ZINC-full': 1,
                       'ogbn-arxiv': 1,
                       'ogbg-molhiv': 5}
    expand_num = expand_num_dict.get(args.data_name, 1)

    if args.data_name in ('flickr', 'ogbn-products', 'ogbn-arxiv'):
        train_set, valid_set, test_set, odata, args = get_node_level_dataset(args.data_name, args=args)
        train_set = expand_node_level_dataset(train_set, expand_num)

    elif args.data_name in ('ZINC', 'UPFD', 'ogbg-molpcba', 'ZINC-full', 'ogbg-molhiv'):
        train_set, valid_set, test_set, odata = get_graph_level_dataset(
            args.data_name, param=args.data_param, set_default_params=True, args=args)
        train_set = expand_graph_level_dataset(train_set, expand_num)

    else:
        raise ValueError('no dataset')

    criterion, metric, task_type, metric_name = get_loss_and_metric(args.data_name)

    # Print the final configuration, including dataset-specific overrides.
    for k, v in vars(args).items():
        print(k, v)

    # ======================== model ========================
    model = get_model(args)

    # The output directory name encodes the main experiment settings, so different
    # configurations never share checkpoints or logs.
    log_file_param_list = (args.data_name,
                           args.model_scale,
                           args.use_super_node,
                           args.node_level_modules,
                           args.eig_pos_dim,
                           args.svd_pos_dim,
                           args.attn_level_modules,
                           args.attn_mask_modules,
                           args.num_hop_bias,
                           args.use_gnn_layers,
                           args.gnn_insert_pos,
                           args.num_gnn_layers,
                           args.gnn_type,
                           args.gnn_dropout,
                           args.sampling_algo,
                           args.depth,
                           args.num_neighbors,
                           args.seed)

    def _to_dir_token(value):
        """Render one setting for the directory name: tuples are '+'-joined, empty ones become 'None'."""
        if isinstance(value, tuple):
            return '+'.join(value) if value else 'None'
        return str(value)

    log_file_param_list_p = [_to_dir_token(x) for x in log_file_param_list]

    output_dir = './outputs/' + '_'.join(log_file_param_list_p)
    log_file_path = output_dir + '/logs.json'

    setattr(args, 'log_file_path', log_file_path)
    setattr(args, 'output_dir', output_dir)

    # ======================== Hugging Face Trainer ========================
    def compute_metrics(p: EvalPrediction):
        """Turn the raw Trainer predictions into ``{metric_name: value}`` for the current task."""
        preds, labels = p

        gc.collect()
        if task_type == 'multi_classification':
            preds = np.argmax(preds, axis=1)
            # np.long was removed in NumPy 1.24; int64 is the intended label dtype.
            labels = labels.astype(np.int64)
            return metric.compute(predictions=preds, references=labels)

        elif task_type == 'multi_binary_classification' and metric_name == 'AP':
            preds = torch.sigmoid(torch.tensor(preds)).numpy()
            # TODO: confirm the expected input/output formats and that node/edge features
            # are handled correctly.
            return {metric_name: metric.eval({'y_true': labels, 'y_pred': preds})['ap']}

        elif task_type == 'regression':
            return {metric_name: metric(torch.tensor(preds), torch.tensor(labels)).item()}  # MAE

        elif task_type == 'binary_classification' and metric_name == 'ROC-AUC':
            return {metric_name: roc_auc_score(y_true=labels,
                                               y_score=torch.sigmoid(torch.tensor(preds)).numpy())}

        elif task_type == 'binary_classification' and metric_name == 'accuracy':
            # Accuracy needs hard 0/1 predictions, not probabilities:
            # sigmoid(logit) > 0.5  <=>  logit > 0.
            hard_preds = (np.asarray(preds).reshape(-1) > 0).astype(np.int64)
            references = np.asarray(labels).reshape(-1).astype(np.int64)
            return metric.compute(predictions=hard_preds, references=references)

        raise ValueError(f'no metric configured for task_type={task_type!r}, metric_name={metric_name!r}')

    from transformers import TrainerCallback, TrainerState, TrainerControl, EarlyStoppingCallback

    class MyCallback(TrainerCallback):
        """Write the full TrainerState to ``log_file_path`` on every log event."""

        def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            print("save logs...")
            # The output directory may not exist yet when the first log event fires
            # (check_checkpoints may have removed it).
            os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
            state.save_to_json(log_file_path)

    class MyTrainer(Trainer):
        """Trainer with a task-dependent loss; NaN targets are ignored for multi-label binary tasks."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra keyword arguments that newer transformers releases may
            # pass (e.g. num_items_in_batch).
            self.inputs = inputs
            labels = inputs['labels']
            outputs = model(inputs)

            labels = labels.long() if task_type == 'multi_classification' else labels.float()
            if task_type == 'multi_binary_classification':
                # Flatten and drop NaN targets together with their logits.
                labels = labels.reshape(-1)
                mask = ~torch.isnan(labels)
                loss = criterion(outputs['logits'].reshape(-1)[mask], labels[mask])
            else:
                loss = criterion(outputs['logits'], labels)
            return (loss, outputs) if return_outputs else loss

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy=args.evaluation_strategy,
        eval_steps=args.eval_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_epsilon=args.adam_epsilon,
        max_grad_norm=args.max_grad_norm,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,  # throughput is sensitive to this value
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        greater_is_better=args.greater_is_better,
        save_steps=args.save_steps,
        save_total_limit=10,
        logging_steps=args.eval_steps,
        seed=args.seed
    )

    training_args.disable_tqdm = args.disable_tqdm
    training_args.ignore_data_skip = True

    # Resume only when output_dir already holds checkpoints and --rerun was not given.
    # check_checkpoints also deletes an existing output_dir that contains no checkpoint.
    resume_from_checkpoint = True if (check_checkpoints(args.output_dir) and not args.rerun) else None

    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        compute_metrics=compute_metrics,
        data_collator=lambda x: collator(x, args),
        callbacks=[MyCallback, EarlyStoppingCallback(early_stopping_patience=20)]
    )
    trainer.args._n_gpu = 1

    print(trainer.evaluate())
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    predictions, labels, test_metrics = trainer.predict(test_set, metric_key_prefix="predict")
    test_metrics['best_val_metric'] = trainer.state.best_metric
    test_metrics['best_model_checkpoint'] = trainer.state.best_model_checkpoint

    os.makedirs(args.output_dir, exist_ok=True)
    with open(args.output_dir + '/test.txt', 'w') as f:
        for k, v in test_metrics.items():
            f.write(str(k) + ':' + str(v) + '\n')
        f.write('\n')
375 |
376 |
377 |
378 |
379 |
380 |
381 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import scipy.sparse as sp
4 | from numpy.linalg import inv
5 | import pickle
6 |
7 | from torch_geometric.datasets import *
8 |
9 | import torch
10 | import numpy as np
11 | from torch_sparse.matmul import matmul
12 | from torch_sparse import SparseTensor
13 |
14 |
# Mixing coefficient used by get_intimacy_matrix: c * (I - (1 - c) * A_norm)^{-1}.
c = 0.15
# NOTE(review): `k` is not used anywhere in this module; possibly consumed by importers — confirm.
k = 5
17 |
18 |
def adj_normalize(mx):
    """Symmetrically normalise a scipy sparse adjacency matrix: D^{-1/2} A D^{-1/2}.

    D holds the row sums of ``mx``. Rows whose sum is zero get a scaling factor of 0
    instead of inf.
    """
    row_sums = np.asarray(mx.sum(1)).ravel()
    inv_sqrt = row_sums ** -0.5
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    d_half = sp.diags(inv_sqrt)
    return d_half.dot(mx).dot(d_half)
26 |
27 |
def get_intimacy_matrix(edges, n):
    """Dense "intimacy" matrix  c * (I - (1 - c) * D^{-1/2} A D^{-1/2})^{-1}.

    Args:
        edges: array-like of shape (num_edges, 2) holding (source, target) node indices;
            duplicate edges are summed when building A.
        n: number of nodes.

    Returns:
        Dense (n, n) numpy array. This inverts a dense n x n matrix, so it is only
        practical for small graphs.
    """
    edge_arr = np.array(edges)
    sources, targets = edge_arr[:, 0], edge_arr[:, 1]
    weights = np.ones(edge_arr.shape[0])
    adj = sp.coo_matrix((weights, (sources, targets)), shape=(n, n), dtype=np.float32)
    print('normalize')
    adj_norm = adj_normalize(adj)
    print('inverse')
    system = (sp.eye(adj.shape[0]) - (1 - c) * adj_norm).toarray()
    return c * inv(system)
39 |
40 |
def adj_normalize_sparse(mx, device=None):
    """Symmetrically normalise a square ``torch_sparse.SparseTensor``: D^{-1/2} A D^{-1/2}.

    Rows whose sum is zero get a scaling factor of 0 instead of inf.

    Args:
        mx: square (n x n) SparseTensor adjacency.
        device: device to compute on; defaults to the device ``mx`` already lives on.
            (Previously this function read module-level ``device`` and ``n`` globals that
            only exist when the module is run as a script.)

    Returns:
        The normalised SparseTensor on ``device``.
    """
    if device is None:
        device = mx.device()
    mx = mx.to(device)
    n = mx.sparse_size(0)
    rowsum = mx.sum(1)
    r_inv = rowsum.pow(-0.5).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    diag_idx = torch.arange(n, device=device)
    r_mat_inv = SparseTensor(row=diag_idx, col=diag_idx, value=r_inv, sparse_sizes=(n, n))
    nr_mx = matmul(matmul(r_mat_inv, mx), r_mat_inv)
    return nr_mx
49 |
def get_intimacy_matrix_sparse(edges, n):
    """Build an n x n unweighted SparseTensor from ``edges`` and return its symmetric normalisation.

    Unlike ``get_intimacy_matrix`` no matrix inverse is computed; the result is
    ``adj_normalize_sparse`` applied to the adjacency.

    Args:
        edges: index tensor of shape (2, num_edges); row 0 holds sources, row 1 targets.
        n: number of nodes.
    """
    # Create the edge weights on the same device as the indices; SparseTensor requires
    # row/col/value to share a device.
    ones = torch.ones(edges.shape[1], device=edges.device)
    adj = SparseTensor(row=edges[0], col=edges[1], value=ones, sparse_sizes=(n, n))
    adj_norm = adj_normalize_sparse(adj)
    return adj_norm
54 |
def get_svd_dense(mx, q=3):
    """Rank-``q`` factors ``(pu, pv)`` with ``pu @ pv.T`` approximating ``mx``.

    Uses ``torch.svd_lowrank`` (a randomized algorithm) and splits the singular values
    evenly between the two factors: ``pu = U S^{1/2}``, ``pv = V S^{1/2}``.
    """
    u, sigma, v = torch.svd_lowrank(mx.float(), q=q)
    root_sigma = torch.diag(sigma).pow(0.5)
    return u @ root_sigma, v @ root_sigma
62 |
63 |
def unweighted_adj_normalize_dense_batch(adj):
    """Symmetric normalisation D^{-1/2} A D^{-1/2} of (a batch of) dense adjacency matrices.

    The input is first symmetrised and binarised (A := 1[A + A^T != 0]), so edge weights
    are ignored. Nodes without any edge get a scaling factor of 0, which keeps their rows
    and columns at zero instead of NaN. This matches ``adj_normalize``.

    Args:
        adj: tensor of shape (..., n, n).

    Returns:
        Float tensor of the same shape.
    """
    adj = (adj + adj.transpose(-1, -2)).bool().float()
    deg_inv_sqrt = adj.sum(-1).pow(-0.5)
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.)
    # Scaling rows and columns elementwise equals diag(d) @ A @ diag(d) without building
    # the diagonal matrices.
    return deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
72 |
73 |
def get_eig_dense(adj):
    """Eigen-decomposition of the normalised Laplacian  I - D^{-1/2} A D^{-1/2}  of a dense adjacency.

    Nodes whose row sum is zero get a scaling factor of 0 instead of inf, matching
    ``adj_normalize``.

    Args:
        adj: square (n, n) tensor.

    Returns:
        ``(eigenvalues, eigenvectors)``: real parts of the n eigenvalues (shape (n,)) and a
        real (n, n) matrix whose columns are the corresponding eigenvectors. Eigenvalues are
        not sorted.
    """
    adj = adj.float()
    rowsum = adj.sum(1)
    r_inv = rowsum.pow(-0.5)
    r_inv = r_inv.masked_fill(torch.isinf(r_inv), 0.)
    r_mat_inv = torch.diag(r_inv)
    nr_adj = torch.matmul(torch.matmul(r_mat_inv, adj), r_mat_inv)
    graph_laplacian = torch.eye(adj.shape[0], dtype=nr_adj.dtype, device=adj.device) - nr_adj

    # torch.eig was deprecated in PyTorch 1.9 and later removed; prefer torch.linalg.eig
    # when available and keep the old call for older PyTorch versions.
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'eig'):
        eigvals, eigvecs = linalg.eig(graph_laplacian)
        return eigvals.real, eigvecs.real
    L, V = torch.eig(graph_laplacian, eigenvectors=True)
    return L.T[0], V
83 |
84 |
85 |
def check_checkpoints(output_dir):
    """Report whether ``output_dir`` contains a checkpoint; otherwise delete the directory.

    Returns True as soon as any entry name contains 'checkpoint'. If the directory exists
    but holds no such entry, it is removed recursively (with a message) and False is
    returned. A missing directory also yields False.
    """
    import os
    import shutil

    if not os.path.exists(output_dir):
        return False
    if any('checkpoint' in entry for entry in os.listdir(output_dir)):
        return True
    print('remove ', output_dir)
    shutil.rmtree(output_dir)
    return False
98 |
99 |
if __name__=='__main__':
    # Manual smoke test of the helpers above; not executed when this module is imported.

    # `device` (cuda:0) and `n` are assigned at module scope here.
    device = torch.device('cuda',0)

    # Flickr graph from PyTorch Geometric, stored under dataset/flickr
    # (presumably downloaded on first use).
    data = Flickr('dataset/flickr')

    # edge_index row 0 is used as SparseTensor rows, row 1 as columns; n is the node count.
    edges= data.data.edge_index
    n=data.data.x.shape[0]


    # Unweighted adjacency as a SparseTensor, then its symmetric normalisation.
    adj = SparseTensor(row=edges[0], col=edges[1], value=torch.ones(edges.shape[1]), sparse_sizes=(n, n))
    nr_adj = adj_normalize_sparse(adj)

    # Rank-10 low-rank SVD factors of the normalised adjacency.
    pu,pv= get_svd_dense(nr_adj.to_torch_sparse_coo_tensor(),q=10)


    # Eigen-decomposition via get_eig_dense of a random 10x10 0/1 matrix.
    adj= (torch.randn(10,10)>0).float()
    L,V = get_eig_dense(adj)
119 |
--------------------------------------------------------------------------------