├── Graph
├── config
│ ├── hiv.json
│ ├── pcba.json
│ ├── pcqm.json
│ ├── pcqms.json
│ └── zinc.json
├── dgl_main.py
├── dgldataclass.py
├── get_dataset.py
├── large_model.py
├── master.csv
├── medium_model.py
├── small_model.py
└── zinc_model.py
├── LICENSE
├── Node
├── config.yaml
├── data
│ └── chameleon.pt.zip
├── main_node.py
├── master.csv
├── model_node.py
├── node_raw_data
│ ├── 2Dgrid.mat
│ ├── Penn94.mat
│ ├── actor
│ │ ├── out1_graph_edges.txt
│ │ └── out1_node_feature_label.txt
│ ├── amazon_electronics_photo.npz
│ ├── chameleon
│ │ ├── out1_graph_edges.txt
│ │ └── out1_node_feature_label.txt
│ ├── citeseer
│ │ ├── ind.citeseer.allx
│ │ ├── ind.citeseer.ally
│ │ ├── ind.citeseer.graph
│ │ ├── ind.citeseer.test.index
│ │ ├── ind.citeseer.tx
│ │ ├── ind.citeseer.ty
│ │ ├── ind.citeseer.x
│ │ └── ind.citeseer.y
│ ├── cora
│ │ ├── ind.cora.allx
│ │ ├── ind.cora.ally
│ │ ├── ind.cora.graph
│ │ ├── ind.cora.test.index
│ │ ├── ind.cora.tx
│ │ ├── ind.cora.ty
│ │ ├── ind.cora.x
│ │ └── ind.cora.y
│ ├── fb100-Penn94-splits.npy
│ └── squirrel
│ │ ├── out1_graph_edges.txt
│ │ └── out1_node_feature_label.txt
├── preprocess_node_data.py
└── utils.py
├── README.md
└── requirements.txt
/Graph/config/hiv.json:
--------------------------------------------------------------------------------
1 | {
2 | "nlayer": 8,
3 | "nheads": 4,
4 | "hidden_dim": 80,
5 | "trans_dropout": 0.1,
6 | "feat_dropout": 0.1,
7 | "adj_dropout": 0.3,
8 | "lr": 1e-4,
9 | "weight_decay": 1e-4,
10 | "epochs": 50,
11 | "warm_up_epoch": 5,
12 | "batch_size": 64
13 | }
14 |
--------------------------------------------------------------------------------
/Graph/config/pcba.json:
--------------------------------------------------------------------------------
1 | {
2 | "nlayer": 8,
3 | "nheads": 8,
4 | "hidden_dim": 272,
5 | "trans_dropout": 0.3,
6 | "feat_dropout": 0.1,
7 | "adj_dropout": 0.1,
8 | "lr": 5e-4,
9 | "weight_decay": 5e-3,
10 | "epochs": 30,
11 | "warm_up_epoch": 5,
12 | "batch_size": 64
13 | }
14 |
--------------------------------------------------------------------------------
/Graph/config/pcqm.json:
--------------------------------------------------------------------------------
1 | {
2 | "nlayer": 10,
3 | "nheads": 16,
4 | "hidden_dim": 400,
5 | "trans_dropout": 0.05,
6 | "feat_dropout": 0.05,
7 | "adj_dropout": 0.05,
8 | "lr": 2e-4,
9 | "weight_decay": 0.0,
10 | "epochs": 150,
11 | "warm_up_epoch": 10,
12 | "batch_size": 64
13 | }
14 |
--------------------------------------------------------------------------------
/Graph/config/pcqms.json:
--------------------------------------------------------------------------------
1 | {
2 | "nlayer": 6,
3 | "nheads": 8,
4 | "hidden_dim": 240,
5 | "filter_poly": 12,
6 | "trans_dropout": 0.3,
7 | "feat_dropout": 0.1,
8 | "adj_dropout": 0.1,
9 | "lr": 5e-4,
10 | "weight_decay": 5e-4,
11 | "epochs": 100,
12 | "warm_up_epoch": 5,
13 | "batch_size": 256
14 | }
15 |
--------------------------------------------------------------------------------
/Graph/config/zinc.json:
--------------------------------------------------------------------------------
1 | {
2 | "nlayer": 4,
3 | "nheads": 8,
4 | "hidden_dim": 160,
5 | "trans_dropout": 0.1,
6 | "feat_dropout": 0.05,
7 | "adj_dropout": 0.0,
8 | "lr": 1e-3,
9 | "weight_decay": 5e-4,
10 | "epochs": 1000,
11 | "warm_up_epoch": 50,
12 | "batch_size": 32
13 | }
14 |
--------------------------------------------------------------------------------
/Graph/dgl_main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import copy
4 | import wandb
5 | import argparse
6 | import datetime
7 | import random, os
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader, Dataset, TensorDataset
13 | from torch.optim.lr_scheduler import LambdaLR
14 | import json5
15 | from easydict import EasyDict
16 |
17 | from ema_pytorch import EMA
18 | from zinc_model import SpecformerZINC
19 | from large_model import SpecformerLarge
20 | from medium_model import SpecformerMedium
21 | from small_model import SpecformerSmall
22 | from get_dataset import DynamicBatchSampler, RandomSampler, collate_pad, collate_dgl, get_dataset
23 |
24 |
def init_params(module):
    """Re-initialise a single layer in place; meant for ``nn.Module.apply``.

    ``nn.Linear`` and ``nn.Embedding`` weights are redrawn from a normal
    distribution with mean 0.0 and std 0.02, and a linear layer's bias (when
    it has one) is zeroed. Every other module type is left untouched.
    """
    is_linear = isinstance(module, nn.Linear)

    if is_linear:
        module.weight.data.normal_(mean=0.0, std=0.02)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()

    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
32 |
33 |
def get_config_from_json(json_file):
    """Load ``config/<json_file>.json`` and return it as an ``EasyDict``.

    The path is relative to the current working directory, and the file is
    parsed with ``json5``.
    """
    config_path = 'config/' + json_file + '.json'

    with open(config_path, 'r') as handle:
        raw_config = json5.load(handle)

    return EasyDict(raw_config)
40 |
41 |
def seed_everything(seed):
    """Seed every random number generator used by the training script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible execution.

    Args:
        seed (int): seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # processes); the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: the call is deferred until CUDA is initialised.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run and may pick
    # different ones each time, which undermines the deterministic flag above.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.allow_tf32 = False
51 |
52 |
def count_parameters(model):
    """Return the number of scalar entries across the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
55 |
56 |
def train_epoch(dataset, model, device, dataloader, loss_fn, optimizer, wandb=None, wandb_item=None):
    """Run one optimisation pass over ``dataloader``.

    Every batch is unpacked as ``(e, u, g, length, y)``; all five items are
    moved to ``device`` and the first four are passed to ``model``. Target
    entries that are NaN are excluded from the loss.

    Args:
        dataset: accepted for call-site symmetry; not used in this function.
        model: module called as ``model(e, u, g, length)``.
        device: device every batch item is moved to.
        dataloader: iterable yielding 5-item batches.
        loss_fn: callable applied to the masked float32 predictions/targets.
        optimizer: optimizer stepped once per batch.
        wandb: optional logger exposing ``log(dict)``; skipped when falsy.
        wandb_item: key under which the batch loss is logged.
    """
    model.train()

    for batch in dataloader:
        e, u, g, length, y = (item.to(device) for item in batch)

        predictions = model(e, u, g, length)
        optimizer.zero_grad()

        # NaN never equals itself, so this mask keeps only labelled targets.
        labelled = y == y
        batch_loss = loss_fn(
            predictions.to(torch.float32)[labelled],
            y.to(torch.float32)[labelled],
        )

        batch_loss.backward()
        optimizer.step()

        if not wandb:
            continue
        wandb.log({wandb_item: batch_loss.item()})
75 |
76 |
def eval_epoch(dataset, model, device, dataloader, evaluator, metric):
    """Score ``model`` on ``dataloader`` and return one metric value.

    Batches are unpacked as ``(e, u, g, length, y)`` and moved to ``device``.
    Targets are reshaped to the shape of the model output, everything is
    gathered on the CPU as NumPy arrays, and ``evaluator.eval`` is called with
    ``{'y_true': ..., 'y_pred': ...}``.

    Args:
        dataset: accepted for call-site symmetry; not used in this function.
        model: module called as ``model(e, u, g, length)``.
        device: device every batch item is moved to.
        dataloader: iterable yielding 5-item batches.
        evaluator: object whose ``eval`` returns a mapping of metric values.
        metric: key selected from the evaluator's result.

    Returns:
        ``evaluator.eval(...)[metric]``.
    """
    model.eval()

    targets, predictions = [], []

    with torch.no_grad():
        for batch in dataloader:
            e, u, g, length, y = (item.to(device) for item in batch)

            logits = model(e, u, g, length)

            targets.append(y.view(logits.shape).detach().cpu())
            predictions.append(logits.detach().cpu())

    stacked_true = torch.cat(targets, dim=0).numpy()
    stacked_pred = torch.cat(predictions, dim=0).numpy()

    scores = evaluator.eval({'y_true': stacked_true, 'y_pred': stacked_pred})
    return scores[metric]
97 |
98 |
def main_worker(args):
    """Train a Specformer model on one dataset and report validation/test scores.

    Reads from ``args``: ``seed``, ``cuda``, ``dataset``, ``batch_size``,
    ``nlayer``, ``hidden_dim``, ``nheads``, ``feat_dropout``, ``trans_dropout``,
    ``adj_dropout``, ``lr``, ``weight_decay``, ``warm_up_epoch``, ``epochs`` and
    ``project_name``.

    A checkpoint is written to ``checkpoint/<project_name>_<epoch>.pth`` after
    every epoch and to ``checkpoint/<project_name>.pth`` at the end. The
    ``checkpoint`` directory is created if it does not exist.
    """
    seed_everything(args.seed)
    rank = 'cuda:{}'.format(args.cuda)
    print(args)

    datainfo = get_dataset(args.dataset)
    nclass = datainfo['num_class']
    loss_fn = datainfo['loss_fn']
    evaluator = datainfo['evaluator']
    train = datainfo['train_dataset']
    valid = datainfo['valid_dataset']
    test = datainfo['test_dataset']
    metric = datainfo['metric']
    metric_mode = datainfo['metric_mode']

    # dataloader
    # Evaluation uses half the training batch size, but never less than one
    # sample (a batch size of 0 is rejected by DataLoader).
    eval_batch_size = max(1, args.batch_size // 2)
    train_dataloader = DataLoader(train, batch_size=args.batch_size, num_workers=4, collate_fn=collate_dgl, shuffle=True)
    valid_dataloader = DataLoader(valid, batch_size=eval_batch_size, num_workers=4, collate_fn=collate_dgl, shuffle=False)
    test_dataloader = DataLoader(test, batch_size=eval_batch_size, num_workers=4, collate_fn=collate_dgl, shuffle=False)

    model_args = (nclass, args.nlayer, args.hidden_dim, args.nheads,
                  args.feat_dropout, args.trans_dropout, args.adj_dropout)

    if args.dataset == 'zinc':
        print('zinc')
        model = SpecformerZINC(*model_args).to(rank)

    elif args.dataset == 'pcqm' or args.dataset == 'pcqms':
        print('pcqm')
        model = SpecformerLarge(*model_args).to(rank)
        print('init')
        model.apply(init_params)

    elif args.dataset == 'pcba':
        print('pcba')
        model = SpecformerMedium(*model_args).to(rank)
        model.apply(init_params)

    else:
        print('hiv')
        model = SpecformerSmall(*model_args).to(rank)

    print(count_parameters(model))

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)

    # Linear warm-up followed by cosine learning-rate decay.
    # LambdaLR evaluates the schedule once more after the final epoch, so the
    # denominator is clamped to avoid a ZeroDivisionError when
    # epochs == warm_up_epoch.
    decay_epochs = max(1, args.epochs - args.warm_up_epoch)

    def lr_plan(cur_epoch):
        if cur_epoch < args.warm_up_epoch:
            return (cur_epoch + 1) / args.warm_up_epoch
        return 0.5 * (1.0 + math.cos(math.pi * (cur_epoch - args.warm_up_epoch) / decay_epochs))

    scheduler = LambdaLR(optimizer, lr_lambda=lr_plan)

    # torch.save does not create missing parent directories.
    os.makedirs('checkpoint', exist_ok=True)

    # The test score reported as "best" is the one from the epoch with the
    # best validation score (earliest epoch wins ties).
    select_best = min if metric_mode == 'min' else max

    results = []
    for epoch in range(args.epochs):

        train_epoch(args.dataset, model, rank, train_dataloader, loss_fn, optimizer, wandb=None, wandb_item='loss')
        scheduler.step()

        torch.save(model.state_dict(), 'checkpoint/{}_{}.pth'.format(args.project_name, epoch))

        # Evaluate after every epoch.
        val_res = eval_epoch(args.dataset, model, rank, valid_dataloader, evaluator, metric)
        test_res = eval_epoch(args.dataset, model, rank, test_dataloader, evaluator, metric)

        results.append([val_res, test_res])
        best_res = select_best(results, key=lambda pair: pair[0])[1]

        print(epoch, 'valid: {:.4f}'.format(val_res), 'test: {:.4f}'.format(test_res), 'best: {:.4f}'.format(best_res))

    torch.save(model.state_dict(), 'checkpoint/{}.pth'.format(args.project_name))
187 |
188 |
if __name__ == '__main__':
    # Command-line entry point: parse the three CLI options, then copy every
    # key of the dataset's JSON config onto the namespace before training.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ('--seed', {'type': int, 'default': 0}),
        ('--cuda', {'type': int, 'default': 0}),
        ('--dataset', {'default': 'zinc'}),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    # Timestamp used as the checkpoint file prefix.
    args.project_name = datetime.datetime.now().strftime('%m-%d-%X')

    config = get_config_from_json(args.dataset)

    for key, value in config.items():
        setattr(args, key, value)

    main_worker(args)
204 |
205 |
--------------------------------------------------------------------------------
/Graph/dgldataclass.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import pandas as pd
3 | import shutil, os
4 | import os.path as osp
5 | import torch
6 | import numpy as np
7 | import dgl
8 | from dgl.data.utils import load_graphs, save_graphs, Subset
9 | from ogb.utils import smiles2graph
10 | from ogb.utils.torch_util import replace_numpy_with_torchtensor
11 | from ogb.utils.url import decide_download, download_url, extract_zip
12 | from ogb.io.read_graph_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw
13 | from tqdm import tqdm
14 | from torch_geometric.utils import to_dense_adj, remove_isolated_nodes
15 | from torch_geometric.data import InMemoryDataset
16 |
17 |
def _normalized_adjacency_eigh(src, dst, num_nodes):
    """Eigendecompose D^-1/2 (A + I) D^-1/2 built from an edge list.

    Returns the ``(eigenvalues, eigenvectors)`` pair of ``torch.linalg.eigh``.
    """
    if num_nodes == 1:  # some graphs have one node
        norm_adj = torch.tensor(1.).view(1, 1)
    else:
        adj = torch.zeros([num_nodes, num_nodes], dtype=torch.float)
        adj[src, dst] = 1.0
        adj.fill_diagonal_(1.0)  # self-loops keep every degree >= 1
        inv_sqrt_deg = torch.sum(adj, dim=0) ** -0.5
        deg_mat = torch.diag(inv_sqrt_deg)
        norm_adj = deg_mat @ adj @ deg_mat
    # NOTE(review): eigh reads only one triangle of its input, so this presumes
    # the edge list is symmetric -- confirm for datasets without inverse edges.
    return torch.linalg.eigh(norm_adj)


def read_graph_dgl(raw_dir, add_inverse_edge = False, additional_node_files = None, additional_edge_files = None, binary=False):
    """Read raw graphs and convert them to fully-connected DGL graphs.

    Each returned graph stores, as node data, the eigenvalues (``'e'``) and
    eigenvectors (``'u'``) of the normalized adjacency with self-loops, plus
    ``'feat'`` when node features exist. Edge features, when present, are
    shifted by +1 (0 is reserved for padding) and densified to one row per
    ordered node pair.

    Args:
        raw_dir: directory holding the raw files.
        add_inverse_edge: forwarded to the raw reader.
        additional_node_files: extra node files for the csv reader (default: none).
        additional_edge_files: extra edge files for the csv reader (default: none).
        binary: read the npz format instead of csv.

    Returns:
        list of DGL graphs.
    """
    additional_node_files = [] if additional_node_files is None else additional_node_files
    additional_edge_files = [] if additional_edge_files is None else additional_edge_files

    if binary:
        # npz
        graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge)
    else:
        # csv
        graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files)

    dgl_graph_list = []

    print('Converting graphs into DGL objects...')

    for graph in tqdm(graph_list):

        src, dst = torch.from_numpy(graph['edge_index'])
        num_nodes = graph['num_nodes']

        e, u = _normalized_adjacency_eigh(src, dst, num_nodes)

        fully_connected = torch.ones([num_nodes, num_nodes], dtype=torch.float).nonzero(as_tuple=True)
        g = dgl.graph(fully_connected, num_nodes = num_nodes)

        g.ndata['e'] = e
        g.ndata['u'] = u

        if graph['node_feat'] is not None:
            g.ndata['feat'] = torch.from_numpy(graph['node_feat'])

        if graph['edge_feat'] is not None:
            edge_idx = torch.stack([src, dst], dim=0)
            edge_attr = torch.from_numpy(graph['edge_feat']) + 1  # for padding

            if len(edge_attr.shape) == 1:
                # max_num_nodes is required: without it the dense size is
                # inferred from the largest node id present in an edge, which
                # is too small when trailing nodes are isolated.
                edge_attr_dense = to_dense_adj(edge_idx, edge_attr=edge_attr.unsqueeze(-1), max_num_nodes=num_nodes).squeeze(0).squeeze(-1).view(-1)
            else:
                if edge_attr.size(0) == 0:
                    edge_attr_dense = torch.zeros([num_nodes ** 2, edge_attr.size(1)]).long()  # for graphs without edge
                else:
                    edge_attr_dense = to_dense_adj(edge_idx, edge_attr=edge_attr, max_num_nodes=num_nodes).squeeze(0).view(-1, edge_attr.shape[-1])

            g.edata['feat'] = edge_attr_dense

        dgl_graph_list.append(g)

    return dgl_graph_list
74 |
75 |
class DglGraphPropPredDataset(object):
    '''Graph property prediction dataset stored as a list of DGL graphs plus labels.

    Adapted from https://docs.dgl.ai/en/latest/_modules/dgl/data/chem/csv_dataset.html#CSVDataset
    '''
    def __init__(self, name, root = 'dataset', meta_dict = None):
        '''
            - name (str): name of the dataset
            - root (str): root directory to store the dataset folder
            - meta_dict: dictionary that stores all the meta-information about data. Default is None,
                    but when something is passed, it uses its information. Useful for debugging for external contributors.
        '''

        self.name = name ## original name, e.g., ogbg-molhiv

        if meta_dict is None:
            self.dir_name = '_'.join(name.split('-'))

            # check if previously-downloaded folder exists.
            # If so, use that one.
            if osp.exists(osp.join(root, self.dir_name + '_dgl')):
                self.dir_name = self.dir_name + '_dgl'

            self.original_root = root
            self.root = osp.join(root, self.dir_name)

            # master.csv (next to this module) has one column per dataset name.
            master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'master.csv'), index_col = 0)
            if self.name not in master:
                error_mssg = 'Invalid dataset name {}.\n'.format(self.name)
                error_mssg += 'Available datasets are as follows:\n'
                error_mssg += '\n'.join(master.keys())
                raise ValueError(error_mssg)
            self.meta_info = master[self.name]

        else:
            self.dir_name = meta_dict['dir_path']
            self.original_root = ''
            self.root = meta_dict['dir_path']
            self.meta_info = meta_dict

        # check version
        # First check whether the dataset has been already downloaded or not.
        # If so, check whether the dataset version is the newest or not.
        # If the dataset is not the newest version, notify this to the user.
        if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
            print(self.name + ' has been updated.')
            if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
                shutil.rmtree(self.root)

        self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., tox21

        self.num_tasks = int(self.meta_info['num tasks'])
        self.eval_metric = self.meta_info['eval metric']
        self.task_type = self.meta_info['task type']
        self.num_classes = self.meta_info['num classes']
        self.binary = self.meta_info['binary'] == 'True'

        super(DglGraphPropPredDataset, self).__init__()

        self.pre_process()

    def pre_process(self):
        '''Load the processed graphs, downloading and converting the raw data first if needed.

        Populates ``self.graphs`` and ``self.labels``.
        '''
        processed_dir = osp.join(self.root, 'processed')
        raw_dir = osp.join(self.root, 'raw')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if self.task_type == 'subtoken prediction':
            target_sequence_file_path = osp.join(processed_dir, 'target_sequence')

        if os.path.exists(pre_processed_file_path):

            if self.task_type == 'subtoken prediction':
                self.graphs, _ = load_graphs(pre_processed_file_path)
                self.labels = torch.load(target_sequence_file_path)

            else:
                self.graphs, label_dict = load_graphs(pre_processed_file_path)
                self.labels = label_dict['labels']

        else:
            ### check download
            if self.binary:
                # npz format
                has_necessary_file = osp.exists(osp.join(self.root, 'raw', 'data.npz'))
            else:
                # csv file
                has_necessary_file = osp.exists(osp.join(self.root, 'raw', 'edge.csv.gz'))

            ### download
            if not has_necessary_file:
                url = self.meta_info['url']
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete folder if there exists (best effort: a missing
                    # or partially removable folder is not an error here)
                    shutil.rmtree(self.root, ignore_errors=True)
                    shutil.move(osp.join(self.original_root, self.download_name), self.root)
                else:
                    print('Stop download.')
                    # Same exception the interactive exit() helper raises, but
                    # available even when the site module is not loaded.
                    raise SystemExit(-1)

            ### preprocess
            add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

            if self.meta_info['additional node files'] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info['additional node files'].split(',')

            if self.meta_info['additional edge files'] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info['additional edge files'].split(',')

            graphs = read_graph_dgl(raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=self.binary)

            # Make sure the output folder exists before anything is saved into it.
            os.makedirs(processed_dir, exist_ok=True)

            if self.task_type == 'subtoken prediction':
                # the downloaded labels are initially joined by ' '
                labels_joined = pd.read_csv(osp.join(raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values
                # need to split each element into subtokens
                labels = [str(row[0]).split(' ') for row in labels_joined]

                print('Saving...')
                save_graphs(pre_processed_file_path, graphs)
                torch.save(labels, target_sequence_file_path)

                ### load preprocessed files
                self.graphs, _ = load_graphs(pre_processed_file_path)
                self.labels = torch.load(target_sequence_file_path)

            else:
                if self.binary:
                    labels = np.load(osp.join(raw_dir, 'graph-label.npz'))['graph_label']
                else:
                    labels = pd.read_csv(osp.join(raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values

                has_nan = np.isnan(labels).any()

                # Classification labels are stored as integers unless NaNs
                # (missing labels) force a floating-point dtype.
                if 'classification' in self.task_type and not has_nan:
                    labels = torch.from_numpy(labels).to(torch.long)
                else:
                    labels = torch.from_numpy(labels).to(torch.float32)

                print('Saving...')
                save_graphs(pre_processed_file_path, graphs, labels={'labels': labels})

                ### load preprocessed files
                self.graphs, label_dict = load_graphs(pre_processed_file_path)
                self.labels = label_dict['labels']

    def get_idx_split(self, split_type = None):
        '''Return a dict with ``'train'``, ``'valid'`` and ``'test'`` index tensors.

        ``split_type`` defaults to the ``'split'`` entry of the meta information.
        '''
        if split_type is None:
            split_type = self.meta_info['split']

        path = osp.join(self.root, 'split', split_type)

        # short-cut if split_dict.pt exists
        if os.path.isfile(os.path.join(path, 'split_dict.pt')):
            return torch.load(os.path.join(path, 'split_dict.pt'))

        train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header = None).values.T[0]
        valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header = None).values.T[0]
        test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header = None).values.T[0]

        return {'train': torch.tensor(train_idx, dtype = torch.long), 'valid': torch.tensor(valid_idx, dtype = torch.long), 'test': torch.tensor(test_idx, dtype = torch.long)}

    def __getitem__(self, idx):
        '''Get datapoint with index

        Integer indices (Python or NumPy) and 0-dim long tensors return a
        ``(graph, label)`` pair; a 1-dim long tensor returns a ``Subset``.
        '''

        if isinstance(idx, (int, np.integer)):
            return self.graphs[idx], self.labels[idx]
        elif torch.is_tensor(idx) and idx.dtype == torch.long:
            if idx.dim() == 0:
                return self.graphs[idx], self.labels[idx]
            elif idx.dim() == 1:
                return Subset(self, idx.cpu())

        raise IndexError(
            'Only integers and long are valid '
            'indices (got {}).'.format(type(idx).__name__))

    def __len__(self):
        '''Length of the dataset
        Returns
        -------
        int
            Length of Dataset
        '''
        return len(self.graphs)

    def __repr__(self):  # pragma: no cover
        return '{}({})'.format(self.__class__.__name__, len(self))
274 |
275 |
class DglPCQM4Mv2Dataset(object):
    def __init__(self, root = 'dataset', smiles2graph = smiles2graph):
        '''
            DGL PCQM4Mv2 dataset object
                - root (str): the dataset folder will be located at root/pcqm4m-v2
                - smiles2graph (callable): A callable function that converts a SMILES string into a graph object
                    * The default smiles2graph requires rdkit to be installed
        '''

        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = osp.join(root, 'pcqm4m-v2')
        self.version = 1

        # Old url hosted at Stanford
        # md5sum: 65b742bafca5670be4497499db7d361b
        # self.url = f'http://ogb-data.stanford.edu/data/lsc/pcqm4m-v2.zip'
        # New url hosted by DGL team at AWS--much faster to download
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip'

        # check version and update if necessary
        if osp.isdir(self.folder) and (not osp.exists(osp.join(self.folder, f'RELEASE_v{self.version}.txt'))):
            print('PCQM4Mv2 dataset has been updated.')
            if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
                shutil.rmtree(self.folder)

        super(DglPCQM4Mv2Dataset, self).__init__()

        # Prepare everything.
        # download if there is no raw file
        # preprocess if there is no processed file
        # load data if processed file is found.
        self.prepare_graph()

    def download(self):
        '''Download and extract the dataset archive into ``self.original_root``.'''
        if decide_download(self.url):
            path = download_url(self.url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
        else:
            print('Stop download.')
            # Same exception the interactive exit() helper raises, but
            # available even when the site module is not loaded.
            raise SystemExit(-1)

    def prepare_graph(self):
        '''Populate ``self.graphs`` / ``self.labels``, building and caching them if needed.

        Every molecule becomes a fully-connected DGL graph whose node data holds
        the eigenvalues (``'e'``) and eigenvectors (``'u'``) of the normalized
        adjacency with self-loops, plus the node features. Edge features are
        shifted by +1 (0 is reserved for padding) and densified to one row per
        ordered node pair.
        '''
        processed_dir = osp.join(self.folder, 'processed')
        raw_dir = osp.join(self.folder, 'raw')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            # if pre-processed file already exists
            self.graphs, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
        else:
            # if pre-processed file does not exist

            if not osp.exists(osp.join(raw_dir, 'data.csv.gz')):
                # if the raw file does not exist, then download it.
                self.download()

            data_df = pd.read_csv(osp.join(raw_dir, 'data.csv.gz'))
            smiles_list = data_df['smiles']
            homolumogap_list = data_df['homolumogap']

            print('Converting SMILES strings into graphs...')
            self.graphs = []
            self.labels = []
            for mol_idx in tqdm(range(len(smiles_list))):

                smiles = smiles_list[mol_idx]
                homolumogap = homolumogap_list[mol_idx]
                graph = self.smiles2graph(smiles)

                assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
                assert(len(graph['node_feat']) == graph['num_nodes'])

                src, dst = torch.from_numpy(graph['edge_index'])
                num_nodes = graph['num_nodes']

                if num_nodes == 1:  # some graphs have one node
                    A_ = torch.tensor(1.).view(1, 1)
                else:
                    A = torch.zeros([num_nodes, num_nodes], dtype=torch.float)
                    A[src, dst] = 1.0
                    A.fill_diagonal_(1.0)  # self-loops keep every degree >= 1
                    deg = torch.sum(A, dim=0) ** -0.5
                    D = torch.diag(deg)
                    A_ = D @ A @ D
                e, u = torch.linalg.eigh(A_)

                fully_connected = torch.ones([num_nodes, num_nodes], dtype=torch.float).nonzero(as_tuple=True)
                g = dgl.graph(fully_connected, num_nodes = num_nodes)

                g.ndata['e'] = e
                g.ndata['u'] = u

                if graph['node_feat'] is not None:
                    g.ndata['feat'] = torch.from_numpy(graph['node_feat']).long()

                if graph['edge_feat'] is not None:
                    edge_idx = torch.stack([src, dst], dim=0)
                    edge_attr = torch.from_numpy(graph['edge_feat']).long() + 1  # for padding

                    if len(edge_attr.shape) == 1:
                        # max_num_nodes is required: without it the dense size
                        # is inferred from the largest node id present in an
                        # edge, which is too small when trailing nodes are isolated.
                        edge_attr_dense = to_dense_adj(edge_idx, edge_attr=edge_attr.unsqueeze(-1), max_num_nodes=num_nodes).squeeze(0).squeeze(-1).view(-1)
                    else:
                        if edge_attr.size(0) == 0:  # for graphs without edge
                            edge_attr_dense = torch.zeros([num_nodes ** 2, edge_attr.size(1)]).long()
                        else:
                            edge_attr_dense = to_dense_adj(edge_idx, edge_attr=edge_attr, max_num_nodes=num_nodes).squeeze(0).view(-1, edge_attr.shape[-1])

                    g.edata['feat'] = edge_attr_dense

                self.graphs.append(g)
                self.labels.append(homolumogap)

            self.labels = torch.tensor(self.labels, dtype=torch.float32)

            # double-check prediction target
            split_dict = self.get_idx_split()
            assert(all([not torch.isnan(self.labels[i]) for i in split_dict['train']]))
            assert(all([not torch.isnan(self.labels[i]) for i in split_dict['valid']]))
            assert(all([torch.isnan(self.labels[i]) for i in split_dict['test-dev']]))
            assert(all([torch.isnan(self.labels[i]) for i in split_dict['test-challenge']]))

            # Make sure the output folder exists before anything is saved into it.
            os.makedirs(processed_dir, exist_ok=True)

            print('Saving...')
            save_graphs(pre_processed_file_path, self.graphs, labels={'labels': self.labels})

    def get_idx_split(self):
        '''Load ``split_dict.pt`` from the dataset folder with arrays converted to tensors.'''
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.folder, 'split_dict.pt')))
        return split_dict

    def __getitem__(self, idx):
        '''Get datapoint with index

        Integer indices (Python or NumPy) and 0-dim long tensors return a
        ``(graph, label)`` pair; a 1-dim long tensor returns a ``Subset``.
        '''

        if isinstance(idx, (int, np.integer)):
            return self.graphs[idx], self.labels[idx]
        elif torch.is_tensor(idx) and idx.dtype == torch.long:
            if idx.dim() == 0:
                return self.graphs[idx], self.labels[idx]
            elif idx.dim() == 1:
                return Subset(self, idx.cpu())

        raise IndexError(
            'Only integers and long are valid '
            'indices (got {}).'.format(type(idx).__name__))

    def __len__(self):
        '''Length of the dataset
        Returns
        -------
        int
            Length of Dataset
        '''
        return len(self.graphs)

    def __repr__(self):  # pragma: no cover
        return '{}({})'.format(self.__class__.__name__, len(self))
435 |
436 |
437 | class DglZincDataset(InMemoryDataset):
438 |
439 | url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1'
440 | split_url = ('https://raw.githubusercontent.com/graphdeeplearning/'
441 | 'benchmarking-gnns/master/data/molecules/{}.index')
442 |
443 | def __init__(self, root, subset=False, split='train', transform=None,
444 | pre_transform=None, pre_filter=None):
445 | self.subset = subset
446 | assert split in ['train', 'val', 'test']
447 | super().__init__(root, transform, pre_transform, pre_filter)
448 | path = osp.join(self.processed_dir, f'dgl_{split}.pt')
449 | print(path)
450 | self.graphs, label_dict = load_graphs(path)
451 | self.labels = label_dict['labels']
452 |
453 | @property
454 | def raw_file_names(self):
455 | return [
456 | 'train.pickle', 'val.pickle', 'test.pickle', 'train.index',
457 | 'val.index', 'test.index'
458 | ]
459 |
460 | @property
461 | def processed_dir(self):
462 | name = 'subset' if self.subset else 'full'
463 | return osp.join(self.root, name, 'processed')
464 |
465 | @property
466 | def processed_file_names(self):
467 | return ['dgl_train.pt', 'dgl_val.pt', 'dgl_test.pt']
468 |
469 | def download(self):
470 | shutil.rmtree(self.raw_dir)
471 | path = download_url(self.url, self.root)
472 | extract_zip(path, self.root)
473 | os.rename(osp.join(self.root, 'molecules'), self.raw_dir)
474 | os.unlink(path)
475 |
476 | for split in ['train', 'val', 'test']:
477 | download_url(self.split_url.format(split), self.raw_dir)
478 |
479 | def process(self):
480 | for split in ['train', 'val', 'test']:
481 | with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:
482 | mols = pickle.load(f)
483 |
484 | indices = range(len(mols))
485 |
486 | if self.subset:
487 | with open(osp.join(self.raw_dir, f'{split}.index'), 'r') as f:
488 | indices = [int(x) for x in f.read()[:-1].split(',')]
489 |
490 | pbar = tqdm(total=len(indices))
491 | pbar.set_description(f'Processing {split} dataset')
492 |
493 | graphs = []
494 | labels = []
495 | for idx in indices:
496 | mol = mols[idx]
497 |
498 | x = mol['atom_type'].to(torch.long).view(-1, 1)
499 | y = mol['logP_SA_cycle_normalized'].to(torch.float)
500 |
501 | adj = mol['bond_type']
502 | edge_idx = adj.nonzero(as_tuple=False).t().contiguous()
503 | edge_attr = adj[edge_idx[0], edge_idx[1]].to(torch.long) + 1 # for padding
504 |
505 | src, dst = edge_idx
506 | num_nodes = x.size(0)
507 |
508 | if num_nodes == 1: # some graphs have one node
509 | A_ = torch.tensor(1.).view(1, 1)
510 | else:
511 | A = torch.zeros([num_nodes, num_nodes], dtype=torch.float)
512 | A[src, dst] = 1.0
513 | for i in range(num_nodes):
514 | A[i, i] = 1.0
515 | deg = torch.sum(A, axis=0).squeeze() ** -0.5
516 | D = torch.diag(deg)
517 | A_ = D @ A @ D
518 | e, u = torch.linalg.eigh(A_)
519 |
520 | fully_connected = torch.ones([num_nodes, num_nodes], dtype=torch.float).nonzero(as_tuple=True)
521 | g = dgl.graph(fully_connected, num_nodes = num_nodes)
522 |
523 | g.ndata['e'] = e
524 | g.ndata['u'] = u
525 |
526 | g.ndata['feat'] = x
527 | g.edata['feat'] = to_dense_adj(edge_idx, edge_attr=edge_attr.unsqueeze(-1)).squeeze(0).squeeze(-1).view(-1)
528 |
529 | if self.pre_filter is not None and not self.pre_filter(data):
530 | continue
531 |
532 | if self.pre_transform is not None:
533 | data = self.pre_transform(data)
534 |
535 | graphs.append(g)
536 | labels.append(y)
537 |
538 | pbar.update(1)
539 |
540 | pbar.close()
541 |
542 | labels = torch.tensor(labels, dtype=torch.float32)
543 | save_graphs(osp.join(self.processed_dir, f'dgl_{split}.pt'), graphs, labels={'labels': labels})
544 |
545 | def __getitem__(self, idx):
546 | '''Get datapoint with index'''
547 |
548 | if isinstance(idx, int):
549 | return self.graphs[idx], self.labels[idx]
550 | elif torch.is_tensor(idx) and idx.dtype == torch.long:
551 | if idx.dim() == 0:
552 | return self.graphs[idx], self.labels[idx]
553 | elif idx.dim() == 1:
554 | return Subset(self, idx.cpu())
555 |
556 | raise IndexError(
557 | 'Only integers and long are valid '
558 | 'indices (got {}).'.format(type(idx).__name__))
559 |
def __len__(self):
    """Return the number of graphs held by the dataset.

    Returns
    -------
    int
        ``len(self.graphs)``.
    """
    graph_count = len(self.graphs)
    return graph_count
568 |
569 |
--------------------------------------------------------------------------------
/Graph/get_dataset.py:
--------------------------------------------------------------------------------
1 | from dataclass import *
2 | from dgldataclass import DglGraphPropPredDataset, DglPCQM4Mv2Dataset, DglZincDataset
3 | from pygdataclass import PygGraphPropPredDataset
4 | import dgl
5 | from dgl.data.utils import load_graphs, save_graphs, Subset
6 | import torch
7 | from torch.nn import functional as F
8 | from torch.utils.data import DataLoader, Sampler, RandomSampler
9 | from torch_geometric.data import InMemoryDataset, Data
10 | from ogb.graphproppred import Evaluator
11 | from torch_geometric.utils import to_dense_adj
12 |
13 |
class PCQM4Mv2Evaluator:
    """Mean-absolute-error evaluator for the PCQM4Mv2 dataset."""

    def __init__(self):
        """Create the evaluator; it keeps no state."""

    def eval(self, input_dict):
        """Return ``{'mae': float}`` for the given predictions and targets.

        ``input_dict`` must contain ``'y_pred'`` and ``'y_true'``. Both are
        flattened with ``reshape(-1)`` and must then be of the same kind
        (both ``numpy.ndarray`` or both ``torch.Tensor``) and the same
        shape; violations trip an ``assert``.
        """
        assert('y_pred' in input_dict)
        assert('y_true' in input_dict)

        y_pred = input_dict['y_pred'].reshape(-1)
        y_true = input_dict['y_true'].reshape(-1)

        both_numpy = isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray)
        both_torch = isinstance(y_true, torch.Tensor) and isinstance(y_pred, torch.Tensor)
        assert(both_numpy or both_torch)
        assert(y_true.shape == y_pred.shape)
        assert(len(y_true.shape) == 1)

        if not isinstance(y_true, torch.Tensor):
            abs_err = np.absolute(y_pred - y_true)
            return {'mae': float(np.mean(abs_err))}

        abs_err = torch.abs(y_pred - y_true)
        return {'mae': torch.mean(abs_err).cpu().item()}
43 |
44 |
class DynamicBatchSampler(Sampler):
    """Batch sampler bounded by both sample count and a memory budget.

    Each sample costs the square of its node count (the original comment
    attributes the square to the Transformer). A batch is emitted as soon
    as it holds ``batch_size`` samples, or just before a sample that would
    push the summed cost above ``max_nodes ** 2 * batch_size``.

    Parameters
    ----------
    sampler : iterable of int
        Yields dataset indices in the order they should be consumed.
    num_nodes_list : sequence
        ``num_nodes_list[idx]`` is the node count of sample ``idx``.
    batch_size : int
        Maximum number of samples per batch.
    max_nodes : int
        Reference node count used to derive the memory budget.
    drop_last : bool
        When true, the trailing batch left over at the end of iteration is
        discarded instead of yielded.
    """

    def __init__(self, sampler, num_nodes_list, batch_size=32, max_nodes=200, drop_last=False):

        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_nodes_list = num_nodes_list
        self.batch_size = batch_size
        self.max_nodes = max_nodes
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices; no yielded list is ever empty."""
        batch = []
        batch_cost = 0
        budget = self.max_nodes * self.max_nodes * self.batch_size

        for idx in self.sampler:
            # square for Transformer
            cost = self.num_nodes_list[idx] ** 2

            # Beyond the budget: close the current batch first. The
            # ``batch`` guard keeps an oversized *first* sample from
            # emitting an empty batch; it simply gets a batch of its own.
            if batch and batch_cost + cost > budget:
                yield batch
                batch = []
                batch_cost = 0

            batch.append(idx)
            batch_cost += cost

            # Checked on every path, so a batch opened right after an
            # overflow still respects ``batch_size``.
            if len(batch) == self.batch_size:
                yield batch
                batch = []
                batch_cost = 0

        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Always raise: the number of batches depends on the sample order.

        Raises
        ------
        TypeError
            The batch count is unknown before iterating, so callers must
            not rely on ``len(dataloader)``.
        """
        raise TypeError(
            'DynamicBatchSampler has no fixed length; '
            'do not call len() on it or on its DataLoader.')
85 |
86 |
def collate_dgl(samples):
    """Collate ``(graph, label)`` pairs into padded spectra and a batched graph.

    For every graph, ``ndata['e']`` is zero-padded to ``[max_nodes]`` and
    ``ndata['u']`` to ``[max_nodes, max_nodes]``, where ``max_nodes`` is the
    largest node count in the batch. The graphs themselves are merged with
    ``dgl.batch``, keeping only the ``'feat'`` node and edge data.

    Returns ``(E, U, batched_graph, length, labels)``: ``E`` and ``U`` are
    the stacked padded tensors, ``length`` is a ``LongTensor`` of node
    counts, and ``labels`` is stacked when the first label is a tensor and
    returned as the plain list otherwise.
    """
    graphs, labels = map(list, zip(*samples))

    node_counts = [g.num_nodes() for g in graphs]
    max_nodes = max(node_counts)

    eig_vals = []
    eig_vecs = []
    for g, n in zip(graphs, node_counts):
        e = g.ndata['e']
        u = g.ndata['u']

        padded_e = e.new_zeros([max_nodes])
        padded_e[:n] = e
        eig_vals.append(padded_e)

        padded_u = u.new_zeros([max_nodes, max_nodes])
        padded_u[:n, :n] = u
        eig_vecs.append(padded_u)

    E = torch.stack(eig_vals, 0)
    U = torch.stack(eig_vecs, 0)
    length = torch.LongTensor(node_counts)
    batched_graph = dgl.batch(graphs, ndata=['feat'], edata=['feat'])

    if not isinstance(labels[0], torch.Tensor):
        return E, U, batched_graph, length, labels
    return E, U, batched_graph, length, torch.stack(labels)
123 |
124 |
def collate_pad(batch):
    """Collate samples into dense, zero-padded tensors.

    Every sample is read through its ``num_nodes``, ``e``, ``u``, ``x``,
    ``edge_attr``, ``edge_index`` and ``y`` attributes. The common size is
    the largest node count in the batch, capped at 128.

    * Samples that fit are padded; node and edge features are shifted by
      ``+1`` so that 0 is left for padding.
    * Larger samples are truncated to their first ``max_nodes`` nodes and
      their spectrum is recomputed with ``torch.linalg.eigh`` from the
      truncated, degree-normalised adjacency (degrees clamped to >= 1).

    Returns the stacked ``(E, U, X, F, Y)`` tensors.
    """
    eig_vals, eig_vecs, node_feats, edge_feats, targets = [], [], [], [], []

    max_nodes = min(max(item.num_nodes for item in batch), 128)

    for item in batch:
        n = item.num_nodes
        e = item.e
        u = item.u.view(n, n)
        x = item.x
        f = item.edge_attr
        src, dst = item.edge_index
        fdim = f.size(-1)

        if n <= max_nodes:
            pad_e = e.new_zeros([max_nodes])
            pad_e[:n] = e

            pad_u = u.new_zeros([max_nodes, max_nodes])
            pad_u[:n, :n] = u

            pad_x = x.new_zeros([max_nodes, x.size(-1)])
            pad_x[:n, :] = x + 1

            pad_f = f.new_zeros([max_nodes, max_nodes, fdim])
            pad_f[src, dst, :] = f + 1
        else:
            # Rebuild the dense adjacency, keep the leading block, then
            # normalise it symmetrically before the eigendecomposition.
            adj = torch.zeros([n, n], dtype=torch.float)
            adj[src, dst] = 1.0
            adj = adj[:max_nodes, :max_nodes]
            deg = torch.sum(adj, axis=0).squeeze()
            deg = torch.clamp(deg, min=1.0) ** -0.5
            scale = torch.diag(deg)
            pad_e, pad_u = torch.linalg.eigh(scale @ adj @ scale)

            pad_x = x[:max_nodes, :] + 1

            dense_f = torch.zeros([n, n, fdim], dtype=torch.long)
            dense_f[src, dst] = f + 1
            pad_f = dense_f[:max_nodes, :max_nodes]

        eig_vals.append(pad_e)
        eig_vecs.append(pad_u)
        node_feats.append(pad_x)
        edge_feats.append(pad_f)
        targets.append(item.y.squeeze())

    return (torch.stack(eig_vals, 0), torch.stack(eig_vecs, 0), torch.stack(node_feats, 0),
            torch.stack(edge_feats, 0), torch.stack(targets, 0))
181 |
182 |
def get_dataset(dataset_name='abaaba'):
    """Build the settings and data splits for a named benchmark.

    Supported names are ``'zinc'``, ``'pcqm'``, ``'pcqms'``, ``'hiv'``,
    ``'pcba'`` and ``'ppa'``. The returned dict has the keys
    ``num_class``, ``loss_fn``, ``metric``, ``metric_mode``, ``evaluator``,
    ``train_dataset``, ``valid_dataset`` and ``test_dataset``.

    Any other name, including the default, raises ``NotImplementedError``.
    """

    def bundle(num_class, loss_fn, metric, metric_mode, evaluator, train, valid, test):
        # Arguments are evaluated left to right, i.e. in key order.
        return {
            'num_class': num_class,
            'loss_fn': loss_fn,
            'metric': metric,
            'metric_mode': metric_mode,
            'evaluator': evaluator,
            'train_dataset': train,
            'valid_dataset': valid,
            'test_dataset': test,
        }

    if dataset_name == 'zinc':
        return bundle(
            1, F.l1_loss, 'mae', 'min', PCQM4Mv2Evaluator(),
            DglZincDataset('dataset/zinc', subset=True, split='train'),
            DglZincDataset('dataset/zinc', subset=True, split='val'),
            DglZincDataset('dataset/zinc', subset=True, split='test'),
        )

    if dataset_name == 'pcqm':
        dataset = DglPCQM4Mv2Dataset()
        split_idx = dataset.get_idx_split()
        # A random 150000-sample slice of the 'train' split is held out for
        # validation; the 'valid' split is then used as the test set.
        train_pool = split_idx['train']
        shuffled = torch.randperm(train_pool.size(0))
        train_idx = train_pool[shuffled[150000:]]
        valid_idx = train_pool[shuffled[:150000]]
        test_idx = split_idx['valid']
        return bundle(
            1, F.l1_loss, 'mae', 'min', PCQM4Mv2Evaluator(),
            dataset[train_idx], dataset[valid_idx], dataset[test_idx],
        )

    if dataset_name == 'pcqms':
        train_g, train_dict = load_graphs('dataset/pcqm_subset_train.pt')
        valid_g, valid_dict = load_graphs('dataset/pcqm_subset_valid.pt')
        test_g, test_dict = load_graphs('dataset/pcqm_subset_test.pt')
        return bundle(
            1, F.l1_loss, 'mae', 'min', PCQM4Mv2Evaluator(),
            list(zip(train_g, train_dict['labels'])),
            list(zip(valid_g, valid_dict['labels'])),
            list(zip(test_g, test_dict['labels'])),
        )

    if dataset_name == 'hiv':
        dataset = DglGraphPropPredDataset('ogbg-molhiv')
        split_idx = dataset.get_idx_split()
        return bundle(
            1, F.binary_cross_entropy_with_logits, 'rocauc', 'max', Evaluator('ogbg-molhiv'),
            dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']],
        )

    if dataset_name == 'pcba':
        dataset = DglGraphPropPredDataset(name = 'ogbg-molpcba')
        split_idx = dataset.get_idx_split()
        return bundle(
            128, F.binary_cross_entropy_with_logits, 'ap', 'max', Evaluator('ogbg-molpcba'),
            dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']],
        )

    if dataset_name == 'ppa':
        dataset = PygGraphPropPredDataset(name = 'ogbg-ppa')
        split_idx = dataset.get_idx_split()
        return bundle(
            37, F.cross_entropy, 'acc', 'max', Evaluator('ogbg-ppa'),
            dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']],
        )

    raise NotImplementedError
273 |
274 |
--------------------------------------------------------------------------------
/Graph/large_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
9 | from dgl.nn.pytorch.glob import AvgPooling
10 | from dgl import function as fn
11 | from dgl.ops.edge_softmax import edge_softmax
12 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
13 |
14 |
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features and add the embeddings together.

    One ``Embedding`` table of width ``emb_dim`` is created for each entry
    returned by ``get_atom_feature_dims()``; every table is initialised
    with Xavier-uniform weights.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_atom_feature_dims():
            table = torch.nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.atom_embedding_list = tables

    def forward(self, x):
        """Sum, over the columns of ``x``, column ``i`` looked up in table ``i``."""
        return sum(self.atom_embedding_list[i](x[:, i]) for i in range(x.shape[1]))
33 |
34 |
class BondEncoder(torch.nn.Module):
    """Embed categorical bond features and add the embeddings together.

    Each table has one more row than the matching entry of
    ``get_bond_feature_dims()``; row 0 is declared as ``padding_idx``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_bond_feature_dims():
            # +1 row: index 0 is reserved for padding.
            table = torch.nn.Embedding(vocab_size + 1, emb_dim, padding_idx=0)
            # NOTE(review): this initialises the whole weight matrix, so the
            # padding row does not stay at zero -- confirm that is intended.
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.bond_embedding_list = tables

    def forward(self, edge_attr):
        """Sum, over the columns of ``edge_attr``, column ``i`` looked up in table ``i``."""
        return sum(self.bond_embedding_list[i](edge_attr[:, i]) for i in range(edge_attr.shape[1]))
53 |
54 |
class SineEncoding(nn.Module):
    """Sinusoidal encoding of eigenvalues followed by a linear projection."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        # Input width: the raw eigenvalue plus the sine and cosine features.
        self.eig_w = nn.Linear(hidden_dim + 1, hidden_dim)

    def forward(self, e):
        """Encode eigenvalues.

        input: [B, N]
        output: [B, N, d]
        """
        scaled = e * self.constant
        log_step = -math.log(10000) / self.hidden_dim
        inv_freq = torch.exp(torch.arange(0, self.hidden_dim, 2) * log_step).to(e.device)
        angles = scaled.unsqueeze(2) * inv_freq
        features = torch.cat((e.unsqueeze(2), torch.sin(angles), torch.cos(angles)), dim=2)

        return self.eig_w(features)
72 |
73 |
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a GELU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``layer1``, then GELU, then ``layer2``."""
        return self.layer2(self.gelu(self.layer1(x)))
87 |
88 |
class Conv(nn.Module):
    """One layer that refines eigenvalue tokens and then propagates node features.

    The eigenvalue tokens go through a pre-norm self-attention block and a
    pre-norm feed-forward block. They are decoded into ``nheads`` spectral
    filters ``u @ diag(.) @ ut``, which together with the identity are
    encoded into per-edge weights and used to aggregate messages on the
    graph.
    """

    def __init__(self, hidden_dim, nheads, trans_dropout, feat_dropout, adj_dropout):
        super().__init__()
        self.nheads = nheads

        # Transformer block over the eigenvalue tokens.
        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.mha_dropout = nn.Dropout(trans_dropout)
        self.ffn_dropout = nn.Dropout(trans_dropout)
        self.mha = nn.MultiheadAttention(hidden_dim, nheads, trans_dropout, batch_first=True)
        self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim)
        # One decoded eigenvalue per head.
        self.decoder = nn.Linear(hidden_dim, nheads)

        # Edge-weight encoder: nheads filters plus the identity basis.
        self.adj_dropout = nn.Dropout(adj_dropout)
        self.filter_encoder = nn.Sequential(
            nn.Linear(nheads + 1, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )

        # Message function applied to (source node + edge) features.
        self.pre_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU()
        )

        self.preffn_dropout = nn.Dropout(feat_dropout)
        self.x_ffn_dropout = nn.Dropout(feat_dropout)

        # Node-wise update applied after aggregation.
        self.x_ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )

    def forward(self, eig, u, ut, graph, x_feat, edge_attr, eig_mask, edge_idx):
        """Return the updated ``(eig, x)`` pair.

        ``eig`` is [B, N, d] (per the original shape comments), ``u`` and
        ``ut`` are the eigenvector matrices and their transposes,
        ``eig_mask`` is the key-padding mask for the attention, and
        ``edge_idx`` selects the entries of the dense [B, N, N, H] bases
        that correspond to edges of ``graph``.
        """
        batch_size, max_nodes = eig.size()[:2]

        # Pre-norm self-attention with a residual connection.
        normed = self.mha_norm(eig)
        attended, _ = self.mha(normed, normed, normed, key_padding_mask=eig_mask)
        eig = eig + self.mha_dropout(attended)

        # Pre-norm feed-forward with a residual connection.
        eig = eig + self.ffn_dropout(self.ffn(self.ffn_norm(eig)))  # [B, N, d]

        new_e = self.decoder(eig).transpose(1, 2)  # [B, m, N]
        diag_e = torch.diag_embed(new_e)  # [B, m, N, N]

        identity = torch.diag_embed(torch.ones(batch_size, max_nodes)).to(u.device)
        bases = [identity] + [u @ diag_e[:, head, :, :] @ ut for head in range(self.nheads)]

        bases = torch.stack(bases, axis=-1)  # [B, N, N, H]
        bases = bases[edge_idx]
        bases = self.adj_dropout(self.filter_encoder(bases))
        bases = edge_softmax(graph, bases)

        with graph.local_scope():
            graph.ndata['x'] = x_feat
            graph.apply_edges(fn.copy_u('x', '_x'))
            graph.edata['v'] = self.pre_ffn(graph.edata['_x'] + edge_attr) * bases
            graph.update_all(fn.copy_e('v', '_aggr_e'), fn.sum('_aggr_e', 'aggr_e'))
            aggregated = self.preffn_dropout(graph.ndata['aggr_e'])
            x = x_feat + aggregated
            x = x + self.x_ffn_dropout(self.x_ffn(x))

        return eig, x
168 |
169 |
class SpecformerLarge(nn.Module):
    """Graph-level Specformer whose every layer has its own eigenvalue Transformer.

    Atom and bond features are embedded, eigenvalues are sine-encoded, and
    ``nlayer`` ``Conv`` layers update the eigenvalue tokens and the node
    features together. Node features are then average-pooled per graph
    and projected to ``nclass`` outputs.
    """

    def __init__(self, nclass, nlayer, hidden_dim=128, nheads=4, feat_dropout=0.1, trans_dropout=0.1, adj_dropout=0.1):
        super(SpecformerLarge, self).__init__()

        self.nlayer = nlayer
        self.nclass = nclass
        self.hidden_dim = hidden_dim
        self.nheads = nheads

        self.atom_encoder = AtomEncoder(hidden_dim)
        self.bond_encoder = BondEncoder(hidden_dim)
        self.eig_encoder = SineEncoding(hidden_dim)

        self.convs = nn.ModuleList([Conv(hidden_dim, nheads, trans_dropout, feat_dropout, adj_dropout) for _ in range(nlayer)])
        self.pool = AvgPooling()
        self.linear = nn.Linear(hidden_dim, nclass)

    def forward(self, e, u, g, length):
        """Return the per-graph outputs of the final linear layer.

        e: [B, N] eigenvalues
        u: [B, N, N] eigenvectors
        g: batched graph carrying ``ndata['feat']`` and ``edata['feat']``
        length: [B], used to build the padding masks
        """
        ut = u.transpose(1, 2)

        node_feat = g.ndata['feat']
        edge_feat = g.edata['feat']

        # do not use u to generate edge_idx because of the connected components
        e_mask, edge_idx = self.length_to_mask(length)

        node_feat = self.atom_encoder(node_feat)
        edge_feat = self.bond_encoder(edge_feat)
        eig = self.eig_encoder(e)

        for conv in self.convs:
            eig, node_feat = conv(eig, u, ut, g, node_feat, edge_feat, e_mask, edge_idx)

        h = self.pool(g, node_feat)
        h = self.linear(h)

        return h

    def length_to_mask(self, length):
        """Build the padding masks for a batch.

        length: [B]
        return: ``(mask1d, mask2d)``, both boolean.

        ``mask1d`` is [B, max_len] and is True at padded positions; it is
        used as ``key_padding_mask`` of the attention. ``mask2d`` is
        [B, max_len, max_len] and is True exactly on the top-left
        ``length[i] x length[i]`` block of sample ``i``; it is used to
        select edge entries out of the padded bases.

        Example: length=[1, 2, 3], B=3, N=3

        mask1d
            [False, True,  True ]
            [False, False, True ]
            [False, False, False]

        mask2d
            [[1, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
             [0, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
             [0, 0, 0]]  |  [0, 0, 0]]  |  [1, 1, 1]]
        """
        B = len(length)
        N = length.max().item()
        mask1d = torch.arange(N, device=length.device).expand(B, N) >= length.unsqueeze(1)

        # Outer AND of the valid positions, i.e. row i and column j are
        # both inside the sample.
        valid = ~mask1d
        mask2d = valid.unsqueeze(2) & valid.unsqueeze(1)

        return mask1d, mask2d
262 |
--------------------------------------------------------------------------------
/Graph/master.csv:
--------------------------------------------------------------------------------
1 | ,ogbg-molbace,ogbg-molbbbp,ogbg-molclintox,ogbg-molmuv,ogbg-molpcba,ogbg-molsider,ogbg-moltox21,ogbg-moltoxcast,ogbg-molhiv,ogbg-molesol,ogbg-molfreesolv,ogbg-mollipo,ogbg-molchembl,ogbg-ppa,ogbg-code2
2 | num tasks,1,1,2,17,128,27,12,617,1,1,1,1,1310,1,1
3 | eval metric,rocauc,rocauc,rocauc,ap,ap,rocauc,rocauc,rocauc,rocauc,rmse,rmse,rmse,rocauc,acc,F1
4 | download_name,bace,bbbp,clintox,muv,pcba,sider,tox21,toxcast,hiv,esol,freesolv,lipophilicity,chembl,ogbg_ppi_medium,code2
5 | version,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
6 | url,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bace.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bbbp.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/clintox.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/muv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/pcba.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/sider.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/tox21.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/toxcast.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/esol.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/lipophilicity.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/chembl.zip,http://snap.stanford.edu/ogb/data/graphproppred/ogbg_ppi_medium.zip,http://snap.stanford.edu/ogb/data/graphproppred/code2.zip
7 | add_inverse_edge,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False
8 | data type,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,,
9 | has_node_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,False,True
10 | has_edge_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False
11 | task type,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,binary classification,regression,regression,regression,binary classification,multiclass classification,subtoken prediction
12 | num classes,2,2,2,2,2,2,2,2,2,-1,-1,-1,2,37,-1
13 | split,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,species,project
14 | additional node files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"node_is_attributed,node_dfs_order,node_depth"
15 | additional edge files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None
16 | binary,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
17 |
--------------------------------------------------------------------------------
/Graph/medium_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
9 | from dgl.nn.pytorch.glob import AvgPooling
10 | from dgl import function as fn
11 | from dgl.ops.edge_softmax import edge_softmax
12 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
13 |
14 |
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features and add the embeddings together.

    One ``Embedding`` table of width ``emb_dim`` is created for each entry
    returned by ``get_atom_feature_dims()``; every table is initialised
    with Xavier-uniform weights.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_atom_feature_dims():
            table = torch.nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.atom_embedding_list = tables

    def forward(self, x):
        """Sum, over the columns of ``x``, column ``i`` looked up in table ``i``."""
        return sum(self.atom_embedding_list[i](x[:, i]) for i in range(x.shape[1]))
33 |
34 |
class BondEncoder(torch.nn.Module):
    """Embed categorical bond features and add the embeddings together.

    Each table has one more row than the matching entry of
    ``get_bond_feature_dims()``; row 0 is declared as ``padding_idx``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_bond_feature_dims():
            # +1 row: index 0 is reserved for padding.
            table = torch.nn.Embedding(vocab_size + 1, emb_dim, padding_idx=0)
            # NOTE(review): this initialises the whole weight matrix, so the
            # padding row does not stay at zero -- confirm that is intended.
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.bond_embedding_list = tables

    def forward(self, edge_attr):
        """Sum, over the columns of ``edge_attr``, column ``i`` looked up in table ``i``."""
        return sum(self.bond_embedding_list[i](edge_attr[:, i]) for i in range(edge_attr.shape[1]))
53 |
54 |
class SineEncoding(nn.Module):
    """Sinusoidal encoding of eigenvalues followed by a linear projection."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        # Input width: the raw eigenvalue plus the sine and cosine features.
        self.eig_w = nn.Linear(hidden_dim + 1, hidden_dim)

    def forward(self, e):
        """Encode eigenvalues.

        input: [B, N]
        output: [B, N, d]
        """
        scaled = e * self.constant
        log_step = -math.log(10000) / self.hidden_dim
        inv_freq = torch.exp(torch.arange(0, self.hidden_dim, 2) * log_step).to(e.device)
        angles = scaled.unsqueeze(2) * inv_freq
        features = torch.cat((e.unsqueeze(2), torch.sin(angles), torch.cos(angles)), dim=2)

        return self.eig_w(features)
72 |
73 |
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a GELU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``layer1``, then GELU, then ``layer2``."""
        return self.layer2(self.gelu(self.layer1(x)))
87 |
88 |
class Conv(nn.Module):
    """Message-passing layer whose edge weights come from encoded spectral bases."""

    def __init__(self, nheads, hidden_size, feat_dropout, adj_dropout):
        super().__init__()

        # Edge-weight encoder: nheads filters plus the identity basis.
        self.adj_dropout = nn.Dropout(adj_dropout)
        self.filter_encoder = nn.Sequential(
            nn.Linear(nheads + 1, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
        )

        # Message function applied to (source node + edge) features;
        # it has no normalisation layer.
        self.pre_ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU()
        )

        self.preffn_dropout = nn.Dropout(feat_dropout)
        self.ffn_dropout = nn.Dropout(feat_dropout)

        # Node-wise update applied after aggregation.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU()
        )

    def forward(self, graph, x_feat, edge_attr, bases):
        """Return node features after one weighted aggregation and a residual FFN.

        ``bases`` holds one row per edge of ``graph``; it is encoded,
        passed through dropout and normalised with ``edge_softmax`` before
        it scales the messages.
        """
        edge_weight = self.adj_dropout(self.filter_encoder(bases))
        edge_weight = edge_softmax(graph, edge_weight)

        with graph.local_scope():
            graph.ndata['x'] = x_feat
            graph.apply_edges(fn.copy_u('x', '_x'))
            graph.edata['v'] = self.pre_ffn(graph.edata['_x'] + edge_attr) * edge_weight
            graph.update_all(fn.copy_e('v', '_aggr_e'), fn.sum('_aggr_e', 'aggr_e'))
            aggregated = self.preffn_dropout(graph.ndata['aggr_e'])
            x = x_feat + aggregated
            x = x + self.ffn_dropout(self.ffn(x))
        return x
138 |
139 |
class SpecformerMedium(nn.Module):
    """Graph-level Specformer with one shared eigenvalue Transformer.

    The eigenvalues are encoded and refined once by a single pre-norm
    attention + feed-forward block, then decoded into ``nheads`` spectral
    filters. The resulting per-edge bases are reused by all ``nlayer``
    ``Conv`` layers, each of which encodes them with its own filter
    encoder.
    """

    def __init__(self, nclass, nlayer, hidden_dim=128, nheads=4, feat_dropout=0.1, trans_dropout=0.1, adj_dropout=0.1):
        super(SpecformerMedium, self).__init__()

        print('medium model')
        self.nlayer = nlayer
        self.nclass = nclass
        self.hidden_dim = hidden_dim
        self.nheads = nheads

        self.atom_encoder = AtomEncoder(hidden_dim)
        self.bond_encoder = BondEncoder(hidden_dim)

        self.eig_encoder = SineEncoding(hidden_dim)
        # One decoded eigenvalue per head.
        self.decoder = nn.Linear(hidden_dim, nheads)

        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.mha_dropout = nn.Dropout(trans_dropout)
        self.ffn_dropout = nn.Dropout(trans_dropout)
        self.mha = nn.MultiheadAttention(hidden_dim, nheads, trans_dropout, batch_first=True)
        self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim)

        self.convs = nn.ModuleList([Conv(nheads, hidden_dim, feat_dropout, adj_dropout) for _ in range(nlayer)])
        self.pool = AvgPooling()
        self.linear = nn.Linear(hidden_dim, nclass)

    def forward(self, e, u, g, length):
        """Return the per-graph outputs of the final linear layer.

        e: [B, N] eigenvalues
        u: [B, N, N] eigenvectors
        g: batched graph carrying ``ndata['feat']`` and ``edata['feat']``
        length: [B], used to build the padding masks
        """
        ut = u.transpose(1, 2)

        node_feat = g.ndata['feat']
        edge_feat = g.edata['feat']

        # do not use u to generate edge_idx because of the connected components
        e_mask, edge_idx = self.length_to_mask(length)

        node_feat = self.atom_encoder(node_feat)
        edge_feat = self.bond_encoder(edge_feat)

        eig = self.eig_encoder(e)

        # Pre-norm self-attention with a residual connection.
        mha_eig = self.mha_norm(eig)
        mha_eig, _ = self.mha(mha_eig, mha_eig, mha_eig, key_padding_mask=e_mask)
        eig = eig + self.mha_dropout(mha_eig)

        # Pre-norm feed-forward with a residual connection.
        ffn_eig = self.ffn_norm(eig)
        ffn_eig = self.ffn(ffn_eig)
        eig = eig + self.ffn_dropout(ffn_eig)

        new_e = self.decoder(eig).transpose(2, 1)  # [B, m, N]
        diag_e = torch.diag_embed(new_e)  # [B, m, N, N]

        identity = torch.diag_embed(torch.ones_like(e))
        bases = [identity]
        for i in range(self.nheads):
            filters = u @ diag_e[:, i, :, :] @ ut
            bases.append(filters)

        bases = torch.stack(bases, axis=-1)  # [B, N, N, H]
        bases = bases[edge_idx]

        for conv in self.convs:
            node_feat = conv(g, node_feat, edge_feat, bases)

        h = self.pool(g, node_feat)
        h = self.linear(h)

        return h

    def length_to_mask(self, length):
        """Build the padding masks for a batch.

        length: [B]
        return: ``(mask1d, mask2d)``, both boolean.

        ``mask1d`` is [B, max_len] and is True at padded positions; it is
        used as ``key_padding_mask`` of the attention. ``mask2d`` is
        [B, max_len, max_len] and is True exactly on the top-left
        ``length[i] x length[i]`` block of sample ``i``; it is used to
        select edge entries out of the padded bases.

        Example: length=[1, 2, 3], B=3, N=3

        mask1d
            [False, True,  True ]
            [False, False, True ]
            [False, False, False]

        mask2d
            [[1, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
             [0, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
             [0, 0, 0]]  |  [0, 0, 0]]  |  [1, 1, 1]]
        """
        B = len(length)
        N = length.max().item()
        mask1d = torch.arange(N, device=length.device).expand(B, N) >= length.unsqueeze(1)

        # Outer AND of the valid positions, i.e. row i and column j are
        # both inside the sample.
        valid = ~mask1d
        mask2d = valid.unsqueeze(2) & valid.unsqueeze(1)

        return mask1d, mask2d
265 |
--------------------------------------------------------------------------------
/Graph/small_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
9 | from dgl.nn.pytorch.glob import AvgPooling
10 | from dgl import function as fn
11 | from dgl.ops.edge_softmax import edge_softmax
12 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
13 |
14 |
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features and add the embeddings together.

    One ``Embedding`` table of width ``emb_dim`` is created for each entry
    returned by ``get_atom_feature_dims()``; every table is initialised
    with Xavier-uniform weights.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_atom_feature_dims():
            table = torch.nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.atom_embedding_list = tables

    def forward(self, x):
        """Sum, over the columns of ``x``, column ``i`` looked up in table ``i``."""
        return sum(self.atom_embedding_list[i](x[:, i]) for i in range(x.shape[1]))
33 |
34 |
class BondEncoder(torch.nn.Module):
    """Embed categorical bond features and add the embeddings together.

    Each table has one more row than the matching entry of
    ``get_bond_feature_dims()``; row 0 is declared as ``padding_idx``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        tables = torch.nn.ModuleList()
        for vocab_size in get_bond_feature_dims():
            # +1 row: index 0 is reserved for padding.
            table = torch.nn.Embedding(vocab_size + 1, emb_dim, padding_idx=0)
            # NOTE(review): this initialises the whole weight matrix, so the
            # padding row does not stay at zero -- confirm that is intended.
            torch.nn.init.xavier_uniform_(table.weight.data)
            tables.append(table)
        self.bond_embedding_list = tables

    def forward(self, edge_attr):
        """Sum, over the columns of ``edge_attr``, column ``i`` looked up in table ``i``."""
        return sum(self.bond_embedding_list[i](edge_attr[:, i]) for i in range(edge_attr.shape[1]))
53 |
54 |
class SineEncoding(nn.Module):
    """Sinusoidal encoding of eigenvalues followed by a linear projection."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        # Input width: the raw eigenvalue plus the sine and cosine features.
        self.eig_w = nn.Linear(hidden_dim + 1, hidden_dim)

    def forward(self, e):
        """Encode eigenvalues.

        input: [B, N]
        output: [B, N, d]
        """
        scaled = e * self.constant
        log_step = -math.log(10000) / self.hidden_dim
        inv_freq = torch.exp(torch.arange(0, self.hidden_dim, 2) * log_step).to(e.device)
        angles = scaled.unsqueeze(2) * inv_freq
        features = torch.cat((e.unsqueeze(2), torch.sin(angles), torch.cos(angles)), dim=2)

        return self.eig_w(features)
72 |
73 |
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a GELU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``layer1``, then GELU, then ``layer2``."""
        return self.layer2(self.gelu(self.layer1(x)))
87 |
88 |
class Conv(nn.Module):
    """Message-passing layer that scales edge messages by precomputed bases."""

    def __init__(self, hidden_size, dropout_rate):
        super().__init__()

        # Message function applied to (source node + edge) features;
        # it has no normalisation layer.
        self.pre_ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU()
        )

        self.preffn_dropout = nn.Dropout(dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

        # Node-wise update applied after aggregation.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU()
        )

    def forward(self, graph, x_feat, edge_attr, bases):
        """Return node features after one weighted aggregation and a residual FFN.

        Each edge message is ``pre_ffn(source feature + edge_attr)``
        multiplied by ``bases``, which is used exactly as passed in.
        """
        with graph.local_scope():
            graph.ndata['x'] = x_feat
            graph.apply_edges(fn.copy_u('x', '_x'))
            graph.edata['v'] = self.pre_ffn(graph.edata['_x'] + edge_attr) * bases
            graph.update_all(fn.copy_e('v', '_aggr_e'), fn.sum('_aggr_e', 'aggr_e'))
            aggregated = self.preffn_dropout(graph.ndata['aggr_e'])
            x = x_feat + aggregated
            x = x + self.ffn_dropout(self.ffn(x))
        return x
125 |
126 |
class SpecformerSmall(nn.Module):
    '''
    Specformer graph-level model for the "small" configuration.

    Pipeline (see forward): eigenvalues are encoded by SineEncoding, refined
    by one pre-norm self-attention + feed-forward block, and decoded into
    ``nheads`` new spectra. Each spectrum is turned into a dense basis
    ``u @ diag(new_e) @ u^T``. The bases (plus the identity) are gathered for
    the valid node pairs, encoded into per-edge weights, normalised with
    edge_softmax, and consumed by a stack of Conv layers. Node states are then
    average-pooled per graph and projected to ``nclass`` outputs.
    '''

    def __init__(self, nclass, nlayer, hidden_dim=128, nheads=4, feat_dropout=0.1, trans_dropout=0.1, adj_dropout=0.1):
        super(SpecformerSmall, self).__init__()

        print('small model')
        self.nlayer = nlayer
        self.nclass = nclass
        self.hidden_dim = hidden_dim
        self.nheads = nheads

        self.atom_encoder = AtomEncoder(hidden_dim)
        self.bond_encoder = BondEncoder(hidden_dim)

        # Eigenvalue encoder and the decoder producing one new spectrum per head.
        self.eig_encoder = SineEncoding(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, nheads)

        # Single pre-norm transformer block over the eigenvalue sequence.
        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.mha_dropout = nn.Dropout(trans_dropout)
        self.ffn_dropout = nn.Dropout(trans_dropout)
        self.mha = nn.MultiheadAttention(hidden_dim, nheads, trans_dropout, batch_first=True)
        self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim)

        # Maps the (nheads + 1) basis values of each node pair to hidden_dim edge weights.
        self.adj_dropout = nn.Dropout(adj_dropout)
        self.filter_encoder = nn.Sequential(
            nn.Linear(nheads + 1, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )

        self.convs = nn.ModuleList([Conv(hidden_dim, feat_dropout) for _ in range(nlayer)])
        self.pool = AvgPooling()
        self.linear = nn.Linear(hidden_dim, nclass)

    def forward(self, e, u, g, length):
        '''
        e:      [B, N] eigenvalues
        u:      [B, N, N] eigenvectors
        g:      batched DGL graph with 'feat' in both ndata and edata
        length: [B] number of valid (non-padded) eigenvalues per graph
        return: [B, nclass] graph-level outputs
        '''
        ut = u.transpose(1, 2)

        node_feat = g.ndata['feat']
        edge_feat = g.edata['feat']

        # do not use u to generate edge_idx because of the connected components
        e_mask, edge_idx = self.length_to_mask(length)

        node_feat = self.atom_encoder(node_feat)
        edge_feat = self.bond_encoder(edge_feat)
        eig = self.eig_encoder(e)

        # Pre-norm self-attention over eigenvalues; padded positions are masked as keys.
        mha_eig = self.mha_norm(eig)
        mha_eig, _ = self.mha(mha_eig, mha_eig, mha_eig, key_padding_mask=e_mask)
        eig = eig + self.mha_dropout(mha_eig)

        ffn_eig = self.ffn_norm(eig)
        ffn_eig = self.ffn(ffn_eig)
        eig = eig + self.ffn_dropout(ffn_eig)

        new_e = self.decoder(eig).transpose(2, 1)   # [B, m, N]

        identity = torch.diag_embed(torch.ones_like(e))
        bases = [identity]
        for i in range(self.nheads):
            # u @ diag(d) @ u^T == (u * d) @ u^T: scaling the columns of u
            # avoids materialising a dense [B, m, N, N] diagonal tensor and
            # saves one batched matmul per head.
            bases.append((u * new_e[:, i].unsqueeze(1)) @ ut)

        bases = torch.stack(bases, axis=-1)   # [B, N, N, H]
        # Keep only the entries of valid node pairs -> [num_pairs, H].
        # NOTE(review): this assumes g holds exactly one edge per valid
        # (i, j) pair, in the same order as the mask - confirm against the
        # dataset construction code.
        bases = bases[edge_idx]
        bases = self.adj_dropout(self.filter_encoder(bases))
        bases = edge_softmax(g, bases)

        for conv in self.convs:
            node_feat = conv(g, node_feat, edge_feat, bases)

        h = self.pool(g, node_feat)
        h = self.linear(h)

        return h

    def length_to_mask(self, length):
        '''
        length: [B]
        return: (mask1d [B, max_len], mask2d [B, max_len, max_len]), both bool.

        mask1d is True at padded positions (the key_padding_mask convention of
        nn.MultiheadAttention). mask2d is True where both the row and the column
        are valid, and is used to pick node-pair entries out of padded tensors.

        Example
        length=[1, 2, 3], B=3, N=3

        mask1d
        [False, True,  True ]
        [False, False, True ]
        [False, False, False]

        mask2d
        [[1, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
         [0, 0, 0],  |  [1, 1, 0],  |  [1, 1, 1],
         [0, 0, 0],  |  [0, 0, 0],  |  [1, 1, 1]]
        '''
        B = len(length)
        N = length.max().item()
        mask1d = torch.arange(N, device=length.device).expand(B, N) >= length.unsqueeze(1)

        # Outer "and" of the valid flags; same result as the float outer
        # product followed by .bool(), without the float round trip.
        valid = ~mask1d
        mask2d = valid.unsqueeze(2) & valid.unsqueeze(1)

        return mask1d, mask2d
262 |
--------------------------------------------------------------------------------
/Graph/zinc_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from dgl.ops.edge_softmax import edge_softmax
9 | from dgl.nn.pytorch.glob import AvgPooling
10 | from dgl import function as fn
11 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
12 |
13 |
class SineEncoding(nn.Module):
    '''
    Sinusoidal encoding of scalar eigenvalues followed by a linear projection.

    The eigenvalue is scaled by ``self.constant`` and multiplied by a set of
    geometrically spaced frequencies. Its sin and cos are concatenated with
    the raw eigenvalue and mapped to ``hidden_dim`` features by ``eig_w``.
    '''

    def __init__(self, hidden_dim=128):
        super(SineEncoding, self).__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        # forward() produces ceil(hidden_dim / 2) frequencies, i.e.
        # 2 * ceil(hidden_dim / 2) + 1 concatenated features. Sizing the
        # projection from that count is identical to hidden_dim + 1 for even
        # hidden_dim and also makes odd hidden_dim work.
        n_freq = (hidden_dim + 1) // 2
        self.eig_w = nn.Linear(2 * n_freq + 1, hidden_dim)

    def forward(self, e):
        # input:  [B, N]
        # output: [B, N, d]

        ee = e * self.constant
        # Allocate the frequencies on e's device rather than on CPU followed
        # by a host-to-device copy in every forward pass.
        exponents = torch.arange(0, self.hidden_dim, 2, device=e.device)
        div = torch.exp(exponents * (-math.log(10000) / self.hidden_dim))
        pe = ee.unsqueeze(2) * div
        eeig = torch.cat((e.unsqueeze(2), torch.sin(pe), torch.cos(pe)), dim=2)

        return self.eig_w(eeig)
31 |
32 |
class FeedForwardNetwork(nn.Module):
    '''Feed-forward block: a linear layer, GELU, then a second linear layer.'''

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        '''Run ``x`` through layer1, GELU and layer2 in that order.'''
        for stage in (self.layer1, self.gelu, self.layer2):
            x = stage(x)
        return x
46 |
47 |
class Conv(nn.Module):
    '''
    Message-passing layer whose edge messages are scaled by ``bases``.

    For every edge, the source-node feature plus the edge feature goes through
    ``pre_ffn`` and is multiplied elementwise by ``bases``. The messages are
    summed per destination node and added to the node features, after which
    ``ffn`` is applied on a second residual branch.
    '''

    def __init__(self, hidden_size, dropout_rate):
        super().__init__()

        def linear():
            return nn.Linear(hidden_size, hidden_size)

        def norm():
            return nn.BatchNorm1d(hidden_size)

        # No batch norm in the message transform (it is disabled in this layer).
        self.pre_ffn = nn.Sequential(linear(), nn.GELU())

        self.preffn_dropout = nn.Dropout(dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

        self.ffn = nn.Sequential(
            linear(), norm(), nn.ReLU(),
            linear(), norm(), nn.ReLU(),
        )

    def forward(self, graph, x_feat, edge_attr, bases):
        '''
        graph:     DGL graph (only modified inside a local scope)
        x_feat:    node features
        edge_attr: edge features
        bases:     per-edge scaling applied to the transformed messages
        return:    node features after both residual updates
        '''
        with graph.local_scope():
            graph.ndata['x'] = x_feat
            # Copy each edge's source-node feature onto the edge.
            graph.apply_edges(fn.copy_u('x', '_x'))
            edge_input = graph.edata['_x'] + edge_attr
            graph.edata['v'] = self.pre_ffn(edge_input) * bases
            # Sum the messages arriving at each destination node.
            graph.update_all(fn.copy_e('v', '_aggr_e'), fn.sum('_aggr_e', 'aggr_e'))
            incoming = graph.ndata['aggr_e']
            residual = x_feat + self.preffn_dropout(incoming)
            refined = self.ffn(residual)
            out = residual + self.ffn_dropout(refined)
        return out
84 |
85 |
class SpecformerZINC(nn.Module):
    '''
    Specformer graph-level model for the ZINC configuration.

    Pipeline (see forward): eigenvalues are encoded by SineEncoding, refined
    by one pre-norm self-attention + feed-forward block, and decoded into
    ``nheads`` new spectra. Each spectrum yields a dense basis
    ``u @ diag(new_e) @ u^T``. The bases (plus the identity) are gathered for
    the valid node pairs and encoded into per-edge weights, which scale the
    messages of a stack of Conv layers. Node states are average-pooled per
    graph and projected to ``nclass`` outputs.
    '''

    def __init__(self, nclass, nlayer, hidden_dim=128, nheads=4, feat_dropout=0.1, trans_dropout=0.1, adj_dropout=0.1):
        super(SpecformerZINC, self).__init__()

        self.nlayer = nlayer
        self.nclass = nclass
        self.hidden_dim = hidden_dim
        self.nheads = nheads

        self.atom_encoder = nn.Embedding(40, hidden_dim)
        self.bond_encoder = nn.Embedding(10, hidden_dim, padding_idx=0)

        # Eigenvalue encoder and the decoder producing one new spectrum per head.
        self.eig_encoder = SineEncoding(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, nheads)

        # Single pre-norm transformer block over the eigenvalue sequence.
        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.mha_dropout = nn.Dropout(trans_dropout)
        self.ffn_dropout = nn.Dropout(trans_dropout)
        self.mha = nn.MultiheadAttention(hidden_dim, nheads, trans_dropout, batch_first=True)
        self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim)

        # Maps the (nheads + 1) basis values of each node pair to hidden_dim edge weights.
        self.adj_dropout = nn.Dropout(adj_dropout)
        self.filter_encoder = nn.Sequential(
            nn.Linear(nheads + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

        self.convs = nn.ModuleList([Conv(hidden_dim, feat_dropout) for _ in range(nlayer)])
        self.pool = AvgPooling()
        self.linear = nn.Linear(hidden_dim, nclass)

    def forward(self, e, u, g, length):
        '''
        e:      [B, N] eigenvalues
        u:      [B, N, N] eigenvectors
        g:      batched DGL graph with 'feat' in both ndata and edata
        length: [B] number of valid (non-padded) eigenvalues per graph
        return: (h, new_e, attn) - graph-level outputs, the decoded spectra
                [B, m, N], and the per-head attention weights returned by
                nn.MultiheadAttention.
        '''
        ut = u.transpose(1, 2)

        node_feat = g.ndata['feat']
        edge_feat = g.edata['feat']

        # do not use u to generate edge_idx because connected components exist
        eig_mask, edge_idx = self.length_to_mask(length)

        node_feat = self.atom_encoder(node_feat).squeeze(-2)
        edge_feat = self.bond_encoder(edge_feat).squeeze(-2)

        # Pre-norm self-attention over eigenvalues; padded positions are masked as keys.
        eig = self.eig_encoder(e)
        mha_eig = self.mha_norm(eig)
        mha_eig, attn = self.mha(mha_eig, mha_eig, mha_eig, key_padding_mask=eig_mask, average_attn_weights=False)
        eig = eig + self.mha_dropout(mha_eig)

        ffn_eig = self.ffn_norm(eig)
        ffn_eig = self.ffn(ffn_eig)
        eig = eig + self.ffn_dropout(ffn_eig)

        new_e = self.decoder(eig).transpose(2, 1)   # [B, m, N]

        bases = [torch.diag_embed(torch.ones_like(e))]
        for i in range(self.nheads):
            # u @ diag(d) @ u^T == (u * d) @ u^T: scaling the columns of u
            # avoids materialising a dense [B, m, N, N] diagonal tensor and
            # saves one batched matmul per head.
            bases.append((u * new_e[:, i].unsqueeze(1)) @ ut)

        bases = torch.stack(bases, axis=-1)   # [B, N, N, H]
        # Keep only the entries of valid node pairs -> [num_pairs, H].
        # NOTE(review): this assumes g holds exactly one edge per valid
        # (i, j) pair, in the same order as the mask - confirm against the
        # dataset construction code.
        bases = bases[edge_idx]
        bases = self.adj_dropout(self.filter_encoder(bases))

        for conv in self.convs:
            node_feat = conv(g, node_feat, edge_feat, bases)

        h = self.pool(g, node_feat)
        h = self.linear(h)

        return h, new_e, attn

    def length_to_mask(self, length):
        '''
        length: [B]
        return: (mask1d [B, max_len], mask2d [B, max_len, max_len]), both bool.

        mask1d is for key_padding_mask: a True value indicates that the
        corresponding key will be ignored for the purpose of attention.
        mask2d is for selecting node-pair entries out of the padded tensors:
        True where both the row and the column index are below the length.
        '''
        B = len(length)
        N = length.max().item()
        mask1d = torch.arange(N, device=length.device).expand(B, N) >= length.unsqueeze(1)

        # Vectorised outer "and" of the valid flags. This replaces a Python
        # loop over the batch that sliced mask2d once per graph.
        valid = ~mask1d
        mask2d = valid.unsqueeze(2) & valid.unsqueeze(1)

        return mask1d, mask2d
186 |
187 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/Node/config.yaml:
--------------------------------------------------------------------------------
1 | signal:
2 | nclass: 1
3 | nlayer: 1
4 | num_heads: 1
5 | hidden_dim: 16
6 | epoch: 2000
7 | lr: 0.01
8 | weight_decay: 0.0
9 | tran_dropout: 0.0
10 | feat_dropout: 0.0
11 | prop_dropout: 0.0
12 | norm: 'none'
13 | cora:
14 | nclass: 7
15 | nlayer: 2
16 | num_heads: 2
17 | hidden_dim: 32
18 | epoch: 2000
19 | lr: 0.0002
20 | weight_decay: 0.0001
21 | tran_dropout: 0.2
22 | feat_dropout: 0.6
23 | prop_dropout: 0.2
24 | norm: 'none'
25 | citeseer:
26 | nclass: 6
27 | nlayer: 2
28 | num_heads: 2
29 | hidden_dim: 32
30 | epoch: 2000
31 | lr: 0.0002
32 | weight_decay: 0.001
33 | tran_dropout: 0.0
34 | feat_dropout: 0.7
35 | prop_dropout: 0.5
36 | norm: 'none'
37 | photo:
38 | nclass: 8
39 | nlayer: 2
40 | num_heads: 4
41 | hidden_dim: 32
42 | epoch: 2000
43 | lr: 0.0002
44 | weight_decay: 0.0001
45 | tran_dropout: 0.2
46 | feat_dropout: 0.3
47 | prop_dropout: 0.2
48 | norm: 'none'
49 | arxiv:
50 | nclass: 40
51 | nlayer: 1
52 | num_heads: 1
53 | hidden_dim: 512
54 | epoch: 2000
55 | lr: 0.001
56 | weight_decay: 0.0
57 | tran_dropout: 0.1
58 | feat_dropout: 0.1
59 | prop_dropout: 0.1
60 | norm: 'layer'
61 | chameleon:
62 | nclass: 5
63 | nlayer: 2
64 | num_heads: 4
65 | hidden_dim: 32
66 | epoch: 2000
67 | lr: 0.001
68 | weight_decay: 0.0005
69 | tran_dropout: 0.2
70 | feat_dropout: 0.4
71 | prop_dropout: 0.5
72 | norm: 'none'
73 | squirrel:
74 | nclass: 5
75 | nlayer: 2
76 | num_heads: 2
77 | hidden_dim: 32
78 | epoch: 2000
79 | lr: 0.001
80 | weight_decay: 0.001
81 | tran_dropout: 0.1
82 | feat_dropout: 0.4
83 | prop_dropout: 0.4
84 | norm: 'none'
85 | actor:
86 | nclass: 5
87 | nlayer: 2
88 | num_heads: 1
89 | hidden_dim: 32
90 | epoch: 2000
91 | lr: 0.0002
92 | weight_decay: 0.0001
93 | tran_dropout: 0.5
94 | feat_dropout: 0.8
95 | prop_dropout: 0.5
96 | norm: 'none'
97 | penn:
98 | nclass: 2
99 | nlayer: 1
100 | num_heads: 1
101 | hidden_dim: 64
102 | epoch: 2000
103 | lr: 0.001
104 | weight_decay: 0.001
105 | tran_dropout: 0.0
106 | feat_dropout: 0.4
107 | prop_dropout: 0.4
108 | norm: 'batch'
109 |
--------------------------------------------------------------------------------
/Node/data/chameleon.pt.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/data/chameleon.pt.zip
--------------------------------------------------------------------------------
/Node/main_node.py:
--------------------------------------------------------------------------------
1 | import time
2 | import yaml
3 | import copy
4 | import math
5 | import random
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchmetrics
12 | from sklearn.metrics import roc_auc_score, mean_absolute_error, accuracy_score, r2_score
13 | from model_node import Specformer
14 | from utils import count_parameters, init_params, seed_everything, get_split
15 |
16 |
def _report_best(res):
    """Print the test accuracy at the lowest validation loss and at the
    highest validation accuracy.

    ``res`` is a non-empty list of ``[val_loss, val_acc, test_acc]`` rows.
    Ties resolve to the earliest epoch.
    """
    acc_at_min_loss = min(res, key=lambda row: row[0])[-1]
    acc_at_max_val_acc = max(res, key=lambda row: row[1])[-1]
    print(acc_at_min_loss, acc_at_max_val_acc)


def main_worker(args, config):
    """Train and evaluate a Specformer on one node-level dataset.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``cuda``,
            ``dataset`` and, for datasets whose name contains ``'signal'``,
            ``image`` (the column of ``x``/``y`` to regress).
        config: Hyper-parameter mapping for the dataset. Reads ``epoch``,
            ``lr``, ``weight_decay``, ``nclass``, ``nlayer``, ``hidden_dim``,
            ``num_heads``, ``tran_dropout``, ``feat_dropout``,
            ``prop_dropout`` and ``norm``.

    Data is loaded from ``data/<dataset>.pt``. Signal datasets are trained
    with a masked sum-of-squared-errors loss and print ``r2, sse`` every
    epoch. All other datasets are trained with cross-entropy on the train
    split and print ``epoch, val_loss, val_acc, test_acc`` every epoch.
    Classification stops early after 200 consecutive epochs without a new
    minimum validation loss; a summary line is printed when training ends.
    """
    print(args, config)
    seed_everything(args.seed)
    # Select the GPU requested with --cuda; every later .cuda() call uses it.
    # (The device index must not be derived from the random seed.)
    torch.cuda.set_device(args.cuda)

    epoch = config['epoch']
    lr = config['lr']
    weight_decay = config['weight_decay']
    nclass = config['nclass']
    nlayer = config['nlayer']
    hidden_dim = config['hidden_dim']
    num_heads = config['num_heads']
    tran_dropout = config['tran_dropout']
    feat_dropout = config['feat_dropout']
    prop_dropout = config['prop_dropout']
    norm = config['norm']

    is_signal = 'signal' in args.dataset

    # e, u are presumably the eigenvalues / eigenvectors produced by the
    # preprocessing script -- TODO confirm against preprocess_node_data.py.
    if is_signal:
        e, u, x, y, m = torch.load('data/{}.pt'.format(args.dataset))
        e, u, x, y, m = e.cuda(), u.cuda(), x.cuda(), y.cuda(), m.cuda()
        mask = torch.where(m == 1)
        x = x[:, args.image].unsqueeze(1)
        y = y[:, args.image]
    else:
        e, u, x, y = torch.load('data/{}.pt'.format(args.dataset))
        e, u, x, y = e.cuda(), u.cuda(), x.cuda(), y.cuda()

        # Collapse 2-D labels to a 1-D vector of class indices.
        if len(y.size()) > 1:
            if y.size(1) > 1:
                y = torch.argmax(y, dim=1)
            else:
                y = y.view(-1)

        train, valid, test = get_split(args.dataset, y, nclass, args.seed)
        train, valid, test = map(torch.LongTensor, (train, valid, test))
        train, valid, test = train.cuda(), valid.cuda(), test.cuda()

    nfeat = x.size(1)
    net = Specformer(nclass, nfeat, nlayer, hidden_dim, num_heads, tran_dropout, feat_dropout, prop_dropout, norm).cuda()
    net.apply(init_params)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    print(count_parameters(net))

    res = []          # one [val_loss, val_acc, test_acc] row per epoch
    min_loss = 100.0  # best validation loss seen so far
    counter = 0       # consecutive epochs without a new best validation loss
    evaluation = torchmetrics.Accuracy(task='multiclass', num_classes=nclass)

    for idx in range(epoch):

        net.train()
        optimizer.zero_grad()
        logits = net(e, u, x)

        if is_signal:
            logits = logits.view(y.size())
            loss = torch.square((logits[mask] - y[mask])).sum()
        else:
            loss = F.cross_entropy(logits[train], y[train])

        loss.backward()
        optimizer.step()

        net.eval()
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits = net(e, u, x)

            if is_signal:
                logits = logits.view(y.size())
                r2 = r2_score(y[mask].data.cpu().numpy(), logits[mask].data.cpu().numpy())
                sse = torch.square(logits[mask] - y[mask]).sum().item()
                print(r2, sse)
            else:
                val_loss = F.cross_entropy(logits[valid], y[valid]).item()

                val_acc = evaluation(logits[valid].cpu(), y[valid].cpu()).item()
                test_acc = evaluation(logits[test].cpu(), y[test].cpu()).item()
                res.append([val_loss, val_acc, test_acc])

                print(idx, val_loss, val_acc, test_acc)

                if val_loss < min_loss:
                    min_loss = val_loss
                    counter = 0
                else:
                    counter += 1

        if counter == 200:
            _report_best(res)
            break
    else:
        # The epoch budget ran out before early stopping fired: still report
        # the best classification result (signal runs record no rows).
        if res:
            _report_best(res)
110 |
111 |
if __name__ == '__main__':
    # Command-line entry point: parse options, pick the matching section of
    # config.yaml and start a single training run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--cuda', type=int, default=0)
    parser.add_argument('--dataset', default='cora')
    parser.add_argument('--image', type=int, default=0)

    args = parser.parse_args()

    # Every dataset whose name contains 'signal' shares the 'signal' section;
    # any other dataset name must be a top-level key of config.yaml.
    config_key = 'signal' if 'signal' in args.dataset else args.dataset

    # Read the file once and close it deterministically.
    with open('config.yaml') as config_file:
        all_configs = yaml.load(config_file, Loader=yaml.SafeLoader)

    config = all_configs[config_key]

    main_worker(args, config)
127 |
128 |
--------------------------------------------------------------------------------
/Node/master.csv:
--------------------------------------------------------------------------------
1 | ,ogbn-proteins,ogbn-products,ogbn-arxiv,ogbn-mag,ogbn-papers100M
2 | num tasks,112,1,1,1,1
3 | num classes,2,47,40,349,172
4 | eval metric,rocauc,acc,acc,acc,acc
5 | task type,binary classification,multiclass classification,multiclass classification,multiclass classification,multiclass classification
6 | download_name,proteins,products,arxiv,mag,papers100M-bin
7 | version,1,1,1,2,1
8 | url,http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip,http://snap.stanford.edu/ogb/data/nodeproppred/products.zip,http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip,http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip,http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip
9 | add_inverse_edge,True,True,False,False,False
10 | has_node_attr,False,True,True,True,True
11 | has_edge_attr,True,False,False,False,False
12 | split,species,sales_ranking,time,time,time
13 | additional node files,node_species,None,node_year,node_year,node_year
14 | additional edge files,None,None,None,edge_reltype,None
15 | is hetero,False,False,False,True,False
16 | binary,False,False,False,False,True
17 |
--------------------------------------------------------------------------------
/Node/model_node.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
9 |
10 |
class SineEncoding(nn.Module):
    """Map a 1-D tensor of scalars to ``hidden_dim``-sized embeddings.

    Each scalar is expanded into itself plus sine and cosine features at
    geometrically spaced frequencies, then projected by a linear layer.
    """

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.constant = 100
        self.hidden_dim = hidden_dim
        # Projection input: the raw value plus the sine and cosine features
        # (hidden_dim + 1 columns when hidden_dim is even).
        self.eig_w = nn.Linear(hidden_dim + 1, hidden_dim)

    def forward(self, e):
        # input:  [N]
        # output: [N, d]
        scaled = e * self.constant

        exponents = torch.arange(0, self.hidden_dim, 2) * (-math.log(10000)/self.hidden_dim)
        freqs = torch.exp(exponents).to(e.device)

        angles = scaled.unsqueeze(1) * freqs
        features = torch.cat((e.unsqueeze(1), torch.sin(angles), torch.cos(angles)), dim=1)

        return self.eig_w(features)
28 |
29 |
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a GELU activation in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # linear -> GELU -> linear
        hidden = self.gelu(self.layer1(x))
        return self.layer2(hidden)
43 |
44 |
class SpecLayer(nn.Module):
    """Weighted sum over the basis axis of a ``[N, m, d]`` tensor.

    With ``norm='none'`` the combination weights start at one and the output
    is the plain weighted sum. Otherwise the weights are drawn from
    N(0, 0.01^2); for ``'layer'`` / ``'batch'`` the sum is then normalised
    and passed through a ReLU.
    """

    def __init__(self, nbases, ncombines, prop_dropout=0.0, norm='none'):
        super().__init__()
        self.prop_dropout = nn.Dropout(prop_dropout)

        weight_shape = (1, nbases, ncombines)
        if norm != 'none':
            self.weight = nn.Parameter(torch.empty(weight_shape))
            nn.init.normal_(self.weight, mean=0.0, std=0.01)
        else:
            self.weight = nn.Parameter(torch.ones(weight_shape))

        # 'layer' is used for Arxiv, 'batch' for Penn; everything else runs
        # without a normalisation module.
        norm_factory = {'layer': nn.LayerNorm, 'batch': nn.BatchNorm1d}.get(norm)
        self.norm = None if norm_factory is None else norm_factory(ncombines)

    def forward(self, x):
        # [N, m, d] * [1, m, d], then reduce over m -> [N, d]
        combined = torch.sum(self.prop_dropout(x) * self.weight, dim=1)

        if self.norm is None:
            return combined

        return F.relu(self.norm(combined))
73 |
74 |
class Specformer(nn.Module):
    """Node-level Specformer.

    The 1-D input ``e`` is embedded by ``SineEncoding``, refined by one
    pre-norm self-attention + feed-forward block, and decoded into ``nheads``
    new values per entry. Each ``SpecLayer`` then mixes the current features
    with ``nheads`` filtered copies ``u @ (new_e[:, i] * (u^T @ h))``.

    With ``norm='none'`` features are mapped to ``nclass`` channels before
    propagation and the propagated tensor is returned directly; otherwise
    propagation runs at ``hidden_dim`` channels and a final linear layer
    produces the ``nclass`` outputs.
    """

    def __init__(self, nclass, nfeat, nlayer=1, hidden_dim=128, nheads=1,
                 tran_dropout=0.0, feat_dropout=0.0, prop_dropout=0.0, norm='none'):
        super().__init__()

        self.norm = norm
        self.nfeat = nfeat
        self.nlayer = nlayer
        self.nheads = nheads
        self.hidden_dim = hidden_dim

        # Feature path used when norm == 'none': MLP straight to class scores.
        self.feat_encoder = nn.Sequential(
            nn.Linear(nfeat, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, nclass),
        )

        # for arxiv & penn
        self.linear_encoder = nn.Linear(nfeat, hidden_dim)
        self.classify = nn.Linear(hidden_dim, nclass)

        self.eig_encoder = SineEncoding(hidden_dim)
        self.decoder = nn.Linear(hidden_dim, nheads)

        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.mha_dropout = nn.Dropout(tran_dropout)
        self.ffn_dropout = nn.Dropout(tran_dropout)
        self.mha = nn.MultiheadAttention(hidden_dim, nheads, tran_dropout)
        self.ffn = FeedForwardNetwork(hidden_dim, hidden_dim, hidden_dim)

        self.feat_dp1 = nn.Dropout(feat_dropout)
        self.feat_dp2 = nn.Dropout(feat_dropout)

        # Propagation width follows whichever feature path is active.
        prop_width = nclass if norm == 'none' else hidden_dim
        self.layers = nn.ModuleList(
            [SpecLayer(nheads + 1, prop_width, prop_dropout, norm=norm) for _ in range(nlayer)]
        )

    def forward(self, e, u, x):
        ut = u.permute(1, 0)

        h = self.feat_dp1(x)
        if self.norm == 'none':
            h = self.feat_dp2(self.feat_encoder(h))
        else:
            h = self.linear_encoder(h)

        eig = self.eig_encoder(e)  # [N, d]

        # Pre-norm self-attention with a residual connection.
        attn_in = self.mha_norm(eig)
        attn_out, _ = self.mha(attn_in, attn_in, attn_in)
        eig = eig + self.mha_dropout(attn_out)

        # Pre-norm feed-forward with a residual connection.
        ffn_out = self.ffn(self.ffn_norm(eig))
        eig = eig + self.ffn_dropout(ffn_out)

        new_e = self.decoder(eig)  # [N, m]

        for conv in self.layers:
            utx = ut @ h
            filtered = [u @ (new_e[:, head].unsqueeze(1) * utx) for head in range(self.nheads)]  # each [N, d]
            bases = torch.stack([h] + filtered, dim=1)  # [N, m, d]
            h = conv(bases)

        if self.norm != 'none':
            h = self.classify(self.feat_dp2(h))
        return h
153 |
154 |
--------------------------------------------------------------------------------
/Node/node_raw_data/2Dgrid.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/2Dgrid.mat
--------------------------------------------------------------------------------
/Node/node_raw_data/Penn94.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/Penn94.mat
--------------------------------------------------------------------------------
/Node/node_raw_data/amazon_electronics_photo.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/amazon_electronics_photo.npz
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.allx
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.ally
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.graph
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.test.index:
--------------------------------------------------------------------------------
1 | 2488
2 | 2644
3 | 3261
4 | 2804
5 | 3176
6 | 2432
7 | 3310
8 | 2410
9 | 2812
10 | 2520
11 | 2994
12 | 3282
13 | 2680
14 | 2848
15 | 2670
16 | 3005
17 | 2977
18 | 2592
19 | 2967
20 | 2461
21 | 3184
22 | 2852
23 | 2768
24 | 2905
25 | 2851
26 | 3129
27 | 3164
28 | 2438
29 | 2793
30 | 2763
31 | 2528
32 | 2954
33 | 2347
34 | 2640
35 | 3265
36 | 2874
37 | 2446
38 | 2856
39 | 3149
40 | 2374
41 | 3097
42 | 3301
43 | 2664
44 | 2418
45 | 2655
46 | 2464
47 | 2596
48 | 3262
49 | 3278
50 | 2320
51 | 2612
52 | 2614
53 | 2550
54 | 2626
55 | 2772
56 | 3007
57 | 2733
58 | 2516
59 | 2476
60 | 2798
61 | 2561
62 | 2839
63 | 2685
64 | 2391
65 | 2705
66 | 3098
67 | 2754
68 | 3251
69 | 2767
70 | 2630
71 | 2727
72 | 2513
73 | 2701
74 | 3264
75 | 2792
76 | 2821
77 | 3260
78 | 2462
79 | 3307
80 | 2639
81 | 2900
82 | 3060
83 | 2672
84 | 3116
85 | 2731
86 | 3316
87 | 2386
88 | 2425
89 | 2518
90 | 3151
91 | 2586
92 | 2797
93 | 2479
94 | 3117
95 | 2580
96 | 3182
97 | 2459
98 | 2508
99 | 3052
100 | 3230
101 | 3215
102 | 2803
103 | 2969
104 | 2562
105 | 2398
106 | 3325
107 | 2343
108 | 3030
109 | 2414
110 | 2776
111 | 2383
112 | 3173
113 | 2850
114 | 2499
115 | 3312
116 | 2648
117 | 2784
118 | 2898
119 | 3056
120 | 2484
121 | 3179
122 | 3132
123 | 2577
124 | 2563
125 | 2867
126 | 3317
127 | 2355
128 | 3207
129 | 3178
130 | 2968
131 | 3319
132 | 2358
133 | 2764
134 | 3001
135 | 2683
136 | 3271
137 | 2321
138 | 2567
139 | 2502
140 | 3246
141 | 2715
142 | 3066
143 | 2390
144 | 2381
145 | 3162
146 | 2741
147 | 2498
148 | 2790
149 | 3038
150 | 3321
151 | 2481
152 | 3050
153 | 3161
154 | 3122
155 | 2801
156 | 2957
157 | 3177
158 | 2965
159 | 2621
160 | 3208
161 | 2921
162 | 2802
163 | 2357
164 | 2677
165 | 2519
166 | 2860
167 | 2696
168 | 2368
169 | 3241
170 | 2858
171 | 2419
172 | 2762
173 | 2875
174 | 3222
175 | 3064
176 | 2827
177 | 3044
178 | 2471
179 | 3062
180 | 2982
181 | 2736
182 | 2322
183 | 2709
184 | 2766
185 | 2424
186 | 2602
187 | 2970
188 | 2675
189 | 3299
190 | 2554
191 | 2964
192 | 2597
193 | 2753
194 | 2979
195 | 2523
196 | 2912
197 | 2896
198 | 2317
199 | 3167
200 | 2813
201 | 2482
202 | 2557
203 | 3043
204 | 3244
205 | 2985
206 | 2460
207 | 2363
208 | 3272
209 | 3045
210 | 3192
211 | 2453
212 | 2656
213 | 2834
214 | 2443
215 | 3202
216 | 2926
217 | 2711
218 | 2633
219 | 2384
220 | 2752
221 | 3285
222 | 2817
223 | 2483
224 | 2919
225 | 2924
226 | 2661
227 | 2698
228 | 2361
229 | 2662
230 | 2819
231 | 3143
232 | 2316
233 | 3196
234 | 2739
235 | 2345
236 | 2578
237 | 2822
238 | 3229
239 | 2908
240 | 2917
241 | 2692
242 | 3200
243 | 2324
244 | 2522
245 | 3322
246 | 2697
247 | 3163
248 | 3093
249 | 3233
250 | 2774
251 | 2371
252 | 2835
253 | 2652
254 | 2539
255 | 2843
256 | 3231
257 | 2976
258 | 2429
259 | 2367
260 | 3144
261 | 2564
262 | 3283
263 | 3217
264 | 3035
265 | 2962
266 | 2433
267 | 2415
268 | 2387
269 | 3021
270 | 2595
271 | 2517
272 | 2468
273 | 3061
274 | 2673
275 | 2348
276 | 3027
277 | 2467
278 | 3318
279 | 2959
280 | 3273
281 | 2392
282 | 2779
283 | 2678
284 | 3004
285 | 2634
286 | 2974
287 | 3198
288 | 2342
289 | 2376
290 | 3249
291 | 2868
292 | 2952
293 | 2710
294 | 2838
295 | 2335
296 | 2524
297 | 2650
298 | 3186
299 | 2743
300 | 2545
301 | 2841
302 | 2515
303 | 2505
304 | 3181
305 | 2945
306 | 2738
307 | 2933
308 | 3303
309 | 2611
310 | 3090
311 | 2328
312 | 3010
313 | 3016
314 | 2504
315 | 2936
316 | 3266
317 | 3253
318 | 2840
319 | 3034
320 | 2581
321 | 2344
322 | 2452
323 | 2654
324 | 3199
325 | 3137
326 | 2514
327 | 2394
328 | 2544
329 | 2641
330 | 2613
331 | 2618
332 | 2558
333 | 2593
334 | 2532
335 | 2512
336 | 2975
337 | 3267
338 | 2566
339 | 2951
340 | 3300
341 | 2869
342 | 2629
343 | 2747
344 | 3055
345 | 2831
346 | 3105
347 | 3168
348 | 3100
349 | 2431
350 | 2828
351 | 2684
352 | 3269
353 | 2910
354 | 2865
355 | 2693
356 | 2884
357 | 3228
358 | 2783
359 | 3247
360 | 2770
361 | 3157
362 | 2421
363 | 2382
364 | 2331
365 | 3203
366 | 3240
367 | 2351
368 | 3114
369 | 2986
370 | 2688
371 | 2439
372 | 2996
373 | 3079
374 | 3103
375 | 3296
376 | 2349
377 | 2372
378 | 3096
379 | 2422
380 | 2551
381 | 3069
382 | 2737
383 | 3084
384 | 3304
385 | 3022
386 | 2542
387 | 3204
388 | 2949
389 | 2318
390 | 2450
391 | 3140
392 | 2734
393 | 2881
394 | 2576
395 | 3054
396 | 3089
397 | 3125
398 | 2761
399 | 3136
400 | 3111
401 | 2427
402 | 2466
403 | 3101
404 | 3104
405 | 3259
406 | 2534
407 | 2961
408 | 3191
409 | 3000
410 | 3036
411 | 2356
412 | 2800
413 | 3155
414 | 3224
415 | 2646
416 | 2735
417 | 3020
418 | 2866
419 | 2426
420 | 2448
421 | 3226
422 | 3219
423 | 2749
424 | 3183
425 | 2906
426 | 2360
427 | 2440
428 | 2946
429 | 2313
430 | 2859
431 | 2340
432 | 3008
433 | 2719
434 | 3058
435 | 2653
436 | 3023
437 | 2888
438 | 3243
439 | 2913
440 | 3242
441 | 3067
442 | 2409
443 | 3227
444 | 2380
445 | 2353
446 | 2686
447 | 2971
448 | 2847
449 | 2947
450 | 2857
451 | 3263
452 | 3218
453 | 2861
454 | 3323
455 | 2635
456 | 2966
457 | 2604
458 | 2456
459 | 2832
460 | 2694
461 | 3245
462 | 3119
463 | 2942
464 | 3153
465 | 2894
466 | 2555
467 | 3128
468 | 2703
469 | 2323
470 | 2631
471 | 2732
472 | 2699
473 | 2314
474 | 2590
475 | 3127
476 | 2891
477 | 2873
478 | 2814
479 | 2326
480 | 3026
481 | 3288
482 | 3095
483 | 2706
484 | 2457
485 | 2377
486 | 2620
487 | 2526
488 | 2674
489 | 3190
490 | 2923
491 | 3032
492 | 2334
493 | 3254
494 | 2991
495 | 3277
496 | 2973
497 | 2599
498 | 2658
499 | 2636
500 | 2826
501 | 3148
502 | 2958
503 | 3258
504 | 2990
505 | 3180
506 | 2538
507 | 2748
508 | 2625
509 | 2565
510 | 3011
511 | 3057
512 | 2354
513 | 3158
514 | 2622
515 | 3308
516 | 2983
517 | 2560
518 | 3169
519 | 3059
520 | 2480
521 | 3194
522 | 3291
523 | 3216
524 | 2643
525 | 3172
526 | 2352
527 | 2724
528 | 2485
529 | 2411
530 | 2948
531 | 2445
532 | 2362
533 | 2668
534 | 3275
535 | 3107
536 | 2496
537 | 2529
538 | 2700
539 | 2541
540 | 3028
541 | 2879
542 | 2660
543 | 3324
544 | 2755
545 | 2436
546 | 3048
547 | 2623
548 | 2920
549 | 3040
550 | 2568
551 | 3221
552 | 3003
553 | 3295
554 | 2473
555 | 3232
556 | 3213
557 | 2823
558 | 2897
559 | 2573
560 | 2645
561 | 3018
562 | 3326
563 | 2795
564 | 2915
565 | 3109
566 | 3086
567 | 2463
568 | 3118
569 | 2671
570 | 2909
571 | 2393
572 | 2325
573 | 3029
574 | 2972
575 | 3110
576 | 2870
577 | 3284
578 | 2816
579 | 2647
580 | 2667
581 | 2955
582 | 2333
583 | 2960
584 | 2864
585 | 2893
586 | 2458
587 | 2441
588 | 2359
589 | 2327
590 | 3256
591 | 3099
592 | 3073
593 | 3138
594 | 2511
595 | 2666
596 | 2548
597 | 2364
598 | 2451
599 | 2911
600 | 3237
601 | 3206
602 | 3080
603 | 3279
604 | 2934
605 | 2981
606 | 2878
607 | 3130
608 | 2830
609 | 3091
610 | 2659
611 | 2449
612 | 3152
613 | 2413
614 | 2722
615 | 2796
616 | 3220
617 | 2751
618 | 2935
619 | 3238
620 | 2491
621 | 2730
622 | 2842
623 | 3223
624 | 2492
625 | 3074
626 | 3094
627 | 2833
628 | 2521
629 | 2883
630 | 3315
631 | 2845
632 | 2907
633 | 3083
634 | 2572
635 | 3092
636 | 2903
637 | 2918
638 | 3039
639 | 3286
640 | 2587
641 | 3068
642 | 2338
643 | 3166
644 | 3134
645 | 2455
646 | 2497
647 | 2992
648 | 2775
649 | 2681
650 | 2430
651 | 2932
652 | 2931
653 | 2434
654 | 3154
655 | 3046
656 | 2598
657 | 2366
658 | 3015
659 | 3147
660 | 2944
661 | 2582
662 | 3274
663 | 2987
664 | 2642
665 | 2547
666 | 2420
667 | 2930
668 | 2750
669 | 2417
670 | 2808
671 | 3141
672 | 2997
673 | 2995
674 | 2584
675 | 2312
676 | 3033
677 | 3070
678 | 3065
679 | 2509
680 | 3314
681 | 2396
682 | 2543
683 | 2423
684 | 3170
685 | 2389
686 | 3289
687 | 2728
688 | 2540
689 | 2437
690 | 2486
691 | 2895
692 | 3017
693 | 2853
694 | 2406
695 | 2346
696 | 2877
697 | 2472
698 | 3210
699 | 2637
700 | 2927
701 | 2789
702 | 2330
703 | 3088
704 | 3102
705 | 2616
706 | 3081
707 | 2902
708 | 3205
709 | 3320
710 | 3165
711 | 2984
712 | 3185
713 | 2707
714 | 3255
715 | 2583
716 | 2773
717 | 2742
718 | 3024
719 | 2402
720 | 2718
721 | 2882
722 | 2575
723 | 3281
724 | 2786
725 | 2855
726 | 3014
727 | 2401
728 | 2535
729 | 2687
730 | 2495
731 | 3113
732 | 2609
733 | 2559
734 | 2665
735 | 2530
736 | 3293
737 | 2399
738 | 2605
739 | 2690
740 | 3133
741 | 2799
742 | 2533
743 | 2695
744 | 2713
745 | 2886
746 | 2691
747 | 2549
748 | 3077
749 | 3002
750 | 3049
751 | 3051
752 | 3087
753 | 2444
754 | 3085
755 | 3135
756 | 2702
757 | 3211
758 | 3108
759 | 2501
760 | 2769
761 | 3290
762 | 2465
763 | 3025
764 | 3019
765 | 2385
766 | 2940
767 | 2657
768 | 2610
769 | 2525
770 | 2941
771 | 3078
772 | 2341
773 | 2916
774 | 2956
775 | 2375
776 | 2880
777 | 3009
778 | 2780
779 | 2370
780 | 2925
781 | 2332
782 | 3146
783 | 2315
784 | 2809
785 | 3145
786 | 3106
787 | 2782
788 | 2760
789 | 2493
790 | 2765
791 | 2556
792 | 2890
793 | 2400
794 | 2339
795 | 3201
796 | 2818
797 | 3248
798 | 3280
799 | 2570
800 | 2569
801 | 2937
802 | 3174
803 | 2836
804 | 2708
805 | 2820
806 | 3195
807 | 2617
808 | 3197
809 | 2319
810 | 2744
811 | 2615
812 | 2825
813 | 2603
814 | 2914
815 | 2531
816 | 3193
817 | 2624
818 | 2365
819 | 2810
820 | 3239
821 | 3159
822 | 2537
823 | 2844
824 | 2758
825 | 2938
826 | 3037
827 | 2503
828 | 3297
829 | 2885
830 | 2608
831 | 2494
832 | 2712
833 | 2408
834 | 2901
835 | 2704
836 | 2536
837 | 2373
838 | 2478
839 | 2723
840 | 3076
841 | 2627
842 | 2369
843 | 2669
844 | 3006
845 | 2628
846 | 2788
847 | 3276
848 | 2435
849 | 3139
850 | 3235
851 | 2527
852 | 2571
853 | 2815
854 | 2442
855 | 2892
856 | 2978
857 | 2746
858 | 3150
859 | 2574
860 | 2725
861 | 3188
862 | 2601
863 | 2378
864 | 3075
865 | 2632
866 | 2794
867 | 3270
868 | 3071
869 | 2506
870 | 3126
871 | 3236
872 | 3257
873 | 2824
874 | 2989
875 | 2950
876 | 2428
877 | 2405
878 | 3156
879 | 2447
880 | 2787
881 | 2805
882 | 2720
883 | 2403
884 | 2811
885 | 2329
886 | 2474
887 | 2785
888 | 2350
889 | 2507
890 | 2416
891 | 3112
892 | 2475
893 | 2876
894 | 2585
895 | 2487
896 | 3072
897 | 3082
898 | 2943
899 | 2757
900 | 2388
901 | 2600
902 | 3294
903 | 2756
904 | 3142
905 | 3041
906 | 2594
907 | 2998
908 | 3047
909 | 2379
910 | 2980
911 | 2454
912 | 2862
913 | 3175
914 | 2588
915 | 3031
916 | 3012
917 | 2889
918 | 2500
919 | 2791
920 | 2854
921 | 2619
922 | 2395
923 | 2807
924 | 2740
925 | 2412
926 | 3131
927 | 3013
928 | 2939
929 | 2651
930 | 2490
931 | 2988
932 | 2863
933 | 3225
934 | 2745
935 | 2714
936 | 3160
937 | 3124
938 | 2849
939 | 2676
940 | 2872
941 | 3287
942 | 3189
943 | 2716
944 | 3115
945 | 2928
946 | 2871
947 | 2591
948 | 2717
949 | 2546
950 | 2777
951 | 3298
952 | 2397
953 | 3187
954 | 2726
955 | 2336
956 | 3268
957 | 2477
958 | 2904
959 | 2846
960 | 3121
961 | 2899
962 | 2510
963 | 2806
964 | 2963
965 | 3313
966 | 2679
967 | 3302
968 | 2663
969 | 3053
970 | 2469
971 | 2999
972 | 3311
973 | 2470
974 | 2638
975 | 3120
976 | 3171
977 | 2689
978 | 2922
979 | 2607
980 | 2721
981 | 2993
982 | 2887
983 | 2837
984 | 2929
985 | 2829
986 | 3234
987 | 2649
988 | 2337
989 | 2759
990 | 2778
991 | 2771
992 | 2404
993 | 2589
994 | 3123
995 | 3209
996 | 2729
997 | 3252
998 | 2606
999 | 2579
1000 | 2552
1001 |
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.tx
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.ty
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.x
--------------------------------------------------------------------------------
/Node/node_raw_data/citeseer/ind.citeseer.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/citeseer/ind.citeseer.y
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.allx
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.ally
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.graph
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.test.index:
--------------------------------------------------------------------------------
1 | 2692
2 | 2532
3 | 2050
4 | 1715
5 | 2362
6 | 2609
7 | 2622
8 | 1975
9 | 2081
10 | 1767
11 | 2263
12 | 1725
13 | 2588
14 | 2259
15 | 2357
16 | 1998
17 | 2574
18 | 2179
19 | 2291
20 | 2382
21 | 1812
22 | 1751
23 | 2422
24 | 1937
25 | 2631
26 | 2510
27 | 2378
28 | 2589
29 | 2345
30 | 1943
31 | 1850
32 | 2298
33 | 1825
34 | 2035
35 | 2507
36 | 2313
37 | 1906
38 | 1797
39 | 2023
40 | 2159
41 | 2495
42 | 1886
43 | 2122
44 | 2369
45 | 2461
46 | 1925
47 | 2565
48 | 1858
49 | 2234
50 | 2000
51 | 1846
52 | 2318
53 | 1723
54 | 2559
55 | 2258
56 | 1763
57 | 1991
58 | 1922
59 | 2003
60 | 2662
61 | 2250
62 | 2064
63 | 2529
64 | 1888
65 | 2499
66 | 2454
67 | 2320
68 | 2287
69 | 2203
70 | 2018
71 | 2002
72 | 2632
73 | 2554
74 | 2314
75 | 2537
76 | 1760
77 | 2088
78 | 2086
79 | 2218
80 | 2605
81 | 1953
82 | 2403
83 | 1920
84 | 2015
85 | 2335
86 | 2535
87 | 1837
88 | 2009
89 | 1905
90 | 2636
91 | 1942
92 | 2193
93 | 2576
94 | 2373
95 | 1873
96 | 2463
97 | 2509
98 | 1954
99 | 2656
100 | 2455
101 | 2494
102 | 2295
103 | 2114
104 | 2561
105 | 2176
106 | 2275
107 | 2635
108 | 2442
109 | 2704
110 | 2127
111 | 2085
112 | 2214
113 | 2487
114 | 1739
115 | 2543
116 | 1783
117 | 2485
118 | 2262
119 | 2472
120 | 2326
121 | 1738
122 | 2170
123 | 2100
124 | 2384
125 | 2152
126 | 2647
127 | 2693
128 | 2376
129 | 1775
130 | 1726
131 | 2476
132 | 2195
133 | 1773
134 | 1793
135 | 2194
136 | 2581
137 | 1854
138 | 2524
139 | 1945
140 | 1781
141 | 1987
142 | 2599
143 | 1744
144 | 2225
145 | 2300
146 | 1928
147 | 2042
148 | 2202
149 | 1958
150 | 1816
151 | 1916
152 | 2679
153 | 2190
154 | 1733
155 | 2034
156 | 2643
157 | 2177
158 | 1883
159 | 1917
160 | 1996
161 | 2491
162 | 2268
163 | 2231
164 | 2471
165 | 1919
166 | 1909
167 | 2012
168 | 2522
169 | 1865
170 | 2466
171 | 2469
172 | 2087
173 | 2584
174 | 2563
175 | 1924
176 | 2143
177 | 1736
178 | 1966
179 | 2533
180 | 2490
181 | 2630
182 | 1973
183 | 2568
184 | 1978
185 | 2664
186 | 2633
187 | 2312
188 | 2178
189 | 1754
190 | 2307
191 | 2480
192 | 1960
193 | 1742
194 | 1962
195 | 2160
196 | 2070
197 | 2553
198 | 2433
199 | 1768
200 | 2659
201 | 2379
202 | 2271
203 | 1776
204 | 2153
205 | 1877
206 | 2027
207 | 2028
208 | 2155
209 | 2196
210 | 2483
211 | 2026
212 | 2158
213 | 2407
214 | 1821
215 | 2131
216 | 2676
217 | 2277
218 | 2489
219 | 2424
220 | 1963
221 | 1808
222 | 1859
223 | 2597
224 | 2548
225 | 2368
226 | 1817
227 | 2405
228 | 2413
229 | 2603
230 | 2350
231 | 2118
232 | 2329
233 | 1969
234 | 2577
235 | 2475
236 | 2467
237 | 2425
238 | 1769
239 | 2092
240 | 2044
241 | 2586
242 | 2608
243 | 1983
244 | 2109
245 | 2649
246 | 1964
247 | 2144
248 | 1902
249 | 2411
250 | 2508
251 | 2360
252 | 1721
253 | 2005
254 | 2014
255 | 2308
256 | 2646
257 | 1949
258 | 1830
259 | 2212
260 | 2596
261 | 1832
262 | 1735
263 | 1866
264 | 2695
265 | 1941
266 | 2546
267 | 2498
268 | 2686
269 | 2665
270 | 1784
271 | 2613
272 | 1970
273 | 2021
274 | 2211
275 | 2516
276 | 2185
277 | 2479
278 | 2699
279 | 2150
280 | 1990
281 | 2063
282 | 2075
283 | 1979
284 | 2094
285 | 1787
286 | 2571
287 | 2690
288 | 1926
289 | 2341
290 | 2566
291 | 1957
292 | 1709
293 | 1955
294 | 2570
295 | 2387
296 | 1811
297 | 2025
298 | 2447
299 | 2696
300 | 2052
301 | 2366
302 | 1857
303 | 2273
304 | 2245
305 | 2672
306 | 2133
307 | 2421
308 | 1929
309 | 2125
310 | 2319
311 | 2641
312 | 2167
313 | 2418
314 | 1765
315 | 1761
316 | 1828
317 | 2188
318 | 1972
319 | 1997
320 | 2419
321 | 2289
322 | 2296
323 | 2587
324 | 2051
325 | 2440
326 | 2053
327 | 2191
328 | 1923
329 | 2164
330 | 1861
331 | 2339
332 | 2333
333 | 2523
334 | 2670
335 | 2121
336 | 1921
337 | 1724
338 | 2253
339 | 2374
340 | 1940
341 | 2545
342 | 2301
343 | 2244
344 | 2156
345 | 1849
346 | 2551
347 | 2011
348 | 2279
349 | 2572
350 | 1757
351 | 2400
352 | 2569
353 | 2072
354 | 2526
355 | 2173
356 | 2069
357 | 2036
358 | 1819
359 | 1734
360 | 1880
361 | 2137
362 | 2408
363 | 2226
364 | 2604
365 | 1771
366 | 2698
367 | 2187
368 | 2060
369 | 1756
370 | 2201
371 | 2066
372 | 2439
373 | 1844
374 | 1772
375 | 2383
376 | 2398
377 | 1708
378 | 1992
379 | 1959
380 | 1794
381 | 2426
382 | 2702
383 | 2444
384 | 1944
385 | 1829
386 | 2660
387 | 2497
388 | 2607
389 | 2343
390 | 1730
391 | 2624
392 | 1790
393 | 1935
394 | 1967
395 | 2401
396 | 2255
397 | 2355
398 | 2348
399 | 1931
400 | 2183
401 | 2161
402 | 2701
403 | 1948
404 | 2501
405 | 2192
406 | 2404
407 | 2209
408 | 2331
409 | 1810
410 | 2363
411 | 2334
412 | 1887
413 | 2393
414 | 2557
415 | 1719
416 | 1732
417 | 1986
418 | 2037
419 | 2056
420 | 1867
421 | 2126
422 | 1932
423 | 2117
424 | 1807
425 | 1801
426 | 1743
427 | 2041
428 | 1843
429 | 2388
430 | 2221
431 | 1833
432 | 2677
433 | 1778
434 | 2661
435 | 2306
436 | 2394
437 | 2106
438 | 2430
439 | 2371
440 | 2606
441 | 2353
442 | 2269
443 | 2317
444 | 2645
445 | 2372
446 | 2550
447 | 2043
448 | 1968
449 | 2165
450 | 2310
451 | 1985
452 | 2446
453 | 1982
454 | 2377
455 | 2207
456 | 1818
457 | 1913
458 | 1766
459 | 1722
460 | 1894
461 | 2020
462 | 1881
463 | 2621
464 | 2409
465 | 2261
466 | 2458
467 | 2096
468 | 1712
469 | 2594
470 | 2293
471 | 2048
472 | 2359
473 | 1839
474 | 2392
475 | 2254
476 | 1911
477 | 2101
478 | 2367
479 | 1889
480 | 1753
481 | 2555
482 | 2246
483 | 2264
484 | 2010
485 | 2336
486 | 2651
487 | 2017
488 | 2140
489 | 1842
490 | 2019
491 | 1890
492 | 2525
493 | 2134
494 | 2492
495 | 2652
496 | 2040
497 | 2145
498 | 2575
499 | 2166
500 | 1999
501 | 2434
502 | 1711
503 | 2276
504 | 2450
505 | 2389
506 | 2669
507 | 2595
508 | 1814
509 | 2039
510 | 2502
511 | 1896
512 | 2168
513 | 2344
514 | 2637
515 | 2031
516 | 1977
517 | 2380
518 | 1936
519 | 2047
520 | 2460
521 | 2102
522 | 1745
523 | 2650
524 | 2046
525 | 2514
526 | 1980
527 | 2352
528 | 2113
529 | 1713
530 | 2058
531 | 2558
532 | 1718
533 | 1864
534 | 1876
535 | 2338
536 | 1879
537 | 1891
538 | 2186
539 | 2451
540 | 2181
541 | 2638
542 | 2644
543 | 2103
544 | 2591
545 | 2266
546 | 2468
547 | 1869
548 | 2582
549 | 2674
550 | 2361
551 | 2462
552 | 1748
553 | 2215
554 | 2615
555 | 2236
556 | 2248
557 | 2493
558 | 2342
559 | 2449
560 | 2274
561 | 1824
562 | 1852
563 | 1870
564 | 2441
565 | 2356
566 | 1835
567 | 2694
568 | 2602
569 | 2685
570 | 1893
571 | 2544
572 | 2536
573 | 1994
574 | 1853
575 | 1838
576 | 1786
577 | 1930
578 | 2539
579 | 1892
580 | 2265
581 | 2618
582 | 2486
583 | 2583
584 | 2061
585 | 1796
586 | 1806
587 | 2084
588 | 1933
589 | 2095
590 | 2136
591 | 2078
592 | 1884
593 | 2438
594 | 2286
595 | 2138
596 | 1750
597 | 2184
598 | 1799
599 | 2278
600 | 2410
601 | 2642
602 | 2435
603 | 1956
604 | 2399
605 | 1774
606 | 2129
607 | 1898
608 | 1823
609 | 1938
610 | 2299
611 | 1862
612 | 2420
613 | 2673
614 | 1984
615 | 2204
616 | 1717
617 | 2074
618 | 2213
619 | 2436
620 | 2297
621 | 2592
622 | 2667
623 | 2703
624 | 2511
625 | 1779
626 | 1782
627 | 2625
628 | 2365
629 | 2315
630 | 2381
631 | 1788
632 | 1714
633 | 2302
634 | 1927
635 | 2325
636 | 2506
637 | 2169
638 | 2328
639 | 2629
640 | 2128
641 | 2655
642 | 2282
643 | 2073
644 | 2395
645 | 2247
646 | 2521
647 | 2260
648 | 1868
649 | 1988
650 | 2324
651 | 2705
652 | 2541
653 | 1731
654 | 2681
655 | 2707
656 | 2465
657 | 1785
658 | 2149
659 | 2045
660 | 2505
661 | 2611
662 | 2217
663 | 2180
664 | 1904
665 | 2453
666 | 2484
667 | 1871
668 | 2309
669 | 2349
670 | 2482
671 | 2004
672 | 1965
673 | 2406
674 | 2162
675 | 1805
676 | 2654
677 | 2007
678 | 1947
679 | 1981
680 | 2112
681 | 2141
682 | 1720
683 | 1758
684 | 2080
685 | 2330
686 | 2030
687 | 2432
688 | 2089
689 | 2547
690 | 1820
691 | 1815
692 | 2675
693 | 1840
694 | 2658
695 | 2370
696 | 2251
697 | 1908
698 | 2029
699 | 2068
700 | 2513
701 | 2549
702 | 2267
703 | 2580
704 | 2327
705 | 2351
706 | 2111
707 | 2022
708 | 2321
709 | 2614
710 | 2252
711 | 2104
712 | 1822
713 | 2552
714 | 2243
715 | 1798
716 | 2396
717 | 2663
718 | 2564
719 | 2148
720 | 2562
721 | 2684
722 | 2001
723 | 2151
724 | 2706
725 | 2240
726 | 2474
727 | 2303
728 | 2634
729 | 2680
730 | 2055
731 | 2090
732 | 2503
733 | 2347
734 | 2402
735 | 2238
736 | 1950
737 | 2054
738 | 2016
739 | 1872
740 | 2233
741 | 1710
742 | 2032
743 | 2540
744 | 2628
745 | 1795
746 | 2616
747 | 1903
748 | 2531
749 | 2567
750 | 1946
751 | 1897
752 | 2222
753 | 2227
754 | 2627
755 | 1856
756 | 2464
757 | 2241
758 | 2481
759 | 2130
760 | 2311
761 | 2083
762 | 2223
763 | 2284
764 | 2235
765 | 2097
766 | 1752
767 | 2515
768 | 2527
769 | 2385
770 | 2189
771 | 2283
772 | 2182
773 | 2079
774 | 2375
775 | 2174
776 | 2437
777 | 1993
778 | 2517
779 | 2443
780 | 2224
781 | 2648
782 | 2171
783 | 2290
784 | 2542
785 | 2038
786 | 1855
787 | 1831
788 | 1759
789 | 1848
790 | 2445
791 | 1827
792 | 2429
793 | 2205
794 | 2598
795 | 2657
796 | 1728
797 | 2065
798 | 1918
799 | 2427
800 | 2573
801 | 2620
802 | 2292
803 | 1777
804 | 2008
805 | 1875
806 | 2288
807 | 2256
808 | 2033
809 | 2470
810 | 2585
811 | 2610
812 | 2082
813 | 2230
814 | 1915
815 | 1847
816 | 2337
817 | 2512
818 | 2386
819 | 2006
820 | 2653
821 | 2346
822 | 1951
823 | 2110
824 | 2639
825 | 2520
826 | 1939
827 | 2683
828 | 2139
829 | 2220
830 | 1910
831 | 2237
832 | 1900
833 | 1836
834 | 2197
835 | 1716
836 | 1860
837 | 2077
838 | 2519
839 | 2538
840 | 2323
841 | 1914
842 | 1971
843 | 1845
844 | 2132
845 | 1802
846 | 1907
847 | 2640
848 | 2496
849 | 2281
850 | 2198
851 | 2416
852 | 2285
853 | 1755
854 | 2431
855 | 2071
856 | 2249
857 | 2123
858 | 1727
859 | 2459
860 | 2304
861 | 2199
862 | 1791
863 | 1809
864 | 1780
865 | 2210
866 | 2417
867 | 1874
868 | 1878
869 | 2116
870 | 1961
871 | 1863
872 | 2579
873 | 2477
874 | 2228
875 | 2332
876 | 2578
877 | 2457
878 | 2024
879 | 1934
880 | 2316
881 | 1841
882 | 1764
883 | 1737
884 | 2322
885 | 2239
886 | 2294
887 | 1729
888 | 2488
889 | 1974
890 | 2473
891 | 2098
892 | 2612
893 | 1834
894 | 2340
895 | 2423
896 | 2175
897 | 2280
898 | 2617
899 | 2208
900 | 2560
901 | 1741
902 | 2600
903 | 2059
904 | 1747
905 | 2242
906 | 2700
907 | 2232
908 | 2057
909 | 2147
910 | 2682
911 | 1792
912 | 1826
913 | 2120
914 | 1895
915 | 2364
916 | 2163
917 | 1851
918 | 2391
919 | 2414
920 | 2452
921 | 1803
922 | 1989
923 | 2623
924 | 2200
925 | 2528
926 | 2415
927 | 1804
928 | 2146
929 | 2619
930 | 2687
931 | 1762
932 | 2172
933 | 2270
934 | 2678
935 | 2593
936 | 2448
937 | 1882
938 | 2257
939 | 2500
940 | 1899
941 | 2478
942 | 2412
943 | 2107
944 | 1746
945 | 2428
946 | 2115
947 | 1800
948 | 1901
949 | 2397
950 | 2530
951 | 1912
952 | 2108
953 | 2206
954 | 2091
955 | 1740
956 | 2219
957 | 1976
958 | 2099
959 | 2142
960 | 2671
961 | 2668
962 | 2216
963 | 2272
964 | 2229
965 | 2666
966 | 2456
967 | 2534
968 | 2697
969 | 2688
970 | 2062
971 | 2691
972 | 2689
973 | 2154
974 | 2590
975 | 2626
976 | 2390
977 | 1813
978 | 2067
979 | 1952
980 | 2518
981 | 2358
982 | 1789
983 | 2076
984 | 2049
985 | 2119
986 | 2013
987 | 2124
988 | 2556
989 | 2105
990 | 2093
991 | 1885
992 | 2305
993 | 2354
994 | 2135
995 | 2601
996 | 1770
997 | 1995
998 | 2504
999 | 1749
1000 | 2157
1001 |
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.tx
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.ty
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.x
--------------------------------------------------------------------------------
/Node/node_raw_data/cora/ind.cora.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/cora/ind.cora.y
--------------------------------------------------------------------------------
/Node/node_raw_data/fb100-Penn94-splits.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DSL-Lab/Specformer/40a35d0c4db0839c9d5e17f45ea7b4618e8fce71/Node/node_raw_data/fb100-Penn94-splits.npy
--------------------------------------------------------------------------------
/Node/preprocess_node_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import math
4 | import time
5 | import pickle as pkl
6 | import scipy as sp
7 | from scipy import io
8 | import numpy as np
9 | import pandas as pd
10 | import networkx as nx
11 | import dgl
12 | import torch
13 | from sklearn.preprocessing import label_binarize
14 | from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
15 | from numpy.linalg import eig, eigh
16 |
17 |
def generate_signal_data():
    """Generate the spectral-filter regression datasets from the 2D grid graph.

    Reads ``node_raw_data/2Dgrid.mat`` (keys ``A``, ``F`` and ``mask``), builds
    the Laplacian ``L = I - D^{-1/2} A D^{-1/2}``, eigendecomposes it with
    ``eigh`` and filters the signal ``F`` in the spectral domain with five
    responses of the eigenvalue ``lam``:

    * ``low``:  ``exp(-10 * lam**2)``
    * ``high``: ``1 - exp(-10 * lam**2)``
    * ``band``: ``exp(-10 * (lam - 1)**2)``
    * ``rej``:  ``1 - exp(-10 * (lam - 1)**2)``
    * ``comb``: ``|sin(pi * lam)|``

    For every response the list ``[e, u, x, y, m]`` (eigenvalues, eigenvectors,
    input signal, filtered signal, mask) is written to
    ``data/signal_<name>.pt``.  Nothing is returned.
    """
    data = io.loadmat('node_raw_data/2Dgrid.mat')
    x = data['F'].astype(np.float32)
    m = data['mask']

    # Use a plain ndarray (not np.matrix) so that ``*`` below is element-wise.
    A = sp.sparse.coo_matrix(data['A']).toarray()

    D_vec = A.sum(axis=1)
    D_vec_invsqrt_corr = 1 / np.sqrt(D_vec)
    # Scaling rows and columns by D^{-1/2} through broadcasting gives the same
    # entries as D^{-1/2} @ A @ D^{-1/2} without two dense O(n^3) products.
    # The identity is sized from A instead of a hard-coded node count.
    L = np.eye(A.shape[0]) - D_vec_invsqrt_corr[:, None] * A * D_vec_invsqrt_corr[None, :]

    e, u = eigh(L)

    responses = {
        'low': np.exp(-10 * (e - 0) ** 2),
        'high': 1 - np.exp(-10 * (e - 0) ** 2),
        'band': np.exp(-10 * (e - 1) ** 2),
        'rej': 1 - np.exp(-10 * (e - 1) ** 2),
        'comb': np.abs(np.sin(e * math.pi)),
    }

    # u @ diag(h) @ u.T @ x == u @ (h[:, None] * (u.T @ x)); project the signal
    # once and rescale per response instead of forming dense n x n products.
    # loadmat returns 2-D arrays, so ``spectrum`` is 2-D here.
    spectrum = u.T @ x
    filtered = {name: u @ (h[:, None] * spectrum) for name, h in responses.items()}

    e = torch.FloatTensor(e)
    u = torch.FloatTensor(u)
    x = torch.FloatTensor(x)
    m = torch.LongTensor(m).squeeze()

    for name in ('low', 'high', 'band', 'rej', 'comb'):
        y = torch.FloatTensor(filtered[name])
        torch.save([e, u, x, y, m], 'data/signal_{}.pt'.format(name))
54 |
55 |
def normalize_graph(g):
    """Return the normalized Laplacian ``I - D^{-1/2} A D^{-1/2}`` of a graph.

    Args:
        g: square adjacency matrix (anything ``np.array`` accepts).  It is
            copied, symmetrized as ``g + g.T`` and every positive entry is set
            to 1 before the degrees are computed.  The input is not modified.

    Returns:
        A dense ``numpy.ndarray`` with the same shape as ``g``.  Nodes whose
        degree is zero are treated as having degree 1, so their row/column of
        the normalized adjacency is zero and their diagonal entry is 1.
    """
    g = np.array(g)
    g = g + g.T
    g[g > 0.] = 1.0
    deg = g.sum(axis=1).reshape(-1)
    # Avoid division by zero for isolated nodes.
    deg[deg == 0.] = 1.0
    inv_sqrt_deg = deg ** -0.5
    # Row/column scaling by broadcasting: same entries as
    # diag(d) @ g @ diag(d), but O(n^2) instead of two dense O(n^3) products.
    adj = inv_sqrt_deg[:, None] * g * inv_sqrt_deg[None, :]
    return np.eye(g.shape[0]) - adj
66 |
67 |
def eigen_decompositon(g):
    """Eigendecompose the normalized Laplacian of adjacency matrix ``g``.

    The Laplacian is built by ``normalize_graph`` and decomposed with
    ``numpy.linalg.eigh``.  Returns ``(w, v)``: the eigenvalues and the
    normalized (unit "length") eigenvectors, such that the column ``v[:, i]``
    is the eigenvector corresponding to the eigenvalue ``w[i]``.
    """
    laplacian = normalize_graph(g)
    eigenvalues, eigenvectors = eigh(laplacian)
    return eigenvalues, eigenvectors
74 |
75 |
def parse_index_file(filename):
    """Parse an index file containing one integer per line.

    Args:
        filename: path of the text file to read.

    Returns:
        A list of ints in file order.  Blank lines (e.g. a trailing empty
        line) are skipped.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    # The context manager closes the handle; the original left it open.
    with open(filename) as f:
        stripped_lines = (line.strip() for line in f)
        return [int(text) for text in stripped_lines if text]
82 |
83 |
def feature_normalize(x):
    """Divide every row of ``x`` by its row sum.

    The row sums are clipped to the interval ``[1, 1e10]`` before dividing, so
    rows whose sum is below 1 (including all-zero rows) are returned unscaled.

    Args:
        x: 2-D array-like feature matrix.

    Returns:
        ``numpy.ndarray`` with the same shape as ``x``.
    """
    feats = np.array(x)
    denominators = np.clip(np.sum(feats, axis=1, keepdims=True), 1, 1e10)
    return feats / denominators
89 |
90 |
def load_data(dataset_str):
    """Load the pickled ``ind.<dataset>.*`` files of a citation dataset.

    Reads ``x, y, tx, ty, allx, ally, graph`` and the ``test.index`` file from
    ``node_raw_data/<dataset_str>/``, stacks ``allx``/``tx`` and ``ally``/``ty``
    and moves the test rows to the positions listed in the index file.

    Args:
        dataset_str: dataset directory / file-name component, e.g. ``'cora'``
            or ``'citeseer'``.

    Returns:
        ``(adj, features, labels)``: the adjacency matrix built by networkx
        from the ``graph`` dict, the LIL feature matrix and the dense label
        array.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted dataset files.
    unpickle_kwargs = {'encoding': 'latin1'} if sys.version_info > (3, 0) else {}

    loaded = []
    for suffix in ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']:
        with open("node_raw_data/{}/ind.{}.{}".format(dataset_str, dataset_str, suffix), 'rb') as f:
            loaded.append(pkl.load(f, **unpickle_kwargs))
    x, y, tx, ty, allx, ally, graph = tuple(loaded)

    test_idx_reorder = parse_index_file("node_raw_data/{}/ind.{}.test.index".format(dataset_str, dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        lowest = min(test_idx_reorder)
        span = len(range(lowest, max(test_idx_reorder) + 1))
        row_positions = test_idx_range - min(test_idx_range)

        padded_tx = sp.sparse.lil_matrix((span, x.shape[1]))
        padded_tx[row_positions, :] = tx
        tx = padded_tx

        padded_ty = np.zeros((span, y.shape[1]))
        padded_ty[row_positions, :] = ty
        ty = padded_ty

    # Rows are stacked in sorted test order, then moved to the file's order.
    features = sp.sparse.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    return adj, features, labels
124 |
125 |
def eig_dgl_adj_sparse(g, sm=0, lm=0):
    """Compute a partial eigendecomposition of a graph's normalized Laplacian.

    The Laplacian is ``I - D^{-1/2} A D^{-1/2}`` where ``A`` is
    ``g.adj(scipy_fmt='csr')`` and ``D`` holds the column sums of ``A``.
    Eigenpairs are found with ``scipy.sparse.linalg.eigsh`` (``tol=1e-5``).

    Args:
        g: graph object exposing ``adj(scipy_fmt='csr')`` and ``num_nodes()``.
            NOTE(review): presumably a DGL graph from a release whose ``adj``
            still accepts ``scipy_fmt`` -- confirm against the pinned version.
        sm: number of eigenpairs requested with ``which='SM'``; skipped if 0.
        lm: number of eigenpairs requested with ``which='LM'``; skipped if 0.

    Returns:
        ``(e, u)`` as ``torch.FloatTensor``s.  When both ``sm`` and ``lm`` are
        positive the ``'SM'`` results come first (values concatenated on dim 0,
        vectors on dim 1).  Returns ``None`` when neither is positive.
    """
    adjacency = g.adj(scipy_fmt='csr')
    degrees = np.array(adjacency.sum(axis=0)).flatten()
    inv_sqrt_deg = sp.sparse.diags(degrees ** -0.5)

    normalized_adj = inv_sqrt_deg.dot(adjacency.dot(inv_sqrt_deg))
    laplacian = sp.sparse.eye(g.num_nodes()) - normalized_adj

    # Solve for each requested end of the spectrum, 'SM' before 'LM'.
    pairs = []
    for count, which in ((sm, 'SM'), (lm, 'LM')):
        if count > 0:
            vals, vecs = sp.sparse.linalg.eigsh(laplacian, k=count, which=which, tol=1e-5)
            pairs.append((torch.FloatTensor(vals), torch.FloatTensor(vecs)))

    if not pairs:
        return None
    if len(pairs) == 1:
        return pairs[0]

    (e_small, u_small), (e_large, u_large) = pairs
    return torch.cat((e_small, e_large), dim=0), torch.cat((u_small, u_large), dim=1)
150 |
151 |
def load_fb100_dataset():
    """Load the Penn94 graph from ``node_raw_data/Penn94.mat``.

    Returns:
        A tuple ``(g, node_feat, label)``: a DGL graph built from the non-zero
        entries of the stored matrix ``A``, a float tensor of one-hot encoded
        node attributes, and a long tensor of labels taken from column 1 of
        ``local_info`` shifted down by one.
    """
    raw = io.loadmat('node_raw_data/Penn94.mat')
    adjacency = raw['A']
    info = raw['local_info'].astype(int)

    src, dst = adjacency.nonzero()
    gender = info[:, 1] - 1  # gender label, -1 means unlabeled

    # Every metadata column except the label column (index 1) is treated as a
    # categorical attribute and expanded into a one-hot encoding.
    attribute_cols = np.hstack((info[:, :1], info[:, 2:]))
    encoded = [np.empty((adjacency.shape[0], 0))]
    for col in range(attribute_cols.shape[1]):
        values = attribute_cols[:, col]
        encoded.append(label_binarize(values, classes=np.unique(values)))
    one_hot = np.hstack(encoded)

    node_feat = torch.tensor(one_hot, dtype=torch.float)
    label = torch.LongTensor(gender)

    g = dgl.graph((src, dst), num_nodes=info.shape[0])

    return g, node_feat, label
176 |
177 |
def _save_dense_spectrum(dataset, adj, x, y):
    """Eigendecompose ``adj`` and save ``[e, u, x, y]`` as tensors to ``data/<dataset>.pt``."""
    e, u = eigen_decompositon(adj)
    tensors = [
        torch.FloatTensor(e),
        torch.FloatTensor(u),
        torch.FloatTensor(x),
        torch.LongTensor(y),
    ]
    torch.save(tensors, 'data/{}.pt'.format(dataset))


def _load_edge_list_dataset(dataset):
    """Read the tab-separated edge and node files of ``dataset`` into ``(adj, x, y)``."""
    root = 'node_raw_data/{}/'.format(dataset)
    edge_df = pd.read_csv(root + 'out1_graph_edges.txt', sep='\t')
    node_df = pd.read_csv(root + 'out1_node_feature_label.txt', sep='\t')

    # Column 1 holds comma-separated feature entries, column 2 the labels.
    feature_lists = [feat.split(',') for feat in node_df[node_df.columns[1]]]
    y = node_df[node_df.columns[2]]

    num_nodes = len(y)
    adj = np.zeros((num_nodes, num_nodes))
    source = edge_df[edge_df.columns[0]].to_numpy()
    target = edge_df[edge_df.columns[1]].to_numpy()
    # Mark every listed edge in both directions so the matrix is symmetric.
    adj[source, target] = 1.
    adj[target, source] = 1.

    if dataset == 'actor':
        # for sparse features: each entry is the index of an active feature
        nfeat = 932
        x = np.zeros((num_nodes, nfeat))
        for row, active in enumerate(feature_lists):
            x[row, [int(ff) for ff in active]] = 1.
    else:
        # dense features: each entry is the feature value itself
        x = np.array([[int(f) for f in feat] for feat in feature_lists])

    return adj, feature_normalize(x), y


def generate_node_data(dataset):
    """Preprocess one node-level dataset and save it under ``data/``.

    For every supported dataset the saved file holds the list ``[e, u, x, y]``:
    eigenvalues, eigenvectors, node features and labels.

    Args:
        dataset: One of ``'cora'``, ``'citeseer'``, ``'photo'``, ``'arxiv'``,
            ``'penn'``, ``'chameleon'``, ``'squirrel'`` or ``'actor'``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names
            (previously such names were silently ignored).
    """
    if dataset in ['cora', 'citeseer']:
        adj, x, y = load_data(dataset)
        adj = adj.todense()
        x = feature_normalize(x.todense())
        _save_dense_spectrum(dataset, adj, x, y)

    elif dataset in ['photo']:
        # Use the archive as a context manager so its file handle is closed.
        with np.load('node_raw_data/amazon_electronics_photo.npz', allow_pickle=True) as data:
            adj = sp.sparse.csr_matrix((data['adj_data'], data['adj_indices'], data['adj_indptr']),
                                       shape=data['adj_shape']).toarray()
            feat = sp.sparse.csr_matrix((data['attr_data'], data['attr_indices'], data['attr_indptr']),
                                        shape=data['attr_shape']).toarray()
            y = data['labels']
        _save_dense_spectrum(dataset, adj, feature_normalize(feat), y)

    elif dataset in ['arxiv']:
        data = DglNodePropPredDataset('ogbn-arxiv')
        g = data[0][0]
        g = dgl.add_reverse_edges(g)
        g = dgl.to_simple(g)

        # Only a partial spectrum is computed here (sparse solver).
        e, u = eig_dgl_adj_sparse(g, sm=5000)
        x = g.ndata['feat']
        y = data[0][1]

        torch.save([e, u, x, y], 'data/arxiv.pt')

    elif dataset in ['penn']:
        g, x, y = load_fb100_dataset()
        g = dgl.add_reverse_edges(g)
        g = dgl.to_simple(g)

        # Partial spectrum from both ends (sparse solver).
        e, u = eig_dgl_adj_sparse(g, sm=3000, lm=3000)

        torch.save([e, u, x, y], 'data/penn.pt')

    elif dataset in ['chameleon', 'squirrel', 'actor']:
        adj, x, y = _load_edge_list_dataset(dataset)
        _save_dense_spectrum(dataset, adj, x, y)

    else:
        raise ValueError('unsupported dataset: {!r}'.format(dataset))
279 |
280 |
if __name__ == '__main__':
    # Preprocess one dataset per run; edit the name below to switch.
    # Names previously toggled here: 'cora', 'citeseer', 'photo',
    # 'chameleon', 'squirrel', 'actor'.
    selected_dataset = 'penn'
    generate_node_data(selected_dataset)
289 |
290 |
--------------------------------------------------------------------------------
/Node/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import random
4 | import numpy as np
5 | import scipy as sp
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
11 |
12 |
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
15 |
16 |
def init_params(module):
    """Re-initialize ``module`` in place if it is an ``nn.Linear`` layer.

    Weights are drawn from a normal distribution (mean 0, std 0.01) and the
    bias, when present, is zeroed. Any other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return

    module.weight.data.normal_(mean=0.0, std=0.01)
    bias = module.bias
    if bias is not None:
        bias.data.zero_()
22 |
23 |
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (CPU and CUDA), and exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    # Exported for child processes; the running interpreter's hash seed was
    # already fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different algorithm from run to run, which defeats
    # the deterministic setting above, so it must be disabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.allow_tf32 = False
33 |
34 |
def get_split(dataset, y, nclass, seed=0):
    """Return ``(train, valid, test)`` node indices for ``dataset``.

    ``'arxiv'`` uses the split shipped with ``ogbn-arxiv`` and ``'penn'`` uses
    the first split stored in ``node_raw_data/fb100-Penn94-splits.npy``. Every
    other dataset gets a random split: ``round(0.6 * len(y) / nclass)`` nodes
    per class for training, then ``round(0.2 * len(y))`` of the shuffled
    remainder for validation and the rest for testing.

    Args:
        dataset: Dataset name.
        y: Label tensor; only read for the random split.
        nclass: Number of classes; only read for the random split.
        seed: Currently unused. The random split draws from torch's global RNG.
    """
    if dataset == 'arxiv':
        ogb_data = DglNodePropPredDataset('ogbn-arxiv')
        parts = ogb_data.get_idx_split()
        return parts['train'], parts['valid'], parts['test']

    if dataset == 'penn':
        parts = np.load('node_raw_data/fb100-Penn94-splits.npy', allow_pickle=True)[0]
        return parts['train'], parts['valid'], parts['test']

    labels = y.cpu()
    n_total = len(labels)
    per_class_train = int(round(0.6 * n_total / nclass))
    n_valid = int(round(0.2 * n_total))

    train_parts, rest_parts = [], []
    for cls in range(nclass):
        # Shuffle the members of this class, then cut off the training share.
        members = (labels == cls).nonzero().view(-1)
        order = torch.randperm(members.size(0), device=members.device)
        members = members[order]
        train_parts.append(members[:per_class_train])
        rest_parts.append(members[per_class_train:])

    train_index = torch.cat(train_parts, dim=0)
    remainder = torch.cat(rest_parts, dim=0)
    remainder = remainder[torch.randperm(remainder.size(0))]

    return train_index, remainder[:n_valid], remainder[n_valid:]
67 |
68 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Specformer
2 | Code of [Specformer: Spectral Graph Neural Networks Meet Transformers](https://openreview.net/forum?id=0pdSt3oyJa1)
3 |
4 | # How to run
5 | - For node-level tasks, e.g., signal regression and node classification, you should first run preprocess_node_data.py to generate .pt files for each dataset.
6 | - For graph-level tasks, you can directly run dgl_main.py.
7 |
8 | # Q&A
9 | Any suggestion/question is welcome.
10 |
11 | # Reference
12 | If you make use of Specformer in your research, please cite the following in your manuscript:
13 |
14 | ```
15 | @inproceedings{specformer2023,
16 | author={Deyu Bo and
17 | Chuan Shi and
18 | Lele Wang and
19 | Renjie Liao},
20 | title={Specformer: Spectral Graph Neural Networks Meet Transformers},
21 | booktitle = {{ICLR}},
22 | publisher = {OpenReview.net},
23 | year = {2023}
24 | }
25 | ```
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dgl==1.0.0
2 | easydict==1.10
3 | ema_pytorch==0.2.1
4 | json5==0.9.11
5 | networkx==3.0
6 | numpy==1.24.2
7 | ogb==1.3.5
8 | pandas==1.5.3
9 | PyYAML==6.0
10 | scikit_learn==1.2.1
11 | scipy==1.10.1
12 | torch==1.13.1
13 | torch_geometric==2.2.0
14 | torchmetrics==0.11.1
15 | tqdm==4.64.1
16 | wandb==0.13.10
17 |
--------------------------------------------------------------------------------