├── .gitignore
├── figs
│   └── graph_meta_learning.png
├── requirements.txt
├── data_process
│   ├── node_process.py
│   └── link_process.py
├── G-Meta
│   ├── learner.py
│   ├── train.py
│   ├── meta.py
│   └── subgraph_data_processing.py
├── README.md
└── test.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | .DS_Store
3 |
4 | __pycache__/
5 |
6 | data/
--------------------------------------------------------------------------------
/figs/graph_meta_learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/G-Meta/HEAD/figs/graph_meta_learning.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch == 1.5.0
2 | dgl == 0.4.3post2
3 | numpy
4 | networkx
5 | scipy
6 | tqdm
7 | scikit-learn
8 | pandas
--------------------------------------------------------------------------------
/data_process/node_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import networkx as nx
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import pickle
7 | import dgl
8 | from tqdm import tqdm
9 | import json
10 |
11 | # this is an example for the disjoint-label, multiple-graph setting.
12 |
13 | path = 'PATH'
14 |
15 | # assume you have a list of DGL graphs stored in the variable dgl_Gs
16 | dgl_Gs = [G1, G2, ...]
17 | # assume you have an array of features [feat_1, feat_2, ...] where each feat_i corresponds to graph i.
18 | feature_map = [feat1, feat2, ...]
19 | # assume you have an array of labels [label_1, label_2, ...] where each label_i corresponds to graph i.
20 | label_map = [label1, label2, ...]
21 | # number of unique labels, e.g. 30
22 | num_of_labels = 30
23 | # number of labels per label set, ideally << num_of_labels so that each task can be formed from a different permutation of labels
24 | num_label_set = 5
25 |
26 | info = {}
27 |
28 | for idx, G in enumerate(dgl_Gs):
29 | # G is a dgl graph
30 | for j in range(len(label_map[idx])):
31 | info[str(idx) + '_' + str(j)] = label_map[idx][j]
32 |
33 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"})
34 |
35 | labels = np.unique(list(range(num_of_labels)))
36 |
37 | test_labels = np.random.choice(labels, num_label_set, False)
38 | labels_left = [i for i in labels if i not in test_labels]
39 | val_labels = np.random.choice(labels_left, num_label_set, False)
40 | train_labels = [i for i in labels_left if i not in val_labels]
41 |
42 | df[df.label.isin(train_labels)].reset_index(drop = True).to_csv(path + '/train.csv')
43 | df[df.label.isin(val_labels)].reset_index(drop = True).to_csv(path + '/val.csv')
44 | df[df.label.isin(test_labels)].reset_index(drop = True).to_csv(path + '/test.csv')
45 |
46 | with open(path + '/graph_dgl.pkl', 'wb') as f:
47 | pickle.dump(dgl_Gs, f)
48 |
49 | with open(path + '/label.pkl', 'wb') as f:
50 | pickle.dump(info, f)
51 |
52 | np.save(path + '/features.npy', np.array(feature_map))
53 |
54 |
55 | # for the shared-label, multiple-graph setting, similarly, assume you have processed the following variables:
56 |
57 | # assume you have a list of DGL graphs stored in the variable dgl_Gs
58 | dgl_Gs = [G1, G2, ...]
59 | # assume you have an array of features [feat_1, feat_2, ...] where each feat_i corresponds to graph i.
60 | feature_map = [feat1, feat2, ...]
61 | # assume you have an array of labels [label_1, label_2, ...] where each label_i corresponds to graph i.
62 | label_map = [label1, label2, ...]
63 | # number of unique labels, e.g. 5
64 | num_of_labels = 5
65 |
66 | info = {}
67 | for idx, G in enumerate(dgl_Gs):
68 | for i in tqdm(list(G.nodes)):
69 | info[str(idx) + '_' + str(i)] = label_map[idx][i]
70 |
71 | np.save(path + '/features.npy', np.array(feature_map))
72 |
73 | with open(path + '/graph_dgl.pkl', 'wb') as f:
74 | pickle.dump(dgl_Gs, f)
75 |
76 | with open(path + '/label.pkl', 'wb') as f:
77 | pickle.dump(info, f)
78 |
79 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"})
80 |
81 | # for example, specify the graph indices to be used for the val and test sets; the remaining graphs go into the meta-train set
82 | folds = [[0, 23], [1, 22], [2, 21], [3, 20], [4, 19]]
83 |
84 | for fold_n, i in enumerate(folds):
85 | temp_path = path + '/fold' + str(fold_n+1)
86 | train_graphs = list(range(len(dgl_Gs)))
87 | train_graphs.remove(i[0])
88 | train_graphs.remove(i[1])
89 | val_graph = i[0]
90 | test_graph = i[1]
91 |
92 | val_df = df[df.name.str.contains('^' + str(val_graph) + '_')]
93 | test_df = df[df.name.str.contains('^' + str(test_graph) + '_')]
94 |
95 | train_df = df[~df.index.isin(val_df.index)]
96 | train_df = train_df[~train_df.index.isin(test_df.index)]
97 | train_df.reset_index(drop = True).to_csv(temp_path + '/train.csv')
98 | val_df.reset_index(drop = True).to_csv(temp_path + '/val.csv')
99 | test_df.reset_index(drop = True).to_csv(temp_path + '/test.csv')
--------------------------------------------------------------------------------
/data_process/link_process.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | from tqdm import tqdm
3 | import networkx as nx
4 | from itertools import combinations
5 | import numpy as np
6 | import random
7 | import pickle
8 | import pandas as pd
9 | path = 'PATH'
10 | adjs = np.load(path + '/graphs_adj.npy', allow_pickle = True)
11 | # this .npy file is an array of 2D arrays [A1, A2, ..., An], where Ai is the adjacency matrix of graph i.
12 |
13 | training_edges_fraction = 0.3
14 | pos_test_edges = []
15 | pos_val_edges = []
16 | pos_train_edges = []
17 | neg_test_edges = []
18 | neg_train_edges = []
19 | neg_val_edges = []
20 |
21 | info = {}
22 | info_spt = {}
23 | info_qry = {}
24 | total_subgraph = {}
25 | center_nodes = {}
26 |
27 | G_all_graphs = []
28 |
29 | for idx_ in tqdm(range(len(adjs))):
30 | G = nx.from_numpy_array(adjs[idx_])
31 |
32 | adj_upp = np.multiply(adjs[idx_], np.triu(np.ones(adjs[idx_].shape)))
33 | x1, x2 = np.where(adj_upp == 1)
34 | edges = list(zip(x1, x2))
35 |
36 | # training edges
37 | sampled = np.random.choice(list(range(len(edges))), int(len(edges)*training_edges_fraction), replace = False)
38 |
39 | pos_train_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(edges)[sampled]])
40 |
41 | pos_test = [i for i in list(range(len(edges))) if i not in sampled]
42 |
43 | pos_test_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(edges)[pos_test]])
44 |
45 | G_sample = dgl.DGLGraph()
46 | G_sample.add_nodes(len(G.nodes))
47 | G_sample.add_edges(np.array(edges).T[0], np.array(edges).T[1])
48 | num_pos = np.sum(adjs[idx_])/2
49 |
50 | sampled_frac = int(5*(sum(sum(adjs[idx_]))/len(G.nodes)))
51 |
52 | comb = []
53 | for i in list(range(len(G.nodes))):
54 | l = list(range(len(G.nodes)))
55 | l.remove(i)
56 | comb = comb + (list(zip([i] * sampled_frac, random.choices(l, k = sampled_frac))))
57 |
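# negative candidate filtering: shuffle the sampled pairs, drop pairs that were sampled in both
# orientations, then remove any pair that is a true edge (in either direction) to obtain candidate negative edges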
58 | random.shuffle(comb)
59 | comb_flipped = [(k,v) for v,k in comb]
60 | l = list(set(comb_flipped) & set(comb))
61 |
62 | neg_edges_sampled = [i for i in comb if i not in l]
63 |
64 | neg_edges = list(set(neg_edges_sampled) - set(edges) - set([(k,v) for (v,k) in edges]))
65 |
66 | np.random.seed(10)
67 | idx_neg = np.random.choice(list(range(len(neg_edges))), len(edges), replace = False)
68 | neg_edges = np.array(neg_edges)[idx_neg]
69 |
70 | idx_neg_train = np.random.choice(list(range(len(neg_edges))), len(sampled), replace = False)
71 |
72 | neg_train_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(neg_edges)[idx_neg_train]])
73 |
74 | neg_test = [i for i in list(range(len(neg_edges))) if i not in idx_neg_train]
75 | neg_test_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(neg_edges)[neg_test]])
76 |
77 | train_edges_pos = np.array(edges)[sampled]
78 | test_edges_pos = np.array(edges)[pos_test]
79 |
80 | train_edges_neg = np.array(neg_edges)[idx_neg_train]
81 | test_edges_neg = np.array(neg_edges)[neg_test]
82 |
83 | for i in np.array(neg_edges):
84 | # negative injection, following SEAL
85 | G_sample.add_edge(i[0],i[1])
86 |
87 | G_all_graphs.append(G_sample)
88 |
89 | for i in np.array(train_edges_pos):
90 | node1 = i[0]
91 | node2 = i[1]
92 |
93 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1
94 | info_spt[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1
95 |
96 | for i in np.array(test_edges_pos):
97 | node1 = i[0]
98 | node2 = i[1]
99 |
100 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1
101 | info_qry[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1
102 |
103 | for i in np.array(train_edges_neg):
104 | node1 = i[0]
105 | node2 = i[1]
106 |
107 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0
108 | info_spt[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0
109 |
110 | for i in np.array(test_edges_neg):
111 | node1 = i[0]
112 | node2 = i[1]
113 |
114 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0
115 | info_qry[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0
116 |
117 | with open(path + '/graph_dgl.pkl', 'wb') as f:
118 | pickle.dump(G_all_graphs, f)
119 |
120 | with open(path + '/label.pkl', 'wb') as f:
121 | pickle.dump(info, f)
122 |
123 | # split on graphs
124 | num_test_graphs = int(0.1 * len(G_all_graphs))
125 |
126 | l = list(range(len(G_all_graphs)))
127 | test_graphs_idx = np.random.choice(l, num_test_graphs, replace = False).tolist()
128 |
129 | l = [i for i in l if i not in test_graphs_idx]
130 | val_graphs_idx = np.random.choice(l, num_test_graphs, replace = False).tolist()
131 |
132 | fold = [test_graphs_idx, val_graphs_idx]
133 |
134 | df_spt = pd.DataFrame.from_dict(info_spt, orient='index').reset_index().rename(columns={"index": "name", 0: "label"})
135 | df_qry = pd.DataFrame.from_dict(info_qry, orient='index').reset_index().rename(columns={"index": "name", 0: "label"})
136 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"})
137 |
138 | i = fold
139 |
140 | temp_path = path
141 | train_graphs = list(range(len(G_all_graphs)))
142 |
143 | train_graphs = [j for j in train_graphs if j not in i[0] + i[1]]
144 | val_graph = i[1]
145 | test_graph = i[0]
146 |
147 | train_spt = pd.DataFrame()
148 | val_spt = pd.DataFrame()
149 | test_spt = pd.DataFrame()
150 |
151 | train_qry = pd.DataFrame()
152 | val_qry = pd.DataFrame()
153 | test_qry = pd.DataFrame()
154 |
155 | train = pd.DataFrame()
156 | val = pd.DataFrame()
157 | test = pd.DataFrame()
158 |
159 | for graph_id in range(len(val_graph)):
160 |
161 | val_df = df_spt[df_spt.name.str.contains('^' + str(val_graph[graph_id])+'_')]
162 | test_df = df_spt[df_spt.name.str.contains('^' + str(test_graph[graph_id])+'_')]
163 |
164 | val_spt = val_spt.append(val_df)
165 | test_spt = test_spt.append(test_df)
166 |
167 | val_df = df_qry[df_qry.name.str.contains('^' + str(val_graph[graph_id])+'_')]
168 | test_df = df_qry[df_qry.name.str.contains('^' + str(test_graph[graph_id])+'_')]
169 |
170 | val_qry = val_qry.append(val_df)
171 | test_qry = test_qry.append(test_df)
172 |
173 | val_df = df[df.name.str.contains('^' + str(val_graph[graph_id])+'_')]
174 | test_df = df[df.name.str.contains('^' + str(test_graph[graph_id])+'_')]
175 |
176 | val = val.append(val_df)
177 | test = test.append(test_df)
178 |
179 | val_spt.reset_index(drop = True).to_csv(temp_path + '/val_spt.csv')
180 | test_spt.reset_index(drop = True).to_csv(temp_path + '/test_spt.csv')
181 |
182 | val_qry.reset_index(drop = True).to_csv(temp_path + '/val_qry.csv')
183 | test_qry.reset_index(drop = True).to_csv(temp_path + '/test_qry.csv')
184 |
185 | val.reset_index(drop = True).to_csv(temp_path + '/val.csv')
186 | test.reset_index(drop = True).to_csv(temp_path + '/test.csv')
187 |
188 | train_df = df_spt[~df_spt.index.isin(val_spt.index)]
189 | train_df = train_df[~train_df.index.isin(test_spt.index)]
190 | train_df.reset_index(drop = True).to_csv(temp_path + '/train_spt.csv')
191 |
192 | train_df = df_qry[~df_qry.index.isin(val_qry.index)]
193 | train_df = train_df[~train_df.index.isin(test_qry.index)]
194 | train_df.reset_index(drop = True).to_csv(temp_path + '/train_qry.csv')
195 |
196 | train_df = df[~df.index.isin(val.index)]
197 | train_df = train_df[~train_df.index.isin(test.index)]
198 | train_df.reset_index(drop = True).to_csv(temp_path + '/train.csv')
199 |
--------------------------------------------------------------------------------
/G-Meta/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import dgl.function as fn
4 | import torch.nn as nn
5 | from torch.nn import init
6 | import dgl
7 |
8 | # Sends a message of node feature h.
9 | msg = fn.copy_src(src='h', out='m')
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 | # copied and edited from the DGL source
13 | class GraphConv(nn.Module):
14 | def __init__(self,
15 | in_feats,
16 | out_feats,
17 | activation=None):
18 | super(GraphConv, self).__init__()
19 | self._in_feats = in_feats
20 | self._out_feats = out_feats
21 | self._norm = True
22 | self._activation = activation
23 |
24 |
25 | def forward(self, graph, feat, weight, bias):
26 |
27 | graph = graph.local_var()
28 | if self._norm:
29 | norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
30 | shp = norm.shape + (1,) * (feat.dim() - 1)
31 | norm = torch.reshape(norm, shp).to(feat.device)
32 | feat = feat * norm
33 |
34 | if self._in_feats > self._out_feats:
35 | # mult W first to reduce the feature size for aggregation.
36 | feat = torch.matmul(feat, weight)
37 | graph.ndata['h'] = feat
38 | graph.update_all(fn.copy_src(src='h', out='m'),
39 | fn.sum(msg='m', out='h'))
40 | rst = graph.ndata['h']
41 | else:
42 | # aggregate first then mult W
43 | graph.ndata['h'] = feat
44 | graph.update_all(fn.copy_src(src='h', out='m'),
45 | fn.sum(msg='m', out='h'))
46 | rst = graph.ndata['h']
47 | rst = torch.matmul(rst, weight)
48 |
49 | rst = rst * norm
50 |
51 | rst = rst + bias
52 |
53 | if self._activation is not None:
54 | rst = self._activation(rst)
55 |
56 | return rst
57 |
58 | def extra_repr(self):
59 | """Set the extra representation of the module,
60 | which will come into effect when printing the model.
61 | """
62 | summary = 'in={_in_feats}, out={_out_feats}'
63 | summary += ', normalization={_norm}'
64 | if '_activation' in self.__dict__:
65 | summary += ', activation={_activation}'
66 | return summary.format(**self.__dict__)
67 |
68 |
69 | class Classifier(nn.Module):
70 | def __init__(self, config):
71 | super(Classifier, self).__init__()
72 |
73 | self.vars = nn.ParameterList()
74 | self.graph_conv = []
75 | self.config = config
76 | self.LinkPred_mode = False
77 |
78 | if self.config[-1][0] == 'LinkPred':
79 | self.LinkPred_mode = True
80 |
81 | for i, (name, param) in enumerate(self.config):
82 |
83 | if name == 'Linear':
84 | if self.LinkPred_mode:
85 | w = nn.Parameter(torch.ones(param[1], param[0] * 2))
86 | else:
87 | w = nn.Parameter(torch.ones(param[1], param[0]))
88 | init.kaiming_normal_(w)
89 | self.vars.append(w)
90 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
91 | if name == 'GraphConv':
92 | # param: in_dim, hidden_dim
93 | w = nn.Parameter(torch.Tensor(param[0], param[1]))
94 | init.xavier_uniform_(w)
95 | self.vars.append(w)
96 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
97 | self.graph_conv.append(GraphConv(param[0], param[1], activation = F.relu))
98 | if name == 'Attention':
99 | # param[0] hidden size
100 | # param[1] attention_head_size
101 | # param[2] hidden_dim for classifier
102 | # param[3] n_ways
103 | # param[4] number of graphlets
104 | if self.LinkPred_mode:
105 | w_q = nn.Parameter(torch.ones(param[1], param[0] * 2))
106 | else:
107 | w_q = nn.Parameter(torch.ones(param[1], param[0]))
108 | w_k = nn.Parameter(torch.ones(param[1], param[0]))
109 | w_v = nn.Parameter(torch.ones(param[1], param[4]))
110 |
111 | if self.LinkPred_mode:
112 | w_l = nn.Parameter(torch.ones(param[3], param[2] * 2 + param[1]))
113 | else:
114 | w_l = nn.Parameter(torch.ones(param[3], param[2] + param[1]))
115 |
116 | init.kaiming_normal_(w_q)
117 | init.kaiming_normal_(w_k)
118 | init.kaiming_normal_(w_v)
119 | init.kaiming_normal_(w_l)
120 |
121 | self.vars.append(w_q)
122 | self.vars.append(w_k)
123 | self.vars.append(w_v)
124 | self.vars.append(w_l)
125 |
126 | #bias for attentions
127 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
128 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
129 | self.vars.append(nn.Parameter(torch.zeros(param[1])))
130 | #bias for classifier
131 | self.vars.append(nn.Parameter(torch.zeros(param[3])))
132 |
133 |
134 | def forward(self, g, to_fetch, features, vars = None):
135 | # For undirected graphs, in_degree is the same as
136 | # out_degree.
137 |
138 | if vars is None:
139 | vars = self.vars
140 |
141 | idx = 0
142 | idx_gcn = 0
143 |
144 | h = features.float()
145 | h = h.to(device)
146 |
147 | for name, param in self.config:
148 | if name == 'GraphConv':
149 | w, b = vars[idx], vars[idx + 1]
150 | conv = self.graph_conv[idx_gcn]
151 |
152 | h = conv(g, h, w, b)
153 |
154 | g.ndata['h'] = h
155 |
156 | idx += 2
157 | idx_gcn += 1
158 |
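# after the last GraphConv layer, fetch the embeddings of the center nodes (or node pairs in
# link prediction mode) from the batched graph, using cumulative node counts as offsets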
159 | if idx_gcn == len(self.graph_conv):
160 | #h = dgl.mean_nodes(g, 'h')
161 | num_nodes_ = g.batch_num_nodes
162 | temp = [0] + num_nodes_
163 | offset = torch.cumsum(torch.LongTensor(temp), dim = 0)[:-1].to(device)
164 |
165 | if self.LinkPred_mode:
166 | h1 = h[to_fetch[:,0] + offset]
167 | h2 = h[to_fetch[:,1] + offset]
168 | h = torch.cat((h1, h2), 1)
169 | else:
170 | h = h[to_fetch + offset]
171 |
172 | if name == 'Linear':
173 | w, b = vars[idx], vars[idx + 1]
174 | h = F.linear(h, w, b)
175 | idx += 2
176 |
177 | if name == 'Attention':
178 | w_q, w_k, w_v, w_l = vars[idx], vars[idx + 1], vars[idx + 2], vars[idx + 3]
179 | b_q, b_k, b_v, b_l = vars[idx + 4], vars[idx + 5], vars[idx + 6], vars[idx + 7]
180 |
181 | Q = F.linear(h, w_q, b_q)
182 | K = F.linear(h_graphlets, w_k, b_k)
183 |
184 | attention_scores = torch.matmul(Q, K.T)
185 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
186 | context = F.linear(attention_probs, w_v, b_v)
187 |
188 | # classify layer, first concatenate the context vector
189 | # with the hidden dim of center nodes
190 | h = torch.cat((context, h), 1)
191 | h = F.linear(h, w_l, b_l)
192 | idx += 8
193 |
194 | return h, h
195 |
196 | def zero_grad(self, vars=None):
197 |
198 | with torch.no_grad():
199 | if vars is None:
200 | for p in self.vars:
201 | if p.grad is not None:
202 | p.grad.zero_()
203 | else:
204 | for p in vars:
205 | if p.grad is not None:
206 | p.grad.zero_()
207 |
208 | def parameters(self):
209 | return self.vars
--------------------------------------------------------------------------------
/G-Meta/train.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import numpy as np
3 | from subgraph_data_processing import Subgraphs
4 | import scipy.stats
5 | from torch.utils.data import DataLoader
6 | from torch.optim import lr_scheduler
7 | import random, sys, pickle
8 | import argparse
9 |
10 | import networkx as nx
11 | import numpy as np
12 | from scipy.special import comb
13 | from itertools import combinations
14 | import networkx.algorithms.isomorphism as iso
15 | from tqdm import tqdm
16 | import dgl
17 |
18 | from meta import Meta
19 | import time
20 | import copy
21 | import psutil
22 | from memory_profiler import memory_usage
23 |
24 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
25 |
26 | def collate(samples):
27 | graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx = map(list, zip(*samples))
28 |
29 | return graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx
30 |
31 | def main():
32 | mem_usage = memory_usage(-1, interval=.5, timeout=1)
33 | torch.manual_seed(222)
34 | torch.cuda.manual_seed_all(222)
35 | np.random.seed(222)
36 |
37 | print(args)
38 |
39 | root = args.data_dir
40 |
41 | feat = np.load(root + 'features.npy', allow_pickle = True)
42 |
43 | with open(root + '/graph_dgl.pkl', 'rb') as f:
44 | dgl_graph = pickle.load(f)
45 |
46 | if args.task_setup == 'Disjoint':
47 | with open(root + 'label.pkl', 'rb') as f:
48 | info = pickle.load(f)
49 | elif args.task_setup == 'Shared':
50 | if args.task_mode == 'True':
51 | root = root + '/task' + str(args.task_n) + '/'
52 | with open(root + 'label.pkl', 'rb') as f:
53 | info = pickle.load(f)
54 |
55 | total_class = len(np.unique(np.array(list(info.values()))))
56 | print('There are {} classes '.format(total_class))
57 |
58 | if args.task_setup == 'Disjoint':
59 | labels_num = args.n_way
60 | elif args.task_setup == 'Shared':
61 | labels_num = total_class
62 |
63 | if len(feat.shape) == 2:
64 | # single graph, to make it compatible to multiple graph retrieval.
65 | feat = [feat]
66 |
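# model config: h stacked GraphConv layers followed by a Linear classifier head;
# a ('LinkPred', [True]) entry is appended when running in link prediction mode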
67 | config = [('GraphConv', [feat[0].shape[1], args.hidden_dim])]
68 |
69 | if args.h > 1:
70 | config = config + [('GraphConv', [args.hidden_dim, args.hidden_dim])] * (args.h - 1)
71 |
72 | config = config + [('Linear', [args.hidden_dim, labels_num])]
73 |
74 | if args.link_pred_mode == 'True':
75 | config.append(('LinkPred', [True]))
76 |
77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78 |
79 | maml = Meta(args, config).to(device)
80 |
81 | tmp = filter(lambda x: x.requires_grad, maml.parameters())
82 | num = sum(map(lambda x: np.prod(x.shape), tmp))
83 | print(maml)
84 | print('Total trainable tensors:', num)
85 |
86 | max_acc = 0
87 | model_max = copy.deepcopy(maml)
88 |
89 | db_train = Subgraphs(root, 'train', info, n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, batchsz=args.batchsz, args = args, adjs = dgl_graph, h = args.h)
90 | db_val = Subgraphs(root, 'val', info, n_way=args.n_way, k_shot=args.k_spt,k_query=args.k_qry, batchsz=100, args = args, adjs = dgl_graph, h = args.h)
91 | db_test = Subgraphs(root, 'test', info, n_way=args.n_way, k_shot=args.k_spt,k_query=args.k_qry, batchsz=100, args = args, adjs = dgl_graph, h = args.h)
92 | print('------ Start Training ------')
93 | s_start = time.time()
94 | max_memory = 0
95 | for epoch in range(args.epoch):
96 | db = DataLoader(db_train, args.task_num, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate)
97 | s_f = time.time()
98 | for step, (x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry) in enumerate(db):
99 | nodes_len = 0
100 | if step >= 1:
101 | data_loading_time = time.time() - s_r
102 | else:
103 | data_loading_time = time.time() - s_f
104 | s = time.time()
105 | # x_spt: a list of #task_num tasks, where each task is a mini-batch of k-shot * n_way subgraphs
106 | # y_spt: a list of #task_num lists of labels. Each list is of length k-shot * n_way int.
107 | nodes_len += sum([sum([len(j) for j in i]) for i in n_spt])
108 | accs = maml(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
109 | max_memory = max(max_memory, float(psutil.virtual_memory().used/(1024**3)))
110 | if step % args.train_result_report_steps == 0:
111 | print('Epoch:', epoch + 1, ' Step:', step, ' training acc:', str(accs[-1])[:5], ' time elapsed:', str(time.time() - s)[:5], ' data loading takes:', str(data_loading_time)[:5], ' Memory usage:', str(float(psutil.virtual_memory().used/(1024**3)))[:5])
112 | s_r = time.time()
113 |
114 | # validation per epoch
115 | db_v = DataLoader(db_val, 1, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate)
116 | accs_all_test = []
117 |
118 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_v:
119 |
120 | accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
121 | accs_all_test.append(accs)
122 |
123 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
124 | print('Epoch:', epoch + 1, ' Val acc:', str(accs[-1])[:5])
125 | if accs[-1] > max_acc:
126 | max_acc = accs[-1]
127 | model_max = copy.deepcopy(maml)
128 |
129 | db_t = DataLoader(db_test, 1, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate)
130 | accs_all_test = []
131 |
132 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t:
133 | accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
134 | accs_all_test.append(accs)
135 |
136 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
137 | print('Test acc:', str(accs[-1])[:5])
138 |
139 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t:
140 | accs = model_max.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
141 | accs_all_test.append(accs)
142 |
143 | #torch.save(model_max.state_dict(), './model.pt')
144 |
145 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
146 | print('Early Stopped Test acc:', str(accs[-1])[:5])
147 | print('Total Time:', str(time.time() - s_start)[:5])
148 | print('Max Memory:', str(max_memory)[:5])
149 |
150 | if __name__ == '__main__':
151 |
152 | argparser = argparse.ArgumentParser()
153 | argparser.add_argument('--epoch', type=int, help='epoch number', default=10)
154 | argparser.add_argument('--n_way', type=int, help='n way', default=3)
155 | argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=3)
156 | argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=24)
157 | argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=8)
158 | argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)
159 | argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=1e-3)
160 | argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
161 | argparser.add_argument('--update_step_test', type=int, help='update steps for fine-tuning', default=10)
162 | argparser.add_argument('--input_dim', type=int, help='input feature dim', default=1)
163 | argparser.add_argument('--hidden_dim', type=int, help='hidden dim', default=64)
164 | argparser.add_argument('--attention_size', type=int, help='dim of attention_size', default=32)
165 | argparser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir.")
166 | argparser.add_argument("--no_finetune", default=True, type=str, required=False, help="no finetune mode.")
167 | argparser.add_argument("--task_setup", default='Disjoint', type=str, required=True, help="Select from Disjoint or Shared Setup. For Disjoint-Label, single/multiple graphs are both considered.")
168 | argparser.add_argument("--method", default='G-Meta', type=str, required=False, help="Use G-Meta")
169 | argparser.add_argument('--task_n', type=int, help='task number', default=1)
170 | argparser.add_argument("--task_mode", default='False', type=str, required=False, help="For Evaluating on Tasks")
171 | argparser.add_argument("--val_result_report_steps", default=100, type=int, required=False, help="validation report")
172 | argparser.add_argument("--train_result_report_steps", default=30, type=int, required=False, help="training report")
173 | argparser.add_argument("--num_workers", default=0, type=int, required=False, help="num of workers")
174 | argparser.add_argument("--batchsz", default=1000, type=int, required=False, help="batch size")
175 | argparser.add_argument("--link_pred_mode", default='False', type=str, required=False, help="For Link Prediction")
176 | argparser.add_argument("--h", default=2, type=int, required=False, help="neighborhood size")
177 | argparser.add_argument('--sample_nodes', type=int, help='sample nodes if above this number of nodes', default=1000)
178 |
179 | args = argparser.parse_args()
180 |
181 | main()
182 |
--------------------------------------------------------------------------------
/G-Meta/meta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from torch.nn import functional as F
5 | from torch.utils.data import TensorDataset, DataLoader
6 | from torch import optim
7 | import numpy as np
8 |
9 | from learner import Classifier
10 | from copy import deepcopy
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | def euclidean_dist(x, y):
15 | # x: N x D
16 | # y: M x D
17 | n = x.size(0)
18 | m = y.size(0)
19 | d = x.size(1)
20 | if d != y.size(1):
21 | raise Exception
22 |
23 | x = x.unsqueeze(1).expand(n, m, d)
24 | y = y.unsqueeze(0).expand(n, m, d)
25 |
26 | return torch.pow(x - y, 2).sum(2)
27 |
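# prototypical loss on the support set: class prototypes are the mean embeddings of each
# class's support examples, and the same support examples are then scored against those prototypes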
28 | def proto_loss_spt(logits, y_t, n_support):
29 | target_cpu = y_t.to('cpu')
30 | input_cpu = logits.to('cpu')
31 |
32 | def supp_idxs(c):
33 | return target_cpu.eq(c).nonzero()[:n_support].squeeze(1)
34 |
35 | classes = torch.unique(target_cpu)
36 | n_classes = len(classes)
37 | n_query = n_support
38 |
39 | support_idxs = list(map(supp_idxs, classes))
40 |
41 | prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])
42 | query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[:n_support], classes))).view(-1)
43 | query_samples = input_cpu[query_idxs]
44 | dists = euclidean_dist(query_samples, prototypes)
45 | log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)
46 |
47 | target_inds = torch.arange(0, n_classes)
48 | target_inds = target_inds.view(n_classes, 1, 1)
49 | target_inds = target_inds.expand(n_classes, n_query, 1).long()
50 |
51 | loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
52 | _, y_hat = log_p_y.max(2)
53 | acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
54 | return loss_val, acc_val, prototypes
55 |
56 | def proto_loss_qry(logits, y_t, prototypes):
57 | target_cpu = y_t.to('cpu')
58 | input_cpu = logits.to('cpu')
59 |
60 | classes = torch.unique(target_cpu)
61 | n_classes = len(classes)
62 |
63 | n_query = int(logits.shape[0]/n_classes)
64 |
65 | query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero(), classes))).view(-1)
66 | query_samples = input_cpu[query_idxs]
67 |
68 | dists = euclidean_dist(query_samples, prototypes)
69 |
70 | log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)
71 |
72 | target_inds = torch.arange(0, n_classes)
73 | target_inds = target_inds.view(n_classes, 1, 1)
74 | target_inds = target_inds.expand(n_classes, n_query, 1).long()
75 |
76 | loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
77 | _, y_hat = log_p_y.max(2)
78 | acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
79 | return loss_val, acc_val
80 |
81 |
82 | class Meta(nn.Module):
83 | def __init__(self, args, config):
84 | super(Meta, self).__init__()
85 | self.update_lr = args.update_lr
86 | self.meta_lr = args.meta_lr
87 | self.n_way = args.n_way
88 | self.k_spt = args.k_spt
89 | self.k_qry = args.k_qry
90 | self.task_num = args.task_num
91 | self.update_step = args.update_step
92 | self.update_step_test = args.update_step_test
93 |
94 | self.net = Classifier(config)
95 | self.net = self.net.to(device)
96 |
97 | self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
98 |
99 | self.method = args.method
100 |
101 | def forward_ProtoMAML(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat):
102 | """
103 | b: number of tasks
104 | setsz: the size for each task
105 |
106 | :param x_spt: [b], where each unit is a mini-batch of subgraphs, i.e. x_spt[0] is a DGL batch of # setsz subgraphs
107 | :param y_spt: [b, setsz]
108 | :param x_qry: [b], where each unit is a mini-batch of subgraphs, i.e. x_spt[0] is a DGL batch of # setsz subgraphs
109 | :param y_qry: [b, querysz]
110 | :return:
111 | """
112 | task_num = len(x_spt)
113 | querysz = len(y_qry[0])
114 | losses_s = [0 for _ in range(self.update_step)]
115 | losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i
116 | corrects = [0 for _ in range(self.update_step + 1)]
117 |
118 | for i in range(task_num):
119 | feat_spt = torch.Tensor(np.vstack(([feat[g_spt[i][j]][np.array(x)] for j, x in enumerate(n_spt[i])]))).to(device)
120 | feat_qry = torch.Tensor(np.vstack(([feat[g_qry[i][j]][np.array(x)] for j, x in enumerate(n_qry[i])]))).to(device)
121 | # 1. run the i-th task and compute loss for k=0
122 | logits, _ = self.net(x_spt[i].to(device), c_spt[i].to(device), feat_spt, vars=None)
123 | loss, _, prototypes = proto_loss_spt(logits, y_spt[i], self.k_spt)
124 | losses_s[0] += loss
125 | grad = torch.autograd.grad(loss, self.net.parameters())
126 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
127 |
128 | # this is the loss and accuracy before first update
129 | with torch.no_grad():
130 | # [setsz, nway]
131 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, self.net.parameters())
132 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes)
133 | losses_q[0] += loss_q
134 | corrects[0] = corrects[0] + acc_q
135 |
136 | # this is the loss and accuracy after the first update
137 | with torch.no_grad():
138 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, fast_weights)
139 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes)
140 | losses_q[1] += loss_q
141 | corrects[1] = corrects[1] + acc_q
142 |
143 | for k in range(1, self.update_step):
144 | # 1. run the i-th task and compute loss for k=1~K-1
145 | logits, _ = self.net(x_spt[i].to(device), c_spt[i].to(device), feat_spt, fast_weights)
146 | loss, _, prototypes = proto_loss_spt(logits, y_spt[i], self.k_spt)
147 | losses_s[k] += loss
148 | # 2. compute grad on theta_pi
149 | grad = torch.autograd.grad(loss, fast_weights, retain_graph=True)
150 | # 3. theta_pi = theta_pi - train_lr * grad
151 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
152 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, fast_weights)
153 | # loss_q will be overwritten and just keep the loss_q on last update step.
154 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes)
155 | losses_q[k + 1] += loss_q
156 |
157 | corrects[k + 1] = corrects[k + 1] + acc_q
158 |
159 | # end of all tasks
160 | # sum over all losses on query set across all tasks
161 | loss_q = losses_q[-1] / task_num
162 |
163 | if torch.isnan(loss_q):
164 | pass
165 | else:
166 | # optimize theta parameters
167 | self.meta_optim.zero_grad()
168 | loss_q.backward()
169 | self.meta_optim.step()
170 |
171 | accs = np.array(corrects) / (task_num)
172 |
173 | return accs
174 |
175 | def finetunning_ProtoMAML(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat):
176 | querysz = len(y_qry[0])
177 |
178 | corrects = [0 for _ in range(self.update_step_test + 1)]
179 |
180 | # fine-tuning on a copy of the model instead of self.net
181 | net = deepcopy(self.net)
182 | x_spt = x_spt[0]
183 | y_spt = y_spt[0]
184 | x_qry = x_qry[0]
185 | y_qry = y_qry[0]
186 | c_spt = c_spt[0]
187 | c_qry = c_qry[0]
188 | n_spt = n_spt[0]
189 | n_qry = n_qry[0]
190 | g_spt = g_spt[0]
191 | g_qry = g_qry[0]
192 |
193 | feat_spt = torch.Tensor(np.vstack(([feat[g_spt[j]][np.array(x)] for j, x in enumerate(n_spt)]))).to(device)
194 | feat_qry = torch.Tensor(np.vstack(([feat[g_qry[j]][np.array(x)] for j, x in enumerate(n_qry)]))).to(device)
195 |
196 |
197 | # 1. run the i-th task and compute loss for k=0
198 | logits, _ = net(x_spt.to(device), c_spt.to(device), feat_spt)
199 | loss, _, prototypes = proto_loss_spt(logits, y_spt, self.k_spt)
200 | grad = torch.autograd.grad(loss, net.parameters())
201 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))
202 |
203 | # this is the loss and accuracy before first update
204 | with torch.no_grad():
205 | # [setsz, nway]
206 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, net.parameters())
207 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes)
208 | corrects[0] = corrects[0] + acc_q
209 | # this is the loss and accuracy after the first update
210 | with torch.no_grad():
211 | # [setsz, nway]
212 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, fast_weights)
213 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes)
214 | corrects[1] = corrects[1] + acc_q
215 |
216 |
217 | for k in range(1, self.update_step_test):
218 | # 1. run the i-th task and compute loss for k=1~K-1
219 | logits, _ = net(x_spt.to(device), c_spt.to(device), feat_spt, fast_weights)
220 | loss, _, prototypes = proto_loss_spt(logits, y_spt, self.k_spt)
221 | # 2. compute grad on theta_pi
222 | grad = torch.autograd.grad(loss, fast_weights, retain_graph=True)
223 | # 3. theta_pi = theta_pi - train_lr * grad
224 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
225 |
226 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, fast_weights)
227 | # loss_q will be overwritten and just keep the loss_q on last update step.
228 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes)
229 | corrects[k + 1] = corrects[k + 1] + acc_q
230 |
231 | del net
232 | accs = np.array(corrects)
233 |
234 | return accs
235 |
236 | def forward(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat):
237 | if self.method == 'G-Meta':
238 | accs = self.forward_ProtoMAML(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
239 | return accs
240 |
241 | def finetunning(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat):
242 | if self.method == 'G-Meta':
243 | accs = self.finetunning_ProtoMAML(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat)
244 | return accs
245 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # G-Meta: Graph Meta Learning via Local Subgraphs
2 |
3 | #### Authors: [Kexin Huang](https://www.kexinhuang.com), [Marinka Zitnik](https://zitniklab.hms.harvard.edu)
4 |
5 | #### [Project Website](https://zitniklab.hms.harvard.edu/projects/G-Meta)
6 |
7 | Prevailing methods for graphs require abundant label and edge information for learning. When data for a new task are scarce, meta learning can learn from prior experiences and form much-needed inductive biases for fast adaptation to new tasks.
8 |
9 | Here, we introduce G-Meta, a novel meta-learning algorithm for graphs.
10 | G-Meta uses local subgraphs to transfer subgraph-specific information and learn transferable knowledge faster via meta gradients. G-Meta learns how to quickly adapt to a new task using only a handful of nodes or edges in the new task, and does so by learning from data points in other graphs or in related, albeit disjoint, label sets. G-Meta is theoretically justified, as we show that the evidence for a prediction can be found in the local subgraph surrounding the target node or edge.
11 |
12 | Experiments on seven datasets and nine baseline methods show that G-Meta outperforms existing methods by up to 16.3%. Unlike previous methods, G-Meta successfully learns in challenging, few-shot learning settings that require generalization to completely new graphs and never-before-seen labels. Finally, G-Meta scales to large graphs, which we demonstrate on a new Tree-of-Life dataset comprising 1,840 graphs, a two-order-of-magnitude increase in the number of graphs used in prior work.
13 |
14 | 
15 |
16 |
17 | ## Environment Installation
18 |
19 | ```bash
20 | python -m pip install --user virtualenv
21 | python -m venv gmeta_env
22 | source gmeta_env/bin/activate
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Run
27 | ```bash
28 | cd G-Meta
29 | # Single graph disjoint label, node classification (e.g. ogbn-arxiv)
30 | python train.py --data_dir DATA_PATH --task_setup Disjoint
31 | # Multiple graph shared label, node classification (e.g. Tissue-PPI)
32 | python train.py --data_dir DATA_PATH --task_setup Shared
33 | # Multiple graph disjoint label, node classification (e.g. Fold-PPI)
34 | python train.py --data_dir DATA_PATH --task_setup Disjoint
35 | # Multiple graph shared label, link prediction (e.g. FirstMM-DB, Tree-of-Life)
36 | python train.py --data_dir DATA_PATH --task_setup Shared --link_pred_mode True
37 | ```
38 |
39 | It also supports various input parameters:
40 |
41 | ```bash
42 | python train.py --data_dir # str: data path
43 | --task_setup # 'Disjoint' or 'Shared': task setup, disjoint label or shared label
44 | --link_pred_mode # 'True' or 'False': link prediction or node classification
45 | --batchsz # int: number of tasks in total
46 | --epoch # int: epoch size
47 | --h # 1, 2, or 3: use the h-hop neighborhood as the subgraph.
48 | --hidden_dim # int: hidden dim size of GNN
49 | --input_dim # int: input dim size of GNN
50 | --k_qry # int: number of query shots for each task
51 | --k_spt # int: number of support shots for each task
52 | --n_way # int: number of ways (size of the label set)
53 | --meta_lr # float: outer loop learning rate
54 | --update_lr # float: inner loop learning rate
55 | --update_step # int: inner loop update steps during training
56 | --update_step_test # int: inner loop update steps during finetuning
57 | --task_num # int: number of tasks for each meta-set
58 | --sample_nodes # int: when subgraph size is above this threshold, it samples this number of nodes from the subgraph
59 | --task_mode # 'True' or 'False': this is specifically for Tissue-PPI, where there are 10 tasks to evaluate.
60 | --num_workers # int: number of workers for the dataloader. default 0.
61 | --train_result_report_steps # int: interval (in steps) at which to print the training accuracy.
62 | ```
63 |
64 | To apply it to the five datasets reported in the paper, use the following commands as examples after you download the processed datasets from the section below.
65 |
66 | **ogbn-arxiv**:
67 |
69 |
70 | ```
71 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/arxiv/ \
72 | --epoch 10 \
73 | --task_setup Disjoint \
74 | --k_spt 3 \
75 | --k_qry 24 \
76 | --n_way 3 \
77 | --update_step 10 \
78 | --update_lr 0.01 \
79 | --num_workers 0 \
80 | --train_result_report_steps 200 \
81 | --hidden_dim 256 \
82 | --update_step_test 20 \
83 | --task_num 32 \
84 | --batchsz 10000
85 | ```
86 |
87 |
88 | **Tissue-PPI**:
89 |
91 |
92 | ```
93 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/tissue_PPI/ \
94 | --epoch 15 \
95 | --task_setup Shared \
96 | --task_mode True \
97 | --task_n 4 \
98 | --k_qry 10 \
99 | --k_spt 3 \
100 | --update_lr 0.01 \
101 | --update_step 10 \
102 | --meta_lr 5e-3 \
103 | --num_workers 0 \
104 | --train_result_report_steps 200 \
105 | --hidden_dim 128 \
106 | --task_num 4 \
107 | --batchsz 1000
108 | ```
109 |
110 |
111 | **Fold-PPI**:
112 |
114 |
115 | ```
116 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/fold_PPI/ \
117 | --epoch 5 \
118 | --task_setup Disjoint \
119 | --k_qry 24 \
120 | --k_spt 3 \
121 | --n_way 3 \
122 | --update_lr 0.005 \
123 | --meta_lr 1e-3 \
124 | --num_workers 0 \
125 | --train_result_report_steps 100 \
126 | --hidden_dim 128 \
127 | --update_step_test 20 \
128 | --task_num 16 \
129 | --batchsz 4000
130 | ```
131 |
132 |
133 | **FirstMM-DB**:
134 |
136 |
137 | ```
138 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/FirstMM_DB/ \
139 | --epoch 15 \
140 | --task_setup Shared \
141 | --k_qry 32 \
142 | --k_spt 16 \
143 | --n_way 2 \
144 | --update_lr 0.01 \
145 | --update_step 10 \
146 | --meta_lr 5e-4 \
147 | --num_workers 0 \
148 | --train_result_report_steps 200 \
149 | --hidden_dim 128 \
150 | --update_step_test 20 \
151 | --task_num 8 \
152 | --batchsz 1500 \
153 | --link_pred_mode True
154 | ```
155 |
156 |
157 | **Tree-of-Life**:
158 |
160 |
161 | ```
162 | python train.py --data_dir PATH/G-Meta_Data/tree-of-life/ \
163 | --epoch 15 \
164 | --task_setup Shared \
165 | --k_qry 16 \
166 | --k_spt 16 \
167 | --n_way 2 \
168 | --update_lr 0.005 \
169 | --update_step 10 \
170 | --meta_lr 0.0005 \
171 | --num_workers 0 \
172 | --train_result_report_steps 200 \
173 | --hidden_dim 256 \
174 | --update_step_test 20 \
175 | --task_num 8 \
176 | --batchsz 5000 \
177 | --link_pred_mode True
178 | ```
179 |
180 |
181 | Also, check out the [Jupyter notebook example](test.ipynb).
182 |
183 |
184 | ## Data Processing
185 |
186 | We provide the processed data files for five real-world datasets in this [Drive folder](https://drive.google.com/file/d/1TC06A02wmIQteKzqGSbl_i3VIQzsHVop/view?usp=drivesdk) and this [Microsoft OneDrive folder](https://hu-my.sharepoint.com/:u:/g/personal/kexinhuang_hsph_harvard_edu/EbSj1CehKDtKniKqtICWsScBESs9ldWWcTttGdADnFc6Wg?e=gJhl7c).
187 |
188 | 1\) To create your own dataset, create the following files and organize them as follows:
189 |
190 | - `graph_dgl.pkl`: A list of DGL graph objects. For a single graph G, use [G].
191 | - `features.npy`: An array of arrays [feat_1, feat_2, ...] where feat_i is the feature matrix of graph i.
192 |
193 | 2.1) Then, for **node classification**, include the following files (a minimal sketch follows this list):
194 | - `train.csv`, `val.csv`, and `test.csv`: Each file has two columns: the first is 'X_Y' (node Y from graph X) and the second is its label 'Z'. The files correspond to the meta-train, meta-validation, and meta-test sets.
195 | - `label.pkl`: A dictionary of labels where {'X_Y': Z} means the node Y in graph X has label Z.
196 |
197 | 2.2) Or, for **link prediction**, note that the support set contains only edges of the highly incomplete graph (e.g., 30% of links), whereas the query set edges come from the rest of the graph (e.g., 70% of links). During neural message passing, the GNN should ONLY exchange messages over the support set graph; otherwise, the query set performance is biased. Because of that, we split the meta-train/val/test files into separate support and query files. For link prediction, create the following files (see the sketch after this list):
198 | - `train_spt.csv`, `val_spt.csv`, and `test_spt.csv`: Two columns; the first is 'A_B_C' (nodes B and C from graph A) and the second is the label. These are the node pairs in the support set, i.e. positive links should be present in the underlying GNN graph.
199 | - `train_qry.csv`, `val_qry.csv`, and `test_qry.csv`: Two columns; the first is 'A_B_C' (nodes B and C from graph A) and the second is the label. These are the node pairs in the query set, i.e. positive links should NOT be present in the underlying GNN graph.
200 | - `train.csv`, `val.csv`, and `test.csv`: The merge of the corresponding support and query files.
201 | - `label.pkl`: A dictionary of labels where {'A_B_C': D} means nodes B and C in graph A have link status D, with D = 1 for a link and D = 0 for no link.
202 |
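As a minimal sketch of the naming convention only (the graph index, node pairs, and the `path` directory below are all made up; the real splits come from `data_process/link_process.py`):

```python
import pickle
import pandas as pd

path = 'PATH'  # placeholder output directory

# 'A_B_C' keys: the node pair (B, C) from graph A; the value is the link status (1 = link, 0 = no link)
info_spt = {'0_4_7': 1, '0_2_5': 0}   # support pairs: positive links are kept in the GNN graph
info_qry = {'0_3_9': 0, '0_6_8': 1}   # query pairs: positive links are held out of the GNN graph
info = {**info_spt, **info_qry}

def dump_csv(d, name):
    df = pd.DataFrame.from_dict(d, orient='index').reset_index().rename(columns={'index': 'name', 0: 'label'})
    df.to_csv(path + '/' + name)

dump_csv(info_spt, 'train_spt.csv')
dump_csv(info_qry, 'train_qry.csv')
dump_csv(info, 'train.csv')           # the merged file

with open(path + '/label.pkl', 'wb') as f:
    pickle.dump(info, f)
```
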
203 | We also provide sample data processing scripts in the `data_process` folder; see `node_process.py` and `link_process.py`.
204 |
205 | ## Cite Us
206 |
207 | ```
208 | @article{g-meta,
209 | title={Graph Meta Learning via Local Subgraphs},
210 | author={Huang, Kexin and Zitnik, Marinka},
211 | journal={NeurIPS},
212 | year={2020}
213 | }
214 | ```
215 |
216 | ## Contact
217 |
218 | Open an issue or send an email to kexinhuang@hsph.harvard.edu if you have any questions.
219 |
220 |
--------------------------------------------------------------------------------
/G-Meta/subgraph_data_processing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | import numpy as np
5 | import collections
6 | import csv
7 | import random
8 | import pickle
9 | from torch.utils.data import DataLoader
10 | import dgl
11 | import networkx as nx
12 | import itertools
13 |
14 | class Subgraphs(Dataset):
15 | def __init__(self, root, mode, subgraph2label, n_way, k_shot, k_query, batchsz, args, adjs, h):
16 | self.batchsz = batchsz # batch of set, not batch of subgraphs
17 | self.n_way = n_way # n-way
18 | self.k_shot = k_shot # k-shot support set
19 | self.k_query = k_query # for query set
20 | self.setsz = self.n_way * self.k_shot # num of samples per support set
21 | self.querysz = self.n_way * self.k_query # number of samples per set for evaluation
22 | self.h = h # number of h hops
23 | self.sample_nodes = args.sample_nodes
24 | print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, %d-hops' % (
25 | mode, batchsz, n_way, k_shot, k_query, h))
26 |
27 | # load subgraph list if preprocessed
28 | self.subgraph2label = subgraph2label
29 |
30 | if args.link_pred_mode == 'True':
31 | self.link_pred_mode = True
32 | else:
33 | self.link_pred_mode = False
34 |
35 | if self.link_pred_mode:
36 | dictLabels_spt, dictGraphs_spt, dictGraphsLabels_spt = self.loadCSV(os.path.join(root, mode + '_spt.csv'))
37 | dictLabels_qry, dictGraphs_qry, dictGraphsLabels_qry = self.loadCSV(os.path.join(root, mode + '_qry.csv'))
38 | dictLabels, dictGraphs, dictGraphsLabels = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path
39 | else:
40 | dictLabels, dictGraphs, dictGraphsLabels = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path
41 |
42 | self.task_setup = args.task_setup
43 |
44 | self.G = []
45 |
46 | for i in adjs:
47 | self.G.append(i)
48 |
49 | self.subgraphs = {}
50 |
51 | if self.task_setup == 'Disjoint':
52 | self.data = []
53 |
54 | for i, (k, v) in enumerate(dictLabels.items()):
55 | self.data.append(v) # [[subgraph1, subgraph2, ...], [subgraph111, ...]]
56 | self.cls_num = len(self.data)
57 |
58 | self.create_batch_disjoint(self.batchsz)
59 | elif self.task_setup == 'Shared':
60 |
61 | if self.link_pred_mode:
62 |
63 | self.data_graph_spt = []
64 |
65 | for i, (k, v) in enumerate(dictGraphs_spt.items()):
66 | self.data_graph_spt.append(v)
67 | self.graph_num_spt = len(self.data_graph_spt)
68 |
69 | self.data_label_spt = [[] for i in range(self.graph_num_spt)]
70 |
71 | relative_idx_map_spt = dict(zip(list(dictGraphs_spt.keys()), range(len(list(dictGraphs_spt.keys())))))
72 |
73 | for i, (k, v) in enumerate(dictGraphsLabels_spt.items()):
74 | for m, n in v.items():
75 | self.data_label_spt[relative_idx_map_spt[k]].append(n)
76 |
77 | self.cls_num_spt = len(self.data_label_spt[0])
78 |
79 | self.data_graph_qry = []
80 |
81 | for i, (k, v) in enumerate(dictGraphs_qry.items()):
82 | self.data_graph_qry.append(v)
83 | self.graph_num_qry = len(self.data_graph_qry)
84 |
85 | self.data_label_qry = [[] for i in range(self.graph_num_qry)]
86 |
87 | relative_idx_map_qry = dict(zip(list(dictGraphs_qry.keys()), range(len(list(dictGraphs_qry.keys())))))
88 |
89 | for i, (k, v) in enumerate(dictGraphsLabels_qry.items()):
90 | for m, n in v.items():
91 | self.data_label_qry[relative_idx_map_qry[k]].append(n)
92 |
93 | self.cls_num_qry = len(self.data_label_qry[0])
94 |
95 | self.create_batch_LinkPred(self.batchsz)
96 |
97 | else:
98 | self.data_graph = []
99 |
100 | for i, (k, v) in enumerate(dictGraphs.items()):
101 | self.data_graph.append(v)
102 | self.graph_num = len(self.data_graph)
103 |
104 | self.data_label = [[] for i in range(self.graph_num)]
105 |
106 | relative_idx_map = dict(zip(list(dictGraphs.keys()), range(len(list(dictGraphs.keys())))))
107 |
108 | for i, (k, v) in enumerate(dictGraphsLabels.items()):
109 | #self.data_label[k] = []
110 | for m, n in v.items():
111 |
112 | self.data_label[relative_idx_map[k]].append(n) # [(graph 1)[(label1)[subgraph1, subgraph2, ...], (label2)[subgraph111, ...]], graph2: [[subgraph1, subgraph2, ...], [subgraph111, ...]] ]
113 | self.cls_num = len(self.data_label[0])
114 | self.graph_num = len(self.data_graph)
115 |
116 | self.create_batch_shared(self.batchsz)
117 |
118 |
119 | def loadCSV(self, csvf):
120 | dictGraphsLabels = {}
121 | dictLabels = {}
122 | dictGraphs = {}
123 |
124 | with open(csvf) as csvfile:
125 | csvreader = csv.reader(csvfile, delimiter=',')
126 | next(csvreader, None) # skip (filename, label)
127 | for i, row in enumerate(csvreader):
128 | filename = row[1]
129 | g_idx = int(filename.split('_')[0])
130 | label = row[2]
131 | # append filename to current label
132 |
133 | if g_idx in dictGraphs.keys():
134 | dictGraphs[g_idx].append(filename)
135 | else:
136 | dictGraphs[g_idx] = [filename]
137 | dictGraphsLabels[g_idx] = {}
138 |
139 | if label in dictGraphsLabels[g_idx].keys():
140 | dictGraphsLabels[g_idx][label].append(filename)
141 | else:
142 | dictGraphsLabels[g_idx][label] = [filename]
143 |
144 | if label in dictLabels.keys():
145 | dictLabels[label].append(filename)
146 | else:
147 | dictLabels[label] = [filename]
148 | return dictLabels, dictGraphs, dictGraphsLabels
149 |
150 | def create_batch_disjoint(self, batchsz):
151 | """
152 | create the entire set of batches of tasks for the disjoint label setting, independent of the number of graphs.
153 | """
154 | self.support_x_batch = [] # support set batch
155 | self.query_x_batch = [] # query set batch
156 | for b in range(batchsz): # for each batch
157 | # 1.select n_way classes randomly
158 | #print(self.cls_num)
159 | #print(self.n_way)
160 | selected_cls = np.random.choice(self.cls_num, self.n_way, False) # no duplicate
161 | np.random.shuffle(selected_cls)
162 | support_x = []
163 | query_x = []
164 | for cls in selected_cls:
165 |
166 | # 2. select k_shot + k_query for each class
167 | selected_subgraphs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
168 |
169 | np.random.shuffle(selected_subgraphs_idx)
170 | indexDtrain = np.array(selected_subgraphs_idx[:self.k_shot]) # idx for Dtrain
171 | indexDtest = np.array(selected_subgraphs_idx[self.k_shot:]) # idx for Dtest
172 | support_x.append(
173 | np.array(self.data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain
174 | query_x.append(np.array(self.data[cls])[indexDtest].tolist())
175 |
176 | # shuffle the corresponding relation between the support set and the query set
177 | random.shuffle(support_x)
178 | random.shuffle(query_x)
179 |
180 | # support_x: [setsz (k_shot+k_query * n_way)] numbers of subgraphs
181 | self.support_x_batch.append(support_x) # append set to current sets
182 | self.query_x_batch.append(query_x) # append sets to current sets
183 |
184 | def create_batch_shared(self, batchsz):
185 | """
186 | create the entire set of batches of tasks for the shared label setting, independent of the number of graphs.
187 | """
188 | k_shot = self.k_shot
189 | k_query = self.k_query
190 |
191 | self.support_x_batch = [] # support set batch
192 | self.query_x_batch = [] # query set batch
193 | for b in range(batchsz): # one loop generates one task
194 | # 1.select n_way classes randomly
195 | #print(self.cls_num)
196 | #print(self.n_way)
197 |
198 | selected_graph = np.random.choice(self.graph_num, 1, False)[0] # select one graph
199 | data = self.data_label[selected_graph]
200 |
201 | selected_cls = np.array(list(range(len(data)))) # for multiple graph setting, we select cls_num * k_shot nodes
202 | np.random.shuffle(selected_cls)
203 |
204 | support_x = []
205 | query_x = []
206 |
207 | for cls in selected_cls:
208 |
209 | # 2. select k_shot + k_query for each class
210 | try:
211 | selected_subgraphs_idx = np.random.choice(len(data[cls]), k_shot + k_query, False)
212 | np.random.shuffle(selected_subgraphs_idx)
213 | indexDtrain = np.array(selected_subgraphs_idx[:k_shot]) # idx for Dtrain
214 | indexDtest = np.array(selected_subgraphs_idx[k_shot:]) # idx for Dtest
215 | support_x.append(
216 | np.array(data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain
217 | query_x.append(np.array(data[cls])[indexDtest].tolist())
218 | except:
219 | # this was not used in practice
220 | if len(data[cls]) >= k_shot:
221 | selected_subgraphs_idx = np.array(range(len(data[cls])))
222 | np.random.shuffle(selected_subgraphs_idx)
223 | indexDtrain = np.array(selected_subgraphs_idx[:k_shot]) # idx for Dtrain
224 | indexDtest = np.array(selected_subgraphs_idx[k_shot:]) # idx for Dtest
225 | support_x.append(
226 | np.array(data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain
227 |
228 | num_more = k_shot + k_query - len(data[cls])
229 | count = 0
230 |
231 | query_tmp = np.array(data[cls])[indexDtest].tolist()
232 |
233 | while count <= num_more:
234 | sub_cls = np.random.choice(selected_cls, 1)[0]
235 | idx = np.random.choice(len(data[sub_cls]), 1)[0]
236 | query_tmp = query_tmp + [np.array(data[sub_cls])[idx]]
237 | count += 1
238 | query_x.append(query_tmp)
239 | else:
240 |                     print('each class in a graph must have more than k_shot entities in the current model')
241 |
242 | random.shuffle(support_x)
243 | random.shuffle(query_x)
244 |
245 | # support_x: [setsz (k_shot+k_query * 1)] numbers of subgraphs
246 | self.support_x_batch.append(support_x) # append set to current sets
247 | self.query_x_batch.append(query_x) # append sets to current sets
248 |
249 | def create_batch_LinkPred(self, batchsz):
250 | """
251 |         create the entire set of task batches for the shared label link prediction setting, independent of the number of graphs.
252 | """
253 | k_shot = self.k_shot
254 | k_query = self.k_query
255 |
256 | self.support_x_batch = [] # support set batch
257 | self.query_x_batch = [] # query set batch
258 |
259 | for b in range(batchsz): # one loop generates one task
260 |
261 | selected_graph = np.random.choice(self.graph_num_spt, 1, False)[0] # select one graph
262 | data_spt = self.data_label_spt[selected_graph]
263 |
264 | selected_cls_spt = np.array(list(range(len(data_spt)))) # for multiple graph setting, we select cls_num * k_shot nodes
265 | np.random.shuffle(selected_cls_spt)
266 |
267 | data_qry = self.data_label_qry[selected_graph]
268 |
269 | selected_cls_qry = np.array(list(range(len(data_qry)))) # for multiple graph setting, we select cls_num * k_shot nodes
270 | np.random.shuffle(selected_cls_qry)
271 |
272 | support_x = []
273 | query_x = []
274 |
275 | for cls in selected_cls_spt:
276 |
277 | selected_subgraphs_idx = np.random.choice(len(data_spt[cls]), k_shot, False)
278 | np.random.shuffle(selected_subgraphs_idx)
279 | support_x.append(
280 |                     np.array(data_spt[cls])[selected_subgraphs_idx].tolist()) # get all subgraph filenames for the current support set
281 |
282 | for cls in selected_cls_qry:
283 |
284 | selected_subgraphs_idx = np.random.choice(len(data_qry[cls]), k_query, False)
285 | np.random.shuffle(selected_subgraphs_idx)
286 | query_x.append(np.array(data_qry[cls])[selected_subgraphs_idx].tolist())
287 |
288 | random.shuffle(support_x)
289 | random.shuffle(query_x)
290 |
291 | self.support_x_batch.append(support_x) # append set to current sets
292 | self.query_x_batch.append(query_x) # append sets to current sets
293 |
294 | # helper to generate subgraphs on the fly.
295 | def generate_subgraph(self, G, i, item):
296 | if item in self.subgraphs:
297 | return self.subgraphs[item]
298 | else:
299 |             # instead of computing shortest-path distances, we found the following way of extracting h-hop subgraphs to be quicker
300 | if self.h == 2:
301 | f_hop = [n.item() for n in G.in_edges(i)[0]]
302 | n_l = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop]
303 | h_hops_neighbor = torch.tensor(list(set(list(itertools.chain(*n_l)) + f_hop + [i]))).numpy()
304 | elif self.h == 1:
305 | f_hop = [n.item() for n in G.in_edges(i)[0]]
306 | h_hops_neighbor = torch.tensor(list(set(f_hop + [i]))).numpy()
307 | elif self.h == 3:
308 | f_hop = [n.item() for n in G.in_edges(i)[0]]
309 | n_2 = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop]
310 | n_3 = [[n.item() for n in G.in_edges(i)[0]] for i in list(itertools.chain(*n_2))]
311 | h_hops_neighbor = torch.tensor(list(set(list(itertools.chain(*n_2)) + list(itertools.chain(*n_3)) + f_hop + [i]))).numpy()
312 | if h_hops_neighbor.reshape(-1,).shape[0] > self.sample_nodes:
313 | h_hops_neighbor = np.random.choice(h_hops_neighbor, self.sample_nodes, replace = False)
314 | h_hops_neighbor = np.unique(np.append(h_hops_neighbor, [i]))
315 |
316 | sub = G.subgraph(h_hops_neighbor)
317 | h_c = list(sub.parent_nid.numpy())
318 | dict_ = dict(zip(h_c, list(range(len(h_c)))))
319 | self.subgraphs[item] = (sub, dict_[i], h_c)
320 |
321 | return sub, dict_[i], h_c
322 |
323 | def generate_subgraph_link_pred(self, G, i, j, item):
324 | if item in self.subgraphs:
325 | return self.subgraphs[item]
326 | else:
327 | f_hop = [n.item() for n in G.in_edges(i)[0]]
328 | n_l = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop]
329 | h_hops_neighbor1 = torch.tensor(list(set([item for sublist in n_l for item in sublist] + f_hop + [i]))).numpy()
330 |
331 | f_hop = [n.item() for n in G.in_edges(j)[0]]
332 |             n_l = [[n.item() for n in G.in_edges(m)[0]] for m in f_hop] # neighbors of each 1-hop neighbor of j
333 | h_hops_neighbor2 = torch.tensor(list(set([item for sublist in n_l for item in sublist] + f_hop + [j]))).numpy()
334 |
335 | h_hops_neighbor = np.union1d(h_hops_neighbor1, h_hops_neighbor2)
336 |
337 | if h_hops_neighbor.reshape(-1,).shape[0] > self.sample_nodes:
338 | h_hops_neighbor = np.random.choice(h_hops_neighbor, self.sample_nodes, replace = False)
339 | h_hops_neighbor = np.unique(np.append(h_hops_neighbor, [i, j]))
340 |
341 | sub = G.subgraph(h_hops_neighbor)
342 | h_c = list(sub.parent_nid.numpy())
343 | dict_ = dict(zip(h_c, list(range(len(h_c)))))
344 | self.subgraphs[item] = (sub, [dict_[i], dict_[j]], h_c)
345 |
346 | return sub, [dict_[i], dict_[j]], h_c
347 |
348 | def __getitem__(self, index):
349 | """
350 | get one task. support_x_batch[index], query_x_batch[index]
351 |
352 | """
353 | #print(self.support_x_batch[index])
354 | if self.link_pred_mode:
355 | info = [self.generate_subgraph_link_pred(self.G[int(item.split('_')[0])], int(item.split('_')[1]), int(item.split('_')[2]), item)
356 | for sublist in self.support_x_batch[index] for item in sublist]
357 | else:
358 | info = [self.generate_subgraph(self.G[int(item.split('_')[0])], int(item.split('_')[1]), item)
359 | for sublist in self.support_x_batch[index] for item in sublist]
360 |
361 |         support_graph_idx = [int(item.split('_')[0]) # graph index of each support subgraph
362 | for sublist in self.support_x_batch[index] for item in sublist]
363 |
364 | support_x = [i for i, j, k in info]
365 | support_y = np.array([self.subgraph2label[item]
366 | for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
367 |
368 | support_center = np.array([j for i, j, k in info]).astype(np.int32)
369 | support_node_idx = [k for i, j, k in info]
370 |
371 |
372 | if self.link_pred_mode:
373 | info = [self.generate_subgraph_link_pred(self.G[int(item.split('_')[0])], int(item.split('_')[1]), int(item.split('_')[2]), item)
374 | for sublist in self.query_x_batch[index] for item in sublist]
375 | else:
376 | info = [self.generate_subgraph(self.G[int(item.split('_')[0])], int(item.split('_')[1]), item)
377 | for sublist in self.query_x_batch[index] for item in sublist]
378 |
379 |         query_graph_idx = [int(item.split('_')[0]) # graph index of each query subgraph
380 | for sublist in self.query_x_batch[index] for item in sublist]
381 |
382 | query_x = [i for i, j, k in info]
383 | query_y = np.array([self.subgraph2label[item]
384 | for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)
385 |
386 | query_center = np.array([j for i, j, k in info]).astype(np.int32)
387 | query_node_idx = [k for i, j, k in info]
388 |
389 | if self.task_setup == 'Disjoint':
390 | unique = np.unique(support_y)
391 | random.shuffle(unique)
392 |             # "relative" means labels are remapped to the range 0 to n_way - 1
393 | support_y_relative = np.zeros(self.setsz)
394 | query_y_relative = np.zeros(self.querysz)
395 | for idx, l in enumerate(unique):
396 | support_y_relative[support_y == l] = idx
397 | query_y_relative[query_y == l] = idx
398 | # this is a set of subgraphs for one task.
399 | batched_graph_spt = dgl.batch(support_x)
400 | batched_graph_qry = dgl.batch(query_x)
401 |
402 | return batched_graph_spt, torch.LongTensor(support_y_relative), batched_graph_qry, torch.LongTensor(query_y_relative), torch.LongTensor(support_center), torch.LongTensor(query_center), support_node_idx, query_node_idx, support_graph_idx, query_graph_idx
403 | elif self.task_setup == 'Shared':
404 |
405 | batched_graph_spt = dgl.batch(support_x)
406 | batched_graph_qry = dgl.batch(query_x)
407 |
408 | return batched_graph_spt, torch.LongTensor(support_y), batched_graph_qry, torch.LongTensor(query_y), torch.LongTensor(support_center), torch.LongTensor(query_center), support_node_idx, query_node_idx, support_graph_idx, query_graph_idx
409 |
410 | def __len__(self):
411 |         # we have built batchsz tasks in total, so the DataLoader can draw smaller batches of tasks from them.
412 | return self.batchsz
413 |
414 | def collate(samples):
415 |     # The input `samples` is a list of task tuples as returned by
416 |     # __getitem__ above (support/query graphs, labels, centers, and indices).
417 | graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx = map(list, zip(*samples))
418 |
419 | return graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx
420 |
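421 | # --- Minimal usage sketch (illustrative only) -------------------------------
422 | # It assumes the dataset class defined above is importable as `Subgraphs` and
423 | # that its constructor accepts the few-shot arguments shown; the exact
424 | # signature may differ, so treat the arguments below as placeholders. The
425 | # module-level `collate` function is passed to a standard torch DataLoader so
426 | # that each batch is a list of tasks rather than a stacked tensor.
427 | #
428 | # from torch.utils.data import DataLoader
429 | #
430 | # db_train = Subgraphs(root='DATA_PATH', mode='train', n_way=3, k_shot=3,
431 | #                      k_query=24, batchsz=10000)  # hypothetical arguments
432 | # loader = DataLoader(db_train, batch_size=4, shuffle=True, num_workers=0,
433 | #                     collate_fn=collate)
434 | # for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in loader:
435 | #     # each of the ten entries is a list with one element per task in the batch
436 | #     pass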
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 10,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Using backend: pytorch\n",
13 | "Namespace(attention_size=32, batchsz=10000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/arxiv/', epoch=10, h=2, hidden_dim=256, input_dim=1, k_qry=24, k_spt=3, link_pred_mode='False', meta_lr=0.001, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=32, task_setup='Disjoint', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=20, val_result_report_steps=100)\n",
14 | "There are 40 classes \n",
15 | "Meta(\n",
16 | " (net): Classifier(\n",
17 | " (vars): ParameterList(\n",
18 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 128x256 (GPU 0)]\n",
19 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n",
20 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 256x256 (GPU 0)]\n",
21 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n",
22 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 3x256 (GPU 0)]\n",
23 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]\n",
24 | " )\n",
25 | " )\n",
26 | ")\n",
27 | "Total trainable tensors: 99587\n",
28 | "shuffle DB :train, b:10000, 3-way, 3-shot, 24-query, 2-hops\n",
29 | "shuffle DB :val, b:100, 3-way, 3-shot, 24-query, 2-hops\n",
30 | "shuffle DB :test, b:100, 3-way, 3-shot, 24-query, 2-hops\n",
31 | "------ Start Training ------\n",
32 | "Epoch: 1 Step: 0 training acc: 0.338 time elapsed: 6.295 data loading takes: 4.125 Memory usage: 28.71\n",
33 | "Epoch: 1 Step: 200 training acc: 0.462 time elapsed: 4.871 data loading takes: 0.591 Memory usage: 31.05\n",
34 | "Epoch: 1 Val acc: 0.456\n",
35 | "Epoch: 2 Step: 0 training acc: 0.440 time elapsed: 4.613 data loading takes: 0.377 Memory usage: 31.74\n",
36 | "Epoch: 2 Step: 200 training acc: 0.480 time elapsed: 4.770 data loading takes: 0.483 Memory usage: 39.74\n",
37 | "Epoch: 2 Val acc: 0.445\n",
38 | "Epoch: 3 Step: 0 training acc: 0.521 time elapsed: 4.731 data loading takes: 0.442 Memory usage: 40.24\n",
39 | "Epoch: 3 Step: 200 training acc: 0.497 time elapsed: 4.776 data loading takes: 0.541 Memory usage: 40.33\n",
40 | "Epoch: 3 Val acc: 0.445\n",
41 | "Epoch: 4 Step: 0 training acc: 0.509 time elapsed: 4.752 data loading takes: 0.476 Memory usage: 39.76\n",
42 | "Epoch: 4 Step: 200 training acc: 0.505 time elapsed: 4.812 data loading takes: 0.471 Memory usage: 39.87\n",
43 | "Epoch: 4 Val acc: 0.443\n",
44 | "Epoch: 5 Step: 0 training acc: 0.503 time elapsed: 4.892 data loading takes: 0.473 Memory usage: 40.82\n",
45 | "Epoch: 5 Step: 200 training acc: 0.477 time elapsed: 4.825 data loading takes: 0.487 Memory usage: 40.92\n",
46 | "Epoch: 5 Val acc: 0.441\n",
47 | "Epoch: 6 Step: 0 training acc: 0.509 time elapsed: 4.829 data loading takes: 0.412 Memory usage: 40.33\n",
48 | "Epoch: 6 Step: 200 training acc: 0.519 time elapsed: 4.861 data loading takes: 0.453 Memory usage: 40.39\n",
49 | "Epoch: 6 Val acc: 0.444\n",
50 | "Epoch: 7 Step: 0 training acc: 0.553 time elapsed: 4.757 data loading takes: 0.399 Memory usage: 41.46\n",
51 | "Epoch: 7 Step: 200 training acc: 0.503 time elapsed: 4.825 data loading takes: 0.519 Memory usage: 41.73\n",
52 | "Epoch: 7 Val acc: 0.439\n",
53 | "Epoch: 8 Step: 0 training acc: 0.523 time elapsed: 4.859 data loading takes: 0.490 Memory usage: 41.53\n",
54 | "Epoch: 8 Step: 200 training acc: 0.516 time elapsed: 4.675 data loading takes: 0.554 Memory usage: 41.43\n",
55 | "Epoch: 8 Val acc: 0.441\n",
56 | "Epoch: 9 Step: 0 training acc: 0.519 time elapsed: 4.795 data loading takes: 0.444 Memory usage: 40.89\n",
57 | "Epoch: 9 Step: 200 training acc: 0.507 time elapsed: 4.875 data loading takes: 3.407 Memory usage: 41.21\n",
58 | "Epoch: 9 Val acc: 0.443\n",
59 | "Epoch: 10 Step: 0 training acc: 0.503 time elapsed: 4.898 data loading takes: 0.423 Memory usage: 41.88\n",
60 | "Epoch: 10 Step: 200 training acc: 0.560 time elapsed: 4.960 data loading takes: 0.492 Memory usage: 42.23\n",
61 | "Epoch: 10 Val acc: 0.44\n",
62 | "Test acc: 0.421\n",
63 | "Early Stopped Test acc: 0.436\n",
64 | "Total Time: 17206\n",
65 | "Max Momory: 42.52\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/arxiv/ \\\n",
71 | " --epoch 10 \\\n",
72 | " --task_setup Disjoint \\\n",
73 | " --k_spt 3 \\\n",
74 | " --k_qry 24 \\\n",
75 | " --n_way 3 \\\n",
76 | " --update_step 10 \\\n",
77 | " --update_lr 0.01 \\\n",
78 | " --num_workers 0 \\\n",
79 | " --train_result_report_steps 200 \\\n",
80 | " --hidden_dim 256 \\\n",
81 | " --update_step_test 20 \\\n",
82 | " --task_num 32 \\\n",
83 | " --batchsz 10000 "
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 1,
89 | "metadata": {
90 | "scrolled": false
91 | },
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "Using backend: pytorch\n",
98 | "Namespace(attention_size=32, batchsz=1000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/tissue_PPI/', epoch=15, h=2, hidden_dim=128, input_dim=1, k_qry=10, k_spt=3, link_pred_mode='False', meta_lr=0.005, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='True', task_n=4, task_num=4, task_setup='Shared', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=10, val_result_report_steps=100)\n",
99 | "There are 2 classes \n",
100 | "Meta(\n",
101 | " (net): Classifier(\n",
102 | " (vars): ParameterList(\n",
103 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 50x128 (GPU 0)]\n",
104 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
105 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n",
106 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
107 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x128 (GPU 0)]\n",
108 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n",
109 | " )\n",
110 | " )\n",
111 | ")\n",
112 | "Total trainable tensors: 23298\n",
113 | "shuffle DB :train, b:1000, 3-way, 3-shot, 10-query, 2-hops\n",
114 | "shuffle DB :val, b:100, 3-way, 3-shot, 10-query, 2-hops\n",
115 | "shuffle DB :test, b:100, 3-way, 3-shot, 10-query, 2-hops\n",
116 | "------ Start Training ------\n",
117 | "Epoch: 1 Step: 0 training acc: 0.575 time elapsed: 1.461 data loading takes: 1.735 Memory usage: 11.21\n",
118 | "Epoch: 1 Step: 200 training acc: 0.575 time elapsed: 0.578 data loading takes: 1.836 Memory usage: 31.95\n",
119 | "Epoch: 1 Val acc: 0.563\n",
120 | "Epoch: 2 Step: 0 training acc: 0.587 time elapsed: 0.532 data loading takes: 0.465 Memory usage: 37.52\n",
121 | "Epoch: 2 Step: 200 training acc: 0.612 time elapsed: 0.560 data loading takes: 0.549 Memory usage: 37.47\n",
122 | "Epoch: 2 Val acc: 0.608\n",
123 | "Epoch: 3 Step: 0 training acc: 0.65 time elapsed: 0.559 data loading takes: 0.484 Memory usage: 37.46\n",
124 | "Epoch: 3 Step: 200 training acc: 0.65 time elapsed: 0.551 data loading takes: 0.536 Memory usage: 37.52\n",
125 | "Epoch: 3 Val acc: 0.623\n",
126 | "Epoch: 4 Step: 0 training acc: 0.662 time elapsed: 0.540 data loading takes: 0.449 Memory usage: 37.59\n",
127 | "Epoch: 4 Step: 200 training acc: 0.724 time elapsed: 0.593 data loading takes: 0.649 Memory usage: 37.61\n",
128 | "Epoch: 4 Val acc: 0.644\n",
129 | "Epoch: 5 Step: 0 training acc: 0.787 time elapsed: 0.548 data loading takes: 0.479 Memory usage: 37.65\n",
130 | "Epoch: 5 Step: 200 training acc: 0.725 time elapsed: 0.555 data loading takes: 0.556 Memory usage: 37.71\n",
131 | "Epoch: 5 Val acc: 0.645\n",
132 | "Epoch: 6 Step: 0 training acc: 0.612 time elapsed: 0.537 data loading takes: 0.443 Memory usage: 37.56\n",
133 | "Epoch: 6 Step: 200 training acc: 0.674 time elapsed: 0.560 data loading takes: 0.562 Memory usage: 37.70\n",
134 | "Epoch: 6 Val acc: 0.666\n",
135 | "Epoch: 7 Step: 0 training acc: 0.575 time elapsed: 0.574 data loading takes: 0.579 Memory usage: 37.63\n",
136 | "Epoch: 7 Step: 200 training acc: 0.775 time elapsed: 0.551 data loading takes: 0.523 Memory usage: 37.67\n",
137 | "Epoch: 7 Val acc: 0.672\n",
138 | "Epoch: 8 Step: 0 training acc: 0.812 time elapsed: 0.531 data loading takes: 0.423 Memory usage: 37.71\n",
139 | "Epoch: 8 Step: 200 training acc: 0.737 time elapsed: 0.520 data loading takes: 0.460 Memory usage: 37.77\n",
140 | "Epoch: 8 Val acc: 0.681\n",
141 | "Epoch: 9 Step: 0 training acc: 0.825 time elapsed: 0.528 data loading takes: 0.437 Memory usage: 37.70\n",
142 | "Epoch: 9 Step: 200 training acc: 0.825 time elapsed: 0.595 data loading takes: 0.622 Memory usage: 37.74\n",
143 | "Epoch: 9 Val acc: 0.689\n",
144 | "Epoch: 10 Step: 0 training acc: 0.812 time elapsed: 0.572 data loading takes: 0.543 Memory usage: 37.72\n",
145 | "Epoch: 10 Step: 200 training acc: 0.762 time elapsed: 0.551 data loading takes: 0.555 Memory usage: 37.76\n",
146 | "Epoch: 10 Val acc: 0.688\n",
147 | "Epoch: 11 Step: 0 training acc: 0.825 time elapsed: 0.554 data loading takes: 0.524 Memory usage: 37.78\n",
148 | "Epoch: 11 Step: 200 training acc: 0.75 time elapsed: 0.608 data loading takes: 0.654 Memory usage: 37.76\n",
149 | "Epoch: 11 Val acc: 0.716\n",
150 | "Epoch: 12 Step: 0 training acc: 0.787 time elapsed: 0.524 data loading takes: 0.445 Memory usage: 37.90\n",
151 | "Epoch: 12 Step: 200 training acc: 0.787 time elapsed: 0.604 data loading takes: 0.683 Memory usage: 37.66\n",
152 | "Epoch: 12 Val acc: 0.703\n",
153 | "Epoch: 13 Step: 0 training acc: 0.762 time elapsed: 0.551 data loading takes: 0.478 Memory usage: 37.59\n",
154 | "Epoch: 13 Step: 200 training acc: 0.85 time elapsed: 0.538 data loading takes: 0.458 Memory usage: 37.62\n",
155 | "Epoch: 13 Val acc: 0.723\n",
156 | "Epoch: 14 Step: 0 training acc: 0.85 time elapsed: 0.540 data loading takes: 0.459 Memory usage: 37.66\n",
157 | "Epoch: 14 Step: 200 training acc: 0.762 time elapsed: 0.568 data loading takes: 0.571 Memory usage: 37.65\n",
158 | "Epoch: 14 Val acc: 0.723\n",
159 | "Epoch: 15 Step: 0 training acc: 0.8 time elapsed: 0.499 data loading takes: 0.321 Memory usage: 37.57\n",
160 | "Epoch: 15 Step: 200 training acc: 0.762 time elapsed: 0.536 data loading takes: 0.481 Memory usage: 37.69\n",
161 | "Epoch: 15 Val acc: 0.730\n",
162 | "Test acc: 0.78\n",
163 | "Early Stopped Test acc: 0.774\n",
164 | "Total Time: 4852.\n",
165 | "Max Momory: 37.90\n"
166 | ]
167 | }
168 | ],
169 | "source": [
170 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/tissue_PPI/ \\\n",
171 | " --epoch 15 \\\n",
172 | " --task_setup Shared \\\n",
173 | " --task_mode True \\\n",
174 | " --task_n 4 \\\n",
175 | " --k_qry 10 \\\n",
176 | " --k_spt 3 \\\n",
177 | " --update_lr 0.01 \\\n",
178 | " --update_step 10 \\\n",
179 | " --meta_lr 5e-3 \\\n",
180 | " --num_workers 0 \\\n",
181 | " --train_result_report_steps 200 \\\n",
182 | " --hidden_dim 128 \\\n",
183 | " --task_num 4 \\\n",
184 | " --batchsz 1000"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 3,
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "Using backend: pytorch\n",
197 | "Namespace(attention_size=32, batchsz=4000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/fold_PPI/', epoch=5, h=2, hidden_dim=128, input_dim=1, k_qry=24, k_spt=3, link_pred_mode='False', meta_lr=0.001, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=16, task_setup='Disjoint', train_result_report_steps=100, update_lr=0.005, update_step=5, update_step_test=20, val_result_report_steps=100)\n",
198 | "There are 30 classes \n",
199 | "Meta(\n",
200 | " (net): Classifier(\n",
201 | " (vars): ParameterList(\n",
202 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 512x128 (GPU 0)]\n",
203 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
204 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n",
205 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
206 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 3x128 (GPU 0)]\n",
207 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]\n",
208 | " )\n",
209 | " )\n",
210 | ")\n",
211 | "Total trainable tensors: 82563\n",
212 | "shuffle DB :train, b:4000, 3-way, 3-shot, 24-query, 2-hops\n",
213 | "shuffle DB :val, b:100, 3-way, 3-shot, 24-query, 2-hops\n",
214 | "shuffle DB :test, b:100, 3-way, 3-shot, 24-query, 2-hops\n",
215 | "------ Start Training ------\n",
216 | "Epoch: 1 Step: 0 training acc: 0.427 time elapsed: 3.852 data loading takes: 6.680 Memory usage: 27.73\n",
217 | "Epoch: 1 Step: 100 training acc: 0.451 time elapsed: 3.616 data loading takes: 1.766 Memory usage: 41.06\n",
218 | "Epoch: 1 Step: 200 training acc: 0.447 time elapsed: 4.523 data loading takes: 1.990 Memory usage: 41.58\n",
219 | "Epoch: 1 Val acc: 0.478\n",
220 | "Epoch: 2 Step: 0 training acc: 0.471 time elapsed: 3.576 data loading takes: 1.379 Memory usage: 43.33\n",
221 | "Epoch: 2 Step: 100 training acc: 0.571 time elapsed: 3.528 data loading takes: 1.434 Memory usage: 42.90\n",
222 | "Epoch: 2 Step: 200 training acc: 0.589 time elapsed: 3.381 data loading takes: 1.399 Memory usage: 43.79\n",
223 | "Epoch: 2 Val acc: 0.494\n",
224 | "Epoch: 3 Step: 0 training acc: 0.566 time elapsed: 3.575 data loading takes: 1.249 Memory usage: 44.07\n",
225 | "Epoch: 3 Step: 100 training acc: 0.662 time elapsed: 3.753 data loading takes: 1.538 Memory usage: 43.48\n",
226 | "Epoch: 3 Step: 200 training acc: 0.663 time elapsed: 3.829 data loading takes: 1.579 Memory usage: 44.14\n",
227 | "Epoch: 3 Val acc: 0.522\n",
228 | "Epoch: 4 Step: 0 training acc: 0.594 time elapsed: 3.804 data loading takes: 1.329 Memory usage: 43.42\n",
229 | "Epoch: 4 Step: 100 training acc: 0.598 time elapsed: 3.962 data loading takes: 1.630 Memory usage: 43.99\n",
230 | "Epoch: 4 Step: 200 training acc: 0.743 time elapsed: 3.637 data loading takes: 1.465 Memory usage: 43.46\n",
231 | "Epoch: 4 Val acc: 0.513\n",
232 | "Epoch: 5 Step: 0 training acc: 0.705 time elapsed: 3.728 data loading takes: 1.319 Memory usage: 43.72\n",
233 | "Epoch: 5 Step: 100 training acc: 0.812 time elapsed: 3.654 data loading takes: 1.521 Memory usage: 44.23\n",
234 | "Epoch: 5 Step: 200 training acc: 0.724 time elapsed: 3.919 data loading takes: 1.637 Memory usage: 44.02\n",
235 | "Epoch: 5 Val acc: 0.543\n",
236 | "Test acc: 0.578\n",
237 | "Early Stopped Test acc: 0.656\n",
238 | "Total Time: 7150.\n",
239 | "Max Momory: 44.39\n"
240 | ]
241 | }
242 | ],
243 | "source": [
244 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/fold_PPI/ \\\n",
245 | " --epoch 5 \\\n",
246 | " --task_setup Disjoint \\\n",
247 | " --k_qry 24 \\\n",
248 | " --k_spt 3 \\\n",
249 | " --n_way 3 \\\n",
250 | " --update_lr 0.005 \\\n",
251 | " --meta_lr 1e-3 \\\n",
252 | " --num_workers 0 \\\n",
253 | " --train_result_report_steps 100 \\\n",
254 | " --hidden_dim 128 \\\n",
255 | " --update_step_test 20 \\\n",
256 | " --task_num 16 \\\n",
257 | " --batchsz 4000"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 2,
263 | "metadata": {},
264 | "outputs": [
265 | {
266 | "name": "stdout",
267 | "output_type": "stream",
268 | "text": [
269 | "Using backend: pytorch\n",
270 | "Namespace(attention_size=32, batchsz=1500, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/FirstMM_DB/', epoch=15, h=2, hidden_dim=128, input_dim=1, k_qry=32, k_spt=16, link_pred_mode='True', meta_lr=0.0005, method='G-Meta', n_way=2, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=8, task_setup='Shared', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=20, val_result_report_steps=100)\n",
271 | "There are 2 classes \n",
272 | "Meta(\n",
273 | " (net): Classifier(\n",
274 | " (vars): ParameterList(\n",
275 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 5x128 (GPU 0)]\n",
276 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
277 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n",
278 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n",
279 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x256 (GPU 0)]\n",
280 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n",
281 | " )\n",
282 | " )\n",
283 | ")\n",
284 | "Total trainable tensors: 17794\n",
285 | "shuffle DB :train, b:1500, 2-way, 16-shot, 32-query, 2-hops\n",
286 | "shuffle DB :val, b:100, 2-way, 16-shot, 32-query, 2-hops\n",
287 | "shuffle DB :test, b:100, 2-way, 16-shot, 32-query, 2-hops\n",
288 | "------ Start Training ------\n",
289 | "Epoch: 1 Step: 0 training acc: 0.492 time elapsed: 1.279 data loading takes: 1.136 Memory usage: 13.26\n",
290 | "Epoch: 1 Val acc: 0.691\n",
291 | "Epoch: 2 Step: 0 training acc: 0.695 time elapsed: 0.635 data loading takes: 0.080 Memory usage: 14.51\n",
292 | "Epoch: 2 Val acc: 0.735\n",
293 | "Epoch: 3 Step: 0 training acc: 0.693 time elapsed: 0.636 data loading takes: 0.075 Memory usage: 14.53\n",
294 | "Epoch: 3 Val acc: 0.762\n",
295 | "Epoch: 4 Step: 0 training acc: 0.705 time elapsed: 0.642 data loading takes: 0.074 Memory usage: 14.51\n",
296 | "Epoch: 4 Val acc: 0.769\n",
297 | "Epoch: 5 Step: 0 training acc: 0.728 time elapsed: 0.630 data loading takes: 0.077 Memory usage: 14.52\n",
298 | "Epoch: 5 Val acc: 0.778\n",
299 | "Epoch: 6 Step: 0 training acc: 0.75 time elapsed: 0.633 data loading takes: 0.077 Memory usage: 14.52\n",
300 | "Epoch: 6 Val acc: 0.780\n",
301 | "Epoch: 7 Step: 0 training acc: 0.734 time elapsed: 0.640 data loading takes: 0.080 Memory usage: 14.58\n",
302 | "Epoch: 7 Val acc: 0.785\n",
303 | "Epoch: 8 Step: 0 training acc: 0.744 time elapsed: 0.639 data loading takes: 0.072 Memory usage: 14.52\n",
304 | "Epoch: 8 Val acc: 0.786\n",
305 | "Epoch: 9 Step: 0 training acc: 0.732 time elapsed: 1.201 data loading takes: 0.149 Memory usage: 14.52\n",
306 | "Epoch: 9 Val acc: 0.785\n",
307 | "Epoch: 10 Step: 0 training acc: 0.787 time elapsed: 0.633 data loading takes: 0.083 Memory usage: 14.54\n",
308 | "Epoch: 10 Val acc: 0.785\n",
309 | "Epoch: 11 Step: 0 training acc: 0.751 time elapsed: 0.633 data loading takes: 0.072 Memory usage: 14.59\n",
310 | "Epoch: 11 Val acc: 0.789\n",
311 | "Epoch: 12 Step: 0 training acc: 0.748 time elapsed: 0.635 data loading takes: 0.074 Memory usage: 14.58\n",
312 | "Epoch: 12 Val acc: 0.791\n",
313 | "Epoch: 13 Step: 0 training acc: 0.753 time elapsed: 0.635 data loading takes: 0.076 Memory usage: 14.55\n",
314 | "Epoch: 13 Val acc: 0.791\n",
315 | "Epoch: 14 Step: 0 training acc: 0.789 time elapsed: 0.660 data loading takes: 0.082 Memory usage: 14.63\n",
316 | "Epoch: 14 Val acc: 0.793\n",
317 | "Epoch: 15 Step: 0 training acc: 0.738 time elapsed: 0.628 data loading takes: 0.080 Memory usage: 14.59\n",
318 | "Epoch: 15 Val acc: 0.799\n",
319 | "Test acc: 0.769\n",
320 | "Early Stopped Test acc: 0.756\n",
321 | "Total Time: 2536.\n",
322 | "Max Momory: 14.86\n"
323 | ]
324 | }
325 | ],
326 | "source": [
327 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/FirstMM_DB/ \\\n",
328 | " --epoch 15 \\\n",
329 | " --task_setup Shared \\\n",
330 | " --k_qry 32 \\\n",
331 | " --k_spt 16 \\\n",
332 | " --n_way 2 \\\n",
333 | " --update_lr 0.01 \\\n",
334 | " --update_step 10 \\\n",
335 | " --meta_lr 5e-4 \\\n",
336 | " --num_workers 0 \\\n",
337 | " --train_result_report_steps 200 \\\n",
338 | " --hidden_dim 128 \\\n",
339 | " --update_step_test 20 \\\n",
340 | " --task_num 8 \\\n",
341 | " --batchsz 1500 \\\n",
342 |     "    --link_pred_mode True"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": 4,
348 | "metadata": {},
349 | "outputs": [
350 | {
351 | "name": "stdout",
352 | "output_type": "stream",
353 | "text": [
354 | "Using backend: pytorch\n",
355 | "Namespace(attention_size=32, batchsz=5000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/tree-of-life/', epoch=15, h=2, hidden_dim=256, input_dim=1, k_qry=16, k_spt=16, link_pred_mode='True', meta_lr=0.0005, method='G-Meta', n_way=2, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=8, task_setup='Shared', train_result_report_steps=200, update_lr=0.005, update_step=10, update_step_test=20, val_result_report_steps=100)\n",
356 | "There are 2 classes \n",
357 | "Meta(\n",
358 | " (net): Classifier(\n",
359 | " (vars): ParameterList(\n",
360 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 1x256 (GPU 0)]\n",
361 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n",
362 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 256x256 (GPU 0)]\n",
363 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n",
364 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x512 (GPU 0)]\n",
365 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n",
366 | " )\n",
367 | " )\n",
368 | ")\n",
369 | "Total trainable tensors: 67330\n",
370 | "shuffle DB :train, b:5000, 2-way, 16-shot, 16-query, 2-hops\n",
371 | "shuffle DB :val, b:100, 2-way, 16-shot, 16-query, 2-hops\n",
372 | "shuffle DB :test, b:100, 2-way, 16-shot, 16-query, 2-hops\n",
373 | "------ Start Training ------\n",
374 | "Epoch: 1 Step: 0 training acc: 0.628 time elapsed: 1.257 data loading takes: 3.446 Memory usage: 28.58\n",
375 | "Epoch: 1 Step: 200 training acc: 0.621 time elapsed: 0.734 data loading takes: 26.83 Memory usage: 41.77\n",
376 | "Epoch: 1 Step: 400 training acc: 0.738 time elapsed: 0.701 data loading takes: 3.756 Memory usage: 54.42\n",
377 | "Epoch: 1 Step: 600 training acc: 0.695 time elapsed: 0.648 data loading takes: 1.814 Memory usage: 65.41\n",
378 | "Epoch: 1 Val acc: 0.694\n",
379 | "Epoch: 2 Step: 0 training acc: 0.667 time elapsed: 0.672 data loading takes: 0.160 Memory usage: 67.54\n",
380 | "Epoch: 2 Step: 200 training acc: 0.660 time elapsed: 0.665 data loading takes: 0.151 Memory usage: 68.14\n",
381 | "Epoch: 2 Step: 400 training acc: 0.691 time elapsed: 0.673 data loading takes: 0.144 Memory usage: 67.23\n",
382 | "Epoch: 2 Step: 600 training acc: 0.714 time elapsed: 0.678 data loading takes: 0.172 Memory usage: 67.63\n",
383 | "Epoch: 2 Val acc: 0.702\n",
384 | "Epoch: 3 Step: 0 training acc: 0.695 time elapsed: 0.666 data loading takes: 0.166 Memory usage: 67.67\n",
385 | "Epoch: 3 Step: 200 training acc: 0.714 time elapsed: 0.718 data loading takes: 0.297 Memory usage: 68.02\n",
386 | "Epoch: 3 Step: 400 training acc: 0.710 time elapsed: 0.743 data loading takes: 0.393 Memory usage: 68.19\n",
387 | "Epoch: 3 Step: 600 training acc: 0.699 time elapsed: 0.726 data loading takes: 0.286 Memory usage: 68.07\n",
388 | "Epoch: 3 Val acc: 0.709\n",
389 | "Epoch: 4 Step: 0 training acc: 0.671 time elapsed: 0.656 data loading takes: 0.151 Memory usage: 68.10\n",
390 | "Epoch: 4 Step: 200 training acc: 0.722 time elapsed: 0.698 data loading takes: 0.172 Memory usage: 67.22\n",
391 | "Epoch: 4 Step: 400 training acc: 0.718 time elapsed: 0.680 data loading takes: 0.232 Memory usage: 67.23\n",
392 | "Epoch: 4 Step: 600 training acc: 0.726 time elapsed: 0.682 data loading takes: 0.201 Memory usage: 67.62\n",
393 | "Epoch: 4 Val acc: 0.723\n",
394 | "Epoch: 5 Step: 0 training acc: 0.679 time elapsed: 0.670 data loading takes: 0.190 Memory usage: 67.73\n",
395 | "Epoch: 5 Step: 200 training acc: 0.703 time elapsed: 0.696 data loading takes: 0.143 Memory usage: 68.12\n",
396 | "Epoch: 5 Step: 400 training acc: 0.679 time elapsed: 0.704 data loading takes: 0.279 Memory usage: 68.20\n",
397 | "Epoch: 5 Step: 600 training acc: 0.660 time elapsed: 0.704 data loading takes: 0.232 Memory usage: 68.31\n",
398 | "Epoch: 5 Val acc: 0.722\n",
399 | "Epoch: 6 Step: 0 training acc: 0.730 time elapsed: 0.722 data loading takes: 0.373 Memory usage: 67.17\n",
400 | "Epoch: 6 Step: 200 training acc: 0.726 time elapsed: 0.700 data loading takes: 0.216 Memory usage: 67.19\n",
401 | "Epoch: 6 Step: 400 training acc: 0.714 time elapsed: 0.740 data loading takes: 0.332 Memory usage: 67.80\n",
402 | "Epoch: 6 Step: 600 training acc: 0.707 time elapsed: 0.692 data loading takes: 0.199 Memory usage: 68.09\n",
403 | "Epoch: 6 Val acc: 0.730\n",
404 | "Epoch: 7 Step: 0 training acc: 0.703 time elapsed: 0.707 data loading takes: 0.271 Memory usage: 68.06\n",
405 | "Epoch: 7 Step: 200 training acc: 0.703 time elapsed: 0.691 data loading takes: 0.227 Memory usage: 68.29\n",
406 | "Epoch: 7 Step: 400 training acc: 0.738 time elapsed: 0.682 data loading takes: 0.178 Memory usage: 68.44\n",
407 | "Epoch: 7 Step: 600 training acc: 0.667 time elapsed: 0.696 data loading takes: 0.192 Memory usage: 67.66\n",
408 | "Epoch: 7 Val acc: 0.712\n",
409 | "Epoch: 8 Step: 0 training acc: 0.707 time elapsed: 0.673 data loading takes: 0.227 Memory usage: 67.77\n",
410 | "Epoch: 8 Step: 200 training acc: 0.765 time elapsed: 0.665 data loading takes: 0.216 Memory usage: 55.23\n",
411 | "Epoch: 8 Step: 400 training acc: 0.738 time elapsed: 0.743 data loading takes: 0.315 Memory usage: 57.62\n",
412 | "Epoch: 8 Step: 600 training acc: 0.734 time elapsed: 0.759 data loading takes: 0.259 Memory usage: 57.61\n",
413 | "Epoch: 8 Val acc: 0.734\n",
414 | "Epoch: 9 Step: 0 training acc: 0.710 time elapsed: 0.740 data loading takes: 0.258 Memory usage: 57.70\n",
415 | "Epoch: 9 Step: 200 training acc: 0.699 time elapsed: 0.701 data loading takes: 0.218 Memory usage: 57.63\n",
416 | "Epoch: 9 Step: 400 training acc: 0.734 time elapsed: 0.706 data loading takes: 0.172 Memory usage: 57.51\n",
417 | "Epoch: 9 Step: 600 training acc: 0.738 time elapsed: 0.701 data loading takes: 0.196 Memory usage: 57.50\n",
418 | "Epoch: 9 Val acc: 0.742\n",
419 | "Epoch: 10 Step: 0 training acc: 0.707 time elapsed: 0.697 data loading takes: 0.194 Memory usage: 57.48\n",
420 | "Epoch: 10 Step: 200 training acc: 0.722 time elapsed: 0.713 data loading takes: 0.213 Memory usage: 57.55\n",
421 | "Epoch: 10 Step: 400 training acc: 0.761 time elapsed: 0.688 data loading takes: 0.111 Memory usage: 57.53\n",
422 | "Epoch: 10 Step: 600 training acc: 0.765 time elapsed: 0.689 data loading takes: 0.160 Memory usage: 57.51\n",
423 | "Epoch: 10 Val acc: 0.752\n",
424 | "Epoch: 11 Step: 0 training acc: 0.671 time elapsed: 0.708 data loading takes: 0.240 Memory usage: 57.48\n",
425 | "Epoch: 11 Step: 200 training acc: 0.734 time elapsed: 0.726 data loading takes: 0.251 Memory usage: 57.54\n",
426 | "Epoch: 11 Step: 400 training acc: 0.726 time elapsed: 0.725 data loading takes: 0.280 Memory usage: 57.47\n",
427 | "Epoch: 11 Step: 600 training acc: 0.726 time elapsed: 0.757 data loading takes: 0.285 Memory usage: 57.47\n",
428 | "Epoch: 11 Val acc: 0.748\n",
429 | "Epoch: 12 Step: 0 training acc: 0.691 time elapsed: 0.702 data loading takes: 0.229 Memory usage: 57.46\n",
430 | "Epoch: 12 Step: 200 training acc: 0.765 time elapsed: 0.687 data loading takes: 0.184 Memory usage: 57.47\n",
431 | "Epoch: 12 Step: 400 training acc: 0.710 time elapsed: 0.711 data loading takes: 0.180 Memory usage: 57.56\n",
432 | "Epoch: 12 Step: 600 training acc: 0.722 time elapsed: 0.716 data loading takes: 0.271 Memory usage: 57.52\n",
433 | "Epoch: 12 Val acc: 0.721\n",
434 | "Epoch: 13 Step: 0 training acc: 0.777 time elapsed: 0.681 data loading takes: 0.166 Memory usage: 57.58\n",
435 | "Epoch: 13 Step: 200 training acc: 0.664 time elapsed: 0.765 data loading takes: 0.432 Memory usage: 57.44\n",
436 | "Epoch: 13 Step: 400 training acc: 0.753 time elapsed: 0.730 data loading takes: 0.226 Memory usage: 57.48\n",
437 | "Epoch: 13 Step: 600 training acc: 0.652 time elapsed: 0.713 data loading takes: 0.218 Memory usage: 57.46\n",
438 | "Epoch: 13 Val acc: 0.729\n",
439 | "Epoch: 14 Step: 0 training acc: 0.746 time elapsed: 0.661 data loading takes: 0.121 Memory usage: 57.49\n",
440 | "Epoch: 14 Step: 200 training acc: 0.664 time elapsed: 0.721 data loading takes: 0.315 Memory usage: 57.47\n",
441 | "Epoch: 14 Step: 400 training acc: 0.687 time elapsed: 0.706 data loading takes: 0.227 Memory usage: 57.48\n",
442 | "Epoch: 14 Step: 600 training acc: 0.730 time elapsed: 0.712 data loading takes: 0.239 Memory usage: 57.43\n",
443 | "Epoch: 14 Val acc: 0.724\n",
444 | "Epoch: 15 Step: 0 training acc: 0.679 time elapsed: 0.714 data loading takes: 0.282 Memory usage: 57.50\n",
445 | "Epoch: 15 Step: 200 training acc: 0.75 time elapsed: 0.695 data loading takes: 0.167 Memory usage: 57.58\n",
446 | "Epoch: 15 Step: 400 training acc: 0.75 time elapsed: 0.736 data loading takes: 0.296 Memory usage: 57.46\n",
447 | "Epoch: 15 Step: 600 training acc: 0.644 time elapsed: 0.725 data loading takes: 0.230 Memory usage: 57.53\n"
448 | ]
449 | },
450 | {
451 | "name": "stdout",
452 | "output_type": "stream",
453 | "text": [
454 | "Epoch: 15 Val acc: 0.721\n",
455 | "Test acc: 0.694\n",
456 | "Early Stopped Test acc: 0.723\n",
457 | "Total Time: 11569\n",
458 | "Max Momory: 68.59\n"
459 | ]
460 | }
461 | ],
462 | "source": [
463 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/tree-of-life/ \\\n",
464 | " --epoch 15 \\\n",
465 | " --task_setup Shared \\\n",
466 | " --k_qry 16 \\\n",
467 | " --k_spt 16 \\\n",
468 | " --n_way 2 \\\n",
469 | " --update_lr 0.005 \\\n",
470 | " --update_step 10 \\\n",
471 | " --meta_lr 0.0005 \\\n",
472 | " --num_workers 0 \\\n",
473 | " --train_result_report_steps 200 \\\n",
474 | " --hidden_dim 256 \\\n",
475 | " --update_step_test 20 \\\n",
476 | " --task_num 8 \\\n",
477 | " --batchsz 5000 \\\n",
478 |     "    --link_pred_mode True"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": null,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": []
487 | }
488 | ],
489 | "metadata": {
490 | "kernelspec": {
491 | "display_name": "Python 3",
492 | "language": "python",
493 | "name": "python3"
494 | },
495 | "language_info": {
496 | "codemirror_mode": {
497 | "name": "ipython",
498 | "version": 3
499 | },
500 | "file_extension": ".py",
501 | "mimetype": "text/x-python",
502 | "name": "python",
503 | "nbconvert_exporter": "python",
504 | "pygments_lexer": "ipython3",
505 | "version": "3.7.4"
506 | }
507 | },
508 | "nbformat": 4,
509 | "nbformat_minor": 4
510 | }
511 |
--------------------------------------------------------------------------------