├── .idea ├── .gitignore ├── Graph-Trans.iml ├── dictionaries ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── data ├── __init__.py ├── algos.pyx ├── collator.py ├── dataset.py ├── ogb_datasets │ ├── __init__.py │ └── ogb_dataset_lookup_table.py ├── pyg_datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── pyg_dataset.cpython-36.pyc │ │ └── pyg_dataset_lookup_table.cpython-36.pyc │ ├── pyg_dataset.py │ └── pyg_dataset_lookup_table.py ├── smiles │ └── smiles_dataset.py └── wrapper.py ├── graphtrasformer ├── __pycache__ │ ├── architectures.cpython-36.pyc │ ├── gt_layers.cpython-36.pyc │ ├── gt_models.cpython-36.pyc │ └── layers.cpython-36.pyc ├── architectures.py ├── gnn_layers.py ├── gt_layers.py ├── gt_models.py ├── layer_tests.py └── layers.py ├── gt_dataset.py ├── run.py └── utils └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/Graph-Trans.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/dictionaries: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 71 | -------------------------------------------------------------------------------- 
/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 36 | 37 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Graph-Transformer Framework 4 | 5 | Source code for the paper "**[Transformer for Graphs: An Overview from Architecture Perspective](https://arxiv.org/pdf/2202.08455.pdf)**" 6 | 7 | 8 | We provide a comprehensive review of various Graph Transformer models from the architectural design perspective. 9 | We first disassemble the existing models and identify three typical ways to incorporate the graph 10 | information into the vanilla Transformer: 11 | - GNNs as Auxiliary Modules 12 | - Improved Positional Embedding from Graphs 13 | - Improved Attention Matrix from Graphs 14 | 15 | We implement the representative components in three groups and conduct a comprehensive comparison on a variety of well-known graph benchmarks to investigate the actual performance gain of each component. 16 | 17 | 18 | 19 | 20 | 21 | ## Running 22 | 23 | - Training and evaluation.
Please see details in our code annotations. 24 | ``` 25 | $python run.py --seed ${CUSTOMIZED_SEED} \ 26 | --model_scale ${CUSTOMIZED_SCALE} \ 27 | --data_name ${CUSTOMIZED_DATASET} \ 28 | --use_super_node ${True/False} \ 29 | --node_level_modules ${CUSTOMIZED_NODE_MODULES} \ 30 | --attn_level_modules ${CUSTOMIZED_ATTENTION_MODULES} \ 31 | --attn_mask_modules ${CUSTOMIZED_MASK_MODULES} \ 32 | --use_gnn_layers ${True/False} \ 33 | --gnn_insert_pos ${CUSTOMIZED_GNN_POSITION} \ 34 | --gnn_type ${CUSTOMIZED_GNN} \ 35 | --sampling_algo ${CUSTOMIZED_SAMPLING_ALGORITHMS} 36 | ``` 37 | - Example 1: Transformer with degree positional embedding, spatial encoding, and shortest-path edge encoding 38 | 39 | ``` 40 | $python run.py --seed 1024 \ 41 | --model_scale small \ 42 | --data_name ZINC \ 43 | --use_super_node True \ 44 | --node_level_modules degree \ 45 | --attn_level_modules spatial,spe 46 | ``` 47 | - Example 2: Transformer with a 1-hop attention mask 48 | ``` 49 | $python run.py --seed 1024 \ 50 | --model_scale middle \ 51 | --data_name flickr \ 52 | --use_super_node True \ 53 | --node_level_modules eig,svd \ 54 | --attn_mask_modules 1hop \ 55 | --sampling_algo shadowkhop \ 56 | --depth 2 \ 57 | --num_neighbors 10 58 | ``` 59 | - Example 3: Transformer with GIN layers before the Transformer layers 60 | ``` 61 | $python run.py --seed 1024 \ 62 | --model_scale large \ 63 | --data_name ZINC \ 64 | --use_super_node True \ 65 | --use_gnn_layers True \ 66 | --gnn_insert_pos before \ 67 | --gnn_type GIN 68 | ``` 69 | 70 | 71 | 72 | 73 | ## Requirements 74 | - Python 3.x 75 | - pytorch >=1.5.0 76 | - torch-geometric >=2.0.3 77 | - transformers >= 4.8.2 78 | - tensorflow >= 2.3.1 79 | - scikit-learn >= 0.23.2 80 | - ogb >= 1.3.2 81 | - datasets >=1.8.0 82 | 83 | ## Results 84 | Please refer to our [paper](https://arxiv.org/pdf/2202.08455.pdf) 85 | 86 | ## Reference 87 | Please cite the paper whenever our graph transformer is used to produce published results or incorporated into other
software: 88 | ``` 89 | @article{min2022transformer, 90 | title={Transformer for Graphs: An Overview from Architecture Perspective}, 91 | author={Min, Erxue and Chen, Runfa and Bian, Yatao and Xu, Tingyang and Zhao, Kangfei and Huang, Wenbing and Zhao, Peilin and Huang, Junzhou and Ananiadou, Sophia and Rong, Yu}, 92 | journal={arXiv preprint arXiv:2202.08455}, 93 | year={2022} 94 | } 95 | ``` 96 | 97 | 98 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | DATASET_REGISTRY = {} 2 | 3 | def register_dataset(name: str): 4 | def register_dataset_func(func): 5 | DATASET_REGISTRY[name] = func() 6 | return register_dataset_func 7 | -------------------------------------------------------------------------------- /data/algos.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import cython 5 | from cython.parallel cimport prange, parallel 6 | cimport numpy 7 | import numpy 8 | 9 | def floyd_warshall(adjacency_matrix): 10 | 11 | (nrows, ncols) = adjacency_matrix.shape 12 | assert nrows == ncols 13 | cdef unsigned int n = nrows 14 | 15 | adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True) 16 | assert adj_mat_copy.flags['C_CONTIGUOUS'] 17 | cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy 18 | cdef numpy.ndarray[long, ndim=2, mode='c'] path = numpy.zeros([n, n], dtype=numpy.int64) 19 | 20 | cdef unsigned int i, j, k 21 | cdef long M_ij, M_ik, cost_ikkj 22 | cdef long* M_ptr = &M[0,0] 23 | cdef long* M_i_ptr 24 | cdef long* M_k_ptr 25 | 26 | # set unreachable nodes distance to 510 27 | for i in range(n): 28 | for j in range(n): 29 | if i == j: 30 | M[i][j] = 0 31 | elif M[i][j] == 0: 32 | M[i][j] = 510 33 | 34 | # floyed algo 35 | for k in range(n): 36 | M_k_ptr = M_ptr + n*k 37 | for i in range(n): 38 | M_i_ptr = M_ptr + n*i 39 | M_ik = M_i_ptr[k] 40 | for j in range(n): 41 | cost_ikkj = M_ik + M_k_ptr[j] 42 | M_ij = M_i_ptr[j] 43 | if M_ij > cost_ikkj: 44 | M_i_ptr[j] = cost_ikkj 45 | path[i][j] = k 46 | 47 | # set unreachable path to 510 48 | for i in range(n): 49 | for j in range(n): 50 | if M[i][j] >= 510: 51 | path[i][j] = 510 52 | M[i][j] = 510 53 | 54 | return M, path 55 | 56 | 57 | def get_all_edges(path, i, j): 58 | cdef unsigned int k = path[i][j] 59 | if k == 0: 60 | return [] 61 | else: 62 | return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) 63 | 64 | 65 | def gen_edge_input(max_dist, path, edge_feat): 66 | 67 | (nrows, ncols) = path.shape 68 | assert nrows == ncols 69 | cdef unsigned int n = nrows 70 | cdef unsigned int max_dist_copy = max_dist 71 | 72 | path_copy = path.astype(long, order='C', casting='safe', copy=True) 73 | edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True) 74 | assert path_copy.flags['C_CONTIGUOUS'] 75 | assert 
edge_feat_copy.flags['C_CONTIGUOUS'] 76 | 77 | cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=numpy.int64) 78 | cdef unsigned int i, j, k, num_path, cur 79 | 80 | for i in range(n): 81 | for j in range(n): 82 | if i == j: 83 | continue 84 | if path_copy[i][j] == 510: 85 | continue 86 | path = [i] + get_all_edges(path_copy, i, j) + [j] 87 | num_path = len(path) - 1 88 | for k in range(num_path): 89 | edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] 90 | 91 | return edge_fea_all 92 | -------------------------------------------------------------------------------- /data/collator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | from torch_geometric.data import Batch,Data 6 | 7 | def pad_1d_unsqueeze(x, padlen): 8 | x = x + 1 # pad id = 0 9 | xlen = x.size(0) 10 | if xlen < padlen: 11 | new_x = x.new_zeros([padlen], dtype=x.dtype) 12 | new_x[:xlen] = x 13 | x = new_x 14 | return x.unsqueeze(0) 15 | 16 | 17 | def pad_2d_unsqueeze(x, padlen): 18 | x = x + 1 # pad id = 0 19 | xlen, xdim = x.size() 20 | if xlen < padlen: 21 | new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) 22 | new_x[:xlen, :] = x 23 | x = new_x 24 | return x.unsqueeze(0) 25 | 26 | 27 | def pad_attn_bias_unsqueeze(x, padlen): 28 | xlen = x.size(0) 29 | if xlen < padlen: 30 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype).fill_(float("-inf")) 31 | new_x[:xlen, :xlen] = x 32 | new_x[xlen:, :xlen] = 0 33 | x = new_x 34 | return x.unsqueeze(0) 35 | 36 | 37 | def pad_edge_type_unsqueeze(x, padlen): 38 | xlen = x.size(0) 39 | if xlen < padlen: 40 | new_x = x.new_zeros([padlen, padlen, x.size(-1)], dtype=x.dtype) 41 | new_x[:xlen, :xlen, :] = x 42 | x = new_x 43 | return x.unsqueeze(0) 44 | 45 | def pad_pos_emb_unsqueeze(x, padlen): 46 | xlen, xdim = x.size() 47 | if xlen < 
padlen: 48 | new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) 49 | new_x[:xlen, :] = x 50 | x = new_x 51 | return x.unsqueeze(0) 52 | 53 | 54 | def pad_spatial_pos_unsqueeze(x, padlen): 55 | x = x + 1 56 | xlen = x.size(0) 57 | if xlen < padlen: 58 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype) 59 | new_x[:xlen, :xlen] = x 60 | x = new_x 61 | return x.unsqueeze(0) 62 | 63 | def pad_adj_unsqueeze(x, padlen): 64 | xlen = x.size(0) 65 | if xlen < padlen: 66 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype) 67 | new_x[:xlen, :xlen] = x 68 | x = new_x 69 | return x.unsqueeze(0) 70 | 71 | 72 | def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3): 73 | x = x + 1 74 | xlen1, xlen2, xlen3, xlen4 = x.size() 75 | if xlen1 < padlen1 or xlen2 < padlen2 or xlen3 < padlen3: 76 | new_x = x.new_zeros([padlen1, padlen2, padlen3, xlen4], dtype=x.dtype) 77 | new_x[:xlen1, :xlen2, :xlen3, :] = x 78 | x = new_x 79 | return x.unsqueeze(0) 80 | 81 | 82 | def collator(items, args): 83 | 84 | max_node = args.max_node 85 | multi_hop_max_dist = args.multi_hop_max_dist 86 | spatial_pos_max = args.spatial_pos_max 87 | 88 | items = [item for item in items if item is not None and item.x.size(0) <= max_node] 89 | items = [ 90 | ( 91 | item.idx, 92 | item.attn_bias, 93 | item.attn_edge_type, 94 | item.spatial_pos, 95 | item.in_degree, 96 | item.out_degree, 97 | item.x, 98 | item.edge_input, 99 | item.y, 100 | item.adj, 101 | item.adj_norm, 102 | item.edge_index, 103 | item.eig_pos_emb, 104 | item.svd_pos_emb, 105 | item.root_n_id 106 | ) 107 | for item in items 108 | ] 109 | ( 110 | idxs, 111 | attn_biases, 112 | attn_edge_types, 113 | spatial_poses, 114 | in_degrees, 115 | out_degrees, 116 | xs, 117 | edge_inputs, 118 | ys, 119 | adjs, 120 | adj_norms, 121 | edge_indexs, 122 | eig_pos_embs, 123 | svd_pos_embs, 124 | root_n_ids 125 | ) = zip(*items) 126 | 127 | for i, _ in enumerate(attn_biases): 128 | attn_biases[i][int(args.use_super_node):, 
int(args.use_super_node):][spatial_poses[i] >= spatial_pos_max] = float("-inf") 129 | max_node_num = max(i.size(0) for i in xs) 130 | ns = [x.size(0) for x in xs] 131 | x_mask = torch.zeros(len(xs),max_node_num) 132 | for i,n in enumerate(ns): 133 | x_mask[i,:n]=1 134 | 135 | 136 | 137 | y = torch.cat(ys) 138 | root_n_id = torch.tensor(root_n_ids) 139 | 140 | if args.node_feature_type=='cate': 141 | x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs]) 142 | else: 143 | x = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in xs]) 144 | 145 | 146 | if isinstance(edge_inputs[0],int): 147 | edge_input=None 148 | attn_edge_type=None 149 | else: 150 | max_dist = max(i.size(-2) for i in edge_inputs) 151 | edge_input = torch.cat( 152 | [pad_3d_unsqueeze(i[:, :, :multi_hop_max_dist, :], max_node_num, max_node_num, max_dist) for i in edge_inputs] 153 | ) 154 | attn_edge_type = torch.cat( 155 | [pad_edge_type_unsqueeze(i, max_node_num) for i in attn_edge_types] 156 | ) 157 | 158 | attn_bias = torch.cat( 159 | [pad_attn_bias_unsqueeze(i, max_node_num + int(args.use_super_node)) for i in attn_biases] 160 | ) 161 | 162 | 163 | in_degree = torch.cat([pad_1d_unsqueeze(i, max_node_num) for i in in_degrees]) if not isinstance(in_degrees[0],int) else None 164 | adj = torch.cat([pad_adj_unsqueeze(a, max_node_num) for a in adjs]) 165 | 166 | adj_norm = torch.cat([pad_adj_unsqueeze(a, max_node_num) for a in adj_norms]) if not isinstance(adj_norms[0], int) else None 167 | 168 | 169 | spatial_pos = torch.cat( 170 | [pad_spatial_pos_unsqueeze(i, max_node_num) for i in spatial_poses] 171 | ) if not isinstance(spatial_poses[i],int) else None 172 | 173 | 174 | batch_edge_index = Batch.from_data_list([Data(edge_index=ei, num_nodes=ns[i]) for i, ei in enumerate(edge_indexs)]).edge_index if args.use_gnn_layers else None 175 | 176 | 177 | eig_pos_embs = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in eig_pos_embs]) if not isinstance(eig_pos_embs[0],int) else None 178 | 
179 | svd_pos_embs = torch.cat([pad_pos_emb_unsqueeze(i, max_node_num) for i in svd_pos_embs]) if not isinstance(svd_pos_embs[0],int) else None 180 | 181 | 182 | 183 | return dict( 184 | idx=torch.LongTensor(idxs), 185 | attn_bias=attn_bias, 186 | attn_edge_type=attn_edge_type, 187 | spatial_pos=spatial_pos, 188 | in_degree=in_degree, 189 | out_degree=in_degree, # for undirected graph 190 | x=x, 191 | edge_input=edge_input, 192 | x_mask = x_mask, 193 | ns = torch.LongTensor(ns), #node number in each graph 194 | labels=y.squeeze(), 195 | adj = adj, 196 | adj_norm=adj_norm, 197 | edge_index = batch_edge_index, 198 | eig_pos_emb=eig_pos_embs, 199 | svd_pos_emb=svd_pos_embs, 200 | root_n_id=root_n_id 201 | ) 202 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, Union 3 | from torch_geometric.data import Data as PYGDataset 4 | from dgl.data import DGLDataset 5 | from .pyg_datasets import PYGDatasetLookupTable, GraphormerPYGDataset 6 | from .ogb_datasets import OGBDatasetLookupTable 7 | 8 | 9 | 10 | 11 | 12 | class GraphormerDataset: 13 | def __init__( 14 | self, 15 | dataset: Optional[Union[PYGDataset, DGLDataset]] = None, 16 | dataset_spec: Optional[str] = None, 17 | dataset_source: Optional[str] = None, 18 | seed: int = 0, 19 | train_idx = None, 20 | valid_idx = None, 21 | test_idx = None, 22 | ): 23 | super().__init__() 24 | if dataset is not None: 25 | self.dataset = GraphormerPYGDataset(dataset, train_idx, valid_idx, test_idx) 26 | 27 | elif dataset_source == "pyg": 28 | self.dataset = PYGDatasetLookupTable.GetPYGDataset(dataset_spec, seed) 29 | elif dataset_source == "ogb": 30 | self.dataset = OGBDatasetLookupTable.GetOGBDataset(dataset_spec, seed) 31 | self.setup() 32 | 33 | def setup(self): 34 | self.train_idx = self.dataset.train_idx 35 | self.valid_idx = self.dataset.valid_idx 36 | 
self.test_idx = self.dataset.test_idx 37 | 38 | self.dataset_train = self.dataset.train_data 39 | self.dataset_val = self.dataset.valid_data 40 | self.dataset_test = self.dataset.test_data 41 | -------------------------------------------------------------------------------- /data/ogb_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .ogb_dataset_lookup_table import OGBDatasetLookupTable 5 | -------------------------------------------------------------------------------- /data/ogb_datasets/ogb_dataset_lookup_table.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional 3 | from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset 4 | from ogb.lsc.pcqm4m_pyg import PygPCQM4MDataset 5 | from ogb.graphproppred import PygGraphPropPredDataset 6 | from torch_geometric.data import Dataset 7 | from ..pyg_datasets import GraphormerPYGDataset 8 | import torch.distributed as dist 9 | import os 10 | 11 | 12 | 13 | 14 | 15 | class MyPygGraphPropPredDataset(PygGraphPropPredDataset): 16 | def download(self): 17 | if not dist.is_initialized() or dist.get_rank() == 0: 18 | super(MyPygGraphPropPredDataset, self).download() 19 | if dist.is_initialized(): 20 | dist.barrier() 21 | 22 | def process(self): 23 | if not dist.is_initialized() or dist.get_rank() == 0: 24 | super(MyPygGraphPropPredDataset, self).process() 25 | if dist.is_initialized(): 26 | dist.barrier() 27 | 28 | 29 | class OGBDatasetLookupTable: 30 | @staticmethod 31 | def GetOGBDataset(dataset_name: str, seed: int) -> Optional[Dataset]: 32 | inner_dataset = None 33 | train_idx = None 34 | valid_idx = None 35 | test_idx = None 36 | if dataset_name == "ogbg-molhiv": 37 | folder_name = dataset_name.replace("-", "_") 38 | os.system(f"mkdir -p dataset/{folder_name}/") 39 | os.system(f"touch 
dataset/{folder_name}/RELEASE_v1.txt") 40 | inner_dataset = MyPygGraphPropPredDataset(dataset_name) 41 | idx_split = inner_dataset.get_idx_split() 42 | train_idx = idx_split["train"] 43 | valid_idx = idx_split["valid"] 44 | test_idx = idx_split["test"] 45 | elif dataset_name == "ogbg-molpcba": 46 | folder_name = dataset_name.replace("-", "_") 47 | os.system(f"mkdir -p dataset/{folder_name}/") 48 | os.system(f"touch dataset/{folder_name}/RELEASE_v1.txt") 49 | inner_dataset = MyPygGraphPropPredDataset(dataset_name) 50 | idx_split = inner_dataset.get_idx_split() 51 | train_idx = idx_split["train"] 52 | valid_idx = idx_split["valid"] 53 | test_idx = idx_split["test"] 54 | elif dataset_name == "pcqm4mv2": 55 | os.system("mkdir -p dataset/pcqm4m-v2/") 56 | os.system("touch dataset/pcqm4m-v2/RELEASE_v1.txt") 57 | inner_dataset = MyPygPCQM4Mv2Dataset() 58 | idx_split = inner_dataset.get_idx_split() 59 | train_idx = idx_split["train"] 60 | valid_idx = idx_split["valid"] 61 | test_idx = idx_split["test-dev"] 62 | elif dataset_name == "pcqm4m": 63 | os.system("mkdir -p dataset/pcqm4m_kddcup2021/") 64 | os.system("touch dataset/pcqm4m_kddcup2021/RELEASE_v1.txt") 65 | inner_dataset = MyPygPCQM4MDataset() 66 | idx_split = inner_dataset.get_idx_split() 67 | train_idx = idx_split["train"] 68 | valid_idx = idx_split["valid"] 69 | test_idx = idx_split["test"] 70 | else: 71 | raise ValueError(f"Unknown dataset name {dataset_name} for ogb source.") 72 | return ( 73 | None 74 | if inner_dataset is None 75 | else GraphormerPYGDataset( 76 | inner_dataset, seed, train_idx, valid_idx, test_idx 77 | ) 78 | ) 79 | -------------------------------------------------------------------------------- /data/pyg_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from .pyg_dataset_lookup_table import PYGDatasetLookupTable 5 | from .pyg_dataset import GraphormerPYGDataset 6 | -------------------------------------------------------------------------------- /data/pyg_datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/pyg_datasets/__pycache__/pyg_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/pyg_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/pyg_datasets/__pycache__/pyg_dataset_lookup_table.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qwerfdsaplking/Graph-Trans/e4f52d0bed92b6aea3812e86fe7de9f997550318/data/pyg_datasets/__pycache__/pyg_dataset_lookup_table.cpython-36.pyc -------------------------------------------------------------------------------- /data/pyg_datasets/pyg_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch_sparse import SparseTensor 7 | from torch_geometric.data import Data, Batch 8 | from torch_geometric.data import Dataset 9 | from sklearn.model_selection import train_test_split 10 | from typing import List 11 | import torch 12 | import numpy as np 13 | 14 | from ..wrapper import preprocess_item 15 | from .. 
import algos 16 | 17 | import copy 18 | from functools import lru_cache 19 | 20 | from typing import Callable, List, NamedTuple, Optional, Tuple, Union 21 | 22 | import torch 23 | from torch import Tensor 24 | from torch_sparse import SparseTensor 25 | 26 | 27 | class EdgeIndex(NamedTuple): 28 | edge_index: Tensor 29 | e_id: Optional[Tensor] 30 | size: Tuple[int, int] 31 | 32 | def to(self, *args, **kwargs): 33 | edge_index = self.edge_index.to(*args, **kwargs) 34 | e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None 35 | return EdgeIndex(edge_index, e_id, self.size) 36 | 37 | 38 | class Adj(NamedTuple): 39 | adj_t: SparseTensor 40 | e_id: Optional[Tensor] 41 | size: Tuple[int, int] 42 | 43 | def to(self, *args, **kwargs): 44 | adj_t = self.adj_t.to(*args, **kwargs) 45 | e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None 46 | return Adj(adj_t, e_id, self.size) 47 | 48 | 49 | 50 | 51 | 52 | class GraphormerPYGDataset(Dataset): 53 | def __init__( 54 | self, 55 | dataset: Dataset, 56 | args, 57 | seed: int = 0, 58 | train_idx=None, 59 | valid_idx=None, 60 | test_idx=None, 61 | train_set=None, 62 | valid_set=None, 63 | test_set=None, 64 | x_norm_func=lambda x:x, 65 | ): 66 | self.args=args 67 | self.dataset = dataset 68 | if self.dataset is not None: 69 | self.num_data = len(self.dataset) 70 | self.seed = seed 71 | self.x_norm_func=x_norm_func 72 | if train_idx is None and train_set is None: 73 | train_valid_idx, test_idx = train_test_split( 74 | np.arange(self.num_data), 75 | test_size=self.num_data // 10, 76 | random_state=seed, 77 | ) 78 | train_idx, valid_idx = train_test_split( 79 | train_valid_idx, test_size=self.num_data // 5, random_state=seed 80 | ) 81 | self.train_idx = torch.from_numpy(train_idx) 82 | self.valid_idx = torch.from_numpy(valid_idx) 83 | self.test_idx = torch.from_numpy(test_idx) 84 | self.train_data = self.index_select(self.train_idx) 85 | self.valid_data = self.index_select(self.valid_idx) 86 | 
self.test_data = self.index_select(self.test_idx) 87 | elif train_set is not None: 88 | self.num_data = len(train_set) + len(valid_set) + len(test_set) 89 | self.train_data = self.create_subset(train_set) 90 | self.valid_data = self.create_subset(valid_set) 91 | self.test_data = self.create_subset(test_set) 92 | self.train_idx = None 93 | self.valid_idx = None 94 | self.test_idx = None 95 | else: 96 | self.num_data = len(train_idx) + len(valid_idx) + len(test_idx) 97 | self.train_idx = train_idx 98 | self.valid_idx = valid_idx 99 | self.test_idx = test_idx 100 | self.train_data = self.index_select(self.train_idx) 101 | self.valid_data = self.index_select(self.valid_idx) 102 | self.test_data = self.index_select(self.test_idx) 103 | self.__indices__ = None 104 | 105 | def index_select(self, idx): 106 | dataset = copy.copy(self) 107 | dataset.dataset = self.dataset.index_select(idx) 108 | if isinstance(idx, torch.Tensor): 109 | dataset.num_data = idx.size(0) 110 | else: 111 | dataset.num_data = idx.shape[0] 112 | dataset.__indices__ = idx 113 | dataset.train_data = None 114 | dataset.valid_data = None 115 | dataset.test_data = None 116 | dataset.train_idx = None 117 | dataset.valid_idx = None 118 | dataset.test_idx = None 119 | return dataset 120 | 121 | def create_subset(self, subset): 122 | dataset = GraphormerPYGDataset(subset,seed=self.seed,args=self.args) 123 | dataset.train_data = None 124 | dataset.valid_data = None 125 | dataset.test_data = None 126 | dataset.train_idx = None 127 | dataset.valid_idx = None 128 | dataset.test_idx = None 129 | return dataset 130 | 131 | 132 | @lru_cache(maxsize=16) 133 | def __getitem__(self, idx): 134 | if isinstance(idx, int): 135 | item = self.dataset[idx] 136 | item.idx = idx 137 | item.y=item.y.reshape(1, -1) if item.y.shape[-1] > 1 else item.y.reshape(-1) 138 | 139 | return preprocess_item(item, self.x_norm_func, args=self.args) 140 | else: 141 | raise TypeError("index to a GraphormerPYGDataset can only be an integer.") 
142 | 143 | def __len__(self): 144 | return self.num_data 145 | 146 | 147 | 148 | class Graphtrans_Sampling_Dataset(Dataset):#shadowhop sampling 149 | def __init__(self, 150 | data, 151 | node_idx, 152 | depth: int, num_neighbors: int, 153 | replace: bool = False, 154 | x_norm_func = lambda x:x, 155 | args=None 156 | ): 157 | 158 | self.data = data#copy.copy(data) 159 | self.depth = depth 160 | self.num_neighbors = num_neighbors 161 | self.replace = replace 162 | self.x_norm_func = x_norm_func 163 | self.args=args 164 | 165 | if data.edge_index is not None: 166 | self.is_sparse_tensor = False 167 | row, col = data.edge_index.cpu() 168 | self.adj_t = SparseTensor( 169 | row=row, col=col, value=torch.arange(col.size(0)), 170 | sparse_sizes=(data.num_nodes, data.num_nodes)).t() 171 | else: 172 | self.is_sparse_tensor = True 173 | self.adj_t = data.adj_t.cpu() 174 | 175 | if node_idx is None: 176 | node_idx = torch.arange(self.adj_t.sparse_size(0)) 177 | elif node_idx.dtype == torch.bool: 178 | node_idx = node_idx.nonzero(as_tuple=False).view(-1) 179 | self.node_idx = node_idx 180 | self.num_data = len(self.node_idx) 181 | 182 | 183 | @lru_cache(maxsize=16) 184 | def __getitem__(self, idx): 185 | n_id = self.node_idx[idx] 186 | 187 | rowptr, col, value = self.adj_t.csr() 188 | out = torch.ops.torch_sparse.ego_k_hop_sample_adj( 189 | rowptr, col, n_id, self.depth, self.num_neighbors, self.replace) 190 | rowptr, col, n_id, e_id, ptr, root_n_id = out 191 | 192 | adj_t = SparseTensor(rowptr=rowptr, col=col, 193 | value=value[e_id] if value is not None else None, 194 | sparse_sizes=(n_id.numel(), n_id.numel()), 195 | is_sorted=True) 196 | 197 | batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()), 198 | ptr=ptr) 199 | batch.root_n_id = root_n_id 200 | 201 | if self.is_sparse_tensor: 202 | batch.adj_t = adj_t 203 | else: 204 | row, col, e_id = adj_t.t().coo() 205 | batch.edge_index = torch.stack([row, col], dim=0) 206 | 207 | for k, v in self.data: 208 | if 
k in ['edge_index', 'adj_t', 'num_nodes']: 209 | continue 210 | if k == 'y' and v.size(0) == self.data.num_nodes: 211 | batch[k] = v[n_id][root_n_id] 212 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes: 213 | batch[k] = v[n_id] 214 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges: 215 | batch[k] = v[e_id] 216 | else: 217 | batch[k] = v 218 | 219 | item = batch 220 | item.idx = self.node_idx[idx] 221 | return preprocess_item(item,x_norm_func=self.x_norm_func,args=self.args) 222 | 223 | def __len__(self): 224 | return self.num_data 225 | 226 | 227 | 228 | 229 | class Graphtrans_Sampling_Dataset_v2(Dataset):#sage sampling +induced subgraph 230 | def __init__(self, 231 | data, 232 | node_idx, 233 | depth: int, 234 | num_neighbors, 235 | replace: bool = False, 236 | x_norm_func = lambda x:x, 237 | args=None 238 | ): 239 | 240 | self.data = copy.copy(data) 241 | self.depth = depth 242 | if isinstance(num_neighbors,int): 243 | self.num_neighbors = [num_neighbors]+(depth-1)*[1] 244 | self.replace = replace 245 | self.x_norm_func = x_norm_func 246 | self.args=args 247 | 248 | 249 | if data.edge_index is not None: 250 | self.is_sparse_tensor = False 251 | row, col = data.edge_index.cpu() 252 | self.adj_t = SparseTensor( 253 | row=row, col=col, value=torch.arange(col.size(0)), 254 | sparse_sizes=(data.num_nodes, data.num_nodes)).t() 255 | else: 256 | self.is_sparse_tensor = True 257 | self.adj_t = data.adj_t.cpu() 258 | 259 | 260 | if node_idx.dtype == torch.bool: 261 | node_idx = node_idx.nonzero(as_tuple=False).view(-1) 262 | self.node_idx = node_idx 263 | self.num_data = len(self.node_idx) 264 | 265 | 266 | def __getitem__(self, idx): 267 | 268 | n_id = self.node_idx[idx].reshape(1) 269 | root_n_id=0 270 | hop_node_nums = [n_id.shape[0]] 271 | 272 | for size in self.num_neighbors: 273 | adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False) 274 | hop_node_nums.append(n_id.shape[0]) 275 | n_hops = 
torch.ones(n_id.shape)+len(self.num_neighbors) 276 | for i,hop_offset in enumerate(hop_node_nums): 277 | n_hops[:hop_offset]-=1 278 | 279 | adj_t,_ = self.adj_t.saint_subgraph(n_id) 280 | row, col, e_id = adj_t.t().coo() 281 | edge_index = torch.stack([row, col]) 282 | 283 | 284 | item = Data(x = self.data.x[n_id],edge_index=edge_index) 285 | for k, v in self.data: 286 | if k in ['edge_index', 'adj_t', 'num_nodes']: 287 | continue 288 | if k == 'y' and v.size(0) == self.data.num_nodes: 289 | item[k] = v[n_id][root_n_id].reshape(1) 290 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes: 291 | item[k] = v[n_id] 292 | elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges: 293 | item[k] = v[e_id] 294 | else: 295 | item[k] = v 296 | item.root_n_id = root_n_id 297 | 298 | item.idx = self.node_idx[idx] 299 | item.n_hops = n_hops 300 | return preprocess_item(item,x_norm_func=self.x_norm_func,args=self.args) 301 | 302 | def __len__(self): 303 | return self.num_data 304 | 305 | 306 | -------------------------------------------------------------------------------- /data/pyg_datasets/pyg_dataset_lookup_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
from typing import Optional
from torch_geometric.datasets import *
from torch_geometric.data import Dataset
from .pyg_dataset import GraphormerPYGDataset
import torch.distributed as dist
import os

# NOTE(review): every wrapper below overrides download()/process() directly on
# its own class rather than through a shared mixin. Some PyG versions decide
# whether to run these hooks by inspecting the subclass's own __dict__, so a
# mixin could silently disable them -- confirm before refactoring.


class MyQM7b(QM7b):
    """QM7b whose download/process steps run only on distributed rank 0.

    Without an initialized process group the steps run normally; with one,
    only rank 0 does the work and every rank then waits at ``dist.barrier()``.
    """

    def download(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyQM7b, self).download()
        if dist.is_initialized():
            dist.barrier()

    def process(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyQM7b, self).process()
        if dist.is_initialized():
            dist.barrier()


class MyQM9(QM9):
    """QM9 whose download/process steps run only on distributed rank 0.

    Without an initialized process group the steps run normally; with one,
    only rank 0 does the work and every rank then waits at ``dist.barrier()``.
    """

    def download(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyQM9, self).download()
        if dist.is_initialized():
            dist.barrier()

    def process(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyQM9, self).process()
        if dist.is_initialized():
            dist.barrier()


class MyZINC(ZINC):
    """ZINC whose download/process steps run only on distributed rank 0.

    Without an initialized process group the steps run normally; with one,
    only rank 0 does the work and every rank then waits at ``dist.barrier()``.
    """

    def download(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyZINC, self).download()
        if dist.is_initialized():
            dist.barrier()

    def process(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyZINC, self).process()
        if dist.is_initialized():
            dist.barrier()


class MyMoleculeNet(MoleculeNet):
    """MoleculeNet whose download/process steps run only on distributed rank 0.

    Without an initialized process group the steps run normally; with one,
    only rank 0 does the work and every rank then waits at ``dist.barrier()``.
    """

    def download(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyMoleculeNet, self).download()
        if dist.is_initialized():
            dist.barrier()

    def process(self):
        if not dist.is_initialized() or dist.get_rank() == 0:
            super(MyMoleculeNet, self).process()
        if dist.is_initialized():
            dist.barrier()


class PYGDatasetLookupTable:
    @staticmethod
    def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]:
        """Build a ``GraphormerPYGDataset`` from a dataset spec string.

        Args:
            dataset_spec: ``"name"`` or ``"name:key=value,key=value"``.
                Supported names are ``qm7b``, ``qm9``, ``zinc``,
                ``zinc-subset`` and ``moleculenet``; ``moleculenet`` requires
                a ``name=<molecule dataset>`` parameter.
            seed: Seed forwarded to ``GraphormerPYGDataset``.

        Returns:
            For ``zinc``/``zinc-subset``, a ``GraphormerPYGDataset`` built
            from the predefined train/val/test splits; otherwise a
            ``GraphormerPYGDataset`` wrapping the whole dataset.

        Raises:
            ValueError: if the spec contains more than one ``:``, the name is
                unknown, or ``moleculenet`` is requested without ``name=``.

        Side effect: creates the directory ``dataset/<name>`` (with ``-``
        replaced by ``_``) if it does not already exist.
        """
        split_result = dataset_spec.split(":")
        if len(split_result) == 2:
            name, raw_params = split_result
            params = raw_params.split(",")
        elif len(split_result) == 1:
            name = dataset_spec
            params = []
        else:
            # Previously this fell through with `name` unbound and crashed
            # with an UnboundLocalError further down.
            raise ValueError(
                f"Malformed dataset spec {dataset_spec!r}; "
                "expected 'name' or 'name:key=value,...'."
            )

        inner_dataset = None
        train_set = None
        valid_set = None
        test_set = None

        folder_name = name.replace("-", "_")
        root = "dataset/" + folder_name
        # Create the directory in-process instead of via `os.system("mkdir -p ...")`,
        # which passed the user-supplied name through a shell.
        os.makedirs(root, exist_ok=True)

        if name == "qm7b":
            inner_dataset = MyQM7b(root=root)
        elif name == "qm9":
            inner_dataset = MyQM9(root=root)
        elif name in ("zinc", "zinc-subset"):
            subset = name == "zinc-subset"
            # Only the explicit splits are used below; the previous extra
            # `MyZINC(root=root)` was a duplicate load of the default 'train'
            # split that was never returned.
            train_set = MyZINC(root=root, subset=subset, split="train")
            valid_set = MyZINC(root=root, subset=subset, split="val")
            test_set = MyZINC(root=root, subset=subset, split="test")
        elif name == "moleculenet":
            nm = None
            for param in params:
                # Use a separate key variable so the dataset `name` is not
                # clobbered by the parameter key.
                key, value = param.split("=")
                if key == "name":
                    nm = value
            if nm is None:
                raise ValueError(
                    "moleculenet requires a 'name=<dataset>' parameter, "
                    "e.g. 'moleculenet:name=esol'."
                )
            inner_dataset = MyMoleculeNet(root=root, name=nm)
        else:
            raise ValueError(f"Unknown dataset name {name} for pyg source.")

        if train_set is not None:
            return GraphormerPYGDataset(
                None,
                seed,
                None,
                None,
                None,
                train_set,
                valid_set,
                test_set,
            )
        return (
            None
            if inner_dataset is None
            else GraphormerPYGDataset(inner_dataset, seed)
        )

# -------------------------------------------------------------------------------- /data/smiles/smiles_dataset.py: --------------------------------------------------------------------------------
# 1 | # Copyright (c) Microsoft Corporation.
# 2 | # Licensed under the MIT License.
3 | 4 | from sklearn.model_selection import train_test_split 5 | import torch 6 | import numpy as np 7 | 8 | from ..wrapper import preprocess_item 9 | from .. import algos 10 | from ..pyg_datasets import GraphormerPYGDataset 11 | 12 | from ogb.utils.mol import smiles2graph 13 | 14 | 15 | class GraphormerSMILESDataset(GraphormerPYGDataset): 16 | def __init__( 17 | self, 18 | dataset: str, 19 | num_class: int, 20 | max_node: int, 21 | multi_hop_max_dist: int, 22 | spatial_pos_max: int, 23 | ): 24 | self.dataset = np.genfromtxt(dataset, delimiter=",", dtype=str) 25 | num_data = len(self.dataset) 26 | self.num_class = num_class 27 | self.__get_graph_metainfo(max_node, multi_hop_max_dist, spatial_pos_max) 28 | train_valid_idx, test_idx = train_test_split(num_data // 10) 29 | train_idx, valid_idx = train_test_split(train_valid_idx, num_data // 5) 30 | self.train_idx = train_idx 31 | self.valid_idx = valid_idx 32 | self.test_idx = test_idx 33 | self.__indices__ = None 34 | self.train_data = self.index_select(train_idx) 35 | self.valid_data = self.index_select(valid_idx) 36 | self.test_data = self.index_select(test_idx) 37 | 38 | def __get_graph_metainfo( 39 | self, max_node: int, multi_hop_max_dist: int, spatial_pos_max: int 40 | ): 41 | self.max_node = min( 42 | max_node, 43 | torch.max(self.dataset[i][0].num_nodes() for i in range(len(self.dataset))), 44 | ) 45 | max_dist = 0 46 | for i in range(len(self.dataset)): 47 | pyg_graph = smiles2graph(self.dataset[i]) 48 | dense_adj = pyg_graph.adj().to_dense().type(torch.int) 49 | shortest_path_result, _ = algos.floyd_warshall(dense_adj.numpy()) 50 | max_dist = max(max_dist, np.amax(shortest_path_result)) 51 | self.multi_hop_max_dist = min(multi_hop_max_dist, max_dist) 52 | self.spatial_pos_max = min(spatial_pos_max, max_dist) 53 | 54 | def __getitem__(self, idx): 55 | if isinstance(idx, int): 56 | item = smiles2graph(self.dataset[idx]) 57 | item.idx = idx 58 | return preprocess_item(item) 59 | else: 60 | raise 
TypeError("index to a GraphormerPYGDataset can only be an integer.") 61 | -------------------------------------------------------------------------------- /data/wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from utils import utils 4 | import torch 5 | import numpy as np 6 | from ogb.graphproppred import PygGraphPropPredDataset 7 | #from ogb.lsc.pcqm4mv2_pyg import PygPCQM4Mv2Dataset 8 | from functools import lru_cache 9 | import pyximport 10 | import torch.distributed as dist 11 | from torch_geometric.utils import to_undirected,add_self_loops 12 | from torch_geometric.data import Data 13 | 14 | pyximport.install(setup_args={"include_dirs": np.get_include()}) 15 | from . import algos 16 | from utils.utils import * 17 | from copy import deepcopy 18 | 19 | @torch.jit.script 20 | def convert_to_single_emb(x, offset: int = 512): 21 | feature_num = x.size(1) if len(x.size()) > 1 else 1 22 | feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long) 23 | x = x + feature_offset 24 | return x 25 | 26 | 27 | def preprocess_item(raw_item,x_norm_func,args): 28 | edge_attr, edge_index, x,y,idx = raw_item.edge_attr, raw_item.edge_index, raw_item.x,raw_item.y,raw_item.idx 29 | root_n_id=raw_item.root_n_id if 'root_n_id' in raw_item.to_dict().keys() else -1 30 | 31 | 32 | 33 | N = x.size(0) 34 | if args.node_feature_type=='cate': 35 | x = convert_to_single_emb(x) 36 | elif args.node_feature_type=='dense': 37 | x = x_norm_func(x) 38 | else: 39 | raise ValueError('node feature type error') 40 | 41 | # node adj matrix [N, N] bool 42 | try: 43 | edge_index = to_undirected(edge_index) 44 | except: 45 | print(edge_index) 46 | assert 1==2 47 | 48 | adj = torch.zeros([N, N], dtype=torch.bool) 49 | adj[edge_index[0, :], edge_index[1, :]] = True 50 | 51 | adj_w_sl = adj.clone()#adj with self loop 52 | 
adj_w_sl[torch.arange(N),torch.arange(N)]=1 53 | 54 | #positional bias 55 | if 'degree' in args.node_level_modules: 56 | in_degree = adj.long().sum(dim=1).view(-1) 57 | else: 58 | in_degree = 0 59 | 60 | if 'eig' in args.node_level_modules: 61 | if N0 23 | sign = 1 if sign else -1 24 | return self.embeddings(pos * sign) 25 | 26 | class SVD_Embedding(nn.Module): 27 | def __init__(self, svd_dim, hidden_dim): 28 | super(SVD_Embedding,self).__init__() 29 | self.svd_dim=svd_dim 30 | self.embeddings = nn.Linear(svd_dim*2,hidden_dim) 31 | def forward(self, batched_data): 32 | pos = batched_data['svd_pos_emb'] 33 | sign = torch.randn(1)[0]>0 34 | sign = 1 if sign else -1 35 | pos_u = pos[:,:,:self.svd_dim]*sign 36 | pos_v = pos[:,:,self.svd_dim:]*(-sign) 37 | pos = torch.cat([pos_u,pos_v],dim=-1) 38 | return self.embeddings(pos) 39 | 40 | 41 | class WL_Role_Embedding(nn.Module): 42 | def __init__(self, max_index, hidden_dim): 43 | super(WL_Role_Embedding,self).__init__() 44 | self.embeddings = nn.Linear(max_index,hidden_dim) 45 | def forward(self, batched_data): 46 | pos = batched_data['wl_role_ids'] 47 | return self.embeddings(pos) 48 | 49 | class Inti_Pos_Embedding(nn.Module): 50 | def __init__(self, max_index, hidden_dim): 51 | super(Inti_Pos_Embedding,self).__init__() 52 | self.embeddings = nn.Linear(max_index,hidden_dim) 53 | def forward(self, batched_data): 54 | pos = batched_data['init_pos_ids'] 55 | return self.embeddings(pos) 56 | 57 | class Hop_Dis_Embedding(nn.Module): 58 | def __init__(self, max_index, hidden_dim): 59 | super(Hop_Dis_Embedding,self).__init__() 60 | self.embeddings = nn.Linear(max_index,hidden_dim) 61 | def forward(self, batched_data): 62 | pos = batched_data['hop_dis_ids'] 63 | return self.embeddings(pos) 64 | 65 | class DegreeEncoder(nn.Module): 66 | def __init__(self, 67 | num_in_degree, 68 | num_out_degree, 69 | hidden_dim, 70 | n_layers #for parameter initialization 71 | ): 72 | super(DegreeEncoder, self).__init__() 73 | 
self.in_degree_encoder = nn.Embedding(num_in_degree, hidden_dim, padding_idx=0) 74 | self.out_degree_encoder = nn.Embedding(num_out_degree, hidden_dim, padding_idx=0) 75 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 76 | 77 | def forward(self, batched_data): 78 | in_degree, out_degree = ( 79 | batched_data["in_degree"], 80 | batched_data["out_degree"], 81 | ) 82 | return self.in_degree_encoder(in_degree)+self.out_degree_encoder(out_degree) 83 | 84 | 85 | class AddSuperNode(nn.Module): 86 | def __init__(self, hidden_dim): 87 | super(AddSuperNode, self).__init__() 88 | self.graph_token = nn.Embedding(1, hidden_dim) 89 | 90 | def forward(self, node_feature): 91 | n_graph = node_feature.size()[0] 92 | graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1) 93 | graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1) 94 | 95 | return graph_node_feature 96 | 97 | 98 | 99 | 100 | 101 | class NodeFeatureEncoder(nn.Module): 102 | def __init__( 103 | self, 104 | feat_type, 105 | hidden_dim, 106 | n_layers, 107 | num_atoms=None, 108 | feat_dim=None 109 | ): 110 | super(NodeFeatureEncoder, self).__init__() 111 | 112 | self.feat_type = feat_type 113 | 114 | if feat_type=='dense' and feat_dim is not None:#dense feature 115 | self.feature_encoder = nn.Linear(feat_dim, hidden_dim) 116 | elif feat_type=='cate' and num_atoms is not None:#cate feature 117 | # 1 for graph token 118 | self.feature_encoder = nn.Embedding(num_atoms + 1, hidden_dim, padding_idx=0) 119 | else: 120 | raise ValueError('conflict feature type') 121 | 122 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 123 | 124 | def forward(self, batched_data): 125 | x=batched_data["x"] 126 | if self.feat_type=='cate':# 127 | node_feature = self.feature_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden] 128 | else: 129 | node_feature = self.feature_encoder(x) 130 | 131 | return node_feature 132 | 133 | 134 | def 
getAttnMasks(batched_data,attn_mask_modules,use_super_node,num_heads): 135 | adj = batched_data['adj'].bool().float() 136 | 137 | attn_mask = torch.ones(adj.shape[0], num_heads,adj.shape[1] + int(use_super_node), 138 | adj.shape[2] + int(use_super_node)).to(adj.device) 139 | if attn_mask_modules == '1hop': 140 | adjs = adj.unsqueeze(1).expand(-1,num_heads,-1,-1).bool().float() 141 | attn_mask[:,:,int(use_super_node):,int(use_super_node):] = adjs 142 | 143 | 144 | if attn_mask_modules == 'nhop': 145 | multi_hop_adjs = torch.cat([torch.matrix_power(adj, i + 1).unsqueeze(1) for i in range(num_heads)], 146 | dim=1).bool().float() 147 | attn_mask[:,:, int(use_super_node):, int(use_super_node):] = multi_hop_adjs 148 | 149 | return attn_mask 150 | 151 | 152 | class GraphAttnHopBias(nn.Module): 153 | def __init__( 154 | self, 155 | num_heads, 156 | n_hops, 157 | use_super_node 158 | ): 159 | super(GraphAttnHopBias, self).__init__() 160 | self.num_heads = num_heads 161 | self.use_super_node=use_super_node 162 | self.hop_bias = nn.Parameter(torch.randn(n_hops,num_heads)) 163 | self.n_hops = n_hops 164 | 165 | def forward(self, batched_data): 166 | x, adj, attn_bias = ( 167 | batched_data["x"], 168 | batched_data['adj_norm'], 169 | batched_data['attn_bias'] 170 | ) 171 | 172 | 173 | adj_n_hops_bias = torch.ones(adj.shape[0],adj.shape[1]+int(self.use_super_node), 174 | adj.shape[2]+int(self.use_super_node),self.n_hops).to(x.device) 175 | adj_list = [torch.matrix_power(adj,i+1).unsqueeze(-1) for i in range(self.n_hops)] 176 | adj_n_hops = torch.cat(adj_list,dim=-1)# n_graph, n_node, n_node, n_hops 177 | adj_n_hops_bias[:,int(self.use_super_node):,int(self.use_super_node):,:] = adj_n_hops 178 | adj_n_hops_bias = torch.matmul(adj_n_hops_bias,self.hop_bias).permute(0, 3, 1, 2) 179 | 180 | return adj_n_hops_bias# [n_graph, n_head, n_node+1, n_node+1] 181 | 182 | 183 | 184 | 185 | class GraphAttnSpatialBias(nn.Module):#refer to Graphormer 186 | def __init__( 187 | self, 188 | 
num_heads, 189 | num_spatial, 190 | n_layers, 191 | use_super_node 192 | ): 193 | super(GraphAttnSpatialBias, self).__init__() 194 | self.num_heads = num_heads 195 | self.use_super_node = use_super_node 196 | 197 | self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0) 198 | 199 | if use_super_node: 200 | self.graph_token_virtual_distance = nn.Embedding(1, num_heads) 201 | 202 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 203 | 204 | def forward(self, batched_data): 205 | attn_bias, spatial_pos, x = ( 206 | batched_data["attn_bias"],#[n_graph, n_node+1, n_node+1] 207 | batched_data["spatial_pos"],#[n_graph, n_node, n_node] 208 | batched_data["x"], 209 | ) 210 | 211 | graph_attn_bias = attn_bias.clone() 212 | graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat( 213 | 1, self.num_heads, 1, 1 214 | ) # [n_graph, n_head, n_node+1, n_node+1] 215 | 216 | # spatial pos 217 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] 218 | spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2) 219 | graph_attn_bias[:, :, int(self.use_super_node):, int(self.use_super_node):] = graph_attn_bias[:, :, int(self.use_super_node):, int(self.use_super_node):] + spatial_pos_bias 220 | 221 | # reset spatial pos here 222 | if self.use_super_node: 223 | t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1) 224 | graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t 225 | graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t 226 | 227 | graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset pad -inf 228 | 229 | return graph_attn_bias# [n_graph, n_head, n_node+1, n_node+1] 230 | 231 | 232 | 233 | class GraphAttnEdgeBias(nn.Module): #refer to Graphormer 234 | """ 235 | Compute attention bias for each head. We do not need to consider super node in this module. 
236 | """ 237 | def __init__( 238 | self, 239 | num_heads, 240 | num_edges, 241 | num_edge_dis, 242 | edge_type, 243 | multi_hop_max_dist, 244 | n_layers, 245 | ): 246 | super(GraphAttnEdgeBias, self).__init__() 247 | self.num_heads = num_heads 248 | self.multi_hop_max_dist = multi_hop_max_dist 249 | #probably some issues here 250 | self.edge_encoder = nn.Embedding(num_edges + 1, num_heads, padding_idx=0) 251 | self.edge_type = edge_type 252 | if self.edge_type == "multi_hop": 253 | self.edge_dis_encoder = nn.Embedding( 254 | num_edge_dis * num_heads * num_heads, 1 255 | ) 256 | 257 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 258 | 259 | def forward(self, batched_data): 260 | attn_bias, spatial_pos, x = ( 261 | batched_data["attn_bias"], 262 | batched_data["spatial_pos"], 263 | batched_data["x"], 264 | ) 265 | edge_input, attn_edge_type = ( 266 | batched_data["edge_input"], 267 | batched_data["attn_edge_type"], 268 | ) 269 | 270 | n_graph, n_node = x.size()[:2] 271 | 272 | 273 | if attn_edge_type is None: 274 | edge_input = torch.zeros(n_graph, self.num_heads, n_node, n_node).to(x.device) 275 | return edge_input 276 | 277 | # edge feature 278 | if self.edge_type == "multi_hop": 279 | spatial_pos_ = spatial_pos.clone() 280 | spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1 281 | # set 1 to 1, x > 1 to x - 1 282 | spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_) 283 | if self.multi_hop_max_dist > 0: 284 | spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist) 285 | edge_input = edge_input[:, :, :, : self.multi_hop_max_dist, :] 286 | # [n_graph, n_node, n_node, max_dist, n_head] 287 | edge_input = self.edge_encoder(edge_input).mean(-2) 288 | max_dist = edge_input.size(-2) 289 | edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape( 290 | max_dist, -1, self.num_heads 291 | ) 292 | edge_input_flat = torch.bmm( 293 | edge_input_flat, 294 | self.edge_dis_encoder.weight.reshape( 295 | -1, self.num_heads, 
self.num_heads 296 | )[:max_dist, :, :], 297 | ) 298 | edge_input = edge_input_flat.reshape( 299 | max_dist, n_graph, n_node, n_node, self.num_heads 300 | ).permute(1, 2, 3, 0, 4) 301 | edge_input = ( 302 | edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1)) 303 | ).permute(0, 3, 1, 2) 304 | else: 305 | # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] 306 | edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2) 307 | 308 | 309 | return edge_input#[n_graph, n_head, n_node, n_node] 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /graphtrasformer/gt_models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Tuple 3 | from graphtrasformer.gnn_layers import * 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch_scatter import scatter 8 | from graphtrasformer.gt_layers import * 9 | from graphtrasformer.layers import * 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def init_graphormer_params(module): 14 | """ 15 | Initialize the weights specific to the Graphormer Model. 
16 | """ 17 | def normal_(data): 18 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 19 | 20 | if isinstance(module, nn.Linear): 21 | normal_(module.weight.data) 22 | if module.bias is not None: 23 | module.bias.data.zero_() 24 | if isinstance(module, nn.Embedding): 25 | normal_(module.weight.data) 26 | if module.padding_idx is not None: 27 | module.weight.data[module.padding_idx].zero_() 28 | if isinstance(module, MultiheadAttention): 29 | normal_(module.q_proj.weight.data) 30 | normal_(module.k_proj.weight.data) 31 | normal_(module.v_proj.weight.data) 32 | 33 | 34 | 35 | 36 | 37 | class GraphTransformer(nn.Module): 38 | def __init__( 39 | self, 40 | num_encoder_layers: int = 12, 41 | hidden_dim: int = 768, 42 | ffn_hidden_dim: int = 768*3, 43 | num_attn_heads: int = 32, 44 | emb_dropout: float = 0, 45 | dropout: float = 0.1, 46 | attn_dropout: float = 0.1, 47 | num_class: int =2 , 48 | encoder_normalize_before: bool = False, 49 | apply_graphormer_init: bool = False, 50 | activation_fn: str = "gelu", 51 | n_trans_layers_to_freeze: int = 0, 52 | traceable = False, 53 | 54 | use_super_node: bool = True, 55 | 56 | node_feature_type: str = 'cate', 57 | node_feature_dim: int = None, 58 | num_atoms: int = None, 59 | 60 | node_level_modules: tuple = ('degree'), 61 | attn_level_modules: tuple = ('spe','spatial'), 62 | attn_mask_modules: str = None, 63 | 64 | num_in_degree: int = None, 65 | num_out_degree: int = None, 66 | eig_pos_dim: int = None, 67 | svd_pos_dim: int = None, 68 | 69 | num_spatial: int = None, 70 | num_edges: int = None, 71 | num_edge_dis: int = None, 72 | edge_type: str = None, 73 | multi_hop_max_dist: int = None, 74 | num_hop_bias: int=None, 75 | 76 | use_gnn_layers: bool=False, 77 | gnn_insert_pos: str='before', 78 | num_gnn_layers: int=1, 79 | gnn_type: str='GAT', 80 | gnn_dropout: float=0.5 81 | ) -> None: 82 | 83 | super().__init__() 84 | self.emb_dropout = nn.Dropout(p=emb_dropout) 85 | self.hidden_dim= hidden_dim 86 | 
self.apply_graphormer_init = apply_graphormer_init 87 | self.traceable=traceable 88 | self.use_super_node = use_super_node 89 | self.use_gnn_layers = use_gnn_layers 90 | self.gnn_insert_pos = gnn_insert_pos 91 | self.num_attn_heads=num_attn_heads 92 | self.attn_mask_modules=attn_mask_modules 93 | 94 | if encoder_normalize_before: 95 | self.emb_layer_norm = nn.LayerNorm(self.hidden_dim) 96 | else: 97 | self.emb_layer_norm = None 98 | 99 | #node feature encoder 100 | self.node_feature_encoder = NodeFeatureEncoder(feat_type=node_feature_type, 101 | hidden_dim=hidden_dim, 102 | n_layers=num_encoder_layers, 103 | num_atoms=num_atoms, 104 | feat_dim=node_feature_dim 105 | ) 106 | 107 | if use_super_node: 108 | self.add_super_node = AddSuperNode(hidden_dim=hidden_dim) 109 | 110 | #node-level graph-structural feature encoder 111 | self.node_level_layers = nn.ModuleList([]) 112 | for module_name in node_level_modules: 113 | if module_name=='degree': 114 | layer = DegreeEncoder(num_in_degree=num_in_degree, 115 | num_out_degree=num_out_degree, 116 | hidden_dim=hidden_dim, 117 | n_layers=num_encoder_layers) 118 | elif module_name=='eig': 119 | layer = Eig_Embedding(eig_dim=eig_pos_dim,hidden_dim=hidden_dim) 120 | elif module_name=='svd': 121 | layer = SVD_Embedding(svd_dim=svd_pos_dim,hidden_dim=hidden_dim) 122 | else: 123 | raise ValueError('node level module error!') 124 | self.node_level_layers.append(layer) 125 | #attention-level graph-structural feature encoder 126 | self.attn_level_layers = nn.ModuleList([]) 127 | for module_name in attn_level_modules: 128 | if module_name=='spatial': 129 | layer = GraphAttnSpatialBias(num_heads=num_attn_heads, 130 | num_spatial=num_spatial, 131 | n_layers=num_encoder_layers, 132 | use_super_node=use_super_node) 133 | elif module_name=='spe': 134 | layer = GraphAttnEdgeBias(num_heads = num_attn_heads, 135 | num_edges = num_edges, 136 | num_edge_dis = num_edge_dis, 137 | edge_type=edge_type, 138 | multi_hop_max_dist=multi_hop_max_dist, 
139 | n_layers=num_encoder_layers) 140 | elif module_name=='nhop': 141 | layer = GraphAttnHopBias(num_heads = num_attn_heads, 142 | n_hops = num_hop_bias, 143 | use_super_node=use_super_node) 144 | else: 145 | raise ValueError('attn level module error!') 146 | self.attn_level_layers.append(layer) 147 | 148 | 149 | 150 | #gnn layers 151 | if use_gnn_layers: 152 | if gnn_insert_pos=='before': 153 | self.gnn_layers = Geometric_GNN(gnn_type=gnn_type, 154 | hidden_dim=hidden_dim, 155 | gnn_dropout=gnn_dropout, 156 | n_layers=num_gnn_layers, 157 | use_super_node=use_super_node) 158 | elif gnn_insert_pos in ('alter','parallel'): 159 | self.gnn_layers = nn.ModuleList([Geometric_GNN(gnn_type=gnn_type, 160 | hidden_dim=hidden_dim, 161 | gnn_dropout=gnn_dropout, 162 | n_layers=num_gnn_layers, 163 | use_super_node=use_super_node) for _ in range(num_encoder_layers)]) 164 | 165 | 166 | #transformer layers 167 | self.transformer_layers =nn.ModuleList([ 168 | Transformer_Layer( 169 | num_heads=num_attn_heads, 170 | hidden_dim=hidden_dim, 171 | ffn_hidden_dim=ffn_hidden_dim, 172 | dropout=dropout, 173 | attn_dropout=attn_dropout, 174 | temperature=1, 175 | activation_fn=activation_fn 176 | ) for _ in range(num_encoder_layers) 177 | ]) 178 | 179 | 180 | self.output_layer_norm = nn.LayerNorm(hidden_dim) 181 | self.output_fc1 = nn.Linear(hidden_dim,hidden_dim) 182 | self.output_fc2 = nn.Linear(hidden_dim,num_class) 183 | self.out_act_fn = get_activation_function(activation_fn) 184 | 185 | 186 | # Apply initialization of model params after building the model 187 | if self.apply_graphormer_init: 188 | self.apply(init_graphormer_params) 189 | 190 | def freeze_module_params(m): 191 | if m is not None: 192 | for p in m.parameters(): 193 | p.requires_grad = False 194 | 195 | for layer in range(n_trans_layers_to_freeze): 196 | freeze_module_params(self.layers[layer]) 197 | 198 | 199 | 200 | def forward( 201 | self, 202 | batched_data, 203 | perturb=None, 204 | last_state_only: bool = False, 
205 | ): 206 | 207 | #==============preparation========================== 208 | # compute padding mask. This is needed for multi-head attention 209 | data_x = batched_data["x"] 210 | n_graph, n_node = data_x.size()[:2] 211 | 212 | #calculate attention padding mask # B x T x T / Bx T+1 x T+1 213 | padding_mask = batched_data['x_mask'] 214 | if self.use_super_node: 215 | padding_mask_cls = torch.ones( 216 | n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype 217 | ) 218 | padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1).float() 219 | attn_mask = torch.matmul(padding_mask.unsqueeze(-1), padding_mask.unsqueeze(1)).long() 220 | self.attn_mask=attn_mask 221 | 222 | #x feature encode 223 | x = self.node_feature_encoder(batched_data)# B x T x C 224 | for nl_layer in self.node_level_layers: 225 | node_bias = nl_layer(batched_data) 226 | x += node_bias 227 | #add the super node 228 | if self.use_super_node: 229 | x = self.add_super_node(x)# B x T+1 x C 230 | 231 | 232 | 233 | # attention bias computation, B x H x (T+1) x (T+1) or B x H x T x T 234 | attn_bias = torch.zeros(n_graph,self.num_attn_heads,n_node+int(self.use_super_node),n_node+int(self.use_super_node)).to(data_x.device) 235 | for al_layer in self.attn_level_layers: 236 | bias = al_layer(batched_data) 237 | if bias.shape[-1]==attn_bias.shape[-1]: 238 | attn_bias+=bias 239 | elif bias.shape[-1]==attn_bias.shape[-1]-1: 240 | attn_bias[:, :, int(self.use_super_node):, int(self.use_super_node):] = attn_bias[:, :, int(self.use_super_node):, int(self.use_super_node):] + bias 241 | else: 242 | raise ValueError('attention calculation error') 243 | 244 | #attention mask 245 | if self.attn_mask_modules in ('1hop','nhop'): 246 | adj_mask = getAttnMasks(batched_data,self.attn_mask_modules,self.use_super_node,self.num_attn_heads) 247 | attn_mask = attn_mask.unsqueeze(1).expand(-1,self.num_attn_heads,-1,-1)*adj_mask 248 | 249 | 250 | #===================data flow=============== 251 | #input feature 
normalization and dropout 252 | if self.emb_layer_norm is not None: 253 | x = self.emb_layer_norm(x) 254 | x = self.emb_dropout(x) # B x T+1 x C 255 | 256 | #gnn layers before transformer 257 | if self.use_gnn_layers and self.gnn_insert_pos=='before': 258 | x = self.gnn_layers(batched_data,x) 259 | 260 | 261 | # graph transformer layers 262 | inner_states = [] 263 | if not last_state_only: 264 | inner_states.append(x) 265 | for i,layer in enumerate(self.transformer_layers): 266 | 267 | if self.use_gnn_layers and self.gnn_insert_pos=='parallel': 268 | x_graph = self.gnn_layers[i](batched_data, x) 269 | else: 270 | x_graph = 0 271 | 272 | #self-attention layer 273 | x, _ = layer.attention( 274 | x=x, 275 | mask=attn_mask, 276 | attn_bias=attn_bias, 277 | ) 278 | 279 | if self.use_gnn_layers and self.gnn_insert_pos=='alter':#by default, gnn after mhsa 280 | x = self.gnn_layers[i](batched_data, x) 281 | 282 | x = x + x_graph 283 | 284 | 285 | #FFN layer 286 | x = layer.ffn_layer(x) 287 | if not last_state_only: 288 | inner_states.append(x) 289 | 290 | 291 | 292 | #output layers 293 | if self.use_super_node: 294 | graph_rep = x[:, 0, :].squeeze()#B x 1 x C 295 | else: 296 | #center node 297 | root_n_id = batched_data['root_n_id'] 298 | root_idx = (torch.arange(n_graph,device=x.device)*n_node+root_n_id).long() 299 | graph_rep = x.reshape(-1,x.shape[-1])[root_idx].squeeze() 300 | #mean pooling, other readout methods to be implemented, e.g, center node 301 | #x = x.reshape(-1, self.hidden_dim) 302 | #padding_mask = padding_mask.reshape(-1).bool() 303 | #x[~padding_mask]=0 304 | #ns = batched_data['ns']#node number in each graph 305 | #graph_rep = x.reshape(-1,n_node,self.hidden_dim).sum(1) / ns.unsqueeze(1) 306 | 307 | #output transformation 308 | out = self.output_layer_norm(self.out_act_fn(self.output_fc1(graph_rep))) 309 | out = self.output_fc2(out).squeeze() 310 | 311 | return {'logits':out} 312 | 313 | 314 | 
-------------------------------------------------------------------------------- /graphtrasformer/layer_tests.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from graphtrasformer.layers import * 4 | if __name__=='__main__': 5 | 6 | 7 | layer = Transformer_Layer( 8 | num_heads=4, 9 | hidden_dim=64, 10 | ffn_hidden_dim=128, 11 | dropout=0.1, 12 | attn_dropout=0.1, 13 | temperature=1, 14 | activation_fn='GELU') 15 | 16 | 17 | 18 | 19 | x = torch.randn(8, 100, 64) 20 | x[:,80:,:]=0 21 | mask = torch.zeros(8, 100,100) 22 | mask[:,:80,:80]=1 23 | 24 | 25 | 26 | out,attn = layer.attention(x,mask) 27 | 28 | 29 | -------------------------------------------------------------------------------- /graphtrasformer/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.nn import init 4 | import json 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | 11 | def get_activation_function(activation: str='PReLU') -> nn.Module: 12 | if activation == 'ReLU': 13 | return nn.ReLU() 14 | elif activation == 'LeakyReLU': 15 | return nn.LeakyReLU(0.1) 16 | elif activation == 'PReLU': 17 | return nn.PReLU() 18 | elif activation == 'tanh': 19 | return nn.Tanh() 20 | elif activation == 'SELU': 21 | return nn.SELU() 22 | elif activation == 'ELU': 23 | return nn.ELU() 24 | elif activation == "Linear": 25 | return lambda x: x 26 | elif activation == 'GELU': 27 | return nn.GELU() 28 | else: 29 | raise ValueError(f'Activation "{activation}" not supported.') 30 | 31 | 32 | 33 | 34 | class PositionwiseFeedForward(nn.Module): 35 | """Implements FFN equation.""" 36 | 37 | def __init__(self,hidden_dim , ffn_hidden_dim, activation_fn="GELU", dropout=0.1): 38 | super(PositionwiseFeedForward, self).__init__() 39 | 40 | self.fc1 = nn.Linear(hidden_dim, ffn_hidden_dim) 41 | self.fc2 = nn.Linear(ffn_hidden_dim, hidden_dim) 42 | 
self.act_dropout = nn.Dropout(dropout) 43 | self.dropout = nn.Dropout(dropout) 44 | self.ffn_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 45 | self.ffn_act_func = get_activation_function(activation_fn) 46 | 47 | def forward(self, x): 48 | residual=x 49 | x = self.dropout(self.fc2(self.act_dropout(self.ffn_act_func(self.fc1(x))))) 50 | x+=residual 51 | x = self.ffn_layer_norm(x) 52 | return x 53 | 54 | 55 | 56 | class MultiheadAttention(nn.Module): 57 | """ 58 | Compute 'Scaled Dot Product SelfAttention 59 | """ 60 | def __init__(self, 61 | num_heads, 62 | hidden_dim, 63 | dropout=0.1, 64 | attn_dropout=0.1, 65 | temperature = 1): 66 | super().__init__() 67 | self.d_k = hidden_dim // num_heads 68 | self.num_heads = num_heads # number of heads 69 | self.temperature =temperature 70 | self.q_proj = nn.Linear(hidden_dim, hidden_dim) 71 | self.k_proj = nn.Linear(hidden_dim, hidden_dim) 72 | self.v_proj = nn.Linear(hidden_dim, hidden_dim) 73 | self.a_proj = nn.Linear(hidden_dim, hidden_dim) 74 | self.attn_dropout = nn.Dropout(attn_dropout) 75 | self.dropout=nn.Dropout(dropout) 76 | self.layer_norm = nn.LayerNorm(hidden_dim,eps=1e-6) 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | nn.init.xavier_uniform_(self.k_proj.weight) 81 | nn.init.xavier_uniform_(self.v_proj.weight) 82 | nn.init.xavier_uniform_(self.q_proj.weight) 83 | nn.init.xavier_uniform_(self.a_proj.weight) 84 | 85 | def forward(self, x, mask=None, attn_bias=None): 86 | residual = x 87 | batch_size = x.size(0) 88 | 89 | query = self.q_proj(x) 90 | key = self.k_proj(x) 91 | value = self.v_proj(x) 92 | 93 | query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) 94 | key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) 95 | value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) 96 | 97 | #ScaledDotProductAttention 98 | if mask is not None and len(mask.shape) == 3: 99 | mask = mask.unsqueeze(1) 100 | 101 | scores = 
torch.matmul(query/self.temperature, key.transpose(-2, -1)) \ 102 | / math.sqrt(query.size(-1)) 103 | 104 | if attn_bias is not None: 105 | scores = scores+attn_bias 106 | 107 | if mask is not None: 108 | if scores.shape==mask.shape:#different heads have different mask 109 | scores = scores * mask 110 | scores = scores.masked_fill(scores == 0, -1e12) 111 | else: 112 | scores = scores.masked_fill(mask == 0, -1e12) 113 | 114 | attn = self.attn_dropout(F.softmax(scores, dim=-1)) 115 | #ScaledDotProductAttention 116 | 117 | out = torch.matmul(attn, value) 118 | out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) 119 | out = self.dropout(self.a_proj(out)) 120 | out += residual 121 | out = self.layer_norm(out) 122 | 123 | return out, attn 124 | 125 | 126 | class Transformer_Layer(nn.Module): 127 | def __init__(self, 128 | num_heads, 129 | hidden_dim, 130 | ffn_hidden_dim, 131 | dropout=0.1, 132 | attn_dropout=0.1, 133 | temperature = 1, 134 | activation_fn='GELU'): 135 | super().__init__() 136 | assert hidden_dim % num_heads == 0 137 | 138 | self.attention = MultiheadAttention(num_heads, 139 | hidden_dim, 140 | dropout, 141 | attn_dropout, 142 | temperature) 143 | self.ffn_layer = PositionwiseFeedForward(hidden_dim,ffn_hidden_dim,activation_fn=activation_fn) 144 | 145 | 146 | def forward(self, x, attn_mask, attn_bias=None): 147 | x, attn = self.attention(x, mask=attn_mask, attn_bias=attn_bias) 148 | x = self.ffn_layer(x) 149 | 150 | return x, attn 151 | 152 | 153 | -------------------------------------------------------------------------------- /gt_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import * 2 | import torch.nn as nn 3 | from tqdm import tqdm 4 | 5 | 6 | from ogb.graphproppred import PygGraphPropPredDataset 7 | from ogb.nodeproppred import PygNodePropPredDataset 8 | from ogb.graphproppred import Evaluator 9 | from datasets import load_dataset, 
load_metric 10 | from data.pyg_datasets.pyg_dataset import GraphormerPYGDataset, Graphtrans_Sampling_Dataset,Graphtrans_Sampling_Dataset_v2 11 | 12 | 13 | def get_loss_and_metric(data_name): 14 | 15 | if data_name in ['ZINC','pcqm4mv2','QM7','QM9','ZINC-full']: 16 | loss = nn.L1Loss(reduction='mean') 17 | metric = nn.L1Loss(reduction='mean') 18 | task_type='regression' 19 | metric_name = 'MAE' 20 | 21 | elif data_name in ['UPFD']: 22 | loss = nn.BCEWithLogitsLoss(reduction='mean') 23 | metric = load_metric("accuracy") 24 | task_type='binary_classification' 25 | metric_name='accuracy' 26 | 27 | elif data_name in ["ogbg-molhiv"]: 28 | loss = nn.BCEWithLogitsLoss(reduction='mean') 29 | metric = Evaluator(name=data_name) 30 | task_type='binary_classification' 31 | metric_name='ROC-AUC' 32 | 33 | elif data_name in ['flickr','ogbn-products','ogbn-arxiv']: 34 | loss = nn.CrossEntropyLoss(reduction='mean') 35 | metric = load_metric('accuracy') 36 | task_type='multi_classification' 37 | metric_name='accuracy' 38 | elif data_name in ["ogbg-molpcba"]: 39 | loss = nn.BCEWithLogitsLoss(reduction='mean') 40 | 41 | metric = Evaluator(name=data_name) 42 | task_type='multi_binary_classification' 43 | metric_name='AP' 44 | 45 | else: 46 | raise ValueError('no such dataset') 47 | 48 | return loss, metric, task_type,metric_name 49 | 50 | 51 | 52 | def normalization(data_list,mean,std): 53 | for i in tqdm(range(len(data_list))): 54 | data_list[i] = (data_list[i].x-mean)/std 55 | return data_list 56 | 57 | 58 | def get_graph_level_dataset(name,param=None,seed=1024,set_default_params=False,args=None): 59 | 60 | path = 'dataset/'+name 61 | print(path) 62 | train_set = None 63 | val_set = None 64 | test_set = None 65 | inner_dataset = None 66 | train_idx=None 67 | val_idx=None 68 | test_idx=None 69 | 70 | #graph regression 71 | if name=='ZINC':#250,000 molecular graphs with up to 38 heavy atoms 72 | train_set = ZINC(path,subset=True,split='train') 73 | val_set = 
ZINC(path,subset=True,split='val') 74 | test_set = ZINC(path,subset=True,split='test') 75 | args.node_feature_type='cate' 76 | args.num_class =1 77 | args.eval_steps=1000 78 | args.save_steps=1000 79 | args.greater_is_better = False 80 | args.warmup_steps=40000 81 | args.max_steps=400000 82 | 83 | elif name == 'ZINC-full': # 250,000 molecular graphs with up to 38 heavy atoms 84 | train_set = ZINC(path, subset=False, split='train') 85 | val_set = ZINC(path, subset=False, split='val') 86 | test_set = ZINC(path, subset=False, split='test') 87 | args.node_feature_type = 'cate' 88 | args.num_class = 1 89 | args.eval_steps = 1000 90 | args.save_steps = 1000 91 | args.greater_is_better = False 92 | args.warmup_steps = 40000 93 | args.max_steps = 400000 94 | 95 | elif name == "ogbg-molpcba": 96 | inner_dataset = PygGraphPropPredDataset(name) 97 | idx_split = inner_dataset.get_idx_split() 98 | train_idx = idx_split["train"] 99 | val_idx = idx_split["valid"] 100 | test_idx = idx_split["test"] 101 | args.node_feature_type = 'cate' 102 | args.num_class = 128 103 | args.eval_steps = 2000 104 | args.save_steps = 2000 105 | args.greater_is_better = True 106 | args.warmup_steps = 40000 107 | args.max_steps = 1000000 108 | 109 | 110 | 111 | elif name == "ogbg-molhiv": 112 | inner_dataset = PygGraphPropPredDataset(name) 113 | idx_split = inner_dataset.get_idx_split() 114 | train_idx = idx_split["train"] 115 | val_idx = idx_split["valid"] 116 | test_idx = idx_split["test"] 117 | args.node_feature_type = 'cate' 118 | args.num_class = 1 119 | args.eval_steps = 1000 120 | args.save_steps = 1000 121 | args.greater_is_better = True 122 | args.warmup_steps = 40000 123 | args.max_steps = 1200000 124 | 125 | 126 | 127 | elif name=='UPFD' and param in ('politifact', 'gossipcop'): 128 | train_set = UPFD(path,param,'bert',split='train') 129 | val_set = UPFD(path,param,'bert',split='val') 130 | test_set = UPFD(path,param,'bert',split='test') 131 | args.learning_rate=1e-5 132 | 
args.node_feature_type='dense' 133 | args.node_feature_dim=768 134 | args.greater_is_better = True 135 | 136 | 137 | 138 | else: 139 | raise ValueError('no such dataset') 140 | 141 | 142 | dataset = GraphormerPYGDataset( 143 | dataset=inner_dataset, 144 | train_idx=train_idx, 145 | valid_idx=val_idx, 146 | test_idx=test_idx, 147 | train_set=train_set, 148 | valid_set=val_set, 149 | test_set=test_set, 150 | seed=seed, 151 | args=args 152 | ) 153 | return dataset.train_data,dataset.valid_data,dataset.test_data, inner_dataset 154 | 155 | 156 | def get_node_level_dataset(name,param=None,args=None): 157 | path = 'dataset/' + name 158 | print(path) 159 | 160 | if args.sampling_algo=='shadowkhop': 161 | args.num_neighbors=10 162 | elif args.sampling_algo=='sage': 163 | args.num_neighbors=50 164 | 165 | if name in ['cora','citeseer','dblp','pubmed']: 166 | dataset = CitationFull(f'dataset/{name}',name) 167 | 168 | 169 | elif name =='flickr': 170 | dataset = Flickr(path) 171 | x_norm_func = lambda x:x # 172 | 173 | args.node_feature_dim=500 174 | args.node_feature_type='dense' 175 | args.num_class =7 176 | 177 | args.encoder_normalize_before =True 178 | args.apply_graphormer_init =True 179 | args.greater_is_better = True 180 | 181 | args.warmup_steps=2000 182 | args.max_steps=100000 183 | 184 | train_idx = dataset.data.train_mask.nonzero().squeeze() 185 | valid_idx = dataset.data.val_mask.nonzero().squeeze() 186 | test_idx = dataset.data.test_mask.nonzero().squeeze() 187 | 188 | 189 | elif name=='ogbn-products': 190 | dataset = PygNodePropPredDataset(name='ogbn-products') 191 | split_idx = dataset.get_idx_split() 192 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 193 | 194 | x_norm_func = lambda x:x 195 | 196 | args.node_feature_dim=100 197 | args.node_feature_type='dense' 198 | args.num_class =47 199 | 200 | args.encoder_normalize_before =True 201 | args.apply_graphormer_init =True 202 | args.greater_is_better = True 203 | 204 
| args.warmup_steps=10000 205 | args.max_steps=400000 206 | 207 | 208 | elif name =='ogbn-arxiv': 209 | dataset = PygNodePropPredDataset(name='ogbn-arxiv') 210 | split_idx = dataset.get_idx_split() 211 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 212 | 213 | x_norm_func = lambda x:x 214 | 215 | args.node_feature_dim=128 216 | args.node_feature_type='dense' 217 | args.num_class =40 218 | 219 | args.encoder_normalize_before =True 220 | args.apply_graphormer_init =True 221 | args.greater_is_better = True 222 | 223 | args.warmup_steps=10000 224 | args.max_steps=800000 225 | 226 | 227 | else: 228 | raise ValueError('no such dataset') 229 | 230 | 231 | if args.sampling_algo=='shadowkhop': 232 | Sampling_Dataset = Graphtrans_Sampling_Dataset 233 | elif args.sampling_algo=='sage': 234 | Sampling_Dataset = Graphtrans_Sampling_Dataset_v2 235 | args.num_neighbors=50 236 | 237 | 238 | train_set = Sampling_Dataset(dataset.data, 239 | node_idx=train_idx, 240 | depth=args.depth, 241 | num_neighbors=args.num_neighbors, 242 | replace=False, 243 | x_norm_func=x_norm_func, 244 | args=args) 245 | valid_set = Sampling_Dataset(dataset.data, 246 | node_idx=valid_idx, 247 | depth=args.depth, 248 | num_neighbors=args.num_neighbors, 249 | replace=False, 250 | x_norm_func=x_norm_func, 251 | args=args) 252 | test_set = Sampling_Dataset(dataset.data, 253 | node_idx=test_idx, 254 | depth=args.depth, 255 | num_neighbors=args.num_neighbors, 256 | replace=False, 257 | x_norm_func=x_norm_func, 258 | args=args) 259 | 260 | return train_set,valid_set,test_set, dataset, args 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | #just test 269 | if __name__=='__main__': 270 | pass 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.util import deprecation 2 | 
deprecation._PRINT_DEPRECATION_WARNINGS = False
from sklearn.metrics import roc_auc_score, average_precision_score
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from argparse import ArgumentParser, Namespace
from data.collator import *
from gt_dataset import *

import gc
from graphtrasformer.architectures import *
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)


from sklearn import metrics
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm
import logging
import time
import torch.onnx
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.logging.set_verbosity(tf.logging.ERROR)
import warnings
warnings.filterwarnings('ignore')


def str_tuple(string):
    """argparse type: split a comma-separated string into a tuple of strings."""
    return tuple(string.split(','))


def boolean_string(s):
    """argparse type: accept exactly 'True' or 'False'."""
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'


# scale -> (num_encoder_layers, hidden_dim, ffn_hidden_dim, num_attn_heads)
_MODEL_SCALE_PRESETS = {
    'mini': (3, 64, 64, 4),
    'small': (6, 80, 80, 8),
    'middle': (12, 80, 80, 8),
    'large': (12, 512, 512, 32),
}


def set_model_scale(model_scale, args):
    """Overwrite the size hyper-parameters on ``args`` for a named model scale.

    An unknown scale leaves ``args`` unchanged. Returns ``args``.
    """
    preset = _MODEL_SCALE_PRESETS.get(model_scale)
    if preset is not None:
        (args.num_encoder_layers,
         args.hidden_dim,
         args.ffn_hidden_dim,
         args.num_attn_heads) = preset
    return args


def parse_args():
    """Parse command-line options, then apply the ``--model_scale`` size preset."""
    parser = ArgumentParser()

    parser.add_argument('--disable_tqdm', type=boolean_string, default=True)  # just for debug

    parser.add_argument('--model_scale', type=str, default='small')  # ('small','middle','large')
    parser.add_argument('--data_name', type=str, default='ZINC')
    parser.add_argument('--data_param', type=str, default=None)
    # basic Transformer parameters
    parser.add_argument('--max_node', type=int, default=512)
    parser.add_argument('--num_encoder_layers', type=int, default=12)
    parser.add_argument('--hidden_dim', type=int, default=768)
    parser.add_argument('--ffn_hidden_dim', type=int, default=768 * 3)
    parser.add_argument('--num_attn_heads', type=int, default=32)
    parser.add_argument('--emb_dropout', type=float, default=0.0)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--attn_dropout', type=float, default=0.1)
    parser.add_argument('--num_class', type=int, default=1)
    parser.add_argument('--encoder_normalize_before', type=boolean_string, default=True)
    parser.add_argument('--apply_graphormer_init', type=boolean_string, default=True)
    parser.add_argument('--activation_fn', type=str, default='GELU')
    parser.add_argument('--n_trans_layers_to_freeze', type=int, default=0)
    parser.add_argument('--traceable', type=boolean_string, default=False)

    # various positional embedding parameters
    parser.add_argument('--use_super_node', type=boolean_string, default=True)
    parser.add_argument('--node_feature_type', type=str, default=None)  # or dense
    parser.add_argument('--node_feature_dim', type=int, default=None)  # valid only for dense feature type
    parser.add_argument('--num_atoms', type=int, default=512 * 9)  # valid only for cate feature type
    parser.add_argument('--node_level_modules', type=str_tuple, default=())  # ,'eig','svd'))'degree'
    parser.add_argument('--eig_pos_dim', type=int, default=3)  # 2
    parser.add_argument('--svd_pos_dim', type=int, default=3)
    parser.add_argument('--num_in_degree', type=int, default=512)
    parser.add_argument('--num_out_degree', type=int, default=512)

    # various attention bias/mask parameters
    parser.add_argument('--attn_level_modules', type=str_tuple, default=())  # ,'nhop'))'spatial',spe
    parser.add_argument('--attn_mask_modules', type=str, default=None)  # 'nhop'
    parser.add_argument('--num_edges', type=int, default=512 * 3)
    parser.add_argument('--num_spatial', type=int, default=512)
    parser.add_argument('--num_edge_dis', type=int, default=128)
    parser.add_argument('--spatial_pos_max', type=int, default=20)
    parser.add_argument('--edge_type', type=str, default=None)
    parser.add_argument('--multi_hop_max_dist', type=int, default=5)
    parser.add_argument('--num_hop_bias', type=int, default=3)  # 2/3/4

    # gnn layers parameters. Insert gnn layers before/alternate/parallel self-attention layers
    # gnn layers are implemented by pytorch geometric for simplicity, so we always require data transformation across gnn layer and self-attention layers
    parser.add_argument('--use_gnn_layers', type=boolean_string, default=False)
    parser.add_argument('--gnn_insert_pos', type=str, default='before')  # 'before'/'alter'/'parallel' gnn insert position
    parser.add_argument('--num_gnn_layers', type=int, default=1)
    parser.add_argument('--gnn_type', type=str, default='GAT')  # GCN,SAGE,GAT,RGCN ... any types of GNN supported by Geometric
    parser.add_argument('--gnn_dropout', type=float, default=0.5)

    # sampling parameters
    parser.add_argument('--depth', type=int, default=2)
    parser.add_argument('--num_neighbors', type=int, default=10)
    parser.add_argument('--sampling_algo', type=str, default='shadowkhop')  # or sage

    # training parameters, we use Trainer class from Huggingface Transformer, which is highly optimized specifically for Transformer architecture
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--output_dir', type=str)  # overwritten in __main__ with a name derived from the run config
    parser.add_argument('--per_device_train_batch_size', type=int, default=256)
    parser.add_argument('--per_device_eval_batch_size', type=int, default=256)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.999)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--max_grad_norm', type=float, default=5.0)
    parser.add_argument('--num_train_epochs', type=int, default=300)
    parser.add_argument('--max_steps', type=int, default=400000)  # 1000000
    parser.add_argument('--lr_scheduler_type', type=str, default='linear')
    parser.add_argument('--warmup_steps', type=int, default=40000)
    parser.add_argument('--dataloader_num_workers', type=int, default=32)
    parser.add_argument('--evaluation_strategy', type=str, default='steps')
    parser.add_argument('--eval_steps', type=int, default=1000)
    parser.add_argument('--save_steps', type=int, default=1000)
    parser.add_argument('--greater_is_better', type=boolean_string, default=True)

    parser.add_argument('--rerun', type=boolean_string, default=False)

    args = parser.parse_args()
    set_model_scale(args.model_scale, args)

    return args


def expand_graph_level_dataset(dataset, N):
    """Repeat the graph list ``N`` times in place (see the Trainer note in __main__)."""
    train_data = list(dataset.dataset)
    dataset.dataset = N * train_data
    dataset.num_data *= N
    return dataset


def expand_node_level_dataset(dataset, N):
    """Repeat the target node indices ``N`` times in place (see the Trainer note in __main__)."""
    dataset.node_idx = dataset.node_idx.expand(N, -1).reshape(-1)
    dataset.num_data *= N
    return dataset


if __name__ == '__main__':
    args = parse_args()

    # For efficiency in Trainer: the HF Trainer seems to stall for several
    # seconds after each pass over the dataloader, so the training set is
    # replicated manually to make each pass longer.
    expand_num_dict = {'flickr': 20,
                       'ZINC': 20,
                       'ogbn-products': 5,
                       'ogbg-molpcba': 1,
                       'ZINC-full': 1,
                       'ogbn-arxiv': 1,
                       "ogbg-molhiv": 5}

    if args.data_name in ('flickr', 'ogbn-products', 'ogbn-arxiv'):
        train_set, valid_set, test_set, odata, args = get_node_level_dataset(args.data_name, args=args)
        train_set = expand_node_level_dataset(train_set, expand_num_dict.get(args.data_name, 1))

    elif args.data_name in ('ZINC', 'UPFD', 'ogbg-molpcba', 'ZINC-full', "ogbg-molhiv"):
        train_set, valid_set, test_set, odata = get_graph_level_dataset(
            args.data_name, param=args.data_param, set_default_params=True, args=args)
        # 'UPFD' has no entry in expand_num_dict; default to no expansion instead of a KeyError.
        train_set = expand_graph_level_dataset(train_set, expand_num_dict.get(args.data_name, 1))

    else:
        raise ValueError('no dataset')

    criterion, metric, task_type, metric_name = get_loss_and_metric(args.data_name)

    # print parameters
    for k, v in vars(args).items():
        print(k, v)
    # ========================model===============================
    model = get_model(args)

    log_file_param_list = (args.data_name,
                           args.model_scale,
                           args.use_super_node,
                           args.node_level_modules,
                           args.eig_pos_dim,
                           args.svd_pos_dim,
                           args.attn_level_modules,
                           args.attn_mask_modules,
                           args.num_hop_bias,
                           args.use_gnn_layers,
                           args.gnn_insert_pos,
                           args.num_gnn_layers,
                           args.gnn_type,
                           args.gnn_dropout,
                           args.sampling_algo,
                           args.depth,
                           args.num_neighbors,
                           args.seed)

    def _to_path_token(value):
        # Tuples (module lists) become 'a+b', or 'None' when empty; everything else uses str().
        if isinstance(value, tuple):
            return '+'.join(value) if value else 'None'
        return str(value)

    output_dir = './outputs/' + '_'.join(_to_path_token(x) for x in log_file_param_list)
    log_file_path = output_dir + '/logs.json'

    setattr(args, 'log_file_path', log_file_path)
    setattr(args, 'output_dir', output_dir)

    ## huggingface trainer============================
    def compute_metrics(p: EvalPrediction):
        """Map Trainer predictions to ``{metric_name: value}`` for the current task."""
        preds, labels = p

        gc.collect()
        if task_type == 'multi_classification':
            preds = np.argmax(preds, axis=1)
            labels = labels.astype(np.int64)  # np.long was removed in NumPy 1.24
            return metric.compute(predictions=preds, references=labels)

        elif task_type == 'multi_binary_classification' and metric_name == 'AP':
            preds = torch.sigmoid(torch.tensor(preds)).numpy()
            # TODO: confirm the input/output formats expected by the evaluator,
            # and that node/edge features are processed correctly.
            return {metric_name: metric.eval({'y_true': labels, 'y_pred': preds})['ap']}

        elif task_type == 'regression':
            return {metric_name: metric(torch.tensor(preds), torch.tensor(labels)).item()}  # mae

        elif task_type == 'binary_classification' and metric_name == 'ROC-AUC':
            return {metric_name: roc_auc_score(y_true=labels, y_score=torch.sigmoid(torch.tensor(preds)).numpy())}

        elif task_type == 'binary_classification' and metric_name == 'accuracy':
            # Accuracy needs hard 0/1 predictions, not probabilities: threshold the sigmoid at 0.5.
            hard_preds = (torch.sigmoid(torch.tensor(preds)) > 0.5).long().reshape(-1).numpy()
            hard_labels = np.asarray(labels).reshape(-1).astype(np.int64)
            return metric.compute(predictions=hard_preds, references=hard_labels)

    from transformers import TrainerCallback, TrainerState, TrainerControl, EarlyStoppingCallback

    class MyCallback(TrainerCallback):
        """Dump the Trainer state to ``log_file_path`` every time the Trainer logs."""

        def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            print("save logs...")
            state.save_to_json(log_file_path)

    class MyTrainer(Trainer):
        """Trainer that feeds the whole batch dict to the model and applies the task criterion."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra arguments that newer Trainer versions pass
            # (e.g. num_items_in_batch).
            self.inputs = inputs
            labels = inputs['labels']
            outputs = model(inputs)

            labels = labels.long() if task_type == 'multi_classification' else labels.float()
            if task_type == 'multi_binary_classification':
                # Flatten all targets and skip missing (NaN) ones.
                labels = labels.reshape(-1)
                mask = ~torch.isnan(labels)
                loss = criterion(outputs['logits'].reshape(-1)[mask], labels[mask])

            else:
                loss = criterion(outputs['logits'], labels)
            return (loss, outputs) if return_outputs else loss

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy=args.evaluation_strategy,
        eval_steps=args.eval_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_epsilon=args.adam_epsilon,
        max_grad_norm=args.max_grad_norm,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,  # sensitive
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        greater_is_better=args.greater_is_better,
        save_steps=args.save_steps,
        save_total_limit=10,
        logging_steps=args.eval_steps,
        seed=args.seed
    )

    training_args.disable_tqdm = args.disable_tqdm
    training_args.ignore_data_skip = True

    # check_checkpoints deletes output_dir when it holds no checkpoint.
    resume_from_checkpoint = True if (check_checkpoints(args.output_dir) and not args.rerun) else None

    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        compute_metrics=compute_metrics,
        data_collator=lambda x: collator(x, args),
        callbacks=[MyCallback, EarlyStoppingCallback(early_stopping_patience=20)]
    )
    trainer.args._n_gpu = 1

    print(trainer.evaluate())
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    predictions, labels, test_metrics = trainer.predict(test_set, metric_key_prefix="predict")
    test_metrics['best_val_metric'] = trainer.state.best_metric
    test_metrics['best_model_checkpoint'] = trainer.state.best_model_checkpoint
    with open(args.output_dir + '/test.txt', 'w') as f:
        for k, v in test_metrics.items():
            f.write(str(k) + ':' + str(v) + '\n')
        f.write('\n')


# --------------------------------------------------------------------------------
# /utils/utils.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch
import scipy.sparse as sp
from numpy.linalg import inv
import pickle

from torch_geometric.datasets import *

import torch
import numpy as np
from torch_sparse.matmul import matmul
from torch_sparse import SparseTensor


c = 0.15  # restart probability in get_intimacy_matrix: c * (I - (1 - c) * A_norm)^-1
k = 5  # not referenced by the functions in this module


def adj_normalize(mx):
    """Return ``D^-1/2 @ mx @ D^-1/2`` for a scipy sparse matrix, with D = row sums.

    Rows whose sum is zero get a scale factor of 0 instead of inf.
    """
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -0.5).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx).dot(r_mat_inv)
    return mx


def get_intimacy_matrix(edges, n):
    """Dense personalized-PageRank matrix ``c * (I - (1 - c) * A_norm)^-1``.

    ``edges`` is a list/array of (src, dst) pairs and ``n`` is the node count.
    Dense inversion is O(n^3) time and O(n^2) memory, so use small graphs only.
    """
    edges = np.array(edges)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(n, n),
                        dtype=np.float32)
    print('normalize')
    adj_norm = adj_normalize(adj)
    print('inverse')
    eigen_adj = c * inv((sp.eye(adj.shape[0]) - (1 - c) * adj_norm).toarray())

    return eigen_adj


def adj_normalize_sparse(mx, device=None):
    """``D^-1/2 @ mx @ D^-1/2`` for a torch_sparse SparseTensor (D = row sums).

    Args:
        mx: square SparseTensor.
        device: optional target device. When None, ``mx`` stays where it is.
            The previous version read module globals ``device`` and ``n`` that
            exist only under ``__main__``.
    """
    if device is not None:
        mx = mx.to(device)
    n = mx.sparse_size(0)
    rowsum = mx.sum(1)
    r_inv = rowsum.pow(-0.5).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    diag_idx = torch.arange(n, device=r_inv.device)
    r_mat_inv = SparseTensor(row=diag_idx, col=diag_idx, value=r_inv, sparse_sizes=(n, n))
    nr_mx = matmul(matmul(r_mat_inv, mx), r_mat_inv)
    return nr_mx


def get_intimacy_matrix_sparse(edges, n):
    """Symmetrically normalized sparse adjacency built from a (2, E) edge index."""
    adj = SparseTensor(row=edges[0], col=edges[1], value=torch.ones(edges.shape[1]), sparse_sizes=(n, n))
    adj_norm = adj_normalize_sparse(adj)
    return adj_norm


def get_svd_dense(mx, q=3):
    """Rank-``q`` factors ``(U S^1/2, V S^1/2)`` from ``torch.svd_lowrank``."""
    mx = mx.float()
    u, s, v = torch.svd_lowrank(mx, q=q)
    s = torch.diag(s)
    pu = u @ s.pow(0.5)
    pv = v @ s.pow(0.5)
    return pu, pv


def unweighted_adj_normalize_dense_batch(adj):
    """Symmetrize, binarize and ``D^-1/2 A D^-1/2``-normalize a batch of dense adjacencies.

    Zero-degree nodes (e.g. padding) get a scale factor of 0. Previously they
    produced inf, which turned into NaN after the matrix products.
    """
    adj = (adj + adj.transpose(-1, -2)).bool().float()
    rowsum = adj.sum(-1)
    r_inv = rowsum.pow(-0.5)
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag_embed(r_inv)
    nr_adj = torch.matmul(torch.matmul(r_mat_inv, adj), r_mat_inv)
    return nr_adj


def get_eig_dense(adj):
    """Eigen-decompose the normalized Laplacian ``I - D^-1/2 A D^-1/2`` of a dense adjacency.

    Returns ``(eigenvalues, eigenvectors)`` as real tensors, in the order the
    solver produces them (not sorted). On PyTorch builds with ``torch.linalg``
    the real parts of ``torch.linalg.eig`` are returned. ``torch.eig`` was
    removed in PyTorch 2.0. For a symmetric ``adj`` the spectrum is real, so
    taking ``.real`` loses nothing; for non-symmetric input the complex parts
    are dropped.
    """
    adj = adj.float()
    rowsum = adj.sum(1)
    r_inv = rowsum.pow(-0.5)
    r_inv[torch.isinf(r_inv)] = 0.  # isolated nodes: avoid inf * 0 = NaN
    r_mat_inv = torch.diag(r_inv)
    nr_adj = torch.matmul(torch.matmul(r_mat_inv, adj), r_mat_inv)
    graph_laplacian = torch.eye(adj.shape[0], device=adj.device) - nr_adj
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
        L, V = torch.linalg.eig(graph_laplacian)
        return L.real, V.real
    L, V = torch.eig(graph_laplacian, eigenvectors=True)  # PyTorch < 1.9 fallback
    return L.T[0], V


def check_checkpoints(output_dir):
    """Return True if ``output_dir`` holds a checkpoint entry.

    Side effect: if ``output_dir`` exists but holds no entry whose name contains
    'checkpoint', the whole directory is deleted and False is returned.
    """
    import os
    import shutil
    if os.path.exists(output_dir):
        if any('checkpoint' in entry for entry in os.listdir(output_dir)):
            return True
        print('remove ', output_dir)
        shutil.rmtree(output_dir)
    return False


if __name__ == '__main__':
    # just test
    device = torch.device('cuda', 0)

    data = Flickr('dataset/flickr')

    edges = data.data.edge_index
    n = data.data.x.shape[0]

    adj = SparseTensor(row=edges[0], col=edges[1], value=torch.ones(edges.shape[1]), sparse_sizes=(n, n))
    nr_adj = adj_normalize_sparse(adj, device=device)

    pu, pv = get_svd_dense(nr_adj.to_torch_sparse_coo_tensor(), q=10)

    adj = (torch.randn(10, 10) > 0).float()
    L, V = get_eig_dense(adj)