├── README.md ├── convertor └── ENZYMES.py ├── data.zip ├── graphdownstream ├── __pycache__ │ ├── basemodel.cpython-36.pyc │ ├── basemodel.cpython-38.pyc │ ├── dataset.cpython-36.pyc │ ├── dataset.cpython-38.pyc │ ├── embedding.cpython-36.pyc │ ├── embedding.cpython-38.pyc │ ├── filternet.cpython-36.pyc │ ├── filternet.cpython-38.pyc │ ├── gin.cpython-36.pyc │ ├── gin.cpython-38.pyc │ ├── graph_finetuning_layer.cpython-36.pyc │ ├── graph_prompt_layer.cpython-36.pyc │ ├── predictnet.cpython-36.pyc │ ├── predictnet.cpython-38.pyc │ ├── utils.cpython-36.pyc │ └── utils.cpython-38.pyc ├── basemodel.py ├── dataset.py ├── embedding.py ├── epoch_loss.png ├── epoch_loss_enzymes1.png ├── filternet.py ├── gin.py ├── gina ├── graph_prompt_layer.py ├── pre_train_before.py ├── predictnet.py ├── prompt_fewshot.py ├── prompt_fewshot_before.py ├── test.ipynb └── utils.py └── nodedownstream ├── ENZYMES2ONE_Graph.py ├── __pycache__ ├── ENZYMES2ONE_Graph.cpython-36.pyc ├── basemodel.cpython-36.pyc ├── dataset.cpython-36.pyc ├── embedding.cpython-36.pyc ├── filternet.cpython-36.pyc ├── gin.cpython-36.pyc ├── node_finetuning_layer.cpython-36.pyc ├── node_prompt_layer.cpython-36.pyc ├── predictnet.cpython-36.pyc ├── split.cpython-36.pyc └── utils.cpython-36.pyc ├── basemodel.py ├── dataset.py ├── datasetInfo.py ├── dataset_flickr.py ├── embedding.py ├── filternet.py ├── flikcrtaskchoose.py ├── gin.py ├── node_prompt_layer.py ├── pre_train.py ├── pre_train_flickr.py ├── predictnet.py ├── prompt_fewshot.py ├── prompt_fewshot_flickr.py ├── run.py ├── run_ablation.py ├── run_before.py ├── split.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | We provide the code (in pytorch) and datasets for our paper [**"GraphPrompt: Unifying Pre-Training and Downstream Tasks 2 | for Graph Neural Networks"**](https://arxiv.org/pdf/2302.08043.pdf), 3 | which is accepted by WWW2023. 
We further extend GraphPrompt to GraphPrompt+ by enhancing the pre-training and prompting stages; see [**"Generalized Graph Prompt: Toward a Unification of Pre-Training and Downstream Tasks on Graphs"**](https://arxiv.org/pdf/2311.15317.pdf), accepted by IEEE TKDE. The code and datasets are publicly available at https://github.com/gmcmt/graph_prompt_extension.
32 | 33 | Prompt tune and test: 34 | - python run.py 35 | 36 | 37 | ## Citation 38 | @inproceedings{liu2023graphprompt,\ 39 | title={GraphPrompt: Unifying Pre-Training and Downstream Tasks for Graph Neural Networks},\ 40 | author={Liu, Zemin and Yu, Xingtong and Fang, Yuan and Zhang, Xinming},\ 41 | booktitle={Proceedings of the ACM Web Conference 2023},\ 42 | year={2023}\ 43 | } 44 | -------------------------------------------------------------------------------- /convertor/ENZYMES.py: -------------------------------------------------------------------------------- 1 | import igraph as ig 2 | import os 3 | import sys 4 | import torch 5 | from tqdm import tqdm 6 | import numpy 7 | 8 | if __name__ == "__main__": 9 | assert len(sys.argv) == 2 10 | nci1_data_path = sys.argv[1] 11 | nodes = list() 12 | node2graph = list() 13 | with open(os.path.join(nci1_data_path, "ENZYMES_graph_indicator.txt"), "r") as f: 14 | for n_id, g_id in enumerate(f): 15 | g_id = int(g_id) - 1 16 | node2graph.append(g_id) 17 | if g_id == len(nodes): 18 | nodes.append(list()) 19 | nodes[-1].append(n_id) 20 | 21 | nodelabels = [list() for _ in range(len(nodes))] 22 | with open(os.path.join(nci1_data_path, "ENZYMES_node_labels.txt"), "r") as f: 23 | _nodelabels = list() 24 | for nl in f: 25 | nl = int(nl) 26 | _nodelabels.append(nl) 27 | n_idx = 0 28 | for g_idx in range(len(nodes)): 29 | for _ in range(len(nodes[g_idx])): 30 | nodelabels[g_idx].append(_nodelabels[n_idx]) 31 | n_idx += 1 32 | del _nodelabels 33 | 34 | edges = [list() for _ in range(len(nodes))] 35 | with open(os.path.join(nci1_data_path, "ENZYMES_A.txt"), "r") as f: 36 | for e in f: 37 | e = [int(v) - 1 for v in e.split(",")] 38 | g_id = node2graph[e[0]] 39 | edges[g_id].append((e[0] - nodes[g_id][0], e[1] - nodes[g_id][0])) 40 | 41 | graphlabels = list() 42 | with open(os.path.join(nci1_data_path, "ENZYMES_graph_labels.txt"), "r") as f: 43 | for nl in f: 44 | nl = int(nl) 45 | graphlabels.append(nl) 46 | 47 | 
#由于煞笔igraph无法保存list作为图的特征,所以只能以string的形式存储特征矩阵,到dgl中再转化成tensor 48 | #注释掉的部分是将string读做float list的形式 49 | # nodeattr = [list() for _ in range(len(nodes))] 50 | # with open(os.path.join(nci1_data_path, "ENZYMES_node_attributes.txt"), "r") as f: 51 | # _nodeattr = list() 52 | # data=f.readlines() 53 | # for line in data: 54 | # numbers = line[:-1].split(',') # 将数据分隔 55 | # numbers_float = [] # 转化为浮点数 56 | # for num in numbers: 57 | # numbers_float.append(float(num)) 58 | # _nodeattr.append(numbers_float) 59 | # '''for nl in f: 60 | # nl = float(nl) 61 | # _nodeattr.append(nl)''' 62 | # n_idx = 0 63 | # for g_idx in range(len(nodes)): 64 | # for _ in range(len(nodes[g_idx])): 65 | # nodeattr[g_idx].append(_nodeattr[n_idx]) 66 | # n_idx += 1 67 | # del _nodeattr 68 | 69 | 70 | nodeattr = [list() for _ in range(len(nodes))] 71 | with open(os.path.join(nci1_data_path, "ENZYMES_node_attributes.txt"), "r") as f: 72 | _nodeattr = list() 73 | data=f.readlines() 74 | for line in data: 75 | _nodeattr.append(line) 76 | '''for nl in f: 77 | nl = float(nl) 78 | _nodeattr.append(nl)''' 79 | n_idx = 0 80 | for g_idx in range(len(nodes)): 81 | for _ in range(len(nodes[g_idx])): 82 | nodeattr[g_idx].append(_nodeattr[n_idx]) 83 | n_idx += 1 84 | del _nodeattr 85 | 86 | 87 | 88 | os.makedirs(os.path.join(nci1_data_path, "raw"), exist_ok=True) 89 | max_n_num = 0 90 | max_e_num = 0 91 | max_nlabel_num = 0 92 | classnum = 2 93 | node_feature_dim=18 94 | num_per_class = torch.zeros(classnum) 95 | ######################################################################### 96 | ####注意NCI1的nodelabel是从1开始而非0开始的!!!!!!!!!!!!!!!!!!!!! 
97 | ####需要注意对于这种没有elabel的图需要设置elabel都为0,否则后续读数据时候会出问题#### 98 | ######################################################################### 99 | count=torch.zeros(6) 100 | for g_id in tqdm(range(len(nodes))): 101 | graph = ig.Graph(directed=True) 102 | vcount = len(nodes[g_id]) 103 | vlabels = nodelabels[g_id] 104 | vfeature=nodeattr[g_id] 105 | vlabels = numpy.array(vlabels)-1 106 | vlabels = vlabels.tolist() 107 | # glabels = graphlabels[g_id] 108 | graph.add_vertices(vcount) 109 | graph.add_edges(edges[g_id]) 110 | enum=graph.ecount() 111 | elabels=torch.zeros(enum,dtype=int) 112 | elabels=elabels.numpy().tolist() 113 | graph.vs["label"] = vlabels 114 | graph["feature"]=vfeature 115 | graph.es["label"] = elabels 116 | graph.es["key"] = [0] * len(edges[g_id]) 117 | graph["label"]=graphlabels[g_id]-1 118 | count[graph["label"]]+=1 119 | graph_id = "G_N%d_E%d_NL%d_GL%d_%d" % ( 120 | vcount, len(edges[g_id]), max(vlabels) + 1, graph["label"], g_id) 121 | if vcount > max_n_num: 122 | max_n_num = vcount 123 | if len(edges[g_id]) > max_e_num: 124 | max_e_num = len(edges[g_id]) 125 | if max(vlabels) + 1 > max_nlabel_num: 126 | max_nlabel_num = max(vlabels) + 1 127 | filename = os.path.join(nci1_data_path, "raw", graph_id) 128 | # nx.nx_pydot.write_dot(pattern, filename + ".dot") 129 | graph.write(filename + ".gml") 130 | print("max_n_num: ", max_n_num) 131 | print("max_e_num: ", max_e_num) 132 | print("max_nlabel_num: ", max_nlabel_num) 133 | print("graph_label_num:",count) -------------------------------------------------------------------------------- /data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/data.zip -------------------------------------------------------------------------------- /graphdownstream/__pycache__/basemodel.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/basemodel.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/basemodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/basemodel.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/embedding.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/embedding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/embedding.cpython-38.pyc 
-------------------------------------------------------------------------------- /graphdownstream/__pycache__/filternet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/filternet.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/filternet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/filternet.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/gin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/gin.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/gin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/gin.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/graph_finetuning_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/graph_finetuning_layer.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/graph_prompt_layer.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/graph_prompt_layer.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/predictnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/predictnet.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/predictnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/predictnet.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /graphdownstream/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /graphdownstream/basemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import 
copy 6 | import numpy as np 7 | from utils import int2onehot 8 | from utils import get_enc_len, split_and_batchify_graph_feats, batch_convert_len_to_mask 9 | from embedding import OrthogonalEmbedding, NormalEmbedding, EquivariantEmbedding 10 | from filternet import MaxGatedFilterNet 11 | from predictnet import MeanPredictNet, SumPredictNet, MaxPredictNet, \ 12 | MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, \ 13 | MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, \ 14 | DIAMNet 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, config): 18 | super(BaseModel, self).__init__() 19 | 20 | self.act_func = config["activation_function"] 21 | self.init_emb = config["init_emb"] 22 | self.share_emb = config["share_emb"] 23 | self.share_arch = config["share_arch"] 24 | self.base = config["base"] 25 | self.max_ngv = config["max_ngv"] 26 | self.max_ngvl = config["max_ngvl"] 27 | self.max_nge = config["max_nge"] 28 | self.max_ngel = config["max_ngel"] 29 | self.max_npv = config["max_npv"] 30 | self.max_npvl = config["max_npvl"] 31 | self.max_npe = config["max_npe"] 32 | self.max_npel = config["max_npel"] 33 | 34 | self.emb_dim = config["emb_dim"] 35 | self.dropout = config["dropout"] 36 | self.dropatt = config["dropatt"] 37 | self.add_enc = config["predict_net_add_enc"] 38 | 39 | # create encoding layer 40 | # create filter layers 41 | # create embedding layers 42 | # create networks 43 | #self.p_net, self.g_net = None, None 44 | self.g_net = None 45 | 46 | # create predict layers 47 | self.predict_net = None 48 | 49 | def get_emb_dim(self): 50 | if self.init_emb == "None": 51 | return self.get_enc_dim() 52 | else: 53 | return self.emb_dim 54 | 55 | def get_enc(self, graph, graph_len): 56 | raise NotImplementedError 57 | 58 | def get_emb(self, graph, graph_len): 59 | raise NotImplementedError 60 | 61 | def get_filter_gate(self, graph, graph_len): 62 | raise NotImplementedError 63 | 64 | def create_filter(self, filter_type): 65 | if 
filter_type == "None": 66 | filter_net = None 67 | elif filter_type == "MaxGatedFilterNet": 68 | filter_net = MaxGatedFilterNet() 69 | else: 70 | raise NotImplementedError("Currently, %s is not supported!" % (filter_type)) 71 | return filter_net 72 | 73 | def create_enc(self, max_n, base): 74 | enc_len = get_enc_len(max_n-1, base) 75 | enc_dim = enc_len * base 76 | enc = nn.Embedding(max_n, enc_dim) 77 | enc.weight.data.copy_(torch.from_numpy(int2onehot(np.arange(0, max_n), enc_len, base))) 78 | enc.weight.requires_grad=False 79 | return enc 80 | 81 | def create_emb(self, input_dim, emb_dim, init_emb="Orthogonal"): 82 | if init_emb == "None": 83 | emb = None 84 | elif init_emb == "Orthogonal": 85 | emb = OrthogonalEmbedding(input_dim, emb_dim) 86 | elif init_emb == "Normal": 87 | emb = NormalEmbedding(input_dim, emb_dim) 88 | elif init_emb == "Equivariant": 89 | emb = EquivariantEmbedding(input_dim, emb_dim) 90 | else: 91 | raise NotImplementedError 92 | return emb 93 | 94 | def create_net(self, name, input_dim, **kw): 95 | raise NotImplementedError 96 | 97 | def create_predict_net(self, predict_type, pattern_dim, graph_dim, **kw): 98 | if predict_type == "None": 99 | predict_net = None 100 | elif predict_type == "MeanPredictNet": 101 | hidden_dim = kw.get("hidden_dim", 64) 102 | predict_net = MeanPredictNet(pattern_dim, graph_dim, hidden_dim, 103 | act_func=self.act_func, dropout=self.dropout) 104 | elif predict_type == "SumPredictNet": 105 | hidden_dim = kw.get("hidden_dim", 64) 106 | predict_net = SumPredictNet(pattern_dim, graph_dim, hidden_dim, 107 | act_func=self.act_func, dropout=self.dropout) 108 | elif predict_type == "MaxPredictNet": 109 | hidden_dim = kw.get("hidden_dim", 64) 110 | predict_net = MaxPredictNet(pattern_dim, graph_dim, hidden_dim, 111 | act_func=self.act_func, dropout=self.dropout) 112 | elif predict_type == "MeanAttnPredictNet": 113 | hidden_dim = kw.get("hidden_dim", 64) 114 | recurrent_steps = kw.get("recurrent_steps", 1) 115 | num_heads 
= kw.get("num_heads", 1) 116 | predict_net = MeanAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 117 | act_func=self.act_func, 118 | num_heads=num_heads, recurrent_steps=recurrent_steps, 119 | dropout=self.dropout, dropatt=self.dropatt) 120 | elif predict_type == "SumAttnPredictNet": 121 | hidden_dim = kw.get("hidden_dim", 64) 122 | recurrent_steps = kw.get("recurrent_steps", 1) 123 | num_heads = kw.get("num_heads", 1) 124 | predict_net = SumAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 125 | act_func=self.act_func, 126 | num_heads=num_heads, recurrent_steps=recurrent_steps, 127 | dropout=self.dropout, dropatt=self.dropatt) 128 | elif predict_type == "MaxAttnPredictNet": 129 | hidden_dim = kw.get("hidden_dim", 64) 130 | recurrent_steps = kw.get("recurrent_steps", 1) 131 | num_heads = kw.get("num_heads", 1) 132 | predict_net = MaxAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 133 | act_func=self.act_func, 134 | num_heads=num_heads, recurrent_steps=recurrent_steps, 135 | dropout=self.dropout, dropatt=self.dropatt) 136 | elif predict_type == "MeanMemAttnPredictNet": 137 | hidden_dim = kw.get("hidden_dim", 64) 138 | recurrent_steps = kw.get("recurrent_steps", 1) 139 | num_heads = kw.get("num_heads", 1) 140 | mem_len = kw.get("mem_len", 4) 141 | predict_net = MeanMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 142 | act_func=self.act_func, 143 | num_heads=num_heads, recurrent_steps=recurrent_steps, 144 | mem_len=mem_len, 145 | dropout=self.dropout, dropatt=self.dropatt) 146 | elif predict_type == "SumMemAttnPredictNet": 147 | hidden_dim = kw.get("hidden_dim", 64) 148 | recurrent_steps = kw.get("recurrent_steps", 1) 149 | num_heads = kw.get("num_heads", 1) 150 | mem_len = kw.get("mem_len", 4) 151 | predict_net = SumMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 152 | act_func=self.act_func, 153 | num_heads=num_heads, recurrent_steps=recurrent_steps, 154 | mem_len=mem_len, 155 | dropout=self.dropout, dropatt=self.dropatt) 156 | elif predict_type 
== "MaxMemAttnPredictNet": 157 | hidden_dim = kw.get("hidden_dim", 64) 158 | recurrent_steps = kw.get("recurrent_steps", 1) 159 | num_heads = kw.get("num_heads", 1) 160 | mem_len = kw.get("mem_len", 4) 161 | predict_net = MaxMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 162 | act_func=self.act_func, 163 | num_heads=num_heads, recurrent_steps=recurrent_steps, 164 | mem_len=mem_len, 165 | dropout=self.dropout, dropatt=self.dropatt) 166 | elif predict_type == "DIAMNet": 167 | hidden_dim = kw.get("hidden_dim", 64) 168 | recurrent_steps = kw.get("recurrent_steps", 1) 169 | num_heads = kw.get("num_heads", 1) 170 | mem_len = kw.get("mem_len", 4) 171 | mem_init = kw.get("mem_init", "mean") 172 | predict_net = DIAMNet(pattern_dim, graph_dim, hidden_dim, 173 | act_func=self.act_func, 174 | num_heads=num_heads, recurrent_steps=recurrent_steps, 175 | mem_len=mem_len, mem_init=mem_init, 176 | dropout=self.dropout, dropatt=self.dropatt) 177 | else: 178 | raise NotImplementedError("Currently, %s is not supported!" 
% (predict_type)) 179 | return predict_net 180 | 181 | def increase_input_size(self, config): 182 | assert config["base"] == self.base 183 | assert config["max_npv"] >= self.max_npv 184 | assert config["max_npvl"] >= self.max_npvl 185 | assert config["max_npe"] >= self.max_npe 186 | assert config["max_npel"] >= self.max_npel 187 | assert config["max_ngv"] >= self.max_ngv 188 | assert config["max_ngvl"] >= self.max_ngvl 189 | assert config["max_nge"] >= self.max_nge 190 | assert config["max_ngel"] >= self.max_ngel 191 | assert config["predict_net_add_enc"] or not self.add_enc 192 | assert config["predict_net_add_degree"] or not self.add_degree 193 | 194 | # create encoding layers 195 | # increase embedding layers 196 | # increase predict network 197 | # set new parameters 198 | 199 | def increase_net(self, config): 200 | raise NotImplementedError 201 | 202 | 203 | class EdgeSeqModel(BaseModel): 204 | def __init__(self, config): 205 | super(EdgeSeqModel, self).__init__(config) 206 | # create encoding layer 207 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = \ 208 | [self.create_enc(max_n, self.base) for max_n in [self.max_ngv, self.max_ngvl, self.max_ngel]] 209 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 210 | if self.share_emb: 211 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 212 | self.g_v_enc, self.g_vl_enc, self.g_el_enc 213 | else: 214 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 215 | [self.create_enc(max_n, self.base) for max_n in [self.max_npv, self.max_npvl, self.max_npel]] 216 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 217 | 218 | # create filter layers 219 | self.ul_flt, self.el_flt, self.vl_flt = [self.create_filter(config["filter_net"]) for _ in range(3)] 220 | 221 | # create embedding layers 222 | self.g_u_emb, self.g_v_emb, self.g_ul_emb, self.g_el_emb, self.g_vl_emb = \ 223 | [self.create_emb(enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) \ 224 | for enc in [self.g_u_enc, self.g_v_enc, self.g_ul_enc, 
self.g_el_enc, self.g_vl_enc]] 225 | if self.share_emb: 226 | self.p_u_emb, self.p_v_emb, self.p_ul_emb, self.p_el_emb, self.p_vl_emb = \ 227 | self.g_u_emb, self.g_v_emb, self.g_ul_emb, self.g_el_emb, self.g_vl_emb 228 | else: 229 | self.p_u_emb, self.p_v_emb, self.p_ul_emb, self.p_el_emb, self.p_vl_emb = \ 230 | [self.create_emb(enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) \ 231 | for enc in [self.p_u_enc, self.p_v_enc, self.p_ul_enc, self.p_el_enc, self.p_vl_enc]] 232 | 233 | # create networks 234 | # create predict layers 235 | 236 | def get_enc_dim(self): 237 | g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) * 2 + \ 238 | get_enc_len(self.max_ngvl-1, self.base) * 2 + \ 239 | get_enc_len(self.max_ngel-1, self.base)) 240 | if self.share_emb: 241 | return g_dim, g_dim 242 | else: 243 | p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) * 2 + \ 244 | get_enc_len(self.max_npvl-1, self.base) * 2 + \ 245 | get_enc_len(self.max_npel-1, self.base)) 246 | return p_dim, g_dim 247 | 248 | def get_emb_dim(self): 249 | if self.init_emb == "None": 250 | return self.get_enc_dim() 251 | else: 252 | return self.emb_dim, self.emb_dim 253 | 254 | def get_enc(self, pattern, pattern_len, graph, graph_len): 255 | pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl = \ 256 | self.p_u_enc(pattern.u), self.p_v_enc(pattern.v), self.p_ul_enc(pattern.ul), self.p_el_enc(pattern.el), self.p_vl_enc(pattern.vl) 257 | graph_u, graph_v, graph_ul, graph_el, graph_vl = \ 258 | self.g_u_enc(graph.u), self.g_v_enc(graph.v), self.g_ul_enc(graph.ul), self.g_el_enc(graph.el), self.g_vl_enc(graph.vl) 259 | 260 | p_enc = torch.cat([ 261 | pattern_u, 262 | pattern_v, 263 | pattern_ul, 264 | pattern_el, 265 | pattern_vl], dim=2) 266 | g_enc = torch.cat([ 267 | graph_u, 268 | graph_v, 269 | graph_ul, 270 | graph_el, 271 | graph_vl], dim=2) 272 | return p_enc, g_enc 273 | 274 | def get_emb(self, pattern, pattern_len, graph, graph_len): 275 | bsz = pattern_len.size(0) 
276 | pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl = \ 277 | self.p_u_enc(pattern.u), self.p_v_enc(pattern.v), self.p_ul_enc(pattern.ul), self.p_el_enc(pattern.el), self.p_vl_enc(pattern.vl) 278 | graph_u, graph_v, graph_ul, graph_el, graph_vl = \ 279 | self.g_u_enc(graph.u), self.g_v_enc(graph.v), self.g_ul_enc(graph.ul), self.g_el_enc(graph.el), self.g_vl_enc(graph.vl) 280 | 281 | if self.init_emb == "None": 282 | p_emb = torch.cat([pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl], dim=2) 283 | g_emb = torch.cat([graph_u, graph_v, graph_ul, graph_el, graph_vl], dim=2) 284 | else: 285 | p_emb = self.p_u_emb(pattern_u) + \ 286 | self.p_v_emb(pattern_v) + \ 287 | self.p_ul_emb(pattern_ul) + \ 288 | self.p_el_emb(pattern_el) + \ 289 | self.p_vl_emb(pattern_vl) 290 | g_emb = self.g_u_emb(graph_u) + \ 291 | self.g_v_emb(graph_v) + \ 292 | self.g_ul_emb(graph_ul) + \ 293 | self.g_el_emb(graph_el) + \ 294 | self.g_vl_emb(graph_vl) 295 | return p_emb, g_emb 296 | 297 | def get_filter_gate(self, pattern, pattern_len, graph, graph_len): 298 | gate = None 299 | if self.ul_flt is not None: 300 | if gate is not None: 301 | gate &= self.ul_flt(pattern.ul, graph.ul) 302 | else: 303 | gate = self.ul_flt(pattern.ul, graph.ul) 304 | if self.el_flt is not None: 305 | if gate is not None: 306 | gate &= self.el_flt(pattern.el, graph.el) 307 | else: 308 | gate = self.el_flt(pattern.el, graph.el) 309 | if self.vl_flt is not None: 310 | if gate is not None: 311 | gate &= self.vl_flt(pattern.vl, graph.vl) 312 | else: 313 | gate = self.vl_flt(pattern.vl, graph.vl) 314 | return gate 315 | 316 | def increase_input_size(self, config): 317 | super(EdgeSeqModel, self).increase_input_size(config) 318 | 319 | # create encoding layers 320 | new_g_v_enc, new_g_vl_enc, new_g_el_enc = \ 321 | [self.create_enc(max_n, self.base) for max_n in [config["max_ngv"], config["max_ngvl"], config["max_ngel"]]] 322 | if self.share_emb: 323 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 324 | 
new_g_v_enc, new_g_vl_enc, new_g_el_enc 325 | else: 326 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 327 | [self.create_enc(max_n, self.base) for max_n in [config["max_npv"], config["max_npvl"], config["max_npel"]]] 328 | del self.g_v_enc, self.g_vl_enc, self.g_el_enc, self.g_u_enc, self.g_ul_enc 329 | del self.p_v_enc, self.p_vl_enc, self.p_el_enc, self.p_u_enc, self.p_ul_enc 330 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = new_g_v_enc, new_g_vl_enc, new_g_el_enc 331 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 332 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = new_p_v_enc, new_p_vl_enc, new_p_el_enc 333 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 334 | 335 | # increase embedding layers 336 | self.g_u_emb.increase_input_size(self.g_u_enc.embedding_dim) 337 | self.g_v_emb.increase_input_size(self.g_v_enc.embedding_dim) 338 | self.g_ul_emb.increase_input_size(self.g_ul_enc.embedding_dim) 339 | self.g_vl_emb.increase_input_size(self.g_vl_enc.embedding_dim) 340 | self.g_el_emb.increase_input_size(self.g_el_enc.embedding_dim) 341 | if not self.share_emb: 342 | self.p_u_emb.increase_input_size(self.p_u_enc.embedding_dim) 343 | self.p_v_emb.increase_input_size(self.p_v_enc.embedding_dim) 344 | self.p_ul_emb.increase_input_size(self.p_ul_enc.embedding_dim) 345 | self.p_vl_emb.increase_input_size(self.p_vl_enc.embedding_dim) 346 | self.p_el_emb.increase_input_size(self.p_el_enc.embedding_dim) 347 | 348 | # increase predict network 349 | 350 | # set new parameters 351 | self.max_npv = config["max_npv"] 352 | self.max_npvl = config["max_npvl"] 353 | self.max_npe = config["max_npe"] 354 | self.max_npel = config["max_npel"] 355 | self.max_ngv = config["max_ngv"] 356 | self.max_ngvl = config["max_ngvl"] 357 | self.max_nge = config["max_nge"] 358 | self.max_ngel = config["max_ngel"] 359 | 360 | 361 | 362 | class GraphAdjModel(BaseModel): 363 | def __init__(self, config): 364 | super(GraphAdjModel, self).__init__(config) 365 | 366 | 
self.add_degree = config["predict_net_add_degree"] 367 | 368 | # create encoding layer 369 | self.g_v_enc, self.g_vl_enc = \ 370 | [self.create_enc(max_n, self.base) for max_n in [self.max_ngv, self.max_ngvl]] 371 | '''if self.share_emb: 372 | self.p_v_enc, self.p_vl_enc = \ 373 | self.g_v_enc, self.g_vl_enc 374 | else: 375 | self.p_v_enc, self.p_vl_enc = \ 376 | [self.create_enc(max_n, self.base) for max_n in [self.max_npv, self.max_npvl]]''' 377 | 378 | # create filter layers 379 | self.vl_flt = self.create_filter(config["filter_net"]) 380 | 381 | # create embedding layers 382 | self.g_vl_emb = self.create_emb(self.g_vl_enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) 383 | '''if self.share_emb: 384 | self.p_vl_emb = self.g_vl_emb 385 | else: 386 | self.p_vl_emb = self.create_emb(self.p_vl_enc.embedding_dim, self.emb_dim, init_emb=self.init_emb)''' 387 | 388 | # create networks 389 | # create predict layers 390 | 391 | def get_enc_dim(self): 392 | g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) + get_enc_len(self.max_ngvl-1, self.base)) 393 | '''if self.share_emb: 394 | return g_dim, g_dim 395 | else: 396 | p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) + get_enc_len(self.max_npvl-1, self.base)) 397 | return p_dim, g_dim''' 398 | return g_dim 399 | 400 | #def get_enc(self, pattern, pattern_len, graph, graph_len): 401 | def get_enc(self, graph, graph_len): 402 | graph_v, graph_vl = self.g_v_enc(graph.ndata["id"]), self.g_vl_enc(graph.ndata["label"]) 403 | g_enc = torch.cat([graph_v, graph_vl], dim=1) 404 | return g_enc 405 | 406 | def get_emb(self, graph, graph_len): 407 | 408 | graph_v, graph_vl = self.g_v_enc(graph.ndata["id"]), self.g_vl_enc(graph.ndata["label"]) 409 | 410 | if self.init_emb == "None": 411 | g_emb = graph_vl 412 | else: 413 | g_emb = self.g_vl_emb(graph_vl) 414 | return g_emb 415 | 416 | def get_filter_gate(self, graph, graph_len): 417 | 418 | gate = None 419 | if self.vl_flt is not None: 420 | '''gate = 
self.vl_flt( 421 | split_and_batchify_graph_feats(pattern.ndata["label"].unsqueeze(-1), pattern_len)[0], 422 | split_and_batchify_graph_feats(graph.ndata["label"].unsqueeze(-1), graph_len)[0])''' 423 | gate = self.vl_flt(split_and_batchify_graph_feats(graph.ndata["label"].unsqueeze(-1), graph_len)[0]) 424 | 425 | if gate is not None: 426 | bsz = graph_len.size(0) 427 | max_g_len = graph_len.max() 428 | if bsz * max_g_len != graph.number_of_nodes(): 429 | graph_mask = batch_convert_len_to_mask(graph_len) # bsz x max_len 430 | gate = gate.masked_select(graph_mask.unsqueeze(-1)).view(-1, 1) 431 | else: 432 | gate = gate.view(-1, 1) 433 | return gate 434 | 435 | def increase_input_size(self, config): 436 | super(GraphAdjModel, self).increase_input_size(config) 437 | 438 | # create encoding layers 439 | new_g_v_enc, new_g_vl_enc = \ 440 | [self.create_enc(max_n, self.base) for max_n in [config["max_ngv"], config["max_ngvl"]]] 441 | '''if self.share_emb: 442 | new_p_v_enc, new_p_vl_enc = \ 443 | new_g_v_enc, new_g_vl_enc 444 | else: 445 | new_p_v_enc, new_p_vl_enc = \ 446 | [self.create_enc(max_n, self.base) for max_n in [config["max_npv"], config["max_npvl"]]]''' 447 | del self.g_v_enc, self.g_vl_enc 448 | #del self.p_v_enc, self.p_vl_enc 449 | self.g_v_enc, self.g_vl_enc = new_g_v_enc, new_g_vl_enc 450 | #self.p_v_enc, self.p_vl_enc = new_p_v_enc, new_p_vl_enc 451 | 452 | # increase embedding layers 453 | self.g_vl_emb.increase_input_size(self.g_vl_enc.embedding_dim) 454 | '''if not self.share_emb: 455 | self.p_vl_emb.increase_input_size(self.p_vl_enc.embedding_dim)''' 456 | 457 | # increase networks 458 | 459 | # increase predict network 460 | 461 | # set new parameters 462 | #npv:pattern vertex 463 | #npvl:pattern vertex label 464 | #npe:pattern edge 465 | #npel:pattern edge label 466 | '''self.max_npv = config["max_npv"] 467 | self.max_npvl = config["max_npvl"] 468 | self.max_npe = config["max_npe"] 469 | self.max_npel = config["max_npel"]''' 470 | self.max_ngv = 
config["max_ngv"] 471 | self.max_ngvl = config["max_ngvl"] 472 | self.max_nge = config["max_nge"] 473 | self.max_ngel = config["max_ngel"] 474 | -------------------------------------------------------------------------------- /graphdownstream/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dgl 4 | import os 5 | import math 6 | import pickle 7 | import json 8 | import copy 9 | import torch.utils.data as data 10 | import random 11 | from collections import defaultdict, Counter 12 | from tqdm import tqdm 13 | from utils import get_enc_len, int2onehot, \ 14 | batch_convert_tensor_to_tensor, batch_convert_array_to_array,label2onehot 15 | 16 | INF = float("inf") 17 | 18 | ############################################## 19 | ################ Sampler Part ################ 20 | ############################################## 21 | class Sampler(data.Sampler): 22 | _type_map = { 23 | int: np.int32, 24 | float: np.float32} 25 | 26 | def __init__(self, dataset, group_by, batch_size, shuffle, drop_last): 27 | super(Sampler, self).__init__(dataset) 28 | if isinstance(group_by, str): 29 | group_by = [group_by] 30 | for attr in group_by: 31 | setattr(self, attr, list()) 32 | self.data_size = len(dataset.data) 33 | for x in dataset.data: 34 | for attr in group_by: 35 | value = x[attr] 36 | if isinstance(value, dgl.DGLGraph): 37 | getattr(self, attr).append(value.number_of_nodes()) 38 | elif hasattr(value, "__len__"): 39 | getattr(self, attr).append(len(value)) 40 | else: 41 | getattr(self, attr).append(value) 42 | self.order = copy.copy(group_by) 43 | self.order.append("rand") 44 | self.batch_size = batch_size 45 | self.shuffle = shuffle 46 | self.drop_last = drop_last 47 | 48 | def make_array(self): 49 | self.rand = np.random.rand(self.data_size).astype(np.float32) 50 | if self.data_size == 0: 51 | types = [np.float32] * len(self.order) 52 | else: 53 | types = [type(getattr(self, attr)[0]) 
for attr in self.order] 54 | types = [Sampler._type_map.get(t, t) for t in types] 55 | dtype = list(zip(self.order, types)) 56 | array = np.array( 57 | list(zip(*[getattr(self, attr) for attr in self.order])), 58 | dtype=dtype) 59 | return array 60 | 61 | def __iter__(self): 62 | array = self.make_array() 63 | indices = np.argsort(array, axis=0, order=self.order) 64 | batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)] 65 | if self.shuffle: 66 | np.random.shuffle(batches) 67 | batch_idx = 0 68 | while batch_idx < len(batches)-1: 69 | yield batches[batch_idx] 70 | batch_idx += 1 71 | if len(batches) > 0 and (len(batches[batch_idx]) == self.batch_size or not self.drop_last): 72 | yield batches[batch_idx] 73 | 74 | def __len__(self): 75 | if self.drop_last: 76 | return math.floor(self.data_size/self.batch_size) 77 | else: 78 | return math.ceil(self.data_size/self.batch_size) 79 | 80 | 81 | ############################################## 82 | ############# EdgeSeq Data Part ############## 83 | ############################################## 84 | class EdgeSeq: 85 | def __init__(self, code): 86 | self.u = code[:,0] 87 | self.v = code[:,1] 88 | self.ul = code[:,2] 89 | self.el = code[:,3] 90 | self.vl = code[:,4] 91 | 92 | def __len__(self): 93 | if len(self.u.shape) == 2: # single code 94 | return self.u.shape[0] 95 | else: # batch code 96 | return self.u.shape[0] * self.u.shape[1] 97 | 98 | @staticmethod 99 | def batch(data): 100 | b = EdgeSeq(torch.empty((0,5), dtype=torch.long)) 101 | b.u = batch_convert_tensor_to_tensor([x.u for x in data]) 102 | b.v = batch_convert_tensor_to_tensor([x.v for x in data]) 103 | b.ul = batch_convert_tensor_to_tensor([x.ul for x in data]) 104 | b.el = batch_convert_tensor_to_tensor([x.el for x in data]) 105 | b.vl = batch_convert_tensor_to_tensor([x.vl for x in data]) 106 | return b 107 | 108 | def to(self, device): 109 | self.u = self.u.to(device) 110 | self.v = self.v.to(device) 111 | self.ul = 
class EdgeSeq:
    """Column view of an edge list code: (u, v, u_label, e_label, v_label)."""

    def __init__(self, code):
        # `code` is an (E x 5) integer tensor; keep one column per field
        self.u = code[:,0]
        self.v = code[:,1]
        self.ul = code[:,2]
        self.el = code[:,3]
        self.vl = code[:,4]

    def __len__(self):
        # NOTE(review): for a single (E x 5) code the columns are 1-D, so the
        # branch labels below look inverted — confirm against callers.
        if len(self.u.shape) == 2:  # single code
            return self.u.shape[0]
        else:  # batch code
            return self.u.shape[0] * self.u.shape[1]

    @staticmethod
    def batch(data):
        """Pad and stack a list of EdgeSeq objects into one batched EdgeSeq."""
        b = EdgeSeq(torch.empty((0,5), dtype=torch.long))
        b.u = batch_convert_tensor_to_tensor([x.u for x in data])
        b.v = batch_convert_tensor_to_tensor([x.v for x in data])
        b.ul = batch_convert_tensor_to_tensor([x.ul for x in data])
        b.el = batch_convert_tensor_to_tensor([x.el for x in data])
        b.vl = batch_convert_tensor_to_tensor([x.vl for x in data])
        return b

    def to(self, device):
        """Move all columns to `device` in place (does not return self)."""
        self.u = self.u.to(device)
        self.v = self.v.to(device)
        self.ul = self.ul.to(device)
        self.el = self.el.to(device)
        self.vl = self.vl.to(device)


##############################################
############# EdgeSeq Data Part ##############
##############################################
class EdgeSeqDataset(data.Dataset):
    """Dataset of (pattern, graph) pairs stored as edge-sequence codes."""

    def __init__(self, data=None):
        super(EdgeSeqDataset, self).__init__()

        if data:
            self.data = EdgeSeqDataset.preprocess_batch(data, use_tqdm=True)
        else:
            self.data = list()
        self._to_tensor()

    def _to_tensor(self):
        # convert the numpy codes produced by preprocess() to torch tensors
        for x in self.data:
            for k in ["pattern", "graph", "subisomorphisms"]:
                if isinstance(x[k], np.ndarray):
                    x[k] = torch.from_numpy(x[k])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def save(self, filename):
        # keys starting with "_" are stripped before serialisation and restored after
        cache = defaultdict(list)
        for x in self.data:
            for k in list(x.keys()):
                if k.startswith("_"):
                    cache[k].append(x.pop(k))
        with open(filename, "wb") as f:
            torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        if len(cache) > 0:
            keys = cache.keys()
            for i in range(len(self.data)):
                for k in keys:
                    self.data[i][k] = cache[k][i]

    def load(self, filename):
        with open(filename, "rb") as f:
            data = torch.load(f)
        del self.data
        self.data = data

        return self

    @staticmethod
    def graph2edgeseq(graph):
        """Encode an igraph graph as a sorted (v, u, vl, el, ul) edge array."""
        labels = graph.vs["label"]
        graph_code = list()

        for edge in graph.es:
            v, u = edge.tuple
            graph_code.append((v, u, labels[v], edge["label"], labels[u]))
        graph_code = np.array(graph_code, dtype=np.int64)
        graph_code.view(
            [("v", "int64"), ("u", "int64"), ("vl", "int64"), ("el", "int64"), ("ul", "int64")]).sort(
            axis=0, order=["v", "u", "el"])
        return graph_code

    @staticmethod
    def preprocess(x):
        """Convert one raw sample into edge-seq codes plus its subisomorphisms."""
        pattern_code = EdgeSeqDataset.graph2edgeseq(x["pattern"])
        graph_code = EdgeSeqDataset.graph2edgeseq(x["graph"])
        subisomorphisms = np.array(x["subisomorphisms"], dtype=np.int32).reshape(-1, x["pattern"].vcount())

        x = {
            "id": x["id"],
            "pattern": pattern_code,
            "graph": graph_code,
            "counts": x["counts"],
            "subisomorphisms": subisomorphisms}
        return x

    @staticmethod
    def preprocess_batch(data, use_tqdm=False):
        d = list()
        if use_tqdm:
            data = tqdm(data)
        for x in data:
            d.append(EdgeSeqDataset.preprocess(x))
        return d

    @staticmethod
    def batchify(batch):
        """Collate samples into (ids, patterns, pattern_lens, graphs, graph_lens, counts)."""
        _id = [x["id"] for x in batch]
        pattern = EdgeSeq.batch([EdgeSeq(x["pattern"]) for x in batch])
        pattern_len = torch.tensor([x["pattern"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1)
        graph = EdgeSeq.batch([EdgeSeq(x["graph"]) for x in batch])
        graph_len = torch.tensor([x["graph"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1)
        counts = torch.tensor([x["counts"] for x in batch], dtype=torch.float32).view(-1, 1)
        return _id, pattern, pattern_len, graph, graph_len, counts


##############################################
######### GraphAdj Data Part ###########
##############################################
class GraphAdjDataset_DGL_Input(data.Dataset):
    """Dataset variant whose samples already contain ready-made DGLGraphs."""

    def __init__(self, data=None):
        super(GraphAdjDataset_DGL_Input, self).__init__()

        self.data = GraphAdjDataset_DGL_Input.preprocess_batch(data, use_tqdm=True)
        #self._to_tensor()

    def _to_tensor(self):
        # convert numpy node/edge features (and subisomorphisms) to torch tensors
        for x in self.data:
            for k in ["graph"]:
                y = x[k]
                for k, v in y.ndata.items():
                    if isinstance(v, np.ndarray):
                        y.ndata[k] = torch.from_numpy(v)
                for k, v in y.edata.items():
                    if isinstance(v, np.ndarray):
                        y.edata[k] = torch.from_numpy(v)
            if isinstance(x["subisomorphisms"], np.ndarray):
                x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def save(self, filename):
        # keys starting with "_" are stripped before serialisation and restored after
        cache = defaultdict(list)
        for x in self.data:
            for k in list(x.keys()):
                if k.startswith("_"):
                    cache[k].append(x.pop(k))
        with open(filename, "wb") as f:
            torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        if len(cache) > 0:
            keys = cache.keys()
            for i in range(len(self.data)):
                for k in keys:
                    self.data[i][k] = cache[k][i]

    def load(self, filename):
        with open(filename, "rb") as f:
            data = torch.load(f)
        del self.data
        self.data = data
        return self

    @staticmethod
    def comp_indeg_norm(graph):
        """Return 1/in_degree per node (0 for isolated nodes)."""
        import igraph as ig
        if isinstance(graph, ig.Graph):
            # 10x faster
            in_deg = np.array(graph.indegree(), dtype=np.float32)
        elif isinstance(graph, dgl.DGLGraph):
            in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy()
        else:
            raise NotImplementedError
        norm = 1.0 / in_deg
        norm[np.isinf(norm)] = 0
        return norm

    @staticmethod
    def graph2dglgraph(graph):
        dglgraph = dgl.DGLGraph(multigraph=True)
        dglgraph.add_nodes(graph.vcount())
        edges = graph.get_edgelist()
        dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges])
        dglgraph.readonly(True)
        return dglgraph

    @staticmethod
    def find_no_connection_node(graph, node):
        """Return a node with no edge from `node` (a negative sample).

        Candidates are visited in shuffled order (instead of sequentially as
        in the pre-8.5 code), which should help pre-training; sampling more
        than one negative per node could help further.

        FIX: the original returned the loop index `i` instead of the shuffled
        candidate `rand[i]`, so the returned node could actually be connected.
        Also falls back to `node` itself when every node is connected (the
        original implicitly returned None, which crashes findsample).
        """
        numnode = graph.number_of_nodes()
        rand = list(range(numnode))
        random.shuffle(rand)
        for cand in rand:
            if not graph.has_edges_between(node, cand):
                return cand
        return node

    @staticmethod
    def findsample(graph):
        """Build per-node (anchor, positive, negative) index triples."""
        nodenum = graph.number_of_nodes()
        result = torch.ones(nodenum, 3)
        adj = graph.adjacency_matrix()
        src = adj._indices()[1].tolist()
        dst = adj._indices()[0].tolist()
        #-----------------------------------------------------------------------------------------
        # This assumes every node has at least one neighbor and at least one
        # non-neighbor (typical for bidirectional graphs without isolated
        # nodes); other kinds of data need extra handling, e.g. selecting
        # only qualifying nodes when building the graph.
        for i in range(nodenum):
            result[i,0] = i
            # NCI1 contains isolated nodes; use the node itself as the positive sample
            if i not in src:
                result[i,1] = i
            else:
                index_i = src.index(i)
                i_point_to = dst[index_i]
                result[i,1] = i_point_to
            # FIX: use this class's corrected negative sampler (the original
            # called GraphAdjDataset's buggy copy)
            result[i,2] = GraphAdjDataset_DGL_Input.find_no_connection_node(graph, i)
        #-------------------------------------------------------------------------------------------
        return result.long()  # same values as torch.tensor(result, dtype=int), no copy warning

    @staticmethod
    def preprocess(x):
        """The graph is already a DGLGraph here; just repackage the record."""
        graph = x["graph"]
        '''graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph)
        graph_dglgraph.ndata["indeg"] = torch.tensor(np.array(graph.indegree(), dtype=np.float32))
        graph_dglgraph.ndata["label"] = torch.tensor(np.array(graph.vs["label"], dtype=np.int64))
        graph_dglgraph.ndata["id"] = torch.tensor(np.arange(0, graph.vcount(), dtype=np.int64))
        graph_dglgraph.ndata["sample"] = GraphAdjDataset.findsample(graph_dglgraph)'''
        x = {
            "id": x["id"],
            "graph": graph,
            "label": x["label"]}
        return x


    @staticmethod
    def preprocess_batch(data, use_tqdm=False):
        d = list()
        if use_tqdm:
            data = tqdm(data)
        for x in data:
            d.append(GraphAdjDataset_DGL_Input.preprocess(x))
        return d

    @staticmethod
    def batchify(batch):
        """Collate samples into (ids, labels, batched_graph, per-graph node counts)."""
        _id = [x["id"] for x in batch]
        graph_label = torch.tensor([x["label"] for x in batch], dtype=torch.float64).view(-1, 1)
        graph = dgl.batch([x["graph"] for x in batch])
        graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1)
        return _id, graph_label, graph, graph_len
class GraphAdjDataset(data.Dataset):
    """Dataset of igraph graphs converted to DGLGraphs with pretraining samples."""

    def __init__(self, data=None):
        super(GraphAdjDataset, self).__init__()

        if data:
            self.data = GraphAdjDataset.preprocess_batch(data, use_tqdm=True)
        else:
            self.data = list()
        # self._to_tensor()

    def _to_tensor(self):
        # convert numpy node/edge features (and subisomorphisms) to torch tensors
        for x in self.data:
            for k in ["graph"]:
                y = x[k]
                for k, v in y.ndata.items():
                    if isinstance(v, np.ndarray):
                        y.ndata[k] = torch.from_numpy(v)
                for k, v in y.edata.items():
                    if isinstance(v, np.ndarray):
                        y.edata[k] = torch.from_numpy(v)
            if isinstance(x["subisomorphisms"], np.ndarray):
                x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def save(self, filename):
        # keys starting with "_" are stripped before serialisation and restored after
        cache = defaultdict(list)
        for x in self.data:
            for k in list(x.keys()):
                if k.startswith("_"):
                    cache[k].append(x.pop(k))
        with open(filename, "wb") as f:
            torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        if len(cache) > 0:
            keys = cache.keys()
            for i in range(len(self.data)):
                for k in keys:
                    self.data[i][k] = cache[k][i]

    def load(self, filename):
        with open(filename, "rb") as f:
            data = torch.load(f)
        del self.data
        self.data = data
        return self

    @staticmethod
    def comp_indeg_norm(graph):
        """Return 1/in_degree per node (0 for isolated nodes)."""
        import igraph as ig
        if isinstance(graph, ig.Graph):
            # 10x faster
            in_deg = np.array(graph.indegree(), dtype=np.float32)
        elif isinstance(graph, dgl.DGLGraph):
            in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy()
        else:
            raise NotImplementedError
        norm = 1.0 / in_deg
        norm[np.isinf(norm)] = 0
        return norm

    @staticmethod
    def graph2dglgraph(graph):
        dglgraph = dgl.DGLGraph(multigraph=True)
        dglgraph.add_nodes(graph.vcount())
        edges = graph.get_edgelist()
        dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges])
        dglgraph.readonly(True)
        return dglgraph

    @staticmethod
    def find_no_connection_node(graph, node):
        """Return a node with no edge from `node` (a negative sample).

        Candidates are visited in shuffled order (instead of sequentially as
        in the pre-8.5 code), which should help pre-training; sampling more
        than one negative per node could help further.

        FIX: the original returned the loop index `i` instead of the shuffled
        candidate `rand[i]`, so the returned node could actually be connected.
        Also falls back to `node` itself when every node is connected (the
        original implicitly returned None, which crashes findsample).
        """
        numnode = graph.number_of_nodes()
        rand = list(range(numnode))
        random.shuffle(rand)
        for cand in rand:
            if not graph.has_edges_between(node, cand):
                return cand
        return node

    @staticmethod
    def findsample(graph):
        """Build per-node (anchor, positive, negative) index triples."""
        nodenum = graph.number_of_nodes()
        result = torch.ones(nodenum, 3)
        adj = graph.adjacency_matrix()
        # newer DGL versions return adjacency indices in the intuitive (src, dst) order
        '''src = adj._indices()[1].tolist()
        dst = adj._indices()[0].tolist()'''
        src = adj._indices()[0].tolist()
        dst = adj._indices()[1].tolist()

        # -----------------------------------------------------------------------------------------
        # This assumes every node has at least one neighbor and at least one
        # non-neighbor (typical for bidirectional graphs without isolated
        # nodes); other kinds of data need extra handling, e.g. selecting
        # only qualifying nodes when building the graph.
        for i in range(nodenum):
            result[i, 0] = i
            # NCI1 contains isolated nodes; use the node itself as the positive sample
            if i not in src:
                result[i, 1] = i
            else:
                index_i = src.index(i)
                i_point_to = dst[index_i]  # first listed neighbor as positive
                result[i, 1] = i_point_to
            result[i, 2] = GraphAdjDataset.find_no_connection_node(graph, i)
        # -------------------------------------------------------------------------------------------
        return result.long()  # same values as torch.tensor(result, dtype=int), no copy warning

    # @staticmethod
    # def igraph_node_feature2dgl_node_feature(input):
input 473 | # b = a.split('[')[1] 474 | # c = b.split(']')[0] 475 | # d = c.split('\', ') 476 | # nodefeaturestring = [] 477 | # for nodefeature in d: 478 | # temp = nodefeature.split('\'')[1] 479 | # temp = temp.split(',') 480 | # nodefeaturestring.append(temp) 481 | # 482 | # numbers_float = [] # 转化为浮点数 483 | # for num in nodefeaturestring: 484 | # temp = [] 485 | # for data in num: 486 | # temp.append(float(data)) 487 | # numbers_float.append(temp) 488 | # return numbers_float 489 | 490 | @staticmethod 491 | def preprocess(x): 492 | graph = x["graph"] 493 | graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph) 494 | '''graph_dglgraph.ndata["indeg"] = np.array(graph.indegree(), dtype=np.float32) 495 | graph_dglgraph.ndata["label"] = np.array(graph.vs["label"], dtype=np.int64) 496 | graph_dglgraph.ndata["id"] = np.arange(0, graph.vcount(), dtype=np.int64)''' 497 | # graph_dglgraph.edata["label"] = np.array(graph.es["label"], dtype=np.int64) 498 | graph_dglgraph.ndata["indeg"] = torch.tensor(np.array(graph.indegree(), dtype=np.float32)) 499 | graph_dglgraph.ndata["label"] = torch.tensor(np.array(graph.vs["label"], dtype=np.int64)) 500 | graph_dglgraph.ndata["id"] = torch.tensor(np.arange(0, graph.vcount(), dtype=np.int64)) 501 | graph_dglgraph.ndata["sample"] = GraphAdjDataset.findsample(graph_dglgraph) 502 | nodefeature=graph["feature"] 503 | label=graph["label"] 504 | graph_dglgraph.ndata["feature"]=torch.tensor(np.array(nodefeature, dtype=np.float32)) 505 | x = { 506 | "id": x["id"], 507 | "graph": graph_dglgraph, 508 | "label": label} 509 | return x 510 | 511 | @staticmethod 512 | def preprocess_batch(data, use_tqdm=False): 513 | d = list() 514 | if use_tqdm: 515 | data = tqdm(data) 516 | for x in data: 517 | d.append(GraphAdjDataset.preprocess(x)) 518 | return d 519 | 520 | @staticmethod 521 | def batchify(batch): 522 | _id = [x["id"] for x in batch] 523 | graph_label = torch.tensor([x["label"] for x in batch], dtype=torch.float64).view(-1, 1) 524 | graph = 
dgl.batch([x["graph"] for x in batch]) 525 | graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1) 526 | return _id, graph_label, graph, graph_len 527 | -------------------------------------------------------------------------------- /graphdownstream/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import extend_dimensions 5 | 6 | 7 | class NormalEmbedding(nn.Module): 8 | def __init__(self, input_dim, emb_dim): 9 | super(NormalEmbedding, self).__init__() 10 | self.input_dim = input_dim 11 | self.emb_dim = emb_dim 12 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 13 | 14 | # init 15 | nn.init.normal_(self.emb_layer.weight, 0.0, 1.0) 16 | 17 | def increase_input_size(self, new_input_dim): 18 | assert new_input_dim >= self.input_dim 19 | if new_input_dim != self.input_dim: 20 | new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False) 21 | del self.emb_layer 22 | self.emb_layer = new_emb_layer 23 | self.input_dim = new_input_dim 24 | 25 | def forward(self, x): 26 | emb = self.emb_layer(x) 27 | return emb 28 | 29 | class OrthogonalEmbedding(nn.Module): 30 | def __init__(self, input_dim, emb_dim): 31 | super(OrthogonalEmbedding, self).__init__() 32 | self.input_dim = input_dim 33 | self.emb_dim = emb_dim 34 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 35 | 36 | # init 37 | nn.init.orthogonal_(self.emb_layer.weight) 38 | 39 | def increase_input_size(self, new_input_dim): 40 | assert new_input_dim >= self.input_dim 41 | if new_input_dim != self.input_dim: 42 | new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False) 43 | del self.emb_layer 44 | self.emb_layer = new_emb_layer 45 | self.input_dim = new_input_dim 46 | 47 | def forward(self, x): 48 | emb = self.emb_layer(x) 49 | return emb 50 
class EquivariantEmbedding(nn.Module):
    """Bias-free linear embedding whose weight columns are cyclic shifts of one
    learned column, making the map equivariant to cyclic input shifts."""

    def __init__(self, input_dim, emb_dim):
        super(EquivariantEmbedding, self).__init__()
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False)

        # init: draw column 0, then fill column i with roll(column 0, i)
        nn.init.normal_(self.emb_layer.weight[:,0], 0.0, 1.0)
        emb_column = self.emb_layer.weight[:,0]
        with torch.no_grad():
            for i in range(1, self.input_dim):
                self.emb_layer.weight[:,i].data.copy_(torch.roll(emb_column, i, 0))

    def increase_input_size(self, new_input_dim):
        """Widen the input dimension; old weights are kept by extend_dimensions."""
        assert new_input_dim >= self.input_dim
        if new_input_dim != self.input_dim:
            new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False)
            del self.emb_layer
            self.emb_layer = new_emb_layer
            self.input_dim = new_input_dim

    def forward(self, x):
        emb = self.emb_layer(x)
        return emb

# ------------------------------------------------------------------------------
# /graphdownstream/epoch_loss.png:
# https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/epoch_loss.png
# ------------------------------------------------------------------------------
# /graphdownstream/epoch_loss_enzymes1.png:
# https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/epoch_loss_enzymes1.png
# ------------------------------------------------------------------------------
# /graphdownstream/filternet.py:
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# Learned-gate variant, kept for reference:
# class MaxGatedFilterNet(nn.Module):
#     def __init__(self, pattern_dim, graph_dim):
#         super(MaxGatedFilterNet, self).__init__()
#         self.g_layer = nn.Linear(graph_dim, pattern_dim)
#         self.f_layer = nn.Linear(pattern_dim, 1)

#         # init
#         scale = (1/pattern_dim)**0.5
#         nn.init.normal_(self.g_layer.weight, 0.0, scale)
#         nn.init.zeros_(self.g_layer.bias)
#         nn.init.normal_(self.f_layer.weight, 0.0, scale)
#         nn.init.ones_(self.f_layer.bias)

#     def forward(self, p_x, g_x):
#         max_x = torch.max(p_x, dim=1, keepdim=True)[0].float()
#         g_x = self.g_layer(g_x.float())
#         f = self.f_layer(g_x * max_x)
#         return F.sigmoid(f)

class MaxGatedFilterNet(nn.Module):
    """Parameter-free gate: an entry of g_x passes iff it is <= the per-row
    maximum of p_x (all feature dims must pass in the 3-D case)."""

    def __init__(self):
        super(MaxGatedFilterNet, self).__init__()

    def forward(self, p_x, g_x):
        max_x = torch.max(p_x, dim=1, keepdim=True)[0]
        if max_x.dim() == 2:
            return g_x <= max_x
        else:
            return (g_x <= max_x).all(keepdim=True, dim=2)


# ------------------------------------------------------------------------------
# /graphdownstream/gin.py:
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import copy
from functools import partial
from dgl.nn.pytorch.conv import RelGraphConv
from basemodel import GraphAdjModel
from utils import map_activation_str_to_layer, split_and_batchify_graph_feats,GetAdj


class GIN(torch.nn.Module):
    """Multi-layer GIN encoder returning layer-concatenated graph and node embeddings."""

    def __init__(self, config):
        super(GIN, self).__init__()

        self.act=torch.nn.ReLU()
        self.g_net, self.bns, g_dim = self.create_net(
            name="graph", input_dim=config["node_feature_dim"], hidden_dim=config["gcn_hidden_dim"],
            num_layers=config["gcn_graph_num_layers"], num_bases=config["gcn_num_bases"], regularizer=config["gcn_regularizer"])
        self.num_layers_num=config["gcn_graph_num_layers"]
        self.dropout=torch.nn.Dropout(p=config["dropout"])

    def create_net(self, name, input_dim, **kw):
        """Build `num_layers` GINConv layers (each with a 2-layer MLP) plus
        matching BatchNorms; returns (convs, bns, hidden_dim).

        Extra kwargs (num_rels, num_bases, regularizer, dropout) are accepted
        for interface compatibility with the RGCN variant but are unused here.
        """
        num_layers = kw.get("num_layers", 1)
        hidden_dim = kw.get("hidden_dim", 64)

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        for i in range(num_layers):
            in_dim = hidden_dim if i else input_dim
            # FIX: the per-layer MLP was named `nn`, shadowing the
            # `torch.nn as nn` module import; renamed to `mlp`.
            mlp = torch.nn.Sequential(
                torch.nn.Linear(in_dim, hidden_dim),
                self.act,
                torch.nn.Linear(hidden_dim, hidden_dim))
            conv = dgl.nn.pytorch.conv.GINConv(apply_func=mlp, aggregator_type='sum')
            bn = torch.nn.BatchNorm1d(hidden_dim)

            self.convs.append(conv)
            self.bns.append(bn)

        return self.convs, self.bns, hidden_dim


    def forward(self, graph, graph_len):
        """Return (graph_embedding, node_embedding), each concatenated over layers.

        Graph embeddings are per-graph sums of node states after each layer.
        """
        graph_output = graph.ndata["feature"]
        xs = []
        for i in range(self.num_layers_num):
            graph_output = F.relu(self.convs[i](graph,graph_output))
            graph_output = self.bns[i](graph_output)
            graph_output = self.dropout(graph_output)
            xs.append(graph_output)
        xpool= []
        for x in xs:
            graph_embedding = split_and_batchify_graph_feats(x, graph_len)[0]
            graph_embedding = torch.sum(graph_embedding, dim=1)  # sum-pool over nodes
            xpool.append(graph_embedding)
        x = torch.cat(xpool, -1)
        return x,torch.cat(xs, -1)

# ------------------------------------------------------------------------------
# /graphdownstream/gina:
# https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/graphdownstream/gina
# ------------------------------------------------------------------------------
# /graphdownstream/graph_prompt_layer.py:
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import copy
from functools import partial
from dgl.nn.pytorch.conv import RelGraphConv
from basemodel import GraphAdjModel
from utils import map_activation_str_to_layer, split_and_batchify_graph_feats,GetAdj

class graph_prompt_layer_mean(nn.Module):
    """Parameter-free readout: mean of each graph's node embeddings."""

    def __init__(self):
        super(graph_prompt_layer_mean, self).__init__()
        # unused 2x2 parameter, kept so existing state_dicts stay compatible
        self.weight = torch.nn.Parameter(torch.Tensor(2, 2))

    def forward(self, graph_embedding, graph_len):
        per_graph = split_and_batchify_graph_feats(graph_embedding, graph_len)[0]
        return per_graph.mean(dim=1)

class graph_prompt_layer_linear_mean(nn.Module):
    """Linear projection followed by mean readout and L2 normalisation."""

    def __init__(self, input_dim, output_dim):
        super(graph_prompt_layer_linear_mean, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, graph_embedding, graph_len):
        projected = self.linear(graph_embedding)
        per_graph = split_and_batchify_graph_feats(projected, graph_len)[0]
        pooled = per_graph.mean(dim=1)
        return torch.nn.functional.normalize(pooled, dim=1)

class graph_prompt_layer_linear_sum(nn.Module):
    """Linear projection followed by sum readout and L2 normalisation."""

    def __init__(self, input_dim, output_dim):
        super(graph_prompt_layer_linear_sum, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, graph_embedding, graph_len):
        projected = self.linear(graph_embedding)
        per_graph = split_and_batchify_graph_feats(projected, graph_len)[0]
        pooled = per_graph.sum(dim=1)
        return torch.nn.functional.normalize(pooled, dim=1)



# sum readout; the original author notes its result matches the mean readout
class graph_prompt_layer_sum(nn.Module):
    """Parameter-free readout: sum of each graph's node embeddings."""

    def __init__(self):
        super(graph_prompt_layer_sum, self).__init__()
        # unused 2x2 parameter, kept so existing state_dicts stay compatible
        self.weight = torch.nn.Parameter(torch.Tensor(2, 2))

    def forward(self, graph_embedding, graph_len):
        per_graph = split_and_batchify_graph_feats(graph_embedding, graph_len)[0]
        return per_graph.sum(dim=1)



class graph_prompt_layer_weighted(nn.Module):
    """Sum readout with one learned scalar weight per node slot."""

    def __init__(self, max_n_num):
        super(graph_prompt_layer_weighted, self).__init__()
        self.weight = torch.nn.Parameter(torch.Tensor(1, max_n_num))
        self.max_n_num = max_n_num
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, graph_embedding, graph_len):
        per_graph = split_and_batchify_graph_feats(graph_embedding, graph_len)[0]
        node_w = self.weight[0][0:per_graph.size(1)]
        # broadcast one scalar per node slot over the feature dimension
        weighted = per_graph * node_w.unsqueeze(-1)
        return weighted.sum(dim=1)

class graph_prompt_layer_feature_weighted_mean(nn.Module):
    """Mean readout with one learned weight per feature dimension."""

    def __init__(self, input_dim):
        super(graph_prompt_layer_feature_weighted_mean, self).__init__()
        self.weight = torch.nn.Parameter(torch.Tensor(1, input_dim))
        self.max_n_num = input_dim
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, graph_embedding, graph_len):
        per_graph = split_and_batchify_graph_feats(graph_embedding, graph_len)[0]
        return (per_graph * self.weight).mean(dim=1)

| class graph_prompt_layer_feature_weighted_sum(nn.Module): 94 | def __init__(self,input_dim): 95 | super(graph_prompt_layer_feature_weighted_sum, self).__init__() 96 | self.weight= torch.nn.Parameter(torch.Tensor(1,input_dim)) 97 | self.max_n_num=input_dim 98 | self.reset_parameters() 99 | def reset_parameters(self): 100 | torch.nn.init.xavier_uniform_(self.weight) 101 | def forward(self, graph_embedding, graph_len): 102 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 103 | graph_embedding=graph_embedding*self.weight 104 | graph_prompt_result=graph_embedding.sum(dim=1) 105 | return graph_prompt_result 106 | 107 | class graph_prompt_layer_weighted_matrix(nn.Module): 108 | def __init__(self,max_n_num,input_dim): 109 | super(graph_prompt_layer_weighted_matrix, self).__init__() 110 | self.weight= torch.nn.Parameter(torch.Tensor(input_dim,max_n_num)) 111 | self.max_n_num=max_n_num 112 | self.reset_parameters() 113 | def reset_parameters(self): 114 | torch.nn.init.xavier_uniform_(self.weight) 115 | def forward(self, graph_embedding, graph_len): 116 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 117 | weight = self.weight.permute(1, 0)[0:graph_embedding.size(1)] 118 | weight = weight.expand(graph_embedding.size(0), weight.size(0), weight.size(1)) 119 | graph_embedding = graph_embedding * weight 120 | #prompt: mean 121 | graph_prompt_result=graph_embedding.sum(dim=1) 122 | return graph_prompt_result 123 | 124 | class graph_prompt_layer_weighted_linear(nn.Module): 125 | def __init__(self,max_n_num,input_dim,output_dim): 126 | super(graph_prompt_layer_weighted_linear, self).__init__() 127 | self.weight= torch.nn.Parameter(torch.Tensor(1,max_n_num)) 128 | self.linear=nn.Linear(input_dim,output_dim) 129 | self.max_n_num=max_n_num 130 | self.reset_parameters() 131 | def reset_parameters(self): 132 | torch.nn.init.xavier_uniform_(self.weight) 133 | def forward(self, graph_embedding, graph_len): 134 | 
graph_embedding=self.linear(graph_embedding) 135 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 136 | weight = self.weight[0][0:graph_embedding.size(1)] 137 | temp1 = torch.ones(graph_embedding.size(0), graph_embedding.size(2), graph_embedding.size(1)).to(graph_embedding.device) 138 | temp1 = weight * temp1 139 | temp1 = temp1.permute(0, 2, 1) 140 | graph_embedding=graph_embedding*temp1 141 | graph_prompt_result = graph_embedding.mean(dim=1) 142 | return graph_prompt_result 143 | 144 | class graph_prompt_layer_weighted_matrix_linear(nn.Module): 145 | def __init__(self,max_n_num,input_dim,output_dim): 146 | super(graph_prompt_layer_weighted_matrix_linear, self).__init__() 147 | self.weight= torch.nn.Parameter(torch.Tensor(output_dim,max_n_num)) 148 | self.linear=nn.Linear(input_dim,output_dim) 149 | self.max_n_num=max_n_num 150 | self.reset_parameters() 151 | def reset_parameters(self): 152 | torch.nn.init.xavier_uniform_(self.weight) 153 | def forward(self, graph_embedding, graph_len): 154 | graph_embedding=self.linear(graph_embedding) 155 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 156 | weight = self.weight.permute(1, 0)[0:graph_embedding.size(1)] 157 | weight = weight.expand(graph_embedding.size(0), weight.size(0), weight.size(1)) 158 | graph_embedding = graph_embedding * weight 159 | graph_prompt_result=graph_embedding.mean(dim=1) 160 | return graph_prompt_result 161 | -------------------------------------------------------------------------------- /graphdownstream/pre_train_before.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import dgl 5 | import logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | import json 11 | import time 12 | import torch.nn.functional as F 13 | import warnings 14 | from functools import partial 15 | from collections import OrderedDict 16 | from 
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import random


try:
    from torch.utils.tensorboard import SummaryWriter
except BaseException as e:
    from tensorboardX import SummaryWriter
from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset
from utils import anneal_fn, get_enc_len, pretrain_load_data, \
    get_linear_schedule_with_warmup,bp_compute_abmae,compareloss,split_and_batchify_graph_feats
'''from cnn import CNN
from rnn import RNN
from txl import TXL
from rgcn import RGCN
from rgin import RGIN

from gin import GIN'''
from gin import GIN

warnings.filterwarnings("ignore")
INF = float("inf")

# Hyper-parameters / paths for self-supervised pre-training on ENZYMES.
# NOTE(review): "model" and "dropout" appear twice in this literal; Python
# keeps the LAST occurrence ("GIN", 0.5), so the earlier values are dead.
train_config = {
    "max_npv": 8,  # max_number_pattern_vertices: 8, 16, 32
    "max_npe": 8,  # max_number_pattern_edges: 8, 16, 32
    "max_npvl": 8,  # max_number_pattern_vertex_labels: 8, 16, 32
    "max_npel": 8,  # max_number_pattern_edge_labels: 8, 16, 32

    "max_ngv": 126,  # max_number_graph_vertices: 64, 512,4096
    "max_nge": 298,  # max_number_graph_edges: 256, 2048, 16384
    "max_ngvl": 7,  # max_number_graph_vertex_labels: 16, 64, 256
    "max_ngel": 2,  # max_number_graph_edge_labels: 16, 64, 256

    "base": 2,

    "gpu_id": 0,    # -1 selects CPU (see device setup in __main__)
    "num_workers": 0,

    "epochs": 400,
    "batch_size": 1024,
    "update_every": 1,  # actual batch_sizer = batch_size * update_every
    "print_every": 100,
    "init_emb": "Equivariant",  # None, Orthogonal, Normal, Equivariant
    "share_emb": True,  # sharing embedding requires the same vector length
    "share_arch": True,  # sharing architectures
    "dropout": 0.2,
    "dropatt": 0.2,

    "reg_loss": "MSE",  # MAE, MSEl
    "bp_loss": "MSE",  # MAE, MSE
    "bp_loss_slp": "anneal_cosine$1.0$0.01",  # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01,
    # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01
    # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01
    "lr": 0.1,
    "weight_decay": 0.00001,
    "max_grad_norm": 8,

    "model": "GIN",  # CNN, RNN, TXL, RGCN, RGIN, RSIN

    "predict_net": "SumPredictNet",  # MeanPredictNet, SumPredictNet, MaxPredictNet,
    # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet,
    # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet,
    # DIAMNet
    # "predict_net_add_enc": True,
    # "predict_net_add_degree": True,
    "predict_net_add_enc": True,
    "predict_net_add_degree": True,

    "predict_net_hidden_dim": 128,
    "predict_net_num_heads": 4,
    "predict_net_mem_len": 4,
    "predict_net_mem_init": "mean",
    # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm
    "predict_net_recurrent_steps": 3,

    "emb_dim": 128,
    "activation_function": "leaky_relu",  # sigmoid, softmax, tanh, relu, leaky_relu, prelu, gelu

    "filter_net": "MaxGatedFilterNet",  # None, MaxGatedFilterNet
    # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet,
    # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet,
    # DIAMNet
    # "predict_net_add_enc": True,
    # "predict_net_add_degree": True,
    "txl_graph_num_layers": 3,
    "txl_pattern_num_layers": 3,
    "txl_d_model": 128,
    "txl_d_inner": 128,
    "txl_n_head": 4,
    "txl_d_head": 4,
    "txl_pre_lnorm": True,
    "txl_tgt_len": 64,
    "txl_ext_len": 0,  # useless in current settings
    "txl_mem_len": 64,
    "txl_clamp_len": -1,  # max positional embedding index
    "txl_attn_type": 0,  # 0 for Dai et al, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
    "txl_same_len": False,

    "gcn_num_bases": 8,
    "gcn_regularizer": "bdd",  # basis, bdd
    "gcn_graph_num_layers": 3,
    "gcn_hidden_dim": 32,
    "gcn_ignore_norm": False,  # ignorm=True -> RGCN-SUM

    "graph_dir": "../data/ENZYMES/ENZYMESPreTrain",
    "save_data_dir": "../data/ENZYMESPreTrain",
    "save_model_dir": "../dumps/ENZYMESPreTrain/TEST",
    "save_pretrain_model_dir": "../dumps/MUTAGPreTrain/GCN",
    "graphslabel_dir":"../data/ENZYMES/ENZYMES_graph_labels.txt",
    "downstream_graph_dir": "../data/debug/graphs",
    "downstream_save_data_dir": "../data/debug",
    "downstream_save_model_dir": "../dumps/debug",
    "downstream_graphslabel_dir":"../data/debug/graphs",
    "train_num_per_class": 3,
    "shot_num": 2,
    "temperature": 1,
    "graph_finetuning_input_dim": 8,
    "graph_finetuning_output_dim": 2,
    "graph_label_num": 6,
    "seed": 0,
    "model": "GIN",    # duplicate key: overrides the "model" entry above
    "dropout": 0.5,    # duplicate key: overrides "dropout": 0.2 above
    "node_feature_dim": 18,
    "pretrain_hop_num": 1
}

def train(model, optimizer, scheduler, data_type, data_loader, device, config, epoch, logger=None, writer=None):
    """Run one pre-training epoch.

    The self-supervised objective comes from `compareloss` applied to
    hop-propagated node predictions gathered at sampled node indices
    (graph.ndata['sample']); `reg_crit`/`bp_crit` below are selected from the
    config but are never actually called in this function — the loss used is
    `compareloss`.  Returns (mean_reg_loss, mean_bp_loss, total_time); the
    "mean" values are in fact sums over batches (see NOTE below).
    """
    epoch_step = len(data_loader)
    total_step = config["epochs"] * epoch_step
    total_reg_loss = 0
    total_bp_loss = 0
    #total_cnt = 1e-6

    # Select the regression criterion (unused in this function; kept for
    # parity with the original training framework).
    if config["reg_loss"] == "MAE":
        reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target)
    elif config["reg_loss"] == "MSE":
        reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target)
    elif config["reg_loss"] == "SMSE":
        reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target)
    elif config["reg_loss"] == "ABMAE":
        reg_crit = lambda pred, target: bp_compute_abmae(F.leaky_relu(pred), target)+0.8*F.l1_loss(F.relu(pred), target)
    else:
        raise NotImplementedError

    # Back-propagation criterion (also unused here — see docstring).
    if config["bp_loss"] == "MAE":
        bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target)
    elif config["bp_loss"] == "MSE":
        bp_crit = lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target)
    elif config["bp_loss"] == "SMSE":
        bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target)
    elif config["bp_loss"] == "ABMAE":
        bp_crit = lambda pred, target, neg_slp: bp_compute_abmae(F.leaky_relu(pred, neg_slp), target)+0.8*F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none")
    else:
        raise NotImplementedError

    model.train()
    total_time=0
    for batch_id, batch in enumerate(data_loader):
        ids, graph_label, graph, graph_len= batch
        # print(batch)
        graph=graph.to(device)
        graph_label=graph_label.to(device)
        graph_len = graph_len.to(device)
        s=time.time()
        x,pred = model(graph, graph_len)
        # NOTE(review): F.sigmoid is deprecated in modern PyTorch in favour of
        # torch.sigmoid — behaviour is identical.
        pred=F.sigmoid(pred)

        adj = graph.adjacency_matrix()
        adj = adj.to(device)
        # Propagate predictions over the graph structure "hop_num" times.
        if train_config["pretrain_hop_num"]==0:
            pred=pred
        else:
            for count in range(train_config["pretrain_hop_num"]):
                pred = torch.matmul(adj, pred)
        #print(pred.size())
        # Gather the propagated predictions at the pre-sampled node indices
        # per graph (padded batch view), then reshape back to the sample grid.
        _pred=split_and_batchify_graph_feats(pred, graph_len)[0]
        sample = graph.ndata['sample']
        _sample=split_and_batchify_graph_feats(sample, graph_len)[0]
        sample_=_sample.reshape(_sample.size(0),-1,1)
        #print(_pred.size())
        #print(sample_.size())
        _pred=torch.gather(input=_pred,dim=1,index=sample_)
        #print(_pred.size())
        # NOTE(review): Tensor.resize_as is deprecated; reshape/view_as would
        # be the modern equivalent — confirm shapes before changing.
        _pred=_pred.resize_as(_sample)
        #print(_pred.size())

        # Contrastive-style comparison loss over the sampled predictions.
        reg_loss = compareloss(_pred,train_config["temperature"])
        reg_loss.requires_grad_(True)
        # print(reg_loss.size())

        # Anneal the leaky-relu negative slope over training (neg_slp is
        # computed but not applied here since bp_crit is unused).
        if isinstance(config["bp_loss_slp"], (int, float)):
            neg_slp = float(config["bp_loss_slp"])
        else:
            bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3)
            neg_slp = anneal_fn(bp_loss_slp, batch_id + epoch * epoch_step, T=total_step // 4, lambda0=float(l0),
                                lambda1=float(l1))
        bp_loss = reg_loss
        bp_loss.requires_grad_(True)


        # float
        reg_loss_item = reg_loss.item()
        bp_loss_item = bp_loss.item()
        total_reg_loss += reg_loss_item
        total_bp_loss += bp_loss_item

        if writer:
            writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item,
                              epoch * epoch_step + batch_id)
            writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch * epoch_step + batch_id)

        if logger and (batch_id % config["print_every"] == 0 or batch_id == epoch_step - 1):
            logger.info(
                "epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format(
                    epoch, config["epochs"], data_type, batch_id, epoch_step,
                    reg_loss_item, bp_loss_item))
        print(bp_loss.grad)
        bp_loss.backward()
        # Gradient accumulation: only step/zero every "update_every" batches.
        if (config["update_every"] < 2 or batch_id % config["update_every"] == 0 or batch_id == epoch_step - 1):
            if config["max_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
            if scheduler is not None:
                scheduler.step(epoch * epoch_step + batch_id)
            optimizer.step()
            optimizer.zero_grad()
        e=time.time()
        total_time+=e-s
    #mean_reg_loss = total_reg_loss / total_cnt
    #mean_bp_loss = total_bp_loss / total_cnt
    # NOTE(review): despite the names, these are SUMS over batches, not means
    # (the division by total_cnt is commented out above).
    mean_reg_loss = total_reg_loss
    mean_bp_loss = total_bp_loss
    if writer:
        writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch)
        writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch)
    if logger:
        logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format(
            epoch, config["epochs"], data_type, mean_reg_loss, mean_bp_loss))

    gc.collect()
    return mean_reg_loss, mean_bp_loss, total_time


def evaluate(model, data_type, data_loader, device, config, epoch, logger=None, writer=None):
    """Evaluate the model on one split with the same compareloss objective.

    Unlike train(), only a single hop of adjacency propagation is applied and
    gathering is done over the flat (unbatched) node dimension.  Returns
    (mean_reg_loss, 0, evaluate_results, total_time); the 0 stands in for the
    bp loss to keep the return shape parallel with train().
    """
    epoch_step = len(data_loader)
    total_reg_loss = 0
    #total_cnt = 1e-6

    evaluate_results = {"data": {"id": list(), "counts": list(), "pred": list()},
                        "error": {"mae": INF, "mse": INF},
                        "time": {"avg": list(), "total": 0.0}}

    # Criteria selected for config-validation parity; not called below —
    # the actual metric is compareloss.
    if config["reg_loss"] == "MAE":
        reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target, reduce="none")
    elif config["reg_loss"] == "MSE":
        reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target, reduce="none")
    elif config["reg_loss"] == "SMSE":
        reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target, reduce="none")
    elif config["reg_loss"] == "ABMAE":
        reg_crit = lambda pred, target: bp_compute_abmae(F.relu(pred), target)+0.8*F.l1_loss(F.relu(pred), target, reduce="none")
    else:
        raise NotImplementedError

    if config["bp_loss"] == "MAE":
        bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none")
    elif config["bp_loss"] == "MSE":
        bp_crit = lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target, reduce="none")
    elif config["bp_loss"] == "SMSE":
        bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none")
    elif config["bp_loss"] == "ABMAE":
        bp_crit = lambda pred, target, neg_slp: bp_compute_abmae(F.leaky_relu(pred, neg_slp), target)+0.8*F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none")
    else:
        raise NotImplementedError

    model.eval()
    total_time=0
    with torch.no_grad():
        for batch_id, batch in enumerate(data_loader):
            ids, graph_label, graph, graph_len= batch
            #cnt = counts.shape[0]
            #total_cnt += cnt

            graph = graph.to(device)
            graph_label=graph_label.to(device)
            graph_len = graph_len.to(device)
            st = time.time()
            # NOTE(review): here the model output is used directly (train()
            # unpacks a pair `x, pred`) — confirm the model's eval-mode return.
            pred = model(graph, graph_len)
            adj = graph.adjacency_matrix()
            adj = adj.to(device)
            pred = torch.matmul(adj, pred)
            sample = graph.ndata['sample']
            _sample = sample.reshape(-1, 1)
            pred = torch.gather(input=pred, dim=0, index=_sample)
            pred = pred.resize_as(sample)

            et=time.time()
            evaluate_results["time"]["total"] += (et - st)
            #avg_t = (et - st) / (cnt + 1e-8)
            #evaluate_results["time"]["avg"].extend([avg_t] * cnt)
            #evaluate_results["data"]["pred"].extend(pred.cpu().view(-1).tolist())

            reg_loss = compareloss(pred, train_config["temperature"])
            reg_loss_item = reg_loss.item()

            if writer:
                writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item,
                                  epoch * epoch_step + batch_id)
                '''writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item,
                                  epoch * epoch_step + batch_id)'''

            if logger and batch_id == epoch_step - 1:
                logger.info(
                    "epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}".format(
                        epoch, config["epochs"], data_type, batch_id, epoch_step,
                        reg_loss_item))
            et=time.time()
            total_time+=et-st
            total_reg_loss+=reg_loss_item
    # As in train(), this is a sum over batches, not a true mean.
    mean_reg_loss = total_reg_loss
    #mean_bp_loss = total_bp_loss / total_cnt
    if writer:
        writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch)
        #writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch)
    if logger:
        logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}".format(
            epoch, config["epochs"], data_type, mean_reg_loss))

    evaluate_results["error"]["loss"] = mean_reg_loss
    #evaluate_results["error"]["mse"] = evaluate_results["error"]["mse"] / total_cnt

    gc.collect()
    return mean_reg_loss,0,evaluate_results, total_time


if __name__ == "__main__":
    # Parse "--key value" pairs from the command line into train_config.
    # NOTE(review): eval() on raw argv is unsafe for untrusted input; it is
    # used here only to coerce numeric literals.
    for i in range(1, len(sys.argv), 2):
        arg = sys.argv[i]
        value = sys.argv[i + 1]

        if arg.startswith("--"):
            arg = arg[2:]
        if arg not in train_config:
            print("Warning: %s is not surported now." % (arg))
            continue
        train_config[arg] = value
        try:
            value = eval(value)
            if isinstance(value, (int, float)):
                train_config[arg] = value
        except:
            pass

    torch.manual_seed(train_config["seed"])
    np.random.seed(train_config["seed"])

    ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    model_name = "%s_%s_%s" % (train_config["model"], train_config["predict_net"], ts)
    save_model_dir = train_config["save_model_dir"]
    os.makedirs(save_model_dir, exist_ok=True)

    # save config
    with open(os.path.join(save_model_dir, "train_config.json"), "w") as f:
        json.dump(train_config, f)

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S')
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    logfile = logging.FileHandler(os.path.join(save_model_dir, "train_log.txt"), 'w')
    logfile.setFormatter(fmt)
    logger.addHandler(logfile)

    # set device
    device = torch.device("cuda:%d" % train_config["gpu_id"] if train_config["gpu_id"] != -1 else "cpu")
    if train_config["gpu_id"] != -1:
        torch.cuda.set_device(device)

    # reset the pattern parameters
    if train_config["share_emb"]:
        train_config["max_npv"], train_config["max_npvl"], train_config["max_npe"], train_config["max_npel"] = \
            train_config["max_ngv"], train_config["max_ngvl"], train_config["max_nge"], train_config["max_ngel"]


    # NOTE(review): only GIN is imported above; selecting GCN/GAT/GraphSage
    # would raise NameError here.
    if train_config["model"] == "GCN":
        model = GCN(train_config)
    if train_config["model"] == "GIN":
        model = GIN(train_config)
    if train_config["model"] == "GAT":
        model = GAT(train_config)
    if train_config["model"] == "GraphSage":
        model = Graphsage(train_config)

    model = model.to(device)
    logger.info(model)
    logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # load data: reuse cached .pt datasets when both splits exist, otherwise
    # preprocess from the raw graph directory and cache.
    os.makedirs(train_config["save_data_dir"], exist_ok=True)
    data_loaders = OrderedDict({"train": None, "dev": None})
    if all([os.path.exists(os.path.join(train_config["save_data_dir"],
                                        "%s_%s_dataset.pt" % (
                                                data_type, "dgl" if train_config["model"] in ["RGCN", "RGIN",
                                                                                              "GAT","GCN","GraphSage","GIN"] else "edgeseq")))
            for data_type in data_loaders]):

        logger.info("loading data from pt...")
        for data_type in data_loaders:
            if train_config["model"] in ["RGCN", "RGIN", "GAT","GCN","GraphSage","GIN"]:
                dataset = GraphAdjDataset(list())
                dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type)))

                sampler = Sampler(dataset, group_by=["graph"], batch_size=train_config["batch_size"],
                                  shuffle=data_type == "train", drop_last=False)
                data_loader = DataLoader(dataset,
                                         batch_sampler=sampler,
                                         collate_fn=GraphAdjDataset.batchify,
                                         pin_memory=data_type == "train")
            else:
                dataset = EdgeSeqDataset(list())
                dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type)))
                sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"],
                                  shuffle=data_type == "train", drop_last=False)
                data_loader = DataLoader(dataset,
                                         batch_sampler=sampler,
                                         collate_fn=EdgeSeqDataset.batchify,
                                         pin_memory=data_type == "train")
            data_loaders[data_type] = data_loader
            logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data)))
            logger.info(
                "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader),
                                                                                            train_config["batch_size"]))
    else:
        data = pretrain_load_data(train_config["graph_dir"], train_config["graphslabel_dir"], num_workers=train_config["num_workers"])
        logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"])))
        for data_type, x in data.items():
            if train_config["model"] in ["RGCN", "RGIN", "GAT","GCN","GraphSage","GIN"]:
                if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))):
                    dataset = GraphAdjDataset(list())
                    dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type)))
                else:
                    # Time the preprocessing of each split separately.
                    if(data_type=="test"):
                        test_time_start=time.time()
                    elif(data_type=="train"):
                        train_time_start = time.time()
                    else:
                        val_time_start = time.time()
                    dataset = GraphAdjDataset(x)
                    if(data_type=="test"):
                        test_time_end=time.time()
                        test_time=test_time_end-test_time_start
                        logger.info(
                            "preprocess test time: {:.3f}".format(test_time))
                    elif(data_type=="train"):
                        train_time_end=time.time()
                        train_time=train_time_end-train_time_start
                        logger.info(
                            "preprocess train time: {:.3f}".format(train_time))
                    else:
                        val_time_end=time.time()
                        val_time=val_time_end-val_time_start
                        logger.info(
                            "preprocess val time: {:.3f}".format(val_time))
                    dataset.save(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type)))
                sampler = Sampler(dataset, group_by=["graph"], batch_size=train_config["batch_size"],
                                  shuffle=data_type == "train", drop_last=False)
                data_loader = DataLoader(dataset,
                                         batch_sampler=sampler,
                                         collate_fn=GraphAdjDataset.batchify,
                                         pin_memory=data_type == "train")
            else:
                if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))):
                    dataset = EdgeSeqDataset(list())
                    dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type)))
                else:
                    dataset = EdgeSeqDataset(x)
                    dataset.save(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type)))
                sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"],
                                  shuffle=data_type == "train", drop_last=False)
                data_loader = DataLoader(dataset,
                                         batch_sampler=sampler,
                                         collate_fn=EdgeSeqDataset.batchify,
                                         pin_memory=data_type == "train")
            data_loaders[data_type] = data_loader
            logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data)))
            logger.info(
                "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader),
                                                                                            train_config["batch_size"]))

    print('data_loaders', data_loaders.items())

    # optimizer and losses
    writer = SummaryWriter(save_model_dir)
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"],
                                  amsgrad=True)
    optimizer.zero_grad()
    scheduler = None
    best_reg_losses = {"train": INF, "dev": INF, "test": INF}
    best_reg_epochs = {"train": -1, "dev": -1, "test": -1}

    total_train_time=0
    total_dev_time=0
    total_test_time=0

    plt_x=list()
    plt_y=list()

    # Main loop: train on "train", evaluate on every other split, track the
    # best loss per split, and record the train-loss curve for plotting.
    for epoch in range(train_config["epochs"]):
        for data_type, data_loader in data_loaders.items():

            if data_type == "train":
                mean_reg_loss, mean_bp_loss, _time = train(model, optimizer, scheduler, data_type, data_loader, device,
                                                           train_config, epoch, logger=logger, writer=writer)
                total_train_time+=_time
                torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch)))
            else:
                mean_reg_loss, mean_bp_loss, evaluate_results, _time = evaluate(model, data_type, data_loader, device,
                                                                                train_config, epoch, logger=logger,
                                                                                writer=writer)
                total_dev_time+=_time
                with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f:
                    json.dump(evaluate_results, f)
            if mean_reg_loss <= best_reg_losses[data_type]:
                best_reg_losses[data_type] = mean_reg_loss
                best_reg_epochs[data_type] = epoch
                logger.info(
                    "data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss,
                                                                                        epoch))
            if data_type == "train":
                plt_x.append(epoch)
                plt_y.append(mean_reg_loss)

    plt.figure(1)
    plt.plot(plt_x,plt_y)
    plt.savefig('epoch_loss.png')
    for data_type in data_loaders.keys():
        logger.info(
            "data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type],
                                                                                best_reg_epochs[data_type]))

    # The LAST epoch's checkpoint is exported as "best.pt" (not the best-loss
    # epoch tracked above).
    best_epoch = train_config["epochs"]-1
    model.load_state_dict(torch.load(os.path.join(save_model_dir, 'epoch%d.pt' % (best_epoch))))
    torch.save(model.state_dict(), os.path.join(save_model_dir, "best.pt"))
import math
import numpy as np
import re
import os
import sys
import json
from torch.optim.lr_scheduler import LambdaLR
from collections import OrderedDict
from multiprocessing import Pool
from tqdm import tqdm
from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score
import random
from sklearn.metrics import precision_recall_fscore_support
import functools
import dgl


def igraph_node_feature2dgl_node_feature(label, input):
    """Parse the stringified per-node feature attribute read from a .gml file.

    `input` is the raw string after the "feature" keyword: a Python-list-like
    dump of comma-separated float strings, one entry per node.  Returns
    (label, features) where features is a list of float lists.
    """
    a = input
    # Strip the surrounding list brackets.
    b = a.split('[')[1]
    c = b.split(']')[0]
    # One chunk per node; the last chunk has a trailing backslash escape
    # that must be trimmed separately.
    d = c.split('\\n\', ')
    fix = d[len(d) - 1]
    fix = fix.split('\\')[0]
    d[len(d) - 1] = fix
    nodefeaturestring = []
    for nodefeature in d:
        # Keep only the quoted payload, then split into individual values.
        temp = nodefeature.split('\'')[1]
        temp = temp.split(',')
        nodefeaturestring.append(temp)

    numbers_float = []
    for num in nodefeaturestring:
        temp = []
        for data in num:
            temp.append(float(data))
        numbers_float.append(temp)
    return label, numbers_float


def FUCK_U_IGraphLoad(path, graph_attr_num):
    """Scan a .gml file line-by-line for graph-level `label` and `feature`.

    Returns (label, raw_feature_string) as soon as `graph_attr_num`
    attributes have been found; returns None implicitly if they are missing.
    """
    with open(path, "r") as f:
        data = f.readlines()
    count = 0
    for line in data:
        gattr = line.split()
        if not gattr:
            # Fix: blank lines used to raise IndexError on gattr[0].
            continue
        if gattr[0] == "label":
            label = int(gattr[1])
            count += 1
        if gattr[0] == "feature":
            feature = line.split("feature")[1]
            count += 1
        if count == graph_attr_num:
            return label, feature


def FUCK_IGraphLoad(path, graph_attr_num):
    """Load graph label + parsed node features from a .gml file."""
    label, feature = FUCK_U_IGraphLoad(path, graph_attr_num)
    return igraph_node_feature2dgl_node_feature(label, feature)


def ReSetNodeId(startid, edgelist):
    """Shift every (src, dst) pair in `edgelist` by `startid` in place.

    Used when appending one graph's nodes after another's so edge endpoints
    stay consistent.  Returns the (mutated) list.
    """
    count = 0
    for edge in edgelist:
        src, dst = edge
        src += startid
        dst += startid
        edgelist[count] = (src, dst)
        count += 1
    return edgelist


def _read_graphs_from_dir(dirpath):
    """Read every .gml graph in `dirpath` and merge them into ONE igraph
    graph (disjoint union), concatenating node/edge labels, keys and node
    features.  Per-graph labels are discarded in the merged result.
    """
    import igraph as ig
    graph = ig.Graph()
    count = 0
    for filename in os.listdir(dirpath):
        if not os.path.isdir(os.path.join(dirpath, filename)):
            names = os.path.splitext(os.path.basename(filename))
            if names[1] != ".gml":
                continue
            try:
                if count == 0:
                    # First graph seeds the merged result.
                    _graph = ig.read(os.path.join(dirpath, filename))
                    label, feature = FUCK_IGraphLoad(os.path.join(dirpath, filename), 2)
                    _graph.vs["label"] = [int(x) for x in _graph.vs["label"]]
                    _graph.es["label"] = [int(x) for x in _graph.es["label"]]
                    _graph.es["key"] = [int(x) for x in _graph.es["key"]]
                    _graph["feature"] = feature
                    graph = _graph
                    count += 1
                else:
                    _graph = ig.read(os.path.join(dirpath, filename))
                    label, feature = FUCK_IGraphLoad(os.path.join(dirpath, filename), 2)
                    _graph.vs["label"] = [int(x) for x in _graph.vs["label"]]
                    _graph.es["label"] = [int(x) for x in _graph.es["label"]]
                    _graph.es["key"] = [int(x) for x in _graph.es["key"]]
                    _graph["feature"] = feature
                    # Concatenate attributes before structurally merging,
                    # because add_vertices/add_edges would reset them.
                    _graph_nodelabel = _graph.vs["label"]
                    graph_nodelabel = graph.vs["label"]
                    new_nodelabel = graph_nodelabel + _graph_nodelabel
                    _graph_edgelabel = _graph.es["label"]
                    graph_edgelabel = graph.es["label"]
                    new_edgelabel = graph_edgelabel + _graph_edgelabel
                    _graph_edgekey = _graph.es["key"]
                    graph_edgekey = graph.es["key"]
                    new_edgekey = graph_edgekey + _graph_edgekey

                    # Append the new graph's nodes/edges with shifted ids.
                    graph_nodenum = graph.vcount()
                    _graph_nodenum = _graph.vcount()
                    graph.add_vertices(_graph_nodenum)
                    _graphedge = _graph.get_edgelist()
                    _graphedge = ReSetNodeId(graph_nodenum, _graphedge)
                    graph.add_edges(_graphedge)
                    graph.vs["label"] = new_nodelabel
                    graph.es["label"] = new_edgelabel
                    graph.es["key"] = new_edgekey
                    graph["feature"] = graph["feature"] + _graph["feature"]

            except BaseException as e:
                # Best-effort: report and stop at the first unreadable file.
                print(e)
                break
    return graph


def graph2dglgraph(graph):
    """Convert an igraph graph into a read-only DGLGraph (structure only).

    NOTE(review): dgl.DGLGraph(multigraph=True)/readonly are legacy APIs —
    confirm against the pinned dgl 0.9 before modernizing.
    """
    dglgraph = dgl.DGLGraph(multigraph=True)
    dglgraph.add_nodes(graph.vcount())
    edges = graph.get_edgelist()
    dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges])
    dglgraph.readonly(True)
    return dglgraph


def dglpreprocess(x):
    """Build a DGLGraph from an igraph graph and attach node data:
    in-degree, node label, node id and the parsed float feature matrix.
    """
    graph = x
    graph_dglgraph = graph2dglgraph(graph)
    graph_dglgraph.ndata["indeg"] = torch.tensor(np.array(graph.indegree(), dtype=np.float32))
    graph_dglgraph.ndata["label"] = torch.tensor(np.array(graph.vs["label"], dtype=np.int64))
    graph_dglgraph.ndata["id"] = torch.tensor(np.arange(0, graph.vcount(), dtype=np.int64))
    nodefeature = graph["feature"]
    graph_dglgraph.ndata["feature"] = torch.tensor(np.array(nodefeature, dtype=np.float32))
    return graph_dglgraph


def read_graphs_from_dir(dirpath):
    """Read every .gml graph in `dirpath` as a SEPARATE igraph graph
    (with graph-level label and node features) and return them as a list.
    """
    import igraph as ig
    ret = []
    for filename in os.listdir(dirpath):
        if not os.path.isdir(os.path.join(dirpath, filename)):
            names = os.path.splitext(os.path.basename(filename))
            if names[1] != ".gml":
                continue
            try:
                graph = ig.read(os.path.join(dirpath, filename))
                label, feature = FUCK_IGraphLoad(os.path.join(dirpath, filename), 2)
                graph.vs["label"] = [int(x) for x in graph.vs["label"]]
                graph.es["label"] = [int(x) for x in graph.es["label"]]
                graph.es["key"] = [int(x) for x in graph.es["key"]]
                graph["label"] = label
                graph["feature"] = feature
                ret.append(graph)
            except BaseException as e:
                print(e)
                break
    return ret


if __name__ == "__main__":
    # Usage: python ENZYMES2ONE_Graph.py <raw_gml_dir>
    # Merges all graphs into one, saves it, then reloads as a sanity check.
    assert len(sys.argv) == 2
    nci1_data_path = sys.argv[1]
    save_path = "../data/ENZYMES/test_allinone"
    graph = _read_graphs_from_dir(nci1_data_path)
    dglgraph = dglpreprocess(graph)
    dgl.data.utils.save_graphs(os.path.join(save_path, "graph"), dglgraph)
    g = dgl.load_graphs(os.path.join(save_path, "graph"))[0][0]
    print(g)
    print(g.number_of_nodes())


def Raw2OneGraph(raw_data, save_data):
    """Convert each raw .gml graph into a DGLGraph file under `save_data`.

    Graphs whose nodes all share a single label are skipped, but the running
    counter still advances, so saved filenames may have gaps.  Returns the
    total number of graphs processed (saved or skipped).
    """
    nci1_data_path = raw_data
    save_path = save_data
    graphs = read_graphs_from_dir(nci1_data_path)
    count = 0
    for graph in graphs:
        print("process graph ", count)
        dglgraph = dglpreprocess(graph)
        if countlabelnum(dglgraph) != 1:
            dgl.data.utils.save_graphs(os.path.join(save_path, str(count)), dglgraph)
        count += 1
    return count


def countlabelnum(graph):
    """Return the number of distinct node labels in `graph` (as a tensor).

    Fix: the label histogram used to be hard-coded to torch.zeros(3), which
    raised IndexError for any dataset with node labels >= 3; it is now sized
    from the observed maximum label.
    """
    labels = graph.ndata["label"]
    num_slots = int(labels.max().item()) + 1  # generalized: no hard-coded 3
    count = torch.zeros(num_slots)
    for i in labels:
        count[i] = 1
    return count.count_nonzero()
/nodedownstream/__pycache__/filternet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/filternet.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/gin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/gin.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/node_finetuning_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/node_finetuning_layer.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/node_prompt_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/node_prompt_layer.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/predictnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/predictnet.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/split.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/split.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Starlien95/GraphPrompt/f27e8328a61a86c3d8f5d0fba4edb890968866be/nodedownstream/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /nodedownstream/basemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import copy 6 | import numpy as np 7 | from utils import int2onehot 8 | from utils import get_enc_len, split_and_batchify_graph_feats, batch_convert_len_to_mask 9 | from embedding import OrthogonalEmbedding, NormalEmbedding, EquivariantEmbedding 10 | from filternet import MaxGatedFilterNet 11 | from predictnet import MeanPredictNet, SumPredictNet, MaxPredictNet, \ 12 | MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, \ 13 | MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, \ 14 | DIAMNet 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, config): 18 | super(BaseModel, self).__init__() 19 | 20 | self.act_func = config["activation_function"] 21 | self.init_emb = config["init_emb"] 22 | self.share_emb = config["share_emb"] 23 | self.share_arch = config["share_arch"] 24 | self.base = config["base"] 25 | self.max_ngv = config["max_ngv"] 26 | self.max_ngvl = config["max_ngvl"] 27 | self.max_nge = config["max_nge"] 28 | self.max_ngel = config["max_ngel"] 29 | self.max_npv = config["max_npv"] 30 | self.max_npvl = config["max_npvl"] 31 | self.max_npe = config["max_npe"] 32 | self.max_npel = config["max_npel"] 33 | 34 | self.emb_dim = config["emb_dim"] 35 | 
    def create_enc(self, max_n, base):
        """Build a frozen one-hot positional encoder for ids in [0, max_n).

        Each integer is written base-`base` with ``get_enc_len(max_n-1, base)``
        digits, and every digit becomes a one-hot segment of width `base`,
        so the lookup row has ``enc_len * base`` entries.

        Args:
            max_n: exclusive upper bound of the ids to encode.
            base: radix of the positional code (2 elsewhere in this file —
                TODO confirm against the config defaults).

        Returns:
            An ``nn.Embedding`` acting as a fixed lookup table: its weight
            is precomputed by ``int2onehot`` for every id at once and then
            frozen with ``requires_grad = False``.
        """
        enc_len = get_enc_len(max_n-1, base)
        enc_dim = enc_len * base
        enc = nn.Embedding(max_n, enc_dim)
        # Fill all rows in one shot from the numpy one-hot codes.
        enc.weight.data.copy_(torch.from_numpy(int2onehot(np.arange(0, max_n), enc_len, base)))
        enc.weight.requires_grad=False
        return enc
None 100 | elif predict_type == "MeanPredictNet": 101 | hidden_dim = kw.get("hidden_dim", 64) 102 | predict_net = MeanPredictNet(pattern_dim, graph_dim, hidden_dim, 103 | act_func=self.act_func, dropout=self.dropout) 104 | elif predict_type == "SumPredictNet": 105 | hidden_dim = kw.get("hidden_dim", 64) 106 | predict_net = SumPredictNet(pattern_dim, graph_dim, hidden_dim, 107 | act_func=self.act_func, dropout=self.dropout) 108 | elif predict_type == "MaxPredictNet": 109 | hidden_dim = kw.get("hidden_dim", 64) 110 | predict_net = MaxPredictNet(pattern_dim, graph_dim, hidden_dim, 111 | act_func=self.act_func, dropout=self.dropout) 112 | elif predict_type == "MeanAttnPredictNet": 113 | hidden_dim = kw.get("hidden_dim", 64) 114 | recurrent_steps = kw.get("recurrent_steps", 1) 115 | num_heads = kw.get("num_heads", 1) 116 | predict_net = MeanAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 117 | act_func=self.act_func, 118 | num_heads=num_heads, recurrent_steps=recurrent_steps, 119 | dropout=self.dropout, dropatt=self.dropatt) 120 | elif predict_type == "SumAttnPredictNet": 121 | hidden_dim = kw.get("hidden_dim", 64) 122 | recurrent_steps = kw.get("recurrent_steps", 1) 123 | num_heads = kw.get("num_heads", 1) 124 | predict_net = SumAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 125 | act_func=self.act_func, 126 | num_heads=num_heads, recurrent_steps=recurrent_steps, 127 | dropout=self.dropout, dropatt=self.dropatt) 128 | elif predict_type == "MaxAttnPredictNet": 129 | hidden_dim = kw.get("hidden_dim", 64) 130 | recurrent_steps = kw.get("recurrent_steps", 1) 131 | num_heads = kw.get("num_heads", 1) 132 | predict_net = MaxAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 133 | act_func=self.act_func, 134 | num_heads=num_heads, recurrent_steps=recurrent_steps, 135 | dropout=self.dropout, dropatt=self.dropatt) 136 | elif predict_type == "MeanMemAttnPredictNet": 137 | hidden_dim = kw.get("hidden_dim", 64) 138 | recurrent_steps = kw.get("recurrent_steps", 1) 139 | 
num_heads = kw.get("num_heads", 1) 140 | mem_len = kw.get("mem_len", 4) 141 | predict_net = MeanMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 142 | act_func=self.act_func, 143 | num_heads=num_heads, recurrent_steps=recurrent_steps, 144 | mem_len=mem_len, 145 | dropout=self.dropout, dropatt=self.dropatt) 146 | elif predict_type == "SumMemAttnPredictNet": 147 | hidden_dim = kw.get("hidden_dim", 64) 148 | recurrent_steps = kw.get("recurrent_steps", 1) 149 | num_heads = kw.get("num_heads", 1) 150 | mem_len = kw.get("mem_len", 4) 151 | predict_net = SumMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 152 | act_func=self.act_func, 153 | num_heads=num_heads, recurrent_steps=recurrent_steps, 154 | mem_len=mem_len, 155 | dropout=self.dropout, dropatt=self.dropatt) 156 | elif predict_type == "MaxMemAttnPredictNet": 157 | hidden_dim = kw.get("hidden_dim", 64) 158 | recurrent_steps = kw.get("recurrent_steps", 1) 159 | num_heads = kw.get("num_heads", 1) 160 | mem_len = kw.get("mem_len", 4) 161 | predict_net = MaxMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 162 | act_func=self.act_func, 163 | num_heads=num_heads, recurrent_steps=recurrent_steps, 164 | mem_len=mem_len, 165 | dropout=self.dropout, dropatt=self.dropatt) 166 | elif predict_type == "DIAMNet": 167 | hidden_dim = kw.get("hidden_dim", 64) 168 | recurrent_steps = kw.get("recurrent_steps", 1) 169 | num_heads = kw.get("num_heads", 1) 170 | mem_len = kw.get("mem_len", 4) 171 | mem_init = kw.get("mem_init", "mean") 172 | predict_net = DIAMNet(pattern_dim, graph_dim, hidden_dim, 173 | act_func=self.act_func, 174 | num_heads=num_heads, recurrent_steps=recurrent_steps, 175 | mem_len=mem_len, mem_init=mem_init, 176 | dropout=self.dropout, dropatt=self.dropatt) 177 | else: 178 | raise NotImplementedError("Currently, %s is not supported!" 
% (predict_type)) 179 | return predict_net 180 | 181 | def increase_input_size(self, config): 182 | assert config["base"] == self.base 183 | assert config["max_npv"] >= self.max_npv 184 | assert config["max_npvl"] >= self.max_npvl 185 | assert config["max_npe"] >= self.max_npe 186 | assert config["max_npel"] >= self.max_npel 187 | assert config["max_ngv"] >= self.max_ngv 188 | assert config["max_ngvl"] >= self.max_ngvl 189 | assert config["max_nge"] >= self.max_nge 190 | assert config["max_ngel"] >= self.max_ngel 191 | assert config["predict_net_add_enc"] or not self.add_enc 192 | assert config["predict_net_add_degree"] or not self.add_degree 193 | 194 | # create encoding layers 195 | # increase embedding layers 196 | # increase predict network 197 | # set new parameters 198 | 199 | def increase_net(self, config): 200 | raise NotImplementedError 201 | 202 | 203 | class EdgeSeqModel(BaseModel): 204 | def __init__(self, config): 205 | super(EdgeSeqModel, self).__init__(config) 206 | # create encoding layer 207 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = \ 208 | [self.create_enc(max_n, self.base) for max_n in [self.max_ngv, self.max_ngvl, self.max_ngel]] 209 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 210 | if self.share_emb: 211 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 212 | self.g_v_enc, self.g_vl_enc, self.g_el_enc 213 | else: 214 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 215 | [self.create_enc(max_n, self.base) for max_n in [self.max_npv, self.max_npvl, self.max_npel]] 216 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 217 | 218 | # create filter layers 219 | self.ul_flt, self.el_flt, self.vl_flt = [self.create_filter(config["filter_net"]) for _ in range(3)] 220 | 221 | # create embedding layers 222 | self.g_u_emb, self.g_v_emb, self.g_ul_emb, self.g_el_emb, self.g_vl_emb = \ 223 | [self.create_emb(enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) \ 224 | for enc in [self.g_u_enc, self.g_v_enc, self.g_ul_enc, 
    def get_enc_dim(self):
        """Return the concatenated one-hot encoding widths.

        The graph code concatenates five per-edge fields (u, v, ul, el, vl);
        u/v share the vertex-id encoder and ul/vl share the vertex-label
        encoder, hence the ``* 2`` factors plus one edge-label term.

        Returns:
            ``(p_dim, g_dim)`` — pattern and graph encoding widths. When
            embeddings are shared the pattern reuses the graph encoders, so
            both entries are ``g_dim``.
        """
        # get_enc_len returns math.floor(math.log(x, base) + 1.0)
        # base defaults to 2
        g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) * 2 + \
                             get_enc_len(self.max_ngvl-1, self.base) * 2 + \
                             get_enc_len(self.max_ngel-1, self.base))
        if self.share_emb:
            return g_dim, g_dim
        else:
            p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) * 2 + \
                                 get_enc_len(self.max_npvl-1, self.base) * 2 + \
                                 get_enc_len(self.max_npel-1, self.base))
            return p_dim, g_dim
pattern, pattern_len, graph, graph_len): 277 | bsz = pattern_len.size(0) 278 | pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl = \ 279 | self.p_u_enc(pattern.u), self.p_v_enc(pattern.v), self.p_ul_enc(pattern.ul), self.p_el_enc(pattern.el), self.p_vl_enc(pattern.vl) 280 | graph_u, graph_v, graph_ul, graph_el, graph_vl = \ 281 | self.g_u_enc(graph.u), self.g_v_enc(graph.v), self.g_ul_enc(graph.ul), self.g_el_enc(graph.el), self.g_vl_enc(graph.vl) 282 | 283 | if self.init_emb == "None": 284 | p_emb = torch.cat([pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl], dim=2) 285 | g_emb = torch.cat([graph_u, graph_v, graph_ul, graph_el, graph_vl], dim=2) 286 | else: 287 | p_emb = self.p_u_emb(pattern_u) + \ 288 | self.p_v_emb(pattern_v) + \ 289 | self.p_ul_emb(pattern_ul) + \ 290 | self.p_el_emb(pattern_el) + \ 291 | self.p_vl_emb(pattern_vl) 292 | g_emb = self.g_u_emb(graph_u) + \ 293 | self.g_v_emb(graph_v) + \ 294 | self.g_ul_emb(graph_ul) + \ 295 | self.g_el_emb(graph_el) + \ 296 | self.g_vl_emb(graph_vl) 297 | return p_emb, g_emb 298 | 299 | def get_filter_gate(self, pattern, pattern_len, graph, graph_len): 300 | gate = None 301 | if self.ul_flt is not None: 302 | if gate is not None: 303 | gate &= self.ul_flt(pattern.ul, graph.ul) 304 | else: 305 | gate = self.ul_flt(pattern.ul, graph.ul) 306 | if self.el_flt is not None: 307 | if gate is not None: 308 | gate &= self.el_flt(pattern.el, graph.el) 309 | else: 310 | gate = self.el_flt(pattern.el, graph.el) 311 | if self.vl_flt is not None: 312 | if gate is not None: 313 | gate &= self.vl_flt(pattern.vl, graph.vl) 314 | else: 315 | gate = self.vl_flt(pattern.vl, graph.vl) 316 | return gate 317 | 318 | def increase_input_size(self, config): 319 | super(EdgeSeqModel, self).increase_input_size(config) 320 | 321 | # create encoding layers 322 | new_g_v_enc, new_g_vl_enc, new_g_el_enc = \ 323 | [self.create_enc(max_n, self.base) for max_n in [config["max_ngv"], config["max_ngvl"], config["max_ngel"]]] 324 | if 
self.share_emb: 325 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 326 | new_g_v_enc, new_g_vl_enc, new_g_el_enc 327 | else: 328 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 329 | [self.create_enc(max_n, self.base) for max_n in [config["max_npv"], config["max_npvl"], config["max_npel"]]] 330 | del self.g_v_enc, self.g_vl_enc, self.g_el_enc, self.g_u_enc, self.g_ul_enc 331 | del self.p_v_enc, self.p_vl_enc, self.p_el_enc, self.p_u_enc, self.p_ul_enc 332 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = new_g_v_enc, new_g_vl_enc, new_g_el_enc 333 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 334 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = new_p_v_enc, new_p_vl_enc, new_p_el_enc 335 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 336 | 337 | # increase embedding layers 338 | self.g_u_emb.increase_input_size(self.g_u_enc.embedding_dim) 339 | self.g_v_emb.increase_input_size(self.g_v_enc.embedding_dim) 340 | self.g_ul_emb.increase_input_size(self.g_ul_enc.embedding_dim) 341 | self.g_vl_emb.increase_input_size(self.g_vl_enc.embedding_dim) 342 | self.g_el_emb.increase_input_size(self.g_el_enc.embedding_dim) 343 | if not self.share_emb: 344 | self.p_u_emb.increase_input_size(self.p_u_enc.embedding_dim) 345 | self.p_v_emb.increase_input_size(self.p_v_enc.embedding_dim) 346 | self.p_ul_emb.increase_input_size(self.p_ul_enc.embedding_dim) 347 | self.p_vl_emb.increase_input_size(self.p_vl_enc.embedding_dim) 348 | self.p_el_emb.increase_input_size(self.p_el_enc.embedding_dim) 349 | 350 | # increase predict network 351 | 352 | # set new parameters 353 | self.max_npv = config["max_npv"] 354 | self.max_npvl = config["max_npvl"] 355 | self.max_npe = config["max_npe"] 356 | self.max_npel = config["max_npel"] 357 | self.max_ngv = config["max_ngv"] 358 | self.max_ngvl = config["max_ngvl"] 359 | self.max_nge = config["max_nge"] 360 | self.max_ngel = config["max_ngel"] 361 | 362 | 363 | 364 | class GraphAdjModel(BaseModel): 365 | def __init__(self, config): 366 
    def get_enc_dim(self):
        """Return the width of a node's one-hot encoding.

        Concatenates the positional code of the node id (``max_ngv`` values)
        with that of the node label (``max_ngvl`` values); each code uses
        ``get_enc_len`` base-`base` digits of width `base`.
        """
        g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) + get_enc_len(self.max_ngvl-1, self.base))
        # Pattern-side encoding is disabled in this node-level variant; the
        # original pattern/graph pair logic is kept below (as a dead string)
        # for reference.
        '''if self.share_emb:
            return g_dim, g_dim
        else:
            p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) + get_enc_len(self.max_npvl-1, self.base))
            return p_dim, g_dim'''
        return g_dim
    def get_filter_gate(self, graph, graph_len):
        """Compute a per-node gate from the vertex-label filter.

        Args:
            graph: batched DGLGraph exposing ``ndata["label"]``.
            graph_len: (bsz, 1) tensor of node counts per graph in the batch.

        Returns:
            A (total_nodes, 1) gate tensor aligned with ``graph``'s flat node
            order, or None when no vertex-label filter is configured.
        """
        gate = None
        if self.vl_flt is not None:
            # split_and_batchify_graph_feats pads the flat label column into
            # one row per graph so the filter sees a (bsz, max_len, 1) batch.
            gate = self.vl_flt(split_and_batchify_graph_feats(graph.ndata["label"].unsqueeze(-1), graph_len)[0])

        if gate is not None:
            bsz = graph_len.size(0)
            max_g_len = graph_len.max()
            # When graphs in the batch have unequal sizes, the padded layout
            # has bsz*max_len slots; drop the padding positions before
            # flattening back to one row per real node.
            if bsz * max_g_len != graph.number_of_nodes():
                graph_mask = batch_convert_len_to_mask(graph_len)  # bsz x max_len
                gate = gate.masked_select(graph_mask.unsqueeze(-1)).view(-1, 1)
            else:
                gate = gate.view(-1, 1)
        return gate
Sampler Part ################ 20 | ############################################## 21 | class Sampler(data.Sampler): 22 | _type_map = { 23 | int: np.int32, 24 | float: np.float32} 25 | 26 | def __init__(self, dataset, group_by, batch_size, shuffle, drop_last): 27 | super(Sampler, self).__init__(dataset) 28 | if isinstance(group_by, str): 29 | group_by = [group_by] 30 | for attr in group_by: 31 | setattr(self, attr, list()) 32 | self.data_size = len(dataset.data) 33 | print(self.data_size) 34 | for x in dataset.data: 35 | for attr in group_by: 36 | value = x[attr] 37 | if isinstance(value, dgl.DGLGraph): 38 | getattr(self, attr).append(value.number_of_nodes()) 39 | elif hasattr(value, "__len__"): 40 | getattr(self, attr).append(len(value)) 41 | else: 42 | getattr(self, attr).append(value) 43 | self.order = copy.copy(group_by) 44 | self.order.append("rand") 45 | self.batch_size = batch_size 46 | self.shuffle = shuffle 47 | self.drop_last = drop_last 48 | 49 | def make_array(self): 50 | self.rand = np.random.rand(self.data_size).astype(np.float32) 51 | if self.data_size == 0: 52 | types = [np.float32] * len(self.order) 53 | else: 54 | types = [type(getattr(self, attr)[0]) for attr in self.order] 55 | types = [Sampler._type_map.get(t, t) for t in types] 56 | dtype = list(zip(self.order, types)) 57 | array = np.array( 58 | list(zip(*[getattr(self, attr) for attr in self.order])), 59 | dtype=dtype) 60 | return array 61 | 62 | def __iter__(self): 63 | array = self.make_array() 64 | indices = np.argsort(array, axis=0, order=self.order) 65 | batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)] 66 | if self.shuffle: 67 | np.random.shuffle(batches) 68 | batch_idx = 0 69 | while batch_idx < len(batches)-1: 70 | yield batches[batch_idx] 71 | batch_idx += 1 72 | if len(batches) > 0 and (len(batches[batch_idx]) == self.batch_size or not self.drop_last): 73 | yield batches[batch_idx] 74 | 75 | def __len__(self): 76 | if self.drop_last: 77 | 
return math.floor(self.data_size/self.batch_size) 78 | else: 79 | return math.ceil(self.data_size/self.batch_size) 80 | 81 | 82 | ############################################## 83 | ############# EdgeSeq Data Part ############## 84 | ############################################## 85 | class EdgeSeq: 86 | def __init__(self, code): 87 | self.u = code[:,0] 88 | self.v = code[:,1] 89 | self.ul = code[:,2] 90 | self.el = code[:,3] 91 | self.vl = code[:,4] 92 | 93 | def __len__(self): 94 | if len(self.u.shape) == 2: # single code 95 | return self.u.shape[0] 96 | else: # batch code 97 | return self.u.shape[0] * self.u.shape[1] 98 | 99 | @staticmethod 100 | def batch(data): 101 | b = EdgeSeq(torch.empty((0,5), dtype=torch.long)) 102 | b.u = batch_convert_tensor_to_tensor([x.u for x in data]) 103 | b.v = batch_convert_tensor_to_tensor([x.v for x in data]) 104 | b.ul = batch_convert_tensor_to_tensor([x.ul for x in data]) 105 | b.el = batch_convert_tensor_to_tensor([x.el for x in data]) 106 | b.vl = batch_convert_tensor_to_tensor([x.vl for x in data]) 107 | return b 108 | 109 | def to(self, device): 110 | self.u = self.u.to(device) 111 | self.v = self.v.to(device) 112 | self.ul = self.ul.to(device) 113 | self.el = self.el.to(device) 114 | self.vl = self.vl.to(device) 115 | 116 | 117 | ############################################## 118 | ############# EdgeSeq Data Part ############## 119 | ############################################## 120 | class EdgeSeqDataset(data.Dataset): 121 | def __init__(self, data=None): 122 | super(EdgeSeqDataset, self).__init__() 123 | 124 | if data: 125 | self.data = EdgeSeqDataset.preprocess_batch(data, use_tqdm=True) 126 | else: 127 | self.data = list() 128 | self._to_tensor() 129 | 130 | def _to_tensor(self): 131 | for x in self.data: 132 | for k in ["pattern", "graph", "subisomorphisms"]: 133 | if isinstance(x[k], np.ndarray): 134 | x[k] = torch.from_numpy(x[k]) 135 | 136 | def __len__(self): 137 | return len(self.data) 138 | 139 | def 
    @staticmethod
    def graph2edgeseq(graph):
        """Convert an igraph graph into a sorted (E, 5) int64 edge code.

        Each row is (v, u, vl, el, ul): edge endpoints, their vertex labels,
        and the edge label, read from ``graph.vs["label"]`` / ``edge["label"]``.

        Returns:
            np.ndarray of shape (E, 5), dtype int64, sorted by (v, u, el).
        """
        labels = graph.vs["label"]
        graph_code = list()

        for edge in graph.es:
            v, u = edge.tuple
            graph_code.append((v, u, labels[v], edge["label"], labels[u]))
        graph_code = np.array(graph_code, dtype=np.int64)
        # The structured view shares memory with graph_code, so sorting the
        # view by (v, u, el) reorders the rows of graph_code in place.
        graph_code.view(
            [("v", "int64"), ("u", "int64"), ("vl", "int64"), ("el", "int64"), ("ul", "int64")]).sort(
            axis=0, order=["v", "u", "el"])
        return graph_code
pattern_len = torch.tensor([x["pattern"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 206 | graph = EdgeSeq.batch([EdgeSeq(x["graph"]) for x in batch]) 207 | graph_len = torch.tensor([x["graph"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 208 | counts = torch.tensor([x["counts"] for x in batch], dtype=torch.float32).view(-1, 1) 209 | return _id, pattern, pattern_len, graph, graph_len, counts 210 | 211 | 212 | ############################################## 213 | ######### GraphAdj Data Part ########### 214 | ############################################## 215 | class GraphAdjDataset_DGL_Input(data.Dataset): 216 | def __init__(self, data=None): 217 | super(GraphAdjDataset_DGL_Input, self).__init__() 218 | 219 | self.data = GraphAdjDataset_DGL_Input.preprocess_batch(data, use_tqdm=True) 220 | #self._to_tensor() 221 | 222 | def _to_tensor(self): 223 | for x in self.data: 224 | for k in ["graph"]: 225 | y = x[k] 226 | for k, v in y.ndata.items(): 227 | if isinstance(v, np.ndarray): 228 | y.ndata[k] = torch.from_numpy(v) 229 | for k, v in y.edata.items(): 230 | if isinstance(v, np.ndarray): 231 | y.edata[k] = torch.from_numpy(v) 232 | if isinstance(x["subisomorphisms"], np.ndarray): 233 | x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"]) 234 | 235 | def __len__(self): 236 | return len(self.data) 237 | 238 | def __getitem__(self, idx): 239 | return self.data[idx] 240 | 241 | def save(self, filename): 242 | cache = defaultdict(list) 243 | for x in self.data: 244 | for k in list(x.keys()): 245 | if k.startswith("_"): 246 | cache[k].append(x.pop(k)) 247 | with open(filename, "wb") as f: 248 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 249 | if len(cache) > 0: 250 | keys = cache.keys() 251 | for i in range(len(self.data)): 252 | for k in keys: 253 | self.data[i][k] = cache[k][i] 254 | 255 | def load(self, filename): 256 | with open(filename, "rb") as f: 257 | data = torch.load(f) 258 | del self.data 259 | self.data = 
data 260 | return self 261 | 262 | @staticmethod 263 | def comp_indeg_norm(graph): 264 | import igraph as ig 265 | if isinstance(graph, ig.Graph): 266 | # 10x faster 267 | in_deg = np.array(graph.indegree(), dtype=np.float32) 268 | elif isinstance(graph, dgl.DGLGraph): 269 | in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy() 270 | else: 271 | raise NotImplementedError 272 | norm = 1.0 / in_deg 273 | norm[np.isinf(norm)] = 0 274 | return norm 275 | 276 | @staticmethod 277 | def graph2dglgraph(graph): 278 | dglgraph = dgl.DGLGraph(multigraph=True) 279 | dglgraph.add_nodes(graph.vcount()) 280 | edges = graph.get_edgelist() 281 | dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges]) 282 | dglgraph.readonly(True) 283 | return dglgraph 284 | 285 | @staticmethod 286 | def find_no_connection_node(graph,node): 287 | numnode=graph.number_of_nodes() 288 | rand=list(range(numnode)) 289 | random.shuffle(rand) 290 | for i in range(numnode): 291 | if graph.has_edges_between(node,rand[i]): 292 | continue 293 | else: 294 | return i 295 | 296 | @staticmethod 297 | def findsample(graph): 298 | nodenum=graph.number_of_nodes() 299 | result=torch.ones(nodenum,3) 300 | adj=graph.adjacency_matrix() 301 | src=adj._indices()[1].tolist() 302 | dst=adj._indices()[0].tolist() 303 | for i in range(nodenum): 304 | result[i,0]=i 305 | # NCI1存在着某些节点是孤立的情况,这里将孤立节点的正样本设为其自身 306 | if i not in src: 307 | result[i,1]=i 308 | else: 309 | index_i=src.index(i) 310 | i_point_to=dst[index_i] 311 | result[i,1]=i_point_to 312 | result[i,2]=GraphAdjDataset.find_no_connection_node(graph,i) 313 | #------------------------------------------------------------------------------------------- 314 | return torch.tensor(result,dtype=int) 315 | 316 | @staticmethod 317 | def preprocess(x): 318 | graph = x["graph"] 319 | '''graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph) 320 | graph_dglgraph.ndata["indeg"] = torch.tensor(np.array(graph.indegree(), dtype=np.float32)) 321 | 
graph_dglgraph.ndata["label"] = torch.tensor(np.array(graph.vs["label"], dtype=np.int64)) 322 | graph_dglgraph.ndata["id"] = torch.tensor(np.arange(0, graph.vcount(), dtype=np.int64)) 323 | graph_dglgraph.ndata["sample"] = GraphAdjDataset.findsample(graph_dglgraph)''' 324 | x = { 325 | "id": x["id"], 326 | "graph": graph, 327 | "label": x["label"]} 328 | return x 329 | 330 | 331 | @staticmethod 332 | def preprocess_batch(data, use_tqdm=False): 333 | d = list() 334 | if use_tqdm: 335 | data = tqdm(data) 336 | for x in data: 337 | d.append(GraphAdjDataset_DGL_Input.preprocess(x)) 338 | return d 339 | 340 | @staticmethod 341 | def batchify(batch): 342 | _id = [x["id"] for x in batch] 343 | graph_label = torch.tensor([x["label"] for x in batch], dtype=torch.float64).view(-1, 1) 344 | graph = dgl.batch([x["graph"] for x in batch]) 345 | graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1) 346 | return _id, graph_label, graph, graph_len 347 | 348 | 349 | class GraphAdjDataset(data.Dataset): 350 | def __init__(self, data=None): 351 | super(GraphAdjDataset, self).__init__() 352 | 353 | if data: 354 | self.data = GraphAdjDataset.preprocess_batch(data, use_tqdm=True) 355 | else: 356 | self.data = list() 357 | # self._to_tensor() 358 | 359 | def _to_tensor(self): 360 | for x in self.data: 361 | for k in ["graph"]: 362 | y = x[k] 363 | for k, v in y.ndata.items(): 364 | if isinstance(v, np.ndarray): 365 | y.ndata[k] = torch.from_numpy(v) 366 | for k, v in y.edata.items(): 367 | if isinstance(v, np.ndarray): 368 | y.edata[k] = torch.from_numpy(v) 369 | if isinstance(x["subisomorphisms"], np.ndarray): 370 | x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"]) 371 | 372 | def __len__(self): 373 | return len(self.data) 374 | 375 | def __getitem__(self, idx): 376 | return self.data[idx] 377 | 378 | def save(self, filename): 379 | cache = defaultdict(list) 380 | for x in self.data: 381 | for k in list(x.keys()): 382 | if 
k.startswith("_"): 383 | cache[k].append(x.pop(k)) 384 | with open(filename, "wb") as f: 385 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 386 | if len(cache) > 0: 387 | keys = cache.keys() 388 | for i in range(len(self.data)): 389 | for k in keys: 390 | self.data[i][k] = cache[k][i] 391 | 392 | def load(self, filename): 393 | with open(filename, "rb") as f: 394 | data = torch.load(f) 395 | del self.data 396 | self.data = data 397 | return self 398 | 399 | @staticmethod 400 | def comp_indeg_norm(graph): 401 | import igraph as ig 402 | if isinstance(graph, ig.Graph): 403 | # 10x faster 404 | in_deg = np.array(graph.indegree(), dtype=np.float32) 405 | elif isinstance(graph, dgl.DGLGraph): 406 | in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy() 407 | else: 408 | raise NotImplementedError 409 | norm = 1.0 / in_deg 410 | norm[np.isinf(norm)] = 0 411 | return norm 412 | 413 | @staticmethod 414 | def graph2dglgraph(graph): 415 | dglgraph = dgl.DGLGraph(multigraph=True) 416 | dglgraph.add_nodes(graph.vcount()) 417 | edges = graph.get_edgelist() 418 | dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges]) 419 | dglgraph.readonly(True) 420 | return dglgraph 421 | 422 | @staticmethod 423 | def find_no_connection_node(graph, node): 424 | numnode = graph.number_of_nodes() 425 | rand = list(range(numnode)) 426 | random.shuffle(rand) 427 | for i in range(numnode): 428 | if graph.has_edges_between(node, rand[i]): 429 | continue 430 | else: 431 | return i 432 | 433 | @staticmethod 434 | def findsample(graph): 435 | nodenum = graph.number_of_nodes() 436 | result = torch.ones(nodenum, 3) 437 | adj = graph.adjacency_matrix() 438 | '''src = adj._indices()[1].tolist() 439 | dst = adj._indices()[0].tolist()''' 440 | src = adj._indices()[0].tolist() 441 | dst = adj._indices()[1].tolist() 442 | 443 | # ----------------------------------------------------------------------------------------- 444 | for i in range(nodenum): 445 | 
result[i, 0] = i 446 | if i not in src: 447 | result[i, 1] = i 448 | else: 449 | index_i = src.index(i) 450 | i_point_to = dst[index_i] 451 | result[i, 1] = i_point_to 452 | result[i, 2] = GraphAdjDataset.find_no_connection_node(graph, i) 453 | # ------------------------------------------------------------------------------------------- 454 | return torch.tensor(result, dtype=int) 455 | 456 | 457 | @staticmethod 458 | def preprocess(x): 459 | graph = x 460 | graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph) 461 | '''graph_dglgraph.ndata["indeg"] = np.array(graph.indegree(), dtype=np.float32) 462 | graph_dglgraph.ndata["label"] = np.array(graph.vs["label"], dtype=np.int64) 463 | graph_dglgraph.ndata["id"] = np.arange(0, graph.vcount(), dtype=np.int64)''' 464 | # graph_dglgraph.edata["label"] = np.array(graph.es["label"], dtype=np.int64) 465 | graph_dglgraph.ndata["indeg"] = torch.tensor(np.array(graph.indegree(), dtype=np.float32)) 466 | graph_dglgraph.ndata["label"] = torch.tensor(np.array(graph.vs["label"], dtype=np.int64)) 467 | graph_dglgraph.ndata["id"] = torch.tensor(np.arange(0, graph.vcount(), dtype=np.int64)) 468 | graph_dglgraph.ndata["sample"] = GraphAdjDataset.findsample(graph_dglgraph) 469 | nodefeature=graph["feature"] 470 | graph_dglgraph.ndata["feature"]=torch.tensor(np.array(nodefeature, dtype=np.float32)) 471 | x = { 472 | "id": "0", 473 | "graph": graph_dglgraph, 474 | "label": 0} 475 | return x 476 | 477 | @staticmethod 478 | def preprocess_batch(data, use_tqdm=False): 479 | d = list() 480 | if use_tqdm: 481 | data = tqdm(data) 482 | for x in data: 483 | d.append(GraphAdjDataset.preprocess(x)) 484 | return d 485 | 486 | @staticmethod 487 | def batchify(batch): 488 | _id = [x["id"] for x in batch] 489 | graph_label = torch.tensor([x["label"] for x in batch], dtype=torch.float64).view(-1, 1) 490 | graph = dgl.batch([x["graph"] for x in batch]) 491 | graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], 
dtype=torch.int32).view(-1, 1) 492 | return _id, graph_label, graph, graph_len 493 | -------------------------------------------------------------------------------- /nodedownstream/datasetInfo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import math 6 | import numpy as np 7 | import re 8 | import os 9 | import sys 10 | import json 11 | from torch.optim.lr_scheduler import LambdaLR 12 | from collections import OrderedDict 13 | from multiprocessing import Pool 14 | from tqdm import tqdm 15 | from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score 16 | import random 17 | from tqdm import trange 18 | from sklearn.metrics import precision_recall_fscore_support 19 | import functools 20 | import dgl 21 | 22 | #drop==True means drop nodes of class drop when split train,val, test;but can only drop the biggest class(ex 0,1,2 can only drop label 2) 23 | def few_shot_split_nodelevel(graph,tasknum,trainshot,valshot,labelnum,seed=0, drop=False): 24 | train=[] 25 | val=[] 26 | test=[] 27 | if drop: 28 | labelnum=labelnum-1 29 | nodenum=graph.number_of_nodes() 30 | random.seed(seed) 31 | for count in range(tasknum): 32 | index = random.sample(range(0, nodenum), nodenum) 33 | trainindex=[] 34 | valindex=[] 35 | testindex=[] 36 | traincount = torch.zeros(labelnum) 37 | valcount = torch.zeros(labelnum) 38 | for i in index: 39 | label=graph.ndata["label"][i] 40 | if drop: 41 | if label==labelnum: 42 | continue 43 | if traincount[label] 0 and (len(batches[batch_idx]) == self.batch_size or not self.drop_last): 67 | yield batches[batch_idx] 68 | 69 | def __len__(self): 70 | if self.drop_last: 71 | return math.floor(self.data_size / self.batch_size) 72 | else: 73 | return math.ceil(self.data_size / self.batch_size) 74 | 75 | 76 | ############################################## 77 | ############# EdgeSeq Data 
Part ############## 78 | ############################################## 79 | class EdgeSeq: 80 | def __init__(self, code): 81 | self.u = code[:, 0] 82 | self.v = code[:, 1] 83 | self.ul = code[:, 2] 84 | self.el = code[:, 3] 85 | self.vl = code[:, 4] 86 | 87 | def __len__(self): 88 | if len(self.u.shape) == 2: # single code 89 | return self.u.shape[0] 90 | else: # batch code 91 | return self.u.shape[0] * self.u.shape[1] 92 | 93 | @staticmethod 94 | def batch(data): 95 | b = EdgeSeq(torch.empty((0, 5), dtype=torch.long)) 96 | b.u = batch_convert_tensor_to_tensor([x.u for x in data]) 97 | b.v = batch_convert_tensor_to_tensor([x.v for x in data]) 98 | b.ul = batch_convert_tensor_to_tensor([x.ul for x in data]) 99 | b.el = batch_convert_tensor_to_tensor([x.el for x in data]) 100 | b.vl = batch_convert_tensor_to_tensor([x.vl for x in data]) 101 | return b 102 | 103 | def to(self, device): 104 | self.u = self.u.to(device) 105 | self.v = self.v.to(device) 106 | self.ul = self.ul.to(device) 107 | self.el = self.el.to(device) 108 | self.vl = self.vl.to(device) 109 | 110 | 111 | ############################################## 112 | ############# EdgeSeq Data Part ############## 113 | ############################################## 114 | class EdgeSeqDataset(data.Dataset): 115 | def __init__(self, data=None): 116 | super(EdgeSeqDataset, self).__init__() 117 | 118 | if data: 119 | self.data = EdgeSeqDataset.preprocess_batch(data, use_tqdm=True) 120 | else: 121 | self.data = list() 122 | self._to_tensor() 123 | 124 | def _to_tensor(self): 125 | for x in self.data: 126 | for k in ["pattern", "graph", "subisomorphisms"]: 127 | if isinstance(x[k], np.ndarray): 128 | x[k] = torch.from_numpy(x[k]) 129 | 130 | def __len__(self): 131 | return len(self.data) 132 | 133 | def __getitem__(self, idx): 134 | return self.data[idx] 135 | 136 | def save(self, filename): 137 | cache = defaultdict(list) 138 | for x in self.data: 139 | for k in list(x.keys()): 140 | if k.startswith("_"): 141 | 
cache[k].append(x.pop(k)) 142 | with open(filename, "wb") as f: 143 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 144 | if len(cache) > 0: 145 | keys = cache.keys() 146 | for i in range(len(self.data)): 147 | for k in keys: 148 | self.data[i][k] = cache[k][i] 149 | 150 | def load(self): 151 | self.data = FlickrDataset()[0] 152 | save_path="../data/Flickr/allinone/graph" 153 | if os.path.exists(save_path)==False: 154 | dgl.data.utils.save_graphs(os.path.join(save_path,"graph"),self.data) 155 | return self 156 | 157 | @staticmethod 158 | def graph2edgeseq(graph): 159 | labels = graph.vs["label"] 160 | graph_code = list() 161 | 162 | for edge in graph.es: 163 | v, u = edge.tuple 164 | graph_code.append((v, u, labels[v], edge["label"], labels[u])) 165 | graph_code = np.array(graph_code, dtype=np.int64) 166 | graph_code.view( 167 | [("v", "int64"), ("u", "int64"), ("vl", "int64"), ("el", "int64"), ("ul", "int64")]).sort( 168 | axis=0, order=["v", "u", "el"]) 169 | return graph_code 170 | 171 | @staticmethod 172 | def preprocess(x): 173 | pattern_code = EdgeSeqDataset.graph2edgeseq(x["pattern"]) 174 | graph_code = EdgeSeqDataset.graph2edgeseq(x["graph"]) 175 | subisomorphisms = np.array(x["subisomorphisms"], dtype=np.int32).reshape(-1, x["pattern"].vcount()) 176 | 177 | x = { 178 | "id": x["id"], 179 | "pattern": pattern_code, 180 | "graph": graph_code, 181 | "counts": x["counts"], 182 | "subisomorphisms": subisomorphisms} 183 | return x 184 | 185 | @staticmethod 186 | def preprocess_batch(data, use_tqdm=False): 187 | d = list() 188 | if use_tqdm: 189 | data = tqdm(data) 190 | for x in data: 191 | d.append(EdgeSeqDataset.preprocess(x)) 192 | return d 193 | 194 | @staticmethod 195 | def batchify(batch): 196 | _id = [x["id"] for x in batch] 197 | pattern = EdgeSeq.batch([EdgeSeq(x["pattern"]) for x in batch]) 198 | pattern_len = torch.tensor([x["pattern"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 199 | graph = 
EdgeSeq.batch([EdgeSeq(x["graph"]) for x in batch]) 200 | graph_len = torch.tensor([x["graph"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 201 | counts = torch.tensor([x["counts"] for x in batch], dtype=torch.float32).view(-1, 1) 202 | return _id, pattern, pattern_len, graph, graph_len, counts 203 | 204 | 205 | ############################################## 206 | ######### GraphAdj Data Part ########### 207 | ############################################## 208 | class GraphAdjDataset_DGL_Input(data.Dataset): 209 | def __init__(self, data=None): 210 | super(GraphAdjDataset_DGL_Input, self).__init__() 211 | 212 | self.data = GraphAdjDataset_DGL_Input.preprocess_batch(data, use_tqdm=True) 213 | # self._to_tensor() 214 | 215 | def _to_tensor(self): 216 | for x in self.data: 217 | for k in ["graph"]: 218 | y = x[k] 219 | for k, v in y.ndata.items(): 220 | if isinstance(v, np.ndarray): 221 | y.ndata[k] = torch.from_numpy(v) 222 | for k, v in y.edata.items(): 223 | if isinstance(v, np.ndarray): 224 | y.edata[k] = torch.from_numpy(v) 225 | if isinstance(x["subisomorphisms"], np.ndarray): 226 | x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"]) 227 | 228 | def __len__(self): 229 | return len(self.data) 230 | 231 | def __getitem__(self, idx): 232 | return self.data[idx] 233 | 234 | def save(self, filename): 235 | cache = defaultdict(list) 236 | for x in self.data: 237 | for k in list(x.keys()): 238 | if k.startswith("_"): 239 | cache[k].append(x.pop(k)) 240 | with open(filename, "wb") as f: 241 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 242 | if len(cache) > 0: 243 | keys = cache.keys() 244 | for i in range(len(self.data)): 245 | for k in keys: 246 | self.data[i][k] = cache[k][i] 247 | 248 | def load(self): 249 | self.data = FlickrDataset()[0] 250 | print(self.data) 251 | save_path="../data/Flickr/allinone/graph" 252 | if os.path.exists(save_path)==False: 253 | 
dgl.data.utils.save_graphs(os.path.join(save_path,"graph"),self.data) 254 | return self 255 | 256 | @staticmethod 257 | def comp_indeg_norm(graph): 258 | import igraph as ig 259 | if isinstance(graph, ig.Graph): 260 | # 10x faster 261 | in_deg = np.array(graph.indegree(), dtype=np.float32) 262 | elif isinstance(graph, dgl.DGLGraph): 263 | in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy() 264 | else: 265 | raise NotImplementedError 266 | norm = 1.0 / in_deg 267 | norm[np.isinf(norm)] = 0 268 | return norm 269 | 270 | @staticmethod 271 | def graph2dglgraph(graph): 272 | dglgraph = dgl.DGLGraph(multigraph=True) 273 | dglgraph.add_nodes(graph.vcount()) 274 | edges = graph.get_edgelist() 275 | dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges]) 276 | dglgraph.readonly(True) 277 | return dglgraph 278 | 279 | @staticmethod 280 | def find_no_connection_node(graph, node): 281 | numnode = graph.number_of_nodes() 282 | rand = list(range(numnode)) 283 | random.shuffle(rand) 284 | for i in range(numnode): 285 | if graph.has_edges_between(node, rand[i]): 286 | continue 287 | else: 288 | return i 289 | 290 | @staticmethod 291 | def findsample(graph): 292 | nodenum = graph.number_of_nodes() 293 | result = torch.ones(nodenum, 3) 294 | adj = graph.adjacency_matrix() 295 | src = adj._indices()[1].tolist() 296 | dst = adj._indices()[0].tolist() 297 | for i in range(nodenum): 298 | result[i, 0] = i 299 | if i not in src: 300 | result[i, 1] = i 301 | else: 302 | index_i = src.index(i) 303 | i_point_to = dst[index_i] 304 | result[i, 1] = i_point_to 305 | result[i, 2] = GraphAdjDataset.find_no_connection_node(graph, i) 306 | # ------------------------------------------------------------------------------------------- 307 | return torch.tensor(result, dtype=int) 308 | 309 | @staticmethod 310 | def preprocess(x): 311 | graph = x["graph"] 312 | '''graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph) 313 | graph_dglgraph.ndata["indeg"] = 
    def __len__(self):
        # NOTE(review): deliberately reports a single example — for the
        # node-level Flickr task the whole graph is stored as one item, so
        # the DataLoader should yield it exactly once per epoch.  The true
        # list length is kept commented out below; confirm no caller relies
        # on len() == len(self.data) before "fixing" this.
        # return len(self.data)
        return 1
defaultdict(list) 373 | for x in self.data: 374 | for k in list(x.keys()): 375 | if k.startswith("_"): 376 | cache[k].append(x.pop(k)) 377 | with open(filename, "wb") as f: 378 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 379 | if len(cache) > 0: 380 | keys = cache.keys() 381 | for i in range(len(self.data)): 382 | for k in keys: 383 | self.data[i][k] = cache[k][i] 384 | 385 | def load(self): 386 | self.data = FlickrDataset()[0] 387 | print(self.data) 388 | save_path="../data/Flickr/allinone/graph" 389 | if os.path.exists(save_path)==False: 390 | dgl.data.utils.save_graphs(os.path.join(save_path,"graph"),self.data) 391 | return self 392 | 393 | @staticmethod 394 | def comp_indeg_norm(graph): 395 | import igraph as ig 396 | if isinstance(graph, ig.Graph): 397 | # 10x faster 398 | in_deg = np.array(graph.indegree(), dtype=np.float32) 399 | elif isinstance(graph, dgl.DGLGraph): 400 | in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy() 401 | else: 402 | raise NotImplementedError 403 | norm = 1.0 / in_deg 404 | norm[np.isinf(norm)] = 0 405 | return norm 406 | 407 | @staticmethod 408 | def graph2dglgraph(graph): 409 | dglgraph = dgl.DGLGraph(multigraph=True) 410 | dglgraph.add_nodes(graph.vcount()) 411 | edges = graph.get_edgelist() 412 | dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges]) 413 | dglgraph.readonly(True) 414 | return dglgraph 415 | 416 | @staticmethod 417 | def find_no_connection_node(graph, node): 418 | numnode = graph.number_of_nodes() 419 | rand = list(range(numnode)) 420 | random.shuffle(rand) 421 | for i in range(numnode): 422 | if graph.has_edges_between(node, rand[i]): 423 | continue 424 | else: 425 | return i 426 | 427 | @staticmethod 428 | def findsample(graph): 429 | nodenum = graph.number_of_nodes() 430 | result = torch.ones(nodenum, 3) 431 | adj = graph.adjacency_matrix() 432 | '''src = adj._indices()[1].tolist() 433 | dst = adj._indices()[0].tolist()''' 434 | src = 
adj._indices()[0].tolist() 435 | dst = adj._indices()[1].tolist() 436 | 437 | for i in range(nodenum): 438 | result[i, 0] = i 439 | if i not in src: 440 | result[i, 1] = i 441 | else: 442 | index_i = src.index(i) 443 | i_point_to = dst[index_i] 444 | result[i, 1] = i_point_to 445 | result[i, 2] = GraphAdjDataset.find_no_connection_node(graph, i) 446 | # ------------------------------------------------------------------------------------------- 447 | return torch.tensor(result, dtype=int) 448 | 449 | @staticmethod 450 | def preprocess(input): 451 | x=input.to_homogeneous() 452 | x.ndata["feature"]=input.ndata["feat"] 453 | x = { 454 | "id": "0", 455 | "graph": x, 456 | "label": 0} 457 | return x 458 | 459 | @staticmethod 460 | def preprocess_batch(data, use_tqdm=False): 461 | d = list() 462 | if use_tqdm: 463 | data = tqdm(data) 464 | for x in data: 465 | d.append(GraphAdjDataset.preprocess(x)) 466 | return d 467 | 468 | @staticmethod 469 | def batchify(batch): 470 | _id = [x["id"] for x in batch] 471 | graph_label = torch.tensor([x["label"] for x in batch], dtype=torch.float64).view(-1, 1) 472 | graph = dgl.batch([x["graph"] for x in batch]) 473 | graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1) 474 | return _id, graph_label, graph, graph_len 475 | -------------------------------------------------------------------------------- /nodedownstream/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import extend_dimensions 5 | 6 | 7 | class NormalEmbedding(nn.Module): 8 | def __init__(self, input_dim, emb_dim): 9 | super(NormalEmbedding, self).__init__() 10 | self.input_dim = input_dim 11 | self.emb_dim = emb_dim 12 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 13 | 14 | # init 15 | nn.init.normal_(self.emb_layer.weight, 0.0, 1.0) 16 | 17 | def 
class OrthogonalEmbedding(nn.Module):
    """Bias-free linear embedding whose weight is orthogonally initialised."""

    def __init__(self, input_dim, emb_dim):
        super(OrthogonalEmbedding, self).__init__()
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False)
        # Orthogonal init keeps the embedding norm-preserving at start.
        nn.init.orthogonal_(self.emb_layer.weight)

    def increase_input_size(self, new_input_dim):
        """Grow the accepted input dimension in place (never shrink)."""
        assert new_input_dim >= self.input_dim
        if new_input_dim == self.input_dim:
            return
        wider = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False)
        del self.emb_layer
        self.emb_layer = wider
        self.input_dim = new_input_dim

    def forward(self, x):
        return self.emb_layer(x)
class MaxGatedFilterNet(nn.Module):
    """Parameter-free gate: keep graph entries not exceeding the pattern max.

    The per-row maximum of the pattern tensor (over dim 1) acts as a
    threshold; graph entries pass the filter when they are <= it.
    """

    def __init__(self):
        super(MaxGatedFilterNet, self).__init__()

    def forward(self, p_x, g_x):
        threshold = torch.max(p_x, dim=1, keepdim=True)[0]
        if threshold.dim() == 2:
            # 2-D case: one scalar threshold per pattern row.
            return g_x <= threshold
        # 3-D case: per-feature thresholds — every feature must pass.
        return (g_x <= threshold).all(keepdim=True, dim=2)
"max_nge": 256, # max_number_graph_edges: 256, 2048, 16384 15 | "max_ngvl": 16, # max_number_graph_vertex_labels: 16, 64, 256 16 | "max_ngel": 16, # max_number_graph_edge_labels: 16, 64, 256 17 | 18 | "base": 2, 19 | 20 | "gpu_id": -1, 21 | "num_workers": 12, 22 | 23 | "epochs": 200, 24 | "batch_size": 512, 25 | "update_every": 1, # actual batch_sizer = batch_size * update_every 26 | "print_every": 100, 27 | "init_emb": "Equivariant", # None, Orthogonal, Normal, Equivariant 28 | "share_emb": True, # sharing embedding requires the same vector length 29 | "share_arch": True, # sharing architectures 30 | "dropout": 0.2, 31 | "dropatt": 0.2, 32 | 33 | "reg_loss": "MSE", # MAE, MSEl 34 | "bp_loss": "MSE", # MAE, MSE 35 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 36 | # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 37 | # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 38 | "lr": 0.001, 39 | "weight_decay": 0.00001, 40 | "max_grad_norm": 8, 41 | 42 | "pretrain_model": "GCN", 43 | 44 | "emb_dim": 128, 45 | "activation_function": "leaky_relu", # sigmoid, softmax, tanh, relu, leaky_relu, prelu, gelu 46 | 47 | "filter_net": "MaxGatedFilterNet", # None, MaxGatedFilterNet 48 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 49 | "predict_net_add_enc": True, 50 | "predict_net_add_degree": True, 51 | 52 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 53 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 54 | # DIAMNet 55 | # "predict_net_add_enc": True, 56 | # "predict_net_add_degree": True, 57 | "txl_graph_num_layers": 3, 58 | "txl_pattern_num_layers": 3, 59 | "txl_d_model": 128, 60 | "txl_d_inner": 128, 61 | "txl_n_head": 4, 62 | "txl_d_head": 4, 63 | "txl_pre_lnorm": True, 64 | "txl_tgt_len": 64, 65 | "txl_ext_len": 0, # useless in current settings 66 | "txl_mem_len": 64, 67 | 
"txl_clamp_len": -1, # max positional embedding index 68 | "txl_attn_type": 0, # 0 for Dai et al, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 69 | "txl_same_len": False, 70 | 71 | "gcn_num_bases": 8, 72 | "gcn_regularizer": "bdd", # basis, bdd 73 | "gcn_graph_num_layers": 3, 74 | "gcn_hidden_dim": 128, 75 | "gcn_ignore_norm": False, # ignorm=True -> RGCN-SUM 76 | 77 | "graph_dir": "../data/debug/graphs", 78 | "save_data_dir": "../data/debug", 79 | "save_model_dir": "../dumps/debug", 80 | "save_pretrain_model_dir": "../dumps/MUTAGPreTrain/GCN", 81 | "graphslabel_dir":"../data/debug/graphs", 82 | "downstream_graph_dir": "../data/debug/graphs", 83 | "downstream_save_data_dir": "../data/debug", 84 | "downstream_save_model_dir": "../dumps/debug", 85 | "downstream_graphslabel_dir":"../data/debug/graphs", 86 | "temperature": 0.01, 87 | "graph_finetuning_input_dim": 8, 88 | "graph_finetuning_output_dim": 2, 89 | "graph_label_num":2, 90 | "seed": 0, 91 | "update_pretrain": False, 92 | "dropout": 0.5, 93 | "gcn_output_dim": 8, 94 | 95 | "prompt": "SUM", 96 | "prompt_output_dim": 2, 97 | "scalar": 1e3, 98 | 99 | "dataset_seed": 0, 100 | "train_shotnum": 50, 101 | "val_shotnum": 50, 102 | "few_shot_tasknum": 10, 103 | 104 | "save_fewshot_dir": "../data/FlickrPreTrainNodeClassification/fewshot", 105 | "select_fewshot_dir": ".../data/FlickrPreTrainNodeClassification/select", 106 | "None": True, 107 | 108 | "downstream_dropout": 0, 109 | "node_feature_dim": 18, 110 | "train_label_num": 6, 111 | "val_label_num": 6, 112 | "test_label_num": 6, 113 | "nhop_neighbour": 1 114 | } 115 | 116 | 117 | fewshot_dir = os.path.join(train_config["save_fewshot_dir"], "%s_trainshot_%s_valshot_%s_tasks" % 118 | (train_config["train_shotnum"], train_config["val_shotnum"], 119 | train_config["few_shot_tasknum"])) 120 | print(os.path.exists(fewshot_dir)) 121 | print("Load Few Shot") 122 | trainset = np.load(os.path.join(fewshot_dir, "train_dgl_dataset.npy"), 
# --- select a subset of few-shot tasks and rewrite the task files ---
valset = np.load(os.path.join(fewshot_dir, "val_dgl_dataset.npy"), allow_pickle=True).tolist()
testset = np.load(os.path.join(fewshot_dir, "test_dgl_dataset.npy"), allow_pickle=True).tolist()

# Indices of the few-shot tasks to keep.
save = [0, 1, 3, 8]
rettrain = []
retval = []
rettest = []
for i in save:
    rettrain.append(trainset[i])
    # BUG FIX: the original always appended valset[1]/testset[1]; keep the
    # val/test split belonging to the *same* task i as the train split.
    retval.append(valset[i])
    rettest.append(testset[i])

selectdir = os.path.join(train_config["save_fewshot_dir"], "%s_trainshot_%s_valshot_%s_tasks" %
                         (train_config["train_shotnum"], train_config["val_shotnum"],
                          train_config["few_shot_tasknum"]))

if train_config["None"]:
    rettrain = np.array(rettrain)
    retval = np.array(retval)
    rettest = np.array(rettest)
else:
    trainset = np.load(os.path.join(selectdir, "train_dgl_dataset.npy"), allow_pickle=True).tolist()
    valset = np.load(os.path.join(selectdir, "val_dgl_dataset.npy"), allow_pickle=True).tolist()
    testset = np.load(os.path.join(selectdir, "test_dgl_dataset.npy"), allow_pickle=True).tolist()
    # BUG FIX: list.append() returns None, so `rettrain = rettrain.append(...)`
    # destroyed the list; the originals also appended the val/test sets onto
    # rettrain.  Extend each list with its own extra tasks instead.
    rettrain.extend(trainset)
    retval.extend(valset)
    rettest.extend(testset)
    rettrain = np.array(rettrain)
    retval = np.array(retval)
    rettest = np.array(rettest)

np.save(os.path.join(fewshot_dir, "train_dgl_dataset"), rettrain)
np.save(os.path.join(fewshot_dir, "val_dgl_dataset"), retval)
np.save(os.path.join(fewshot_dir, "test_dgl_dataset"), rettest)
| from utils import map_activation_str_to_layer, split_and_batchify_graph_feats,GetAdj 11 | 12 | 13 | class GIN(torch.nn.Module): 14 | def __init__(self, config): 15 | super(GIN, self).__init__() 16 | 17 | self.act=torch.nn.ReLU() 18 | self.g_net, self.bns, g_dim = self.create_net( 19 | name="graph", input_dim=config["node_feature_dim"], hidden_dim=config["gcn_hidden_dim"], 20 | num_layers=config["gcn_graph_num_layers"], num_bases=config["gcn_num_bases"], regularizer=config["gcn_regularizer"]) 21 | self.num_layers_num=config["gcn_graph_num_layers"] 22 | self.dropout=torch.nn.Dropout(p=config["dropout"]) 23 | 24 | 25 | def create_net(self, name, input_dim, **kw): 26 | num_layers = kw.get("num_layers", 1) 27 | hidden_dim = kw.get("hidden_dim", 64) 28 | num_rels = kw.get("num_rels", 1) 29 | num_bases = kw.get("num_bases", 8) 30 | regularizer = kw.get("regularizer", "basis") 31 | dropout = kw.get("dropout", 0.5) 32 | 33 | 34 | self.convs = torch.nn.ModuleList() 35 | self.bns = torch.nn.ModuleList() 36 | 37 | for i in range(num_layers): 38 | 39 | if i: 40 | nn = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), self.act, torch.nn.Linear(hidden_dim, hidden_dim)) 41 | else: 42 | nn = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim), self.act, torch.nn.Linear(hidden_dim, hidden_dim)) 43 | conv = dgl.nn.pytorch.conv.GINConv(apply_func=nn,aggregator_type='sum') 44 | bn = torch.nn.BatchNorm1d(hidden_dim) 45 | 46 | self.convs.append(conv) 47 | self.bns.append(bn) 48 | 49 | return self.convs, self.bns, hidden_dim 50 | 51 | 52 | #def forward(self, pattern, pattern_len, graph, graph_len): 53 | def forward(self, graph, graph_len,graphtask=False): 54 | graph_output = graph.ndata["feature"] 55 | xs = [] 56 | for i in range(self.num_layers_num): 57 | graph_output = F.relu(self.convs[i](graph,graph_output)) 58 | graph_output = self.bns[i](graph_output) 59 | graph_output = self.dropout(graph_output) 60 | xs.append(graph_output) 61 | xpool= [] 62 | for x in xs: 63 | 
if graphtask: 64 | graph_embedding = split_and_batchify_graph_feats(x, graph_len)[0] 65 | else: 66 | graph_embedding=x 67 | graph_embedding = torch.sum(graph_embedding, dim=1) 68 | xpool.append(graph_embedding) 69 | x = torch.cat(xpool, -1) 70 | #x is graph level embedding; xs is node level embedding 71 | return x,torch.cat(xs, -1) 72 | -------------------------------------------------------------------------------- /nodedownstream/node_prompt_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | import dgl.function as fn 6 | import copy 7 | from functools import partial 8 | from dgl.nn.pytorch.conv import RelGraphConv 9 | from basemodel import GraphAdjModel 10 | from utils import map_activation_str_to_layer, split_and_batchify_graph_feats,GetAdj 11 | 12 | #use prompt to finish step 1 13 | class graph_prompt_layer_mean(nn.Module): 14 | def __init__(self): 15 | super(graph_prompt_layer_mean, self).__init__() 16 | self.weight= torch.nn.Parameter(torch.Tensor(2, 2)) 17 | def forward(self, graph_embedding, graph_len): 18 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 19 | graph_prompt_result=graph_embedding.mean(dim=1) 20 | return graph_prompt_result 21 | 22 | class node_prompt_layer_linear_mean(nn.Module): 23 | def __init__(self,input_dim,output_dim): 24 | super(node_prompt_layer_linear_mean, self).__init__() 25 | self.linear=torch.nn.Linear(input_dim,output_dim) 26 | 27 | def forward(self, graph_embedding, graph_len): 28 | graph_embedding=self.linear(graph_embedding) 29 | return graph_embedding 30 | 31 | class node_prompt_layer_linear_sum(nn.Module): 32 | def __init__(self,input_dim,output_dim): 33 | super(node_prompt_layer_linear_sum, self).__init__() 34 | self.linear=torch.nn.Linear(input_dim,output_dim) 35 | 36 | def forward(self, graph_embedding, graph_len): 37 | 
class node_prompt_layer_feature_weighted_mean(nn.Module):
    """Prompt layer that re-weights each embedding feature channel.

    Holds one learnable weight per input feature; forward is a broadcast
    element-wise product.  `graph_len` is accepted only for interface
    compatibility with the other prompt layers and is unused.
    """

    def __init__(self, input_dim):
        super(node_prompt_layer_feature_weighted_mean, self).__init__()
        self.weight = torch.nn.Parameter(torch.Tensor(1, input_dim))
        self.max_n_num = input_dim
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier keeps the initial re-weighting near unit scale.
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, graph_embedding, graph_len):
        return graph_embedding * self.weight
self.max_n_num=input_dim 87 | self.reset_parameters() 88 | def reset_parameters(self): 89 | torch.nn.init.xavier_uniform_(self.weight) 90 | def forward(self, graph_embedding, graph_len): 91 | graph_embedding=graph_embedding*self.weight 92 | return graph_embedding 93 | 94 | class graph_prompt_layer_weighted_matrix(nn.Module): 95 | def __init__(self,max_n_num,input_dim): 96 | super(graph_prompt_layer_weighted_matrix, self).__init__() 97 | self.weight= torch.nn.Parameter(torch.Tensor(input_dim,max_n_num)) 98 | self.max_n_num=max_n_num 99 | self.reset_parameters() 100 | def reset_parameters(self): 101 | torch.nn.init.xavier_uniform_(self.weight) 102 | def forward(self, graph_embedding, graph_len): 103 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 104 | weight = self.weight.permute(1, 0)[0:graph_embedding.size(1)] 105 | weight = weight.expand(graph_embedding.size(0), weight.size(0), weight.size(1)) 106 | graph_embedding = graph_embedding * weight 107 | graph_prompt_result=graph_embedding.sum(dim=1) 108 | return graph_prompt_result 109 | 110 | class graph_prompt_layer_weighted_linear(nn.Module): 111 | def __init__(self,max_n_num,input_dim,output_dim): 112 | super(graph_prompt_layer_weighted_linear, self).__init__() 113 | self.weight= torch.nn.Parameter(torch.Tensor(1,max_n_num)) 114 | self.linear=nn.Linear(input_dim,output_dim) 115 | self.max_n_num=max_n_num 116 | self.reset_parameters() 117 | def reset_parameters(self): 118 | torch.nn.init.xavier_uniform_(self.weight) 119 | def forward(self, graph_embedding, graph_len): 120 | graph_embedding=self.linear(graph_embedding) 121 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 122 | weight = self.weight[0][0:graph_embedding.size(1)] 123 | temp1 = torch.ones(graph_embedding.size(0), graph_embedding.size(2), graph_embedding.size(1)).to(graph_embedding.device) 124 | temp1 = weight * temp1 125 | temp1 = temp1.permute(0, 2, 1) 126 | 
graph_embedding=graph_embedding*temp1 127 | graph_prompt_result = graph_embedding.mean(dim=1) 128 | return graph_prompt_result 129 | 130 | class graph_prompt_layer_weighted_matrix_linear(nn.Module): 131 | def __init__(self,max_n_num,input_dim,output_dim): 132 | super(graph_prompt_layer_weighted_matrix_linear, self).__init__() 133 | self.weight= torch.nn.Parameter(torch.Tensor(output_dim,max_n_num)) 134 | self.linear=nn.Linear(input_dim,output_dim) 135 | self.max_n_num=max_n_num 136 | self.reset_parameters() 137 | def reset_parameters(self): 138 | torch.nn.init.xavier_uniform_(self.weight) 139 | def forward(self, graph_embedding, graph_len): 140 | graph_embedding=self.linear(graph_embedding) 141 | graph_embedding=split_and_batchify_graph_feats(graph_embedding, graph_len)[0] 142 | weight = self.weight.permute(1, 0)[0:graph_embedding.size(1)] 143 | weight = weight.expand(graph_embedding.size(0), weight.size(0), weight.size(1)) 144 | graph_embedding = graph_embedding * weight 145 | graph_prompt_result=graph_embedding.mean(dim=1) 146 | return graph_prompt_result 147 | -------------------------------------------------------------------------------- /nodedownstream/pre_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import dgl 5 | import logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | import json 11 | import time 12 | import torch.nn.functional as F 13 | import warnings 14 | from functools import partial 15 | from collections import OrderedDict 16 | from torch.utils.data import DataLoader 17 | import matplotlib.pyplot as plt 18 | import random 19 | 20 | 21 | try: 22 | from torch.utils.tensorboard import SummaryWriter 23 | except BaseException as e: 24 | from tensorboardX import SummaryWriter 25 | from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset 26 | from utils import anneal_fn, get_enc_len, pretrain_load_data, \ 27 | 
get_linear_schedule_with_warmup,bp_compute_abmae,compareloss,split_and_batchify_graph_feats 28 | '''from cnn import CNN 29 | from rnn import RNN 30 | from txl import TXL 31 | from rgcn import RGCN 32 | from rgin import RGIN 33 | 34 | from gin import GIN''' 35 | from gin import GIN 36 | 37 | warnings.filterwarnings("ignore") 38 | INF = float("inf") 39 | 40 | train_config = { 41 | "max_npv": 8, # max_number_pattern_vertices: 8, 16, 32 42 | "max_npe": 8, # max_number_pattern_edges: 8, 16, 32 43 | "max_npvl": 8, # max_number_pattern_vertex_labels: 8, 16, 32 44 | "max_npel": 8, # max_number_pattern_edge_labels: 8, 16, 32 45 | 46 | "max_ngv": 64, # max_number_graph_vertices: 64, 512,4096 47 | "max_nge": 256, # max_number_graph_edges: 256, 2048, 16384 48 | "max_ngvl": 16, # max_number_graph_vertex_labels: 16, 64, 256 49 | "max_ngel": 16, # max_number_graph_edge_labels: 16, 64, 256 50 | 51 | "base": 2, 52 | 53 | "gpu_id": -1, 54 | "num_workers": 12, 55 | 56 | "epochs": 200, 57 | "batch_size": 512, 58 | "update_every": 1, # actual batch_sizer = batch_size * update_every 59 | "print_every": 100, 60 | "init_emb": "Equivariant", # None, Orthogonal, Normal, Equivariant 61 | "share_emb": True, # sharing embedding requires the same vector length 62 | "share_arch": True, # sharing architectures 63 | "dropout": 0.2, 64 | "dropatt": 0.2, 65 | 66 | "reg_loss": "MSE", # MAE, MSEl 67 | "bp_loss": "MSE", # MAE, MSE 68 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 69 | # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 70 | # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 71 | "lr": 0.001, 72 | "weight_decay": 0.00001, 73 | "max_grad_norm": 8, 74 | 75 | "model": "CNN", # CNN, RNN, TXL, RGCN, RGIN, RSIN 76 | 77 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 78 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 79 | # 
MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 80 | # DIAMNet 81 | # "predict_net_add_enc": True, 82 | # "predict_net_add_degree": True, 83 | "predict_net_add_enc": True, 84 | "predict_net_add_degree": True, 85 | 86 | "predict_net_hidden_dim": 128, 87 | "predict_net_num_heads": 4, 88 | "predict_net_mem_len": 4, 89 | "predict_net_mem_init": "mean", 90 | # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm 91 | "predict_net_recurrent_steps": 3, 92 | 93 | "emb_dim": 128, 94 | "activation_function": "leaky_relu", # sigmoid, softmax, tanh, relu, leaky_relu, prelu, gelu 95 | 96 | "filter_net": "MaxGatedFilterNet", # None, MaxGatedFilterNet 97 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 98 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 99 | # DIAMNet 100 | # "predict_net_add_enc": True, 101 | # "predict_net_add_degree": True, 102 | "txl_graph_num_layers": 3, 103 | "txl_pattern_num_layers": 3, 104 | "txl_d_model": 128, 105 | "txl_d_inner": 128, 106 | "txl_n_head": 4, 107 | "txl_d_head": 4, 108 | "txl_pre_lnorm": True, 109 | "txl_tgt_len": 64, 110 | "txl_ext_len": 0, # useless in current settings 111 | "txl_mem_len": 64, 112 | "txl_clamp_len": -1, # max positional embedding index 113 | "txl_attn_type": 0, # 0 for Dai et al, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 
114 | "txl_same_len": False, 115 | 116 | "gcn_num_bases": 8, 117 | "gcn_regularizer": "bdd", # basis, bdd 118 | "gcn_graph_num_layers": 3, 119 | "gcn_hidden_dim": 128, 120 | "gcn_ignore_norm": False, # ignorm=True -> RGCN-SUM 121 | 122 | "graph_dir": "../data/debug/graphs", 123 | "save_data_dir": "../data/debug", 124 | "save_model_dir": "../dumps/debug", 125 | "save_pretrain_model_dir": "../dumps/MUTAGPreTrain/GCN", 126 | "graphslabel_dir":"../data/debug/graphs", 127 | "downstream_graph_dir": "../data/debug/graphs", 128 | "downstream_save_data_dir": "../data/debug", 129 | "downstream_save_model_dir": "../dumps/debug", 130 | "downstream_graphslabel_dir":"../data/debug/graphs", 131 | "train_num_per_class": 3, 132 | "shot_num": 2, 133 | "temperature": 1, 134 | "graph_finetuning_input_dim": 8, 135 | "graph_finetuning_output_dim": 2, 136 | "graph_label_num": 2, 137 | "seed": 0, 138 | "model": "GIN", 139 | "dropout": 0.5, 140 | "node_feature_dim": 18 141 | } 142 | 143 | 144 | def train(model, optimizer, scheduler, data_type, data_loader, device, config, epoch, logger=None, writer=None): 145 | epoch_step = len(data_loader) 146 | total_step = config["epochs"] * epoch_step 147 | total_reg_loss = 0 148 | total_bp_loss = 0 149 | #total_cnt = 1e-6 150 | 151 | if config["reg_loss"] == "MAE": 152 | reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target) 153 | elif config["reg_loss"] == "MSE": 154 | reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target) 155 | elif config["reg_loss"] == "SMSE": 156 | reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target) 157 | elif config["reg_loss"] == "ABMAE": 158 | reg_crit = lambda pred, target: bp_compute_abmae(F.leaky_relu(pred), target)+0.8*F.l1_loss(F.relu(pred), target) 159 | else: 160 | raise NotImplementedError 161 | 162 | if config["bp_loss"] == "MAE": 163 | bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target) 164 | elif config["bp_loss"] == "MSE": 165 | bp_crit = 
lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target) 166 | elif config["bp_loss"] == "SMSE": 167 | bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target) 168 | elif config["bp_loss"] == "ABMAE": 169 | bp_crit = lambda pred, target, neg_slp: bp_compute_abmae(F.leaky_relu(pred, neg_slp), target)+0.8*F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none") 170 | else: 171 | raise NotImplementedError 172 | 173 | model.train() 174 | total_time=0 175 | for batch_id, batch in enumerate(data_loader): 176 | ids, graph_label, graph, graph_len= batch 177 | # print(batch) 178 | graph=graph.to(device) 179 | graph_label=graph_label.to(device) 180 | graph_len = graph_len.to(device) 181 | s=time.time() 182 | x,pred = model(graph, graph_len) 183 | pred=F.sigmoid(pred) 184 | 185 | adj = graph.adjacency_matrix() 186 | adj = adj.to(device) 187 | '''print('---------------------------------------------------') 188 | print('adj: ',adj.size()) 189 | print('pred: ',adj.size()) 190 | print('---------------------------------------------------')''' 191 | 192 | pred = torch.matmul(adj, pred) 193 | #print(pred.size()) 194 | _pred=split_and_batchify_graph_feats(pred, graph_len)[0] 195 | sample = graph.ndata['sample'] 196 | _sample=split_and_batchify_graph_feats(sample, graph_len)[0] 197 | sample_=_sample.reshape(_sample.size(0),-1,1) 198 | #print(_pred.size()) 199 | #print(sample_.size()) 200 | _pred=torch.gather(input=_pred,dim=1,index=sample_) 201 | #print(_pred.size()) 202 | _pred=_pred.resize_as(_sample) 203 | #print(_pred.size()) 204 | 205 | reg_loss = compareloss(_pred,train_config["temperature"]) 206 | reg_loss.requires_grad_(True) 207 | # print(reg_loss.size()) 208 | 209 | if isinstance(config["bp_loss_slp"], (int, float)): 210 | neg_slp = float(config["bp_loss_slp"]) 211 | else: 212 | bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3) 213 | neg_slp = anneal_fn(bp_loss_slp, batch_id + epoch * epoch_step, 
T=total_step // 4, lambda0=float(l0), 214 | lambda1=float(l1)) 215 | bp_loss = reg_loss 216 | bp_loss.requires_grad_(True) 217 | 218 | 219 | # float 220 | reg_loss_item = reg_loss.item() 221 | bp_loss_item = bp_loss.item() 222 | total_reg_loss += reg_loss_item 223 | total_bp_loss += bp_loss_item 224 | 225 | if writer: 226 | writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item, 227 | epoch * epoch_step + batch_id) 228 | writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch * epoch_step + batch_id) 229 | 230 | if logger and (batch_id % config["print_every"] == 0 or batch_id == epoch_step - 1): 231 | logger.info( 232 | "epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format( 233 | epoch, config["epochs"], data_type, batch_id, epoch_step, 234 | reg_loss_item, bp_loss_item)) 235 | print(bp_loss.grad) 236 | bp_loss.backward() 237 | if (config["update_every"] < 2 or batch_id % config["update_every"] == 0 or batch_id == epoch_step - 1): 238 | if config["max_grad_norm"] > 0: 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"]) 240 | if scheduler is not None: 241 | scheduler.step(epoch * epoch_step + batch_id) 242 | optimizer.step() 243 | optimizer.zero_grad() 244 | e=time.time() 245 | total_time+=e-s 246 | mean_reg_loss = total_reg_loss 247 | mean_bp_loss = total_bp_loss 248 | if writer: 249 | writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch) 250 | writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch) 251 | if logger: 252 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format( 253 | epoch, config["epochs"], data_type, mean_reg_loss, mean_bp_loss)) 254 | 255 | gc.collect() 256 | return mean_reg_loss, mean_bp_loss, total_time 257 | 258 | 259 | def evaluate(model, data_type, data_loader, device, config, 
epoch, logger=None, writer=None): 260 | epoch_step = len(data_loader) 261 | total_reg_loss = 0 262 | #total_cnt = 1e-6 263 | 264 | evaluate_results = {"data": {"id": list(), "counts": list(), "pred": list()}, 265 | "error": {"mae": INF, "mse": INF}, 266 | "time": {"avg": list(), "total": 0.0}} 267 | 268 | if config["reg_loss"] == "MAE": 269 | reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target, reduce="none") 270 | elif config["reg_loss"] == "MSE": 271 | reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target, reduce="none") 272 | elif config["reg_loss"] == "SMSE": 273 | reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target, reduce="none") 274 | elif config["reg_loss"] == "ABMAE": 275 | reg_crit = lambda pred, target: bp_compute_abmae(F.relu(pred), target)+0.8*F.l1_loss(F.relu(pred), target, reduce="none") 276 | else: 277 | raise NotImplementedError 278 | 279 | if config["bp_loss"] == "MAE": 280 | bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none") 281 | elif config["bp_loss"] == "MSE": 282 | bp_crit = lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target, reduce="none") 283 | elif config["bp_loss"] == "SMSE": 284 | bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none") 285 | elif config["bp_loss"] == "ABMAE": 286 | bp_crit = lambda pred, target, neg_slp: bp_compute_abmae(F.leaky_relu(pred, neg_slp), target)+0.8*F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduce="none") 287 | else: 288 | raise NotImplementedError 289 | 290 | model.eval() 291 | total_time=0 292 | with torch.no_grad(): 293 | for batch_id, batch in enumerate(data_loader): 294 | ids, graph_label, graph, graph_len= batch 295 | #cnt = counts.shape[0] 296 | #total_cnt += cnt 297 | 298 | graph = graph.to(device) 299 | graph_label=graph_label.to(device) 300 | graph_len = graph_len.to(device) 301 | st = time.time() 302 | pred = model(graph, 
graph_len) 303 | adj = graph.adjacency_matrix() 304 | adj = adj.to(device) 305 | pred = torch.matmul(adj, pred) 306 | sample = graph.ndata['sample'] 307 | _sample = sample.reshape(-1, 1) 308 | pred = torch.gather(input=pred, dim=0, index=_sample) 309 | pred = pred.resize_as(sample) 310 | 311 | et=time.time() 312 | evaluate_results["time"]["total"] += (et - st) 313 | 314 | reg_loss = compareloss(pred, train_config["temperature"]) 315 | reg_loss_item = reg_loss.item() 316 | 317 | if writer: 318 | writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item, 319 | epoch * epoch_step + batch_id) 320 | '''writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, 321 | epoch * epoch_step + batch_id)''' 322 | 323 | if logger and batch_id == epoch_step - 1: 324 | logger.info( 325 | "epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}".format( 326 | epoch, config["epochs"], data_type, batch_id, epoch_step, 327 | reg_loss_item)) 328 | et=time.time() 329 | total_time+=et-st 330 | total_reg_loss+=reg_loss_item 331 | mean_reg_loss = total_reg_loss 332 | #mean_bp_loss = total_bp_loss / total_cnt 333 | if writer: 334 | writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch) 335 | #writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch) 336 | if logger: 337 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}".format( 338 | epoch, config["epochs"], data_type, mean_reg_loss)) 339 | 340 | evaluate_results["error"]["loss"] = mean_reg_loss 341 | 342 | gc.collect() 343 | return mean_reg_loss,0,evaluate_results, total_time 344 | 345 | 346 | if __name__ == "__main__": 347 | for i in range(1, len(sys.argv), 2): 348 | arg = sys.argv[i] 349 | value = sys.argv[i + 1] 350 | 351 | if arg.startswith("--"): 352 | arg = arg[2:] 353 | if arg not in train_config: 354 | print("Warning: %s is not surported now." 
% (arg)) 355 | continue 356 | train_config[arg] = value 357 | try: 358 | value = eval(value) 359 | if isinstance(value, (int, float)): 360 | train_config[arg] = value 361 | except: 362 | pass 363 | 364 | torch.manual_seed(train_config["seed"]) 365 | np.random.seed(train_config["seed"]) 366 | 367 | ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 368 | model_name = "%s_%s_%s" % (train_config["model"], train_config["predict_net"], ts) 369 | save_model_dir = train_config["save_model_dir"] 370 | os.makedirs(save_model_dir, exist_ok=True) 371 | 372 | # save config 373 | with open(os.path.join(save_model_dir, "train_config.json"), "w") as f: 374 | json.dump(train_config, f) 375 | 376 | # set logger 377 | logger = logging.getLogger() 378 | logger.setLevel(logging.INFO) 379 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S') 380 | console = logging.StreamHandler() 381 | console.setFormatter(fmt) 382 | logger.addHandler(console) 383 | logfile = logging.FileHandler(os.path.join(save_model_dir, "train_log.txt"), 'w') 384 | logfile.setFormatter(fmt) 385 | logger.addHandler(logfile) 386 | 387 | # set device 388 | device = torch.device("cuda:%d" % train_config["gpu_id"] if train_config["gpu_id"] != -1 else "cpu") 389 | if train_config["gpu_id"] != -1: 390 | torch.cuda.set_device(device) 391 | 392 | # reset the pattern parameters 393 | if train_config["share_emb"]: 394 | train_config["max_npv"], train_config["max_npvl"], train_config["max_npe"], train_config["max_npel"] = \ 395 | train_config["max_ngv"], train_config["max_ngvl"], train_config["max_nge"], train_config["max_ngel"] 396 | 397 | 398 | if train_config["model"] == "GCN": 399 | model = GCN(train_config) 400 | if train_config["model"] == "GIN": 401 | model = GIN(train_config) 402 | if train_config["model"] == "GAT": 403 | model = GAT(train_config) 404 | if train_config["model"] == "GraphSage": 405 | model = Graphsage(train_config) 406 | 407 | model = model.to(device) 408 | 
logger.info(model) 409 | logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad))) 410 | 411 | # load data 412 | os.makedirs(train_config["save_data_dir"], exist_ok=True) 413 | data_loaders = OrderedDict({"train": None, "dev": None}) 414 | if all([os.path.exists(os.path.join(train_config["save_data_dir"], 415 | "%s_%s_dataset.pt" % ( 416 | data_type, "dgl" if train_config["model"] in ["RGCN", "RGIN", 417 | "GAT","GCN","GraphSage","GIN"] else "edgeseq"))) 418 | for data_type in data_loaders]): 419 | 420 | logger.info("loading data from pt...") 421 | for data_type in data_loaders: 422 | if train_config["model"] in ["RGCN", "RGIN", "GAT","GCN","GraphSage","GIN"]: 423 | dataset = GraphAdjDataset(list()) 424 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 425 | print(dataset) 426 | sampler = Sampler(dataset, group_by=["graph"], batch_size=train_config["batch_size"], 427 | shuffle=data_type == "train", drop_last=False) 428 | data_loader = DataLoader(dataset, 429 | batch_sampler=sampler, 430 | collate_fn=GraphAdjDataset.batchify, 431 | pin_memory=data_type == "train") 432 | else: 433 | dataset = EdgeSeqDataset(list()) 434 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 435 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], 436 | shuffle=data_type == "train", drop_last=False) 437 | data_loader = DataLoader(dataset, 438 | batch_sampler=sampler, 439 | collate_fn=EdgeSeqDataset.batchify, 440 | pin_memory=data_type == "train") 441 | data_loaders[data_type] = data_loader 442 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 443 | logger.info( 444 | "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), 445 | train_config["batch_size"])) 446 | else: 447 | data = 
pretrain_load_data(train_config["graph_dir"], train_config["graphslabel_dir"], num_workers=train_config["num_workers"]) 448 | logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"]))) 449 | for data_type, x in data.items(): 450 | if train_config["model"] in ["RGCN", "RGIN", "GAT","GCN","GraphSage","GIN"]: 451 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))): 452 | dataset = GraphAdjDataset(list()) 453 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 454 | else: 455 | if(data_type=="test"): 456 | test_time_start=time.time() 457 | elif(data_type=="train"): 458 | train_time_start = time.time() 459 | else: 460 | val_time_start = time.time() 461 | dataset = GraphAdjDataset(x) 462 | if(data_type=="test"): 463 | test_time_end=time.time() 464 | test_time=test_time_end-test_time_start 465 | logger.info( 466 | "preprocess test time: {:.3f}".format(test_time)) 467 | elif(data_type=="train"): 468 | train_time_end=time.time() 469 | train_time=train_time_end-train_time_start 470 | logger.info( 471 | "preprocess train time: {:.3f}".format(train_time)) 472 | else: 473 | val_time_end=time.time() 474 | val_time=val_time_end-val_time_start 475 | logger.info( 476 | "preprocess val time: {:.3f}".format(val_time)) 477 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 478 | sampler = Sampler(dataset, group_by=["graph"], batch_size=train_config["batch_size"], 479 | shuffle=data_type == "train", drop_last=False) 480 | data_loader = DataLoader(dataset, 481 | batch_sampler=sampler, 482 | collate_fn=GraphAdjDataset.batchify, 483 | pin_memory=data_type == "train") 484 | else: 485 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))): 486 | dataset = EdgeSeqDataset(list()) 487 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % 
(data_type))) 488 | else: 489 | dataset = EdgeSeqDataset(x) 490 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 491 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], 492 | shuffle=data_type == "train", drop_last=False) 493 | data_loader = DataLoader(dataset, 494 | batch_sampler=sampler, 495 | collate_fn=EdgeSeqDataset.batchify, 496 | pin_memory=data_type == "train") 497 | data_loaders[data_type] = data_loader 498 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 499 | logger.info( 500 | "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), 501 | train_config["batch_size"])) 502 | 503 | print('data_loaders', data_loaders.items()) 504 | 505 | # optimizer and losses 506 | writer = SummaryWriter(save_model_dir) 507 | optimizer = torch.optim.AdamW(model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"], 508 | amsgrad=True) 509 | optimizer.zero_grad() 510 | scheduler = None 511 | # scheduler = get_linear_schedule_with_warmup(optimizer, 512 | # len(data_loaders["train"]), train_config["epochs"]*len(data_loaders["train"]), min_percent=0.0001) 513 | best_reg_losses = {"train": INF, "dev": INF, "test": INF} 514 | best_reg_epochs = {"train": -1, "dev": -1, "test": -1} 515 | 516 | total_train_time=0 517 | total_dev_time=0 518 | total_test_time=0 519 | 520 | plt_x=list() 521 | plt_y=list() 522 | for epoch in range(train_config["epochs"]): 523 | for data_type, data_loader in data_loaders.items(): 524 | if data_type == "train": 525 | mean_reg_loss, mean_bp_loss, _time = train(model, optimizer, scheduler, data_type, data_loader, device, 526 | train_config, epoch, logger=logger, writer=writer) 527 | total_train_time+=_time 528 | torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch))) 529 | else: 530 | mean_reg_loss, mean_bp_loss, 
evaluate_results, _time = evaluate(model, data_type, data_loader, device, 531 | train_config, epoch, logger=logger, 532 | writer=writer) 533 | total_dev_time+=_time 534 | with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f: 535 | json.dump(evaluate_results, f) 536 | if mean_reg_loss <= best_reg_losses[data_type]: 537 | best_reg_losses[data_type] = mean_reg_loss 538 | best_reg_epochs[data_type] = epoch 539 | logger.info( 540 | "data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss, 541 | epoch)) 542 | if data_type == "train": 543 | plt_x.append(epoch) 544 | plt_y.append(mean_reg_loss) 545 | 546 | plt.figure(1) 547 | plt.plot(plt_x,plt_y) 548 | plt.savefig('epoch_loss.png') 549 | for data_type in data_loaders.keys(): 550 | logger.info( 551 | "data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type], 552 | best_reg_epochs[data_type])) 553 | 554 | best_epoch = train_config["epochs"]-1 555 | model.load_state_dict(torch.load(os.path.join(save_model_dir, 'epoch%d.pt' % (best_epoch)))) 556 | torch.save(model.state_dict(), os.path.join(save_model_dir, "best.pt")) 557 | -------------------------------------------------------------------------------- /nodedownstream/split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import math 6 | import numpy as np 7 | import re 8 | import os 9 | import sys 10 | import json 11 | from torch.optim.lr_scheduler import LambdaLR 12 | from collections import OrderedDict 13 | from multiprocessing import Pool 14 | from tqdm import tqdm 15 | from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score 16 | import random 17 | from tqdm import trange 18 | from sklearn.metrics import precision_recall_fscore_support 19 | import functools 20 | import dgl 21 | 22 | 
#drop==True means drop nodes of class drop when split train,val, test;but can only drop the biggest class(ex 0,1,2 can only drop label 2) 23 | def few_shot_split_nodelevel(graph,tasknum,trainshot,valshot,labelnum,seed=0, drop=False): 24 | train=[] 25 | val=[] 26 | test=[] 27 | if drop: 28 | labelnum=labelnum-1 29 | nodenum=graph.number_of_nodes() 30 | random.seed(seed) 31 | for count in range(tasknum): 32 | index = random.sample(range(0, nodenum), nodenum) 33 | trainindex=[] 34 | valindex=[] 35 | testindex=[] 36 | traincount = torch.zeros(labelnum) 37 | valcount = torch.zeros(labelnum) 38 | for i in index: 39 | label=graph.ndata["label"][i] 40 | if drop: 41 | if label==labelnum: 42 | continue 43 | if traincount[label]