├── .gitignore ├── figs └── graph_meta_learning.png ├── requirements.txt ├── data_process ├── node_process.py └── link_process.py ├── G-Meta ├── learner.py ├── train.py ├── meta.py └── subgraph_data_processing.py ├── README.md └── test.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .DS_Store 3 | 4 | __pycache__/ 5 | 6 | data/ -------------------------------------------------------------------------------- /figs/graph_meta_learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/G-Meta/HEAD/figs/graph_meta_learning.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch == 1.5.0 2 | dgl == 0.4.3post2 3 | numpy 4 | networkx 5 | scipy 6 | tqdm 7 | scikit-learn 8 | pandas -------------------------------------------------------------------------------- /data_process/node_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import networkx as nx 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import pickle 7 | import dgl 8 | from tqdm import tqdm 9 | import json 10 | 11 | # this is an example of disjoint label multiple graphs. 12 | 13 | path = 'PATH' 14 | 15 | # assume you have a list of DGL graphs stored in the variable dgl_Gs 16 | dgl_Gs = [G1, G2, ...] 17 | # assume you have an array of features where [feat_1, feat_2, ...] and each feat_i corresponding to the graph i. 18 | feature_map = [feat1, feat2, ...] 19 | # assume you have an array of labels where [label_1, label_2, ...] and each label_i corresponding to the graph i. 20 | label_map = [label1, label2, ...] 21 | # number of unique labels, e.g. 30 22 | num_of_labels = 30 23 | # number of labels for each label set, ideally << num_of_labels so that each task can from different permutation of labels 24 | num_label_set = 5 25 | 26 | info = {} 27 | 28 | for idx, G in enumerate(dgl_Gs): 29 | # G is a dgl graph 30 | for j in range(len(label_map[idx])): 31 | info[str(idx) + '_' + str(j)] = label_map[idx][j] 32 | 33 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"}) 34 | 35 | labels = np.unique(list(range(num_of_labels))) 36 | 37 | test_labels = np.random.choice(labels, num_label_set, False) 38 | labels_left = [i for i in labels if i not in test_labels] 39 | val_labels = np.random.choice(labels_left, num_label_set, False) 40 | train_labels = [i for i in labels_left if i not in val_labels] 41 | 42 | df[df.label.isin(train_labels)].reset_index(drop = True).to_csv(path + '/train.csv') 43 | df[df.label.isin(val_labels)].reset_index(drop = True).to_csv(path + '/val.csv') 44 | df[df.label.isin(test_labels)].reset_index(drop = True).to_csv(path + '/test.csv') 45 | 46 | with open(path + '/graph_dgl.pkl', 'wb') as f: 47 | pickle.dump(dgl_Gs, f) 48 | 49 | with open(path + '/label.pkl', 'wb') as f: 50 | pickle.dump(info, f) 51 | 52 | np.save(path + '/features.npy', np.array(feature_map)) 53 | 54 | 55 | # for shared labels, multiple graph setting, similarly, assume you have process the following variables: 56 | 57 | # assume you have a list of DGL graphs stored in the variable dgl_Gs 58 | dgl_Gs = [G1, G2, ...] 59 | # assume you have an array of features where [feat_1, feat_2, ...] 
and each feat_i corresponds to graph i. 60 | feature_map = [feat1, feat2, ...] 61 | # assume you have an array of labels where [label_1, label_2, ...] and each label_i corresponds to graph i. 62 | label_map = [label1, label2, ...] 63 | # number of unique labels, e.g. 5 64 | num_of_labels = 5 65 | 66 | info = {} 67 | for idx, G in enumerate(dgl_Gs): 68 | for i in tqdm(list(G.nodes)): 69 | info[str(idx) + '_' + str(i)] = label_map[idx][i] 70 | 71 | np.save(path + '/features.npy', np.array(feature_map)) 72 | 73 | with open(path + '/graph_dgl.pkl', 'wb') as f: 74 | pickle.dump(dgl_Gs, f) 75 | 76 | with open(path + '/label.pkl', 'wb') as f: 77 | pickle.dump(info, f) 78 | 79 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"}) 80 | 81 | # for example, specify the graph idx to be used for val, test set, other graphs are put in the meta-train 82 | folds = [[0, 23], [1, 22], [2, 21], [3, 20], [4, 19]] 83 | 84 | for fold_n, i in enumerate(folds): 85 | temp_path = path + '/fold' + str(fold_n+1) 86 | train_graphs = list(range(len(dgl_Gs))) 87 | train_graphs.remove(i[0]) 88 | train_graphs.remove(i[1]) 89 | val_graph = i[0] 90 | test_graph = i[1] 91 | 92 | val_df = df[df.name.str.contains('^' + str(val_graph)+'_')] # anchor the graph index so that e.g. graph 2 does not also match 12 or 22 93 | test_df = df[df.name.str.contains('^' + str(test_graph)+'_')] 94 | 95 | train_df = df[~df.index.isin(val_df.index)] 96 | train_df = train_df[~train_df.index.isin(test_df.index)] 97 | train_df.reset_index(drop = True).to_csv(temp_path + '/train.csv') 98 | val_df.reset_index(drop = True).to_csv(temp_path + '/val.csv') 99 | test_df.reset_index(drop = True).to_csv(temp_path + '/test.csv') -------------------------------------------------------------------------------- /data_process/link_process.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | from tqdm import tqdm 3 | import networkx as nx 4 | from itertools import combinations 5 | import numpy as np 6 | import random 7 | import pickle 8 | import pandas as pd # used for the csv exports below 9 | path = 'PATH' 10 | adjs = np.load(path + '/graphs_adj.npy', allow_pickle = True) 11 | # this .npy file is an array of 2D arrays [A1, A2, ..., An], where Ai is the adjacency matrix of graph i.
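# A minimal sketch (not part of this script) of how such a graphs_adj.npy file could be produced,
# assuming you start from a hypothetical list of networkx graphs named nx_graphs:
#
#     import numpy as np
#     import networkx as nx
#     adj_list = [nx.to_numpy_array(g) for g in nx_graphs]  # one dense adjacency matrix per graph
#     np.save(path + '/graphs_adj.npy', np.array(adj_list, dtype=object), allow_pickle=True)
#
# Storing the matrices in an object array keeps graphs of different sizes in one file,
# which is why the np.load call above passes allow_pickle=True.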
12 | 13 | training_edges_fraction = 0.3 14 | pos_test_edges = [] 15 | pos_val_edges = [] 16 | pos_train_edges = [] 17 | neg_test_edges = [] 18 | neg_train_edges = [] 19 | neg_val_edges = [] 20 | 21 | info = {} 22 | info_spt = {} 23 | info_qry = {} 24 | total_subgraph = {} 25 | center_nodes = {} 26 | 27 | G_all_graphs = [] 28 | 29 | for idx_ in tqdm(range(len(adjs))): 30 | G = nx.from_numpy_array(adjs[idx_]) 31 | 32 | adj_upp = np.multiply(adjs[idx_], np.triu(np.ones(adjs[idx_].shape))) 33 | x1, x2 = np.where(adj_upp == 1) 34 | edges = list(zip(x1, x2)) 35 | 36 | # training edges 37 | sampled = np.random.choice(list(range(len(edges))), int(len(edges)*training_edges_fraction), replace = False) 38 | 39 | pos_train_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(edges)[sampled]]) 40 | 41 | pos_test = [i for i in list(range(len(edges))) if i not in sampled] 42 | 43 | pos_test_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(edges)[pos_test]]) 44 | 45 | G_sample = dgl.DGLGraph() 46 | G_sample.add_nodes(len(G.nodes)) 47 | G_sample.add_edges(np.array(edges).T[0], np.array(edges).T[1]) 48 | num_pos = np.sum(adjs[idx_])/2 49 | 50 | sampled_frac = int(5*(sum(sum(adjs[idx_]))/len(G.nodes))) 51 | 52 | comb = [] 53 | for i in list(range(len(G.nodes))): 54 | l = list(range(len(G.nodes))) 55 | l.remove(i) 56 | comb = comb + (list(zip([i] * sampled_frac, random.choices(l, k = sampled_frac)))) 57 | 58 | random.shuffle(comb) 59 | comb_flipped = [(k,v) for v,k in comb] 60 | l = list(set(comb_flipped) & set(comb)) 61 | 62 | neg_edges_sampled = [i for i in comb if i not in l] 63 | 64 | neg_edges = list(set(neg_edges_sampled) - set(edges) - set([(k,v) for (v,k) in edges])) 65 | 66 | np.random.seed(10) 67 | idx_neg = np.random.choice(list(range(len(neg_edges))), len(edges), replace = False) 68 | neg_edges = np.array(neg_edges)[idx_neg] 69 | 70 | idx_neg_train = np.random.choice(list(range(len(neg_edges))), len(sampled), replace = False) 71 | 72 | neg_train_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(neg_edges)[idx_neg_train]]) 73 | 74 | neg_test = [i for i in list(range(len(neg_edges))) if i not in idx_neg_train] 75 | neg_test_edges.append([str(idx_) + '_' + str(i[0]) + '_' + str(i[1]) for i in np.array(neg_edges)[neg_test]]) 76 | 77 | train_edges_pos = np.array(edges)[sampled] 78 | test_edges_pos = np.array(edges)[pos_test] 79 | 80 | train_edges_neg = np.array(neg_edges)[idx_neg_train] 81 | test_edges_neg = np.array(neg_edges)[neg_test] 82 | 83 | for i in np.array(neg_edges): 84 | # negative injection, following SEAL 85 | G_sample.add_edge(i[0],i[1]) 86 | 87 | G_all_graphs.append(G_sample) 88 | 89 | for i in np.array(train_edges_pos): 90 | node1 = i[0] 91 | node2 = i[1] 92 | 93 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1 94 | info_spt[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1 95 | 96 | for i in np.array(test_edges_pos): 97 | node1 = i[0] 98 | node2 = i[1] 99 | 100 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1 101 | info_qry[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 1 102 | 103 | for i in np.array(train_edges_neg): 104 | node1 = i[0] 105 | node2 = i[1] 106 | 107 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0 108 | info_spt[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0 109 | 110 | for i in np.array(test_edges_neg): 111 | node1 = i[0] 112 | node2 = i[1] 113 | 114 | info[str(idx_) + '_' + str(node1) + '_' + str(node2)] = 0 115 | info_qry[str(idx_) + '_' + 
str(node1) + '_' + str(node2)] = 0 116 | 117 | with open(path + '/graph_dgl.pkl', 'wb') as f: 118 | pickle.dump(G_all_graphs, f) 119 | 120 | with open(path + '/label.pkl', 'wb') as f: 121 | pickle.dump(info, f) 122 | 123 | # split on graphs 124 | num_test_graphs = int(0.1 * len(G_all_graphs)) 125 | 126 | l = list(range(len(G_all_graphs))) 127 | test_graphs_idx = np.random.choice(l, num_test_graphs, replace = False).tolist() 128 | 129 | l = [i for i in l if i not in test_graphs_idx] 130 | val_graphs_idx = np.random.choice(l, num_test_graphs, replace = False).tolist() 131 | 132 | fold = [test_graphs_idx, val_graphs_idx] 133 | 134 | df_spt = pd.DataFrame.from_dict(info_spt, orient='index').reset_index().rename(columns={"index": "name", 0: "label"}) 135 | df_qry = pd.DataFrame.from_dict(info_qry, orient='index').reset_index().rename(columns={"index": "name", 0: "label"}) 136 | df = pd.DataFrame.from_dict(info, orient='index').reset_index().rename(columns={"index": "name", 0: "label"}) 137 | 138 | i = fold 139 | 140 | temp_path = path 141 | train_graphs = list(range(len(G_all_graphs))) 142 | 143 | train_graphs = [j for j in train_graphs if j not in i[0] + i[1]] 144 | val_graph = i[1] 145 | test_graph = i[0] 146 | 147 | train_spt = pd.DataFrame() 148 | val_spt = pd.DataFrame() 149 | test_spt = pd.DataFrame() 150 | 151 | train_qry = pd.DataFrame() 152 | val_qry = pd.DataFrame() 153 | test_qry = pd.DataFrame() 154 | 155 | train = pd.DataFrame() 156 | val = pd.DataFrame() 157 | test = pd.DataFrame() 158 | 159 | for graph_id in range(len(val_graph)): 160 | 161 | val_df = df_spt[df_spt.name.str.contains('^' + str(val_graph[graph_id])+'_')] 162 | test_df = df_spt[df_spt.name.str.contains('^' + str(test_graph[graph_id])+'_')] 163 | 164 | val_spt = val_spt.append(val_df) 165 | test_spt = test_spt.append(test_df) 166 | 167 | val_df = df_qry[df_qry.name.str.contains('^' + str(val_graph[graph_id])+'_')] 168 | test_df = df_qry[df_qry.name.str.contains('^' + str(test_graph[graph_id])+'_')] 169 | 170 | val_qry = val_qry.append(val_df) 171 | test_qry = test_qry.append(test_df) 172 | 173 | val_df = df[df.name.str.contains('^' + str(val_graph[graph_id])+'_')] 174 | test_df = df[df.name.str.contains('^' + str(test_graph[graph_id])+'_')] 175 | 176 | val = val.append(val_df) 177 | test = test.append(test_df) 178 | 179 | val_spt.reset_index(drop = True).to_csv(temp_path + '/val_spt.csv') 180 | test_spt.reset_index(drop = True).to_csv(temp_path + '/test_spt.csv') 181 | 182 | val_qry.reset_index(drop = True).to_csv(temp_path + '/val_qry.csv') 183 | test_qry.reset_index(drop = True).to_csv(temp_path + '/test_qry.csv') 184 | 185 | val.reset_index(drop = True).to_csv(temp_path + '/val.csv') 186 | test.reset_index(drop = True).to_csv(temp_path + '/test.csv') 187 | 188 | train_df = df_spt[~df_spt.index.isin(val_spt.index)] 189 | train_df = train_df[~train_df.index.isin(test_spt.index)] 190 | train_df.reset_index(drop = True).to_csv(temp_path + '/train_spt.csv') 191 | 192 | train_df = df_qry[~df_qry.index.isin(val_qry.index)] 193 | train_df = train_df[~train_df.index.isin(test_qry.index)] 194 | train_df.reset_index(drop = True).to_csv(temp_path + '/train_qry.csv') 195 | 196 | train_df = df[~df.index.isin(val.index)] 197 | train_df = train_df[~train_df.index.isin(test.index)] 198 | train_df.reset_index(drop = True).to_csv(temp_path + '/train.csv') 199 | -------------------------------------------------------------------------------- /G-Meta/learner.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import dgl.function as fn 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import dgl 7 | 8 | # Sends a message of node feature h. 9 | msg = fn.copy_src(src='h', out='m') 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | # copied and editted from DGL Source 13 | class GraphConv(nn.Module): 14 | def __init__(self, 15 | in_feats, 16 | out_feats, 17 | activation=None): 18 | super(GraphConv, self).__init__() 19 | self._in_feats = in_feats 20 | self._out_feats = out_feats 21 | self._norm = True 22 | self._activation = activation 23 | 24 | 25 | def forward(self, graph, feat, weight, bias): 26 | 27 | graph = graph.local_var() 28 | if self._norm: 29 | norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5) 30 | shp = norm.shape + (1,) * (feat.dim() - 1) 31 | norm = torch.reshape(norm, shp).to(feat.device) 32 | feat = feat * norm 33 | 34 | if self._in_feats > self._out_feats: 35 | # mult W first to reduce the feature size for aggregation. 36 | feat = torch.matmul(feat, weight) 37 | graph.ndata['h'] = feat 38 | graph.update_all(fn.copy_src(src='h', out='m'), 39 | fn.sum(msg='m', out='h')) 40 | rst = graph.ndata['h'] 41 | else: 42 | # aggregate first then mult W 43 | graph.ndata['h'] = feat 44 | graph.update_all(fn.copy_src(src='h', out='m'), 45 | fn.sum(msg='m', out='h')) 46 | rst = graph.ndata['h'] 47 | rst = torch.matmul(rst, weight) 48 | 49 | rst = rst * norm 50 | 51 | rst = rst + bias 52 | 53 | if self._activation is not None: 54 | rst = self._activation(rst) 55 | 56 | return rst 57 | 58 | def extra_repr(self): 59 | """Set the extra representation of the module, 60 | which will come into effect when printing the model. 
61 | """ 62 | summary = 'in={_in_feats}, out={_out_feats}' 63 | summary += ', normalization={_norm}' 64 | if '_activation' in self.__dict__: 65 | summary += ', activation={_activation}' 66 | return summary.format(**self.__dict__) 67 | 68 | 69 | class Classifier(nn.Module): 70 | def __init__(self, config): 71 | super(Classifier, self).__init__() 72 | 73 | self.vars = nn.ParameterList() 74 | self.graph_conv = [] 75 | self.config = config 76 | self.LinkPred_mode = False 77 | 78 | if self.config[-1][0] == 'LinkPred': 79 | self.LinkPred_mode = True 80 | 81 | for i, (name, param) in enumerate(self.config): 82 | 83 | if name is 'Linear': 84 | if self.LinkPred_mode: 85 | w = nn.Parameter(torch.ones(param[1], param[0] * 2)) 86 | else: 87 | w = nn.Parameter(torch.ones(param[1], param[0])) 88 | init.kaiming_normal_(w) 89 | self.vars.append(w) 90 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 91 | if name is 'GraphConv': 92 | # param: in_dim, hidden_dim 93 | w = nn.Parameter(torch.Tensor(param[0], param[1])) 94 | init.xavier_uniform_(w) 95 | self.vars.append(w) 96 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 97 | self.graph_conv.append(GraphConv(param[0], param[1], activation = F.relu)) 98 | if name is 'Attention': 99 | # param[0] hidden size 100 | # param[1] attention_head_size 101 | # param[2] hidden_dim for classifier 102 | # param[3] n_ways 103 | # param[4] number of graphlets 104 | if self.LinkPred_mode: 105 | w_q = nn.Parameter(torch.ones(param[1], param[0] * 2)) 106 | else: 107 | w_q = nn.Parameter(torch.ones(param[1], param[0])) 108 | w_k = nn.Parameter(torch.ones(param[1], param[0])) 109 | w_v = nn.Parameter(torch.ones(param[1], param[4])) 110 | 111 | if self.LinkPred_mode: 112 | w_l = nn.Parameter(torch.ones(param[3], param[2] * 2 + param[1])) 113 | else: 114 | w_l = nn.Parameter(torch.ones(param[3], param[2] + param[1])) 115 | 116 | init.kaiming_normal_(w_q) 117 | init.kaiming_normal_(w_k) 118 | init.kaiming_normal_(w_v) 119 | init.kaiming_normal_(w_l) 120 | 121 | self.vars.append(w_q) 122 | self.vars.append(w_k) 123 | self.vars.append(w_v) 124 | self.vars.append(w_l) 125 | 126 | #bias for attentions 127 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 128 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 129 | self.vars.append(nn.Parameter(torch.zeros(param[1]))) 130 | #bias for classifier 131 | self.vars.append(nn.Parameter(torch.zeros(param[3]))) 132 | 133 | 134 | def forward(self, g, to_fetch, features, vars = None): 135 | # For undirected graphs, in_degree is the same as 136 | # out_degree. 
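# vars follows the parameter layout built in __init__ from self.config: each 'GraphConv' or
# 'Linear' entry contributes one (weight, bias) pair, and an 'Attention' entry contributes
# eight tensors (w_q, w_k, w_v, w_l followed by four bias vectors), which is why the loop
# below advances idx by 2 or by 8.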
137 | 138 | if vars is None: 139 | vars = self.vars 140 | 141 | idx = 0 142 | idx_gcn = 0 143 | 144 | h = features.float() 145 | h = h.to(device) 146 | 147 | for name, param in self.config: 148 | if name is 'GraphConv': 149 | w, b = vars[idx], vars[idx + 1] 150 | conv = self.graph_conv[idx_gcn] 151 | 152 | h = conv(g, h, w, b) 153 | 154 | g.ndata['h'] = h 155 | 156 | idx += 2 157 | idx_gcn += 1 158 | 159 | if idx_gcn == len(self.graph_conv): 160 | #h = dgl.mean_nodes(g, 'h') 161 | num_nodes_ = g.batch_num_nodes 162 | temp = [0] + num_nodes_ 163 | offset = torch.cumsum(torch.LongTensor(temp), dim = 0)[:-1].to(device) 164 | 165 | if self.LinkPred_mode: 166 | h1 = h[to_fetch[:,0] + offset] 167 | h2 = h[to_fetch[:,1] + offset] 168 | h = torch.cat((h1, h2), 1) 169 | else: 170 | h = h[to_fetch + offset] 171 | 172 | if name is 'Linear': 173 | w, b = vars[idx], vars[idx + 1] 174 | h = F.linear(h, w, b) 175 | idx += 2 176 | 177 | if name is 'Attention': 178 | w_q, w_k, w_v, w_l = vars[idx], vars[idx + 1], vars[idx + 2], vars[idx + 3] 179 | b_q, b_k, b_v, b_l = vars[idx + 4], vars[idx + 5], vars[idx + 6], vars[idx + 7] 180 | 181 | Q = F.linear(h, w_q, b_q) 182 | K = F.linear(h_graphlets, w_k, b_k) 183 | 184 | attention_scores = torch.matmul(Q, K.T) 185 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 186 | context = F.linear(attention_probs, w_v, b_v) 187 | 188 | # classify layer, first concatenate the context vector 189 | # with the hidden dim of center nodes 190 | h = torch.cat((context, h), 1) 191 | h = F.linear(h, w_l, b_l) 192 | idx += 8 193 | 194 | return h, h 195 | 196 | def zero_grad(self, vars=None): 197 | 198 | with torch.no_grad(): 199 | if vars is None: 200 | for p in self.vars: 201 | if p.grad is not None: 202 | p.grad.zero_() 203 | else: 204 | for p in vars: 205 | if p.grad is not None: 206 | p.grad.zero_() 207 | 208 | def parameters(self): 209 | return self.vars -------------------------------------------------------------------------------- /G-Meta/train.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import numpy as np 3 | from subgraph_data_processing import Subgraphs 4 | import scipy.stats 5 | from torch.utils.data import DataLoader 6 | from torch.optim import lr_scheduler 7 | import random, sys, pickle 8 | import argparse 9 | 10 | import networkx as nx 11 | import numpy as np 12 | from scipy.special import comb 13 | from itertools import combinations 14 | import networkx.algorithms.isomorphism as iso 15 | from tqdm import tqdm 16 | import dgl 17 | 18 | from meta import Meta 19 | import time 20 | import copy 21 | import psutil 22 | from memory_profiler import memory_usage 23 | 24 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 25 | 26 | def collate(samples): 27 | graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx = map(list, zip(*samples)) 28 | 29 | return graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx 30 | 31 | def main(): 32 | mem_usage = memory_usage(-1, interval=.5, timeout=1) 33 | torch.manual_seed(222) 34 | torch.cuda.manual_seed_all(222) 35 | np.random.seed(222) 36 | 37 | print(args) 38 | 39 | root = args.data_dir 40 | 41 | feat = np.load(root + 'features.npy', allow_pickle = True) 42 | 43 | with open(root + '/graph_dgl.pkl', 'rb') as f: 44 | dgl_graph = pickle.load(f) 45 | 46 | if args.task_setup == 'Disjoint': 47 | with open(root + 'label.pkl', 
'rb') as f: 48 | info = pickle.load(f) 49 | elif args.task_setup == 'Shared': 50 | if args.task_mode == 'True': 51 | root = root + '/task' + str(args.task_n) + '/' 52 | with open(root + 'label.pkl', 'rb') as f: 53 | info = pickle.load(f) 54 | 55 | total_class = len(np.unique(np.array(list(info.values())))) 56 | print('There are {} classes '.format(total_class)) 57 | 58 | if args.task_setup == 'Disjoint': 59 | labels_num = args.n_way 60 | elif args.task_setup == 'Shared': 61 | labels_num = total_class 62 | 63 | if len(feat.shape) == 2: 64 | # single graph, to make it compatible to multiple graph retrieval. 65 | feat = [feat] 66 | 67 | config = [('GraphConv', [feat[0].shape[1], args.hidden_dim])] 68 | 69 | if args.h > 1: 70 | config = config + [('GraphConv', [args.hidden_dim, args.hidden_dim])] * (args.h - 1) 71 | 72 | config = config + [('Linear', [args.hidden_dim, labels_num])] 73 | 74 | if args.link_pred_mode == 'True': 75 | config.append(('LinkPred', [True])) 76 | 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | 79 | maml = Meta(args, config).to(device) 80 | 81 | tmp = filter(lambda x: x.requires_grad, maml.parameters()) 82 | num = sum(map(lambda x: np.prod(x.shape), tmp)) 83 | print(maml) 84 | print('Total trainable tensors:', num) 85 | 86 | max_acc = 0 87 | model_max = copy.deepcopy(maml) 88 | 89 | db_train = Subgraphs(root, 'train', info, n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, batchsz=args.batchsz, args = args, adjs = dgl_graph, h = args.h) 90 | db_val = Subgraphs(root, 'val', info, n_way=args.n_way, k_shot=args.k_spt,k_query=args.k_qry, batchsz=100, args = args, adjs = dgl_graph, h = args.h) 91 | db_test = Subgraphs(root, 'test', info, n_way=args.n_way, k_shot=args.k_spt,k_query=args.k_qry, batchsz=100, args = args, adjs = dgl_graph, h = args.h) 92 | print('------ Start Training ------') 93 | s_start = time.time() 94 | max_memory = 0 95 | for epoch in range(args.epoch): 96 | db = DataLoader(db_train, args.task_num, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate) 97 | s_f = time.time() 98 | for step, (x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry) in enumerate(db): 99 | nodes_len = 0 100 | if step >= 1: 101 | data_loading_time = time.time() - s_r 102 | else: 103 | data_loading_time = time.time() - s_f 104 | s = time.time() 105 | # x_spt: a list of #task_num tasks, where each task is a mini-batch of k-shot * n_way subgraphs 106 | # y_spt: a list of #task_num lists of labels. Each list is of length k-shot * n_way int. 
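# e.g., with the default arguments below (task_num=8, n_way=3, k_spt=3, k_qry=24),
# len(x_spt) == 8, each support set batches 3*3 = 9 subgraphs and each query set batches 3*24 = 72 subgraphs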
107 | nodes_len += sum([sum([len(j) for j in i]) for i in n_spt]) 108 | accs = maml(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 109 | max_memory = max(max_memory, float(psutil.virtual_memory().used/(1024**3))) 110 | if step % args.train_result_report_steps == 0: 111 | print('Epoch:', epoch + 1, ' Step:', step, ' training acc:', str(accs[-1])[:5], ' time elapsed:', str(time.time() - s)[:5], ' data loading takes:', str(data_loading_time)[:5], ' Memory usage:', str(float(psutil.virtual_memory().used/(1024**3)))[:5]) 112 | s_r = time.time() 113 | 114 | # validation per epoch 115 | db_v = DataLoader(db_val, 1, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate) 116 | accs_all_test = [] 117 | 118 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_v: 119 | 120 | accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 121 | accs_all_test.append(accs) 122 | 123 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16) 124 | print('Epoch:', epoch + 1, ' Val acc:', str(accs[-1])[:5]) 125 | if accs[-1] > max_acc: 126 | max_acc = accs[-1] 127 | model_max = copy.deepcopy(maml) 128 | 129 | db_t = DataLoader(db_test, 1, shuffle=True, num_workers=args.num_workers, pin_memory=True, collate_fn = collate) 130 | accs_all_test = [] 131 | 132 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t: 133 | accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 134 | accs_all_test.append(accs) 135 | 136 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16) 137 | print('Test acc:', str(accs[1])[:5]) 138 | 139 | for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in db_t: 140 | accs = model_max.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 141 | accs_all_test.append(accs) 142 | 143 | #torch.save(model_max.state_dict(), './model.pt') 144 | 145 | accs = np.array(accs_all_test).mean(axis=0).astype(np.float16) 146 | print('Early Stopped Test acc:', str(accs[-1])[:5]) 147 | print('Total Time:', str(time.time() - s_start)[:5]) 148 | print('Max Momory:', str(max_memory)[:5]) 149 | 150 | if __name__ == '__main__': 151 | 152 | argparser = argparse.ArgumentParser() 153 | argparser.add_argument('--epoch', type=int, help='epoch number', default=10) 154 | argparser.add_argument('--n_way', type=int, help='n way', default=3) 155 | argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=3) 156 | argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=24) 157 | argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=8) 158 | argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3) 159 | argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=1e-3) 160 | argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5) 161 | argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10) 162 | argparser.add_argument('--input_dim', type=int, help='input feature dim', default=1) 163 | argparser.add_argument('--hidden_dim', type=int, help='hidden dim', default=64) 164 | argparser.add_argument('--attention_size', type=int, help='dim of attention_size', default=32) 165 | 
argparser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir.") 166 | argparser.add_argument("--no_finetune", default=True, type=str, required=False, help="no finetune mode.") 167 | argparser.add_argument("--task_setup", default='Disjoint', type=str, required=True, help="Select from Disjoint or Shared Setup. For Disjoint-Label, single/multiple graphs are both considered.") 168 | argparser.add_argument("--method", default='G-Meta', type=str, required=False, help="Use G-Meta") 169 | argparser.add_argument('--task_n', type=int, help='task number', default=1) 170 | argparser.add_argument("--task_mode", default='False', type=str, required=False, help="For Evaluating on Tasks") 171 | argparser.add_argument("--val_result_report_steps", default=100, type=int, required=False, help="validation report") 172 | argparser.add_argument("--train_result_report_steps", default=30, type=int, required=False, help="training report") 173 | argparser.add_argument("--num_workers", default=0, type=int, required=False, help="num of workers") 174 | argparser.add_argument("--batchsz", default=1000, type=int, required=False, help="batch size") 175 | argparser.add_argument("--link_pred_mode", default='False', type=str, required=False, help="For Link Prediction") 176 | argparser.add_argument("--h", default=2, type=int, required=False, help="neighborhood size") 177 | argparser.add_argument('--sample_nodes', type=int, help='sample nodes if above this number of nodes', default=1000) 178 | 179 | args = argparser.parse_args() 180 | 181 | main() 182 | -------------------------------------------------------------------------------- /G-Meta/meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from torch.nn import functional as F 5 | from torch.utils.data import TensorDataset, DataLoader 6 | from torch import optim 7 | import numpy as np 8 | 9 | from learner import Classifier 10 | from copy import deepcopy 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | def euclidean_dist(x, y): 15 | # x: N x D 16 | # y: M x D 17 | n = x.size(0) 18 | m = y.size(0) 19 | d = x.size(1) 20 | if d != y.size(1): 21 | raise Exception 22 | 23 | x = x.unsqueeze(1).expand(n, m, d) 24 | y = y.unsqueeze(0).expand(n, m, d) 25 | 26 | return torch.pow(x - y, 2).sum(2) 27 | 28 | def proto_loss_spt(logits, y_t, n_support): 29 | target_cpu = y_t.to('cpu') 30 | input_cpu = logits.to('cpu') 31 | 32 | def supp_idxs(c): 33 | return target_cpu.eq(c).nonzero()[:n_support].squeeze(1) 34 | 35 | classes = torch.unique(target_cpu) 36 | n_classes = len(classes) 37 | n_query = n_support 38 | 39 | support_idxs = list(map(supp_idxs, classes)) 40 | 41 | prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs]) 42 | query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[:n_support], classes))).view(-1) 43 | query_samples = input_cpu[query_idxs] 44 | dists = euclidean_dist(query_samples, prototypes) 45 | log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1) 46 | 47 | target_inds = torch.arange(0, n_classes) 48 | target_inds = target_inds.view(n_classes, 1, 1) 49 | target_inds = target_inds.expand(n_classes, n_query, 1).long() 50 | 51 | loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean() 52 | _, y_hat = log_p_y.max(2) 53 | acc_val = y_hat.eq(target_inds.squeeze()).float().mean() 54 | return loss_val, acc_val, 
prototypes 55 | 56 | def proto_loss_qry(logits, y_t, prototypes): 57 | target_cpu = y_t.to('cpu') 58 | input_cpu = logits.to('cpu') 59 | 60 | classes = torch.unique(target_cpu) 61 | n_classes = len(classes) 62 | 63 | n_query = int(logits.shape[0]/n_classes) 64 | 65 | query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero(), classes))).view(-1) 66 | query_samples = input_cpu[query_idxs] 67 | 68 | dists = euclidean_dist(query_samples, prototypes) 69 | 70 | log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1) 71 | 72 | target_inds = torch.arange(0, n_classes) 73 | target_inds = target_inds.view(n_classes, 1, 1) 74 | target_inds = target_inds.expand(n_classes, n_query, 1).long() 75 | 76 | loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean() 77 | _, y_hat = log_p_y.max(2) 78 | acc_val = y_hat.eq(target_inds.squeeze()).float().mean() 79 | return loss_val, acc_val 80 | 81 | 82 | class Meta(nn.Module): 83 | def __init__(self, args, config): 84 | super(Meta, self).__init__() 85 | self.update_lr = args.update_lr 86 | self.meta_lr = args.meta_lr 87 | self.n_way = args.n_way 88 | self.k_spt = args.k_spt 89 | self.k_qry = args.k_qry 90 | self.task_num = args.task_num 91 | self.update_step = args.update_step 92 | self.update_step_test = args.update_step_test 93 | 94 | self.net = Classifier(config) 95 | self.net = self.net.to(device) 96 | 97 | self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr) 98 | 99 | self.method = args.method 100 | 101 | def forward_ProtoMAML(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat): 102 | """ 103 | b: number of tasks 104 | setsz: the size for each task 105 | 106 | :param x_spt: [b], where each unit is a mini-batch of subgraphs, i.e. x_spt[0] is a DGL batch of # setsz subgraphs 107 | :param y_spt: [b, setsz] 108 | :param x_qry: [b], where each unit is a mini-batch of subgraphs, i.e. x_spt[0] is a DGL batch of # setsz subgraphs 109 | :param y_qry: [b, querysz] 110 | :return: 111 | """ 112 | task_num = len(x_spt) 113 | querysz = len(y_qry[0]) 114 | losses_s = [0 for _ in range(self.update_step)] 115 | losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i 116 | corrects = [0 for _ in range(self.update_step + 1)] 117 | 118 | for i in range(task_num): 119 | feat_spt = torch.Tensor(np.vstack(([feat[g_spt[i][j]][np.array(x)] for j, x in enumerate(n_spt[i])]))).to(device) 120 | feat_qry = torch.Tensor(np.vstack(([feat[g_qry[i][j]][np.array(x)] for j, x in enumerate(n_qry[i])]))).to(device) 121 | # 1. 
run the i-th task and compute loss for k=0 122 | logits, _ = self.net(x_spt[i].to(device), c_spt[i].to(device), feat_spt, vars=None) 123 | loss, _, prototypes = proto_loss_spt(logits, y_spt[i], self.k_spt) 124 | losses_s[0] += loss 125 | grad = torch.autograd.grad(loss, self.net.parameters()) 126 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters()))) 127 | 128 | # this is the loss and accuracy before first update 129 | with torch.no_grad(): 130 | # [setsz, nway] 131 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, self.net.parameters()) 132 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes) 133 | losses_q[0] += loss_q 134 | corrects[0] = corrects[0] + acc_q 135 | 136 | # this is the loss and accuracy after the first update 137 | with torch.no_grad(): 138 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, fast_weights) 139 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes) 140 | losses_q[1] += loss_q 141 | corrects[1] = corrects[1] + acc_q 142 | 143 | for k in range(1, self.update_step): 144 | # 1. run the i-th task and compute loss for k=1~K-1 145 | logits, _ = self.net(x_spt[i].to(device), c_spt[i].to(device), feat_spt, fast_weights) 146 | loss, _, prototypes = proto_loss_spt(logits, y_spt[i], self.k_spt) 147 | losses_s[k] += loss 148 | # 2. compute grad on theta_pi 149 | grad = torch.autograd.grad(loss, fast_weights, retain_graph=True) 150 | # 3. theta_pi = theta_pi - train_lr * grad 151 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))) 152 | logits_q, _ = self.net(x_qry[i].to(device), c_qry[i].to(device), feat_qry, fast_weights) 153 | # loss_q will be overwritten and just keep the loss_q on last update step. 154 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry[i], prototypes) 155 | losses_q[k + 1] += loss_q 156 | 157 | corrects[k + 1] = corrects[k + 1] + acc_q 158 | 159 | # end of all tasks 160 | # sum over all losses on query set across all tasks 161 | loss_q = losses_q[-1] / task_num 162 | 163 | if torch.isnan(loss_q): 164 | pass 165 | else: 166 | # optimize theta parameters 167 | self.meta_optim.zero_grad() 168 | loss_q.backward() 169 | self.meta_optim.step() 170 | 171 | accs = np.array(corrects) / (task_num) 172 | 173 | return accs 174 | 175 | def finetunning_ProtoMAML(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat): 176 | querysz = len(y_qry[0]) 177 | 178 | corrects = [0 for _ in range(self.update_step_test + 1)] 179 | 180 | # finetunning on the copied model instead of self.net 181 | net = deepcopy(self.net) 182 | x_spt = x_spt[0] 183 | y_spt = y_spt[0] 184 | x_qry = x_qry[0] 185 | y_qry = y_qry[0] 186 | c_spt = c_spt[0] 187 | c_qry = c_qry[0] 188 | n_spt = n_spt[0] 189 | n_qry = n_qry[0] 190 | g_spt = g_spt[0] 191 | g_qry = g_qry[0] 192 | 193 | feat_spt = torch.Tensor(np.vstack(([feat[g_spt[j]][np.array(x)] for j, x in enumerate(n_spt)]))).to(device) 194 | feat_qry = torch.Tensor(np.vstack(([feat[g_qry[j]][np.array(x)] for j, x in enumerate(n_qry)]))).to(device) 195 | 196 | 197 | # 1. 
run the i-th task and compute loss for k=0 198 | logits, _ = net(x_spt.to(device), c_spt.to(device), feat_spt) 199 | loss, _, prototypes = proto_loss_spt(logits, y_spt, self.k_spt) 200 | grad = torch.autograd.grad(loss, net.parameters()) 201 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))) 202 | 203 | # this is the loss and accuracy before first update 204 | with torch.no_grad(): 205 | # [setsz, nway] 206 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, net.parameters()) 207 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes) 208 | corrects[0] = corrects[0] + acc_q 209 | # this is the loss and accuracy after the first update 210 | with torch.no_grad(): 211 | # [setsz, nway] 212 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, fast_weights) 213 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes) 214 | corrects[1] = corrects[1] + acc_q 215 | 216 | 217 | for k in range(1, self.update_step_test): 218 | # 1. run the i-th task and compute loss for k=1~K-1 219 | logits, _ = net(x_spt.to(device), c_spt.to(device), feat_spt, fast_weights) 220 | loss, _, prototypes = proto_loss_spt(logits, y_spt, self.k_spt) 221 | # 2. compute grad on theta_pi 222 | grad = torch.autograd.grad(loss, fast_weights, retain_graph=True) 223 | # 3. theta_pi = theta_pi - train_lr * grad 224 | fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))) 225 | 226 | logits_q, _ = net(x_qry.to(device), c_qry.to(device), feat_qry, fast_weights) 227 | # loss_q will be overwritten and just keep the loss_q on last update step. 228 | loss_q, acc_q = proto_loss_qry(logits_q, y_qry, prototypes) 229 | corrects[k + 1] = corrects[k + 1] + acc_q 230 | 231 | del net 232 | accs = np.array(corrects) 233 | 234 | return accs 235 | 236 | def forward(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat): 237 | if self.method == 'G-Meta': 238 | accs = self.forward_ProtoMAML(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 239 | return accs 240 | 241 | def finetunning(self, x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry,feat): 242 | if self.method == 'G-Meta': 243 | accs = self.finetunning_ProtoMAML(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry, feat) 244 | return accs 245 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # G-Meta: Graph Meta Learning via Local Subgraphs 2 | 3 | #### Authors: [Kexin Huang](https://www.kexinhuang.com), [Marinka Zitnik](https://zitniklab.hms.harvard.edu) 4 | 5 | #### [Project Website](https://zitniklab.hms.harvard.edu/projects/G-Meta) 6 | 7 | Prevailing methods for graphs require abundant label and edge information for learning. When data for a new task are scarce, meta learning can learn from prior experiences and form much-needed inductive biases for fast adaption to new tasks. 8 | 9 | Here, we introduce G-Meta, a novel meta-learning algorithm for graphs. 10 | G-Meta uses local subgraphs to transfer subgraph-specific information and learn transferable knowledge faster via meta gradients. G-Meta learns how to quickly adapt to a new task using only a handful of nodes or edges in the new task and does so by learning from data points in other graphs or related, albeit disjoint label sets. 
G-Meta is theoretically justified as we show that the evidence for a prediction can be found in the local subgraph surrounding the target node or edge. 11 | 12 | Experiments on seven datasets and nine baseline methods show that G-Meta outperforms existing methods by up to 16.3%. Unlike previous methods, G-Meta successfully learns in challenging, few-shot learning settings that require generalization to completely new graphs and never-before-seen labels. Finally, G-Meta scales to large graphs, which we demonstrate on a new Tree-of-Life dataset comprising 1,840 graphs, a two-order-of-magnitude increase in the number of graphs used in prior work. 13 | 14 | ![Graph Meta Learning Problems](figs/graph_meta_learning.png) 15 | 16 | 17 | ## Environment Installation 18 | 19 | ```bash 20 | python -m pip install --user virtualenv 21 | python -m venv gmeta_env 22 | source gmeta_env/bin/activate 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Run 27 | ```bash 28 | cd G-Meta 29 | # Single graph disjoint label, node classification (e.g. arxiv-ogbn) 30 | python train.py --data_dir DATA_PATH --task_setup Disjoint 31 | # Multiple graph shared label, node classification (e.g. Tissue-PPI) 32 | python train.py --data_dir DATA_PATH --task_setup Shared 33 | # Multiple graph disjoint label, node classification (e.g. Fold-PPI) 34 | python train.py --data_dir DATA_PATH --task_setup Disjoint 35 | # Multiple graph shared label, link prediction (e.g. FirstMM-DB, Tree-of-Life) 36 | python train.py --data_dir DATA_PATH --task_setup Shared --link_pred_mode True 37 | ``` 38 | 39 | It also supports various input parameters: 40 | 41 | ```bash 42 | python train.py --data_dir # str: data path 43 | --task_setup # 'Disjoint' or 'Shared': task setup, disjoint label or shared label 44 | --link_pred_mode # 'True' or 'False': link prediction or node classification 45 | --batchsz # int: number of tasks in total 46 | --epoch # int: epoch size 47 | --h # 1 or 2 or 3: use the h-hop neighborhood as the subgraph. 48 | --hidden_dim # int: hidden dim size of GNN 49 | --input_dim # int: input dim size of GNN 50 | --k_qry # int: number of query shots for each task 51 | --k_spt # int: number of support shots for each task 52 | --n_way # int: number of ways (size of the label set) 53 | --meta_lr # float: outer loop learning rate 54 | --update_lr # float: inner loop learning rate 55 | --update_step # int: inner loop update steps during training 56 | --update_step_test # int: inner loop update steps during finetuning 57 | --task_num # int: number of tasks for each meta-set 58 | --sample_nodes # int: when subgraph size is above this threshold, it samples this number of nodes from the subgraph 59 | --task_mode # 'True' or 'False': this is specifically for Tissue-PPI, where there are 10 tasks to evaluate. 60 | --num_workers # int: number of workers to process the dataloader. default 0. 61 | --train_result_report_steps # int: step interval for printing the training accuracy. 62 | ``` 63 | 64 | To apply it to the five datasets reported in the paper, use the following commands as examples after you download the processed datasets from the section below. 65 | 66 | **arxiv-ogbn**: 67 |
68 | CLICK HERE FOR THE CODE! 69 | 70 | ``` 71 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/arxiv/ \ 72 | --epoch 10 \ 73 | --task_setup Disjoint \ 74 | --k_spt 3 \ 75 | --k_qry 24 \ 76 | --n_way 3 \ 77 | --update_step 10 \ 78 | --update_lr 0.01 \ 79 | --num_workers 0 \ 80 | --train_result_report_steps 200 \ 81 | --hidden_dim 256 \ 82 | --update_step_test 20 \ 83 | --task_num 32 \ 84 | --batchsz 10000 85 | ``` 86 |
87 | 88 | **Tissue-PPI**: 89 |
90 | CLICK HERE FOR THE CODE! 91 | 92 | ``` 93 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/tissue_PPI/ \ 94 | --epoch 15 \ 95 | --task_setup Shared \ 96 | --task_mode True \ 97 | --task_n 4 \ 98 | --k_qry 10 \ 99 | --k_spt 3 \ 100 | --update_lr 0.01 \ 101 | --update_step 10 \ 102 | --meta_lr 5e-3 \ 103 | --num_workers 0 \ 104 | --train_result_report_steps 200 \ 105 | --hidden_dim 128 \ 106 | --task_num 4 \ 107 | --batchsz 1000 108 | ``` 109 |
110 | 111 | **Fold-PPI**: 112 |
113 | CLICK HERE FOR THE CODE! 114 | 115 | ``` 116 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/fold_PPI/ \ 117 | --epoch 5 \ 118 | --task_setup Disjoint \ 119 | --k_qry 24 \ 120 | --k_spt 3 \ 121 | --n_way 3 \ 122 | --update_lr 0.005 \ 123 | --meta_lr 1e-3 \ 124 | --num_workers 0 \ 125 | --train_result_report_steps 100 \ 126 | --hidden_dim 128 \ 127 | --update_step_test 20 \ 128 | --task_num 16 \ 129 | --batchsz 4000 130 | ``` 131 |
132 | 133 | **FirstMM-DB**: 134 |
135 | CLICK HERE FOR THE CODE! 136 | 137 | ``` 138 | python G-Meta/train.py --data_dir PATH/G-Meta_Data/FirstMM_DB/ \ 139 | --epoch 15 \ 140 | --task_setup Shared \ 141 | --k_qry 32 \ 142 | --k_spt 16 \ 143 | --n_way 2 \ 144 | --update_lr 0.01 \ 145 | --update_step 10 \ 146 | --meta_lr 5e-4 \ 147 | --num_workers 0 \ 148 | --train_result_report_steps 200 \ 149 | --hidden_dim 128 \ 150 | --update_step_test 20 \ 151 | --task_num 8 \ 152 | --batchsz 1500 \ 153 | --link_pred_mode True 154 | ``` 155 |
156 | 157 | **Tree-of-Life**: 158 |
159 | CLICK HERE FOR THE CODE! 160 | 161 | ``` 162 | python train.py --data_dir PATH/G-Meta_Data/tree-of-life/ \ 163 | --epoch 15 \ 164 | --task_setup Shared \ 165 | --k_qry 16 \ 166 | --k_spt 16 \ 167 | --n_way 2 \ 168 | --update_lr 0.005 \ 169 | --update_step 10 \ 170 | --meta_lr 0.0005 \ 171 | --num_workers 0 \ 172 | --train_result_report_steps 200 \ 173 | --hidden_dim 256 \ 174 | --update_step_test 20 \ 175 | --task_num 8 \ 176 | --batchsz 5000 \ 177 | --link_pred_mode True 178 | ``` 179 |
180 | 181 | Also, check out the [Jupyter notebook example](test.ipynb). 182 | 183 | 184 | ## Data Processing 185 | 186 | We provide the processed data files for five real-world datasets in this [Drive folder](https://drive.google.com/file/d/1TC06A02wmIQteKzqGSbl_i3VIQzsHVop/view?usp=drivesdk) and this [Microsoft OneDrive folder](https://hu-my.sharepoint.com/:u:/g/personal/kexinhuang_hsph_harvard_edu/EbSj1CehKDtKniKqtICWsScBESs9ldWWcTttGdADnFc6Wg?e=gJhl7c). 187 | 188 | 1\) To create your own dataset, create the following files and organize them as follows: 189 | 190 | - `graph_dgl.pkl`: A list of DGL graph objects. For a single graph G, use [G]. 191 | - `features.npy`: An array of arrays [feat_1, feat_2, ...] where feat_i is the feature matrix of graph i. 192 | 193 | 2.1) Then, for **node classification**, include the following files: 194 | - `train.csv`, `val.csv`, and `test.csv`: Each file has two columns: the first is 'X_Y' (node Y from graph X) and the second is its label 'Z'. The files correspond to the meta-train, meta-val, and meta-test sets. 195 | - `label.pkl`: A dictionary of labels where {'X_Y': Z} means node Y in graph X has label Z. 196 | 197 | 2.2) Or, for **link prediction**, note that the support set contains only edges in the highly incomplete graph (e.g., 30% of links) whereas the query set edges are in the rest of the graph (e.g., 70% of links). In the neural message passing, the GNN should ONLY exchange neural messages on the support set graph. Otherwise, the query set performance is biased. Because of that, we split the meta-train/val/test files into separate support and query files. For link prediction, create the following files: 198 | - `train_spt.csv`, `val_spt.csv`, and `test_spt.csv`: Two columns: the first is 'A_B_C' (nodes B and C from graph A) and the second is the label. This is for the node pairs in the support set, i.e. positive links should be in the underlying GNN graph. 199 | - `train_qry.csv`, `val_qry.csv`, and `test_qry.csv`: Two columns: the first is 'A_B_C' (nodes B and C from graph A) and the second is the label. This is for the node pairs in the query set, i.e. positive links should NOT be in the underlying GNN graph. 200 | - `train.csv`, `val.csv`, and `test.csv`: Merge the above two csv files. 201 | - `label.pkl`: A dictionary of labels where {'A_B_C': D} means nodes B and C in graph A have link status D; D is 1 if the link exists and 0 otherwise. 202 | 203 | We also provide sample data processing scripts in the `data_process` folder. See `node_process.py` and `link_process.py`. 204 | 205 | ## Cite Us 206 | 207 | ``` 208 | @article{g-meta, 209 | title={Graph Meta Learning via Local Subgraphs}, 210 | author={Huang, Kexin and Zitnik, Marinka}, 211 | journal={NeurIPS}, 212 | year={2020} 213 | } 214 | ``` 215 | 216 | ## Contact 217 | 218 | Open an issue or send an email to kexinhuang@hsph.harvard.edu if you have any questions.
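As a minimal sketch of the node classification layout described in the Data Processing section above (a toy illustration, not one of the provided scripts: the graph, the random features, the labels, and the `PATH` output directory are all placeholders), the files could be assembled like this:

```python
import pickle
import dgl
import numpy as np
import pandas as pd

out_dir = 'PATH'  # placeholder: an existing output directory

# a toy single graph with 4 nodes; edges are added in both directions
g = dgl.DGLGraph()
g.add_nodes(4)
src, dst = [0, 1, 2, 0], [1, 2, 3, 3]
g.add_edges(src + dst, dst + src)

features = np.random.rand(4, 16).astype(np.float32)  # one feature matrix per graph
labels = {'0_0': 0, '0_1': 1, '0_2': 0, '0_3': 1}     # '{graph idx}_{node idx}': label

with open(out_dir + '/graph_dgl.pkl', 'wb') as f:
    pickle.dump([g], f)                               # a list of DGL graphs; [G] for a single graph

np.save(out_dir + '/features.npy', np.array([features]))

with open(out_dir + '/label.pkl', 'wb') as f:
    pickle.dump(labels, f)

# train/val/test.csv: everything goes into one split here for brevity; in practice the three
# files should use disjoint label sets (Disjoint setup) or disjoint graphs (Shared setup)
df = pd.DataFrame(list(labels.items()), columns=['name', 'label'])
for split in ['train', 'val', 'test']:
    df.to_csv(out_dir + '/' + split + '.csv')
```

The written csv files keep the pandas index as the first column, which matches how `subgraph_data_processing.py` reads `row[1]` as the subgraph name and `row[2]` as the label.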
219 | 220 | -------------------------------------------------------------------------------- /G-Meta/subgraph_data_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import collections 6 | import csv 7 | import random 8 | import pickle 9 | from torch.utils.data import DataLoader 10 | import dgl 11 | import networkx as nx 12 | import itertools 13 | 14 | class Subgraphs(Dataset): 15 | def __init__(self, root, mode, subgraph2label, n_way, k_shot, k_query, batchsz, args, adjs, h): 16 | self.batchsz = batchsz # batch of set, not batch of subgraphs 17 | self.n_way = n_way # n-way 18 | self.k_shot = k_shot # k-shot support set 19 | self.k_query = k_query # for query set 20 | self.setsz = self.n_way * self.k_shot # num of samples per support set 21 | self.querysz = self.n_way * self.k_query # number of samples per set for evaluation 22 | self.h = h # number of h hops 23 | self.sample_nodes = args.sample_nodes 24 | print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, %d-hops' % ( 25 | mode, batchsz, n_way, k_shot, k_query, h)) 26 | 27 | # load subgraph list if preprocessed 28 | self.subgraph2label = subgraph2label 29 | 30 | if args.link_pred_mode == 'True': 31 | self.link_pred_mode = True 32 | else: 33 | self.link_pred_mode = False 34 | 35 | if self.link_pred_mode: 36 | dictLabels_spt, dictGraphs_spt, dictGraphsLabels_spt = self.loadCSV(os.path.join(root, mode + '_spt.csv')) 37 | dictLabels_qry, dictGraphs_qry, dictGraphsLabels_qry = self.loadCSV(os.path.join(root, mode + '_qry.csv')) 38 | dictLabels, dictGraphs, dictGraphsLabels = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path 39 | else: 40 | dictLabels, dictGraphs, dictGraphsLabels = self.loadCSV(os.path.join(root, mode + '.csv')) # csv path 41 | 42 | self.task_setup = args.task_setup 43 | 44 | self.G = [] 45 | 46 | for i in adjs: 47 | self.G.append(i) 48 | 49 | self.subgraphs = {} 50 | 51 | if self.task_setup == 'Disjoint': 52 | self.data = [] 53 | 54 | for i, (k, v) in enumerate(dictLabels.items()): 55 | self.data.append(v) # [[subgraph1, subgraph2, ...], [subgraph111, ...]] 56 | self.cls_num = len(self.data) 57 | 58 | self.create_batch_disjoint(self.batchsz) 59 | elif self.task_setup == 'Shared': 60 | 61 | if self.link_pred_mode: 62 | 63 | self.data_graph_spt = [] 64 | 65 | for i, (k, v) in enumerate(dictGraphs_spt.items()): 66 | self.data_graph_spt.append(v) 67 | self.graph_num_spt = len(self.data_graph_spt) 68 | 69 | self.data_label_spt = [[] for i in range(self.graph_num_spt)] 70 | 71 | relative_idx_map_spt = dict(zip(list(dictGraphs_spt.keys()), range(len(list(dictGraphs_spt.keys()))))) 72 | 73 | for i, (k, v) in enumerate(dictGraphsLabels_spt.items()): 74 | for m, n in v.items(): 75 | self.data_label_spt[relative_idx_map_spt[k]].append(n) 76 | 77 | self.cls_num_spt = len(self.data_label_spt[0]) 78 | 79 | self.data_graph_qry = [] 80 | 81 | for i, (k, v) in enumerate(dictGraphs_qry.items()): 82 | self.data_graph_qry.append(v) 83 | self.graph_num_qry = len(self.data_graph_qry) 84 | 85 | self.data_label_qry = [[] for i in range(self.graph_num_qry)] 86 | 87 | relative_idx_map_qry = dict(zip(list(dictGraphs_qry.keys()), range(len(list(dictGraphs_qry.keys()))))) 88 | 89 | for i, (k, v) in enumerate(dictGraphsLabels_qry.items()): 90 | for m, n in v.items(): 91 | self.data_label_qry[relative_idx_map_qry[k]].append(n) 92 | 93 | self.cls_num_qry = len(self.data_label_qry[0]) 94 | 95 | 
self.create_batch_LinkPred(self.batchsz) 96 | 97 | else: 98 | self.data_graph = [] 99 | 100 | for i, (k, v) in enumerate(dictGraphs.items()): 101 | self.data_graph.append(v) 102 | self.graph_num = len(self.data_graph) 103 | 104 | self.data_label = [[] for i in range(self.graph_num)] 105 | 106 | relative_idx_map = dict(zip(list(dictGraphs.keys()), range(len(list(dictGraphs.keys()))))) 107 | 108 | for i, (k, v) in enumerate(dictGraphsLabels.items()): 109 | #self.data_label[k] = [] 110 | for m, n in v.items(): 111 | 112 | self.data_label[relative_idx_map[k]].append(n) # [(graph 1)[(label1)[subgraph1, subgraph2, ...], (label2)[subgraph111, ...]], graph2: [[subgraph1, subgraph2, ...], [subgraph111, ...]] ] 113 | self.cls_num = len(self.data_label[0]) 114 | self.graph_num = len(self.data_graph) 115 | 116 | self.create_batch_shared(self.batchsz) 117 | 118 | 119 | def loadCSV(self, csvf): 120 | dictGraphsLabels = {} 121 | dictLabels = {} 122 | dictGraphs = {} 123 | 124 | with open(csvf) as csvfile: 125 | csvreader = csv.reader(csvfile, delimiter=',') 126 | next(csvreader, None) # skip (filename, label) 127 | for i, row in enumerate(csvreader): 128 | filename = row[1] 129 | g_idx = int(filename.split('_')[0]) 130 | label = row[2] 131 | # append filename to current label 132 | 133 | if g_idx in dictGraphs.keys(): 134 | dictGraphs[g_idx].append(filename) 135 | else: 136 | dictGraphs[g_idx] = [filename] 137 | dictGraphsLabels[g_idx] = {} 138 | 139 | if label in dictGraphsLabels[g_idx].keys(): 140 | dictGraphsLabels[g_idx][label].append(filename) 141 | else: 142 | dictGraphsLabels[g_idx][label] = [filename] 143 | 144 | if label in dictLabels.keys(): 145 | dictLabels[label].append(filename) 146 | else: 147 | dictLabels[label] = [filename] 148 | return dictLabels, dictGraphs, dictGraphsLabels 149 | 150 | def create_batch_disjoint(self, batchsz): 151 | """ 152 | create the entire set of batches of tasks for disjoint label setting, indepedent of # of graphs. 153 | """ 154 | self.support_x_batch = [] # support set batch 155 | self.query_x_batch = [] # query set batch 156 | for b in range(batchsz): # for each batch 157 | # 1.select n_way classes randomly 158 | #print(self.cls_num) 159 | #print(self.n_way) 160 | selected_cls = np.random.choice(self.cls_num, self.n_way, False) # no duplicate 161 | np.random.shuffle(selected_cls) 162 | support_x = [] 163 | query_x = [] 164 | for cls in selected_cls: 165 | 166 | # 2. select k_shot + k_query for each class 167 | selected_subgraphs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False) 168 | 169 | np.random.shuffle(selected_subgraphs_idx) 170 | indexDtrain = np.array(selected_subgraphs_idx[:self.k_shot]) # idx for Dtrain 171 | indexDtest = np.array(selected_subgraphs_idx[self.k_shot:]) # idx for Dtest 172 | support_x.append( 173 | np.array(self.data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain 174 | query_x.append(np.array(self.data[cls])[indexDtest].tolist()) 175 | 176 | # shuffle the correponding relation between support set and query set 177 | random.shuffle(support_x) 178 | random.shuffle(query_x) 179 | 180 | # support_x: [setsz (k_shot+k_query * n_way)] numbers of subgraphs 181 | self.support_x_batch.append(support_x) # append set to current sets 182 | self.query_x_batch.append(query_x) # append sets to current sets 183 | 184 | def create_batch_shared(self, batchsz): 185 | """ 186 | create the entire set of batches of tasks for shared label setting, indepedent of # of graphs. 
187 | """ 188 | k_shot = self.k_shot 189 | k_query = self.k_query 190 | 191 | self.support_x_batch = [] # support set batch 192 | self.query_x_batch = [] # query set batch 193 | for b in range(batchsz): # one loop generates one task 194 | # 1.select n_way classes randomly 195 | #print(self.cls_num) 196 | #print(self.n_way) 197 | 198 | selected_graph = np.random.choice(self.graph_num, 1, False)[0] # select one graph 199 | data = self.data_label[selected_graph] 200 | 201 | selected_cls = np.array(list(range(len(data)))) # for multiple graph setting, we select cls_num * k_shot nodes 202 | np.random.shuffle(selected_cls) 203 | 204 | support_x = [] 205 | query_x = [] 206 | 207 | for cls in selected_cls: 208 | 209 | # 2. select k_shot + k_query for each class 210 | try: 211 | selected_subgraphs_idx = np.random.choice(len(data[cls]), k_shot + k_query, False) 212 | np.random.shuffle(selected_subgraphs_idx) 213 | indexDtrain = np.array(selected_subgraphs_idx[:k_shot]) # idx for Dtrain 214 | indexDtest = np.array(selected_subgraphs_idx[k_shot:]) # idx for Dtest 215 | support_x.append( 216 | np.array(data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain 217 | query_x.append(np.array(data[cls])[indexDtest].tolist()) 218 | except: 219 | # this was not used in practice 220 | if len(data[cls]) >= k_shot: 221 | selected_subgraphs_idx = np.array(range(len(data[cls]))) 222 | np.random.shuffle(selected_subgraphs_idx) 223 | indexDtrain = np.array(selected_subgraphs_idx[:k_shot]) # idx for Dtrain 224 | indexDtest = np.array(selected_subgraphs_idx[k_shot:]) # idx for Dtest 225 | support_x.append( 226 | np.array(data[cls])[indexDtrain].tolist()) # get all subgraphs filename for current Dtrain 227 | 228 | num_more = k_shot + k_query - len(data[cls]) 229 | count = 0 230 | 231 | query_tmp = np.array(data[cls])[indexDtest].tolist() 232 | 233 | while count <= num_more: 234 | sub_cls = np.random.choice(selected_cls, 1)[0] 235 | idx = np.random.choice(len(data[sub_cls]), 1)[0] 236 | query_tmp = query_tmp + [np.array(data[sub_cls])[idx]] 237 | count += 1 238 | query_x.append(query_tmp) 239 | else: 240 | print('each class in a graph must have larger than k_shot entities in the current model') 241 | 242 | random.shuffle(support_x) 243 | random.shuffle(query_x) 244 | 245 | # support_x: [setsz (k_shot+k_query * 1)] numbers of subgraphs 246 | self.support_x_batch.append(support_x) # append set to current sets 247 | self.query_x_batch.append(query_x) # append sets to current sets 248 | 249 | def create_batch_LinkPred(self, batchsz): 250 | """ 251 | create the entire set of batches of tasks for shared label linked prediction setting, indepedent of # of graphs. 
252 | """
253 | k_shot = self.k_shot
254 | k_query = self.k_query
255 | 
256 | self.support_x_batch = [] # support set batch
257 | self.query_x_batch = [] # query set batch
258 | 
259 | for b in range(batchsz): # one loop generates one task
260 | 
261 | selected_graph = np.random.choice(self.graph_num_spt, 1, False)[0] # select one graph
262 | data_spt = self.data_label_spt[selected_graph]
263 | 
264 | selected_cls_spt = np.array(list(range(len(data_spt)))) # every support class of the selected graph is used
265 | np.random.shuffle(selected_cls_spt)
266 | 
267 | data_qry = self.data_label_qry[selected_graph]
268 | 
269 | selected_cls_qry = np.array(list(range(len(data_qry)))) # every query class of the selected graph is used
270 | np.random.shuffle(selected_cls_qry)
271 | 
272 | support_x = []
273 | query_x = []
274 | 
275 | for cls in selected_cls_spt:
276 | 
277 | selected_subgraphs_idx = np.random.choice(len(data_spt[cls]), k_shot, False)
278 | np.random.shuffle(selected_subgraphs_idx)
279 | support_x.append(
280 | np.array(data_spt[cls])[selected_subgraphs_idx].tolist()) # get all subgraph filenames for current Dtrain
281 | 
282 | for cls in selected_cls_qry:
283 | 
284 | selected_subgraphs_idx = np.random.choice(len(data_qry[cls]), k_query, False)
285 | np.random.shuffle(selected_subgraphs_idx)
286 | query_x.append(np.array(data_qry[cls])[selected_subgraphs_idx].tolist())
287 | 
288 | random.shuffle(support_x)
289 | random.shuffle(query_x)
290 | 
291 | self.support_x_batch.append(support_x) # append set to current sets
292 | self.query_x_batch.append(query_x) # append sets to current sets
293 | 
294 | # helper to generate subgraphs on the fly.
295 | def generate_subgraph(self, G, i, item):
296 | if item in self.subgraphs:
297 | return self.subgraphs[item]
298 | else:
299 | # instead of computing shortest-path distances, we found the following way of extracting subgraphs to be quicker
300 | if self.h == 2:
301 | f_hop = [n.item() for n in G.in_edges(i)[0]]
302 | n_l = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop] # neighbors of each first-hop node (the comprehension variable shadows the argument i)
303 | h_hops_neighbor = torch.tensor(list(set(list(itertools.chain(*n_l)) + f_hop + [i]))).numpy()
304 | elif self.h == 1:
305 | f_hop = [n.item() for n in G.in_edges(i)[0]]
306 | h_hops_neighbor = torch.tensor(list(set(f_hop + [i]))).numpy()
307 | elif self.h == 3:
308 | f_hop = [n.item() for n in G.in_edges(i)[0]]
309 | n_2 = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop]
310 | n_3 = [[n.item() for n in G.in_edges(i)[0]] for i in list(itertools.chain(*n_2))]
311 | h_hops_neighbor = torch.tensor(list(set(list(itertools.chain(*n_2)) + list(itertools.chain(*n_3)) + f_hop + [i]))).numpy()
312 | if h_hops_neighbor.reshape(-1,).shape[0] > self.sample_nodes:
313 | h_hops_neighbor = np.random.choice(h_hops_neighbor, self.sample_nodes, replace = False)
314 | h_hops_neighbor = np.unique(np.append(h_hops_neighbor, [i]))
315 | 
316 | sub = G.subgraph(h_hops_neighbor)
317 | h_c = list(sub.parent_nid.numpy())
318 | dict_ = dict(zip(h_c, list(range(len(h_c)))))
319 | self.subgraphs[item] = (sub, dict_[i], h_c)
320 | 
321 | return sub, dict_[i], h_c
322 | 
323 | def generate_subgraph_link_pred(self, G, i, j, item):
324 | if item in self.subgraphs:
325 | return self.subgraphs[item]
326 | else:
327 | f_hop = [n.item() for n in G.in_edges(i)[0]]
328 | n_l = [[n.item() for n in G.in_edges(i)[0]] for i in f_hop]
329 | h_hops_neighbor1 = torch.tensor(list(set([item for sublist in n_l for item in sublist] + f_hop + [i]))).numpy()
330 | 
331 | f_hop = [n.item() for n in G.in_edges(j)[0]]
332 | n_l = [[n.item() for n in G.in_edges(m)[0]] for m in f_hop] # neighbors of each first-hop node of j (2-hop expansion around j)
333 | h_hops_neighbor2 = torch.tensor(list(set([item for sublist in n_l for item in sublist] + f_hop + [j]))).numpy()
334 | 
335 | h_hops_neighbor = np.union1d(h_hops_neighbor1, h_hops_neighbor2)
336 | 
337 | if h_hops_neighbor.reshape(-1,).shape[0] > self.sample_nodes:
338 | h_hops_neighbor = np.random.choice(h_hops_neighbor, self.sample_nodes, replace = False)
339 | h_hops_neighbor = np.unique(np.append(h_hops_neighbor, [i, j]))
340 | 
341 | sub = G.subgraph(h_hops_neighbor)
342 | h_c = list(sub.parent_nid.numpy())
343 | dict_ = dict(zip(h_c, list(range(len(h_c)))))
344 | self.subgraphs[item] = (sub, [dict_[i], dict_[j]], h_c)
345 | 
346 | return sub, [dict_[i], dict_[j]], h_c
347 | 
348 | def __getitem__(self, index):
349 | """
350 | get one task. support_x_batch[index], query_x_batch[index]
351 | 
352 | """
353 | #print(self.support_x_batch[index])
354 | if self.link_pred_mode:
355 | info = [self.generate_subgraph_link_pred(self.G[int(item.split('_')[0])], int(item.split('_')[1]), int(item.split('_')[2]), item)
356 | for sublist in self.support_x_batch[index] for item in sublist]
357 | else:
358 | info = [self.generate_subgraph(self.G[int(item.split('_')[0])], int(item.split('_')[1]), item)
359 | for sublist in self.support_x_batch[index] for item in sublist]
360 | 
361 | support_graph_idx = [int(item.split('_')[0]) # graph index of each support subgraph
362 | for sublist in self.support_x_batch[index] for item in sublist]
363 | 
364 | support_x = [i for i, j, k in info]
365 | support_y = np.array([self.subgraph2label[item]
366 | for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
367 | 
368 | support_center = np.array([j for i, j, k in info]).astype(np.int32)
369 | support_node_idx = [k for i, j, k in info]
370 | 
371 | 
372 | if self.link_pred_mode:
373 | info = [self.generate_subgraph_link_pred(self.G[int(item.split('_')[0])], int(item.split('_')[1]), int(item.split('_')[2]), item)
374 | for sublist in self.query_x_batch[index] for item in sublist]
375 | else:
376 | info = [self.generate_subgraph(self.G[int(item.split('_')[0])], int(item.split('_')[1]), item)
377 | for sublist in self.query_x_batch[index] for item in sublist]
378 | 
379 | query_graph_idx = [int(item.split('_')[0]) # graph index of each query subgraph
380 | for sublist in self.query_x_batch[index] for item in sublist]
381 | 
382 | query_x = [i for i, j, k in info]
383 | query_y = np.array([self.subgraph2label[item]
384 | for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)
385 | 
386 | query_center = np.array([j for i, j, k in info]).astype(np.int32)
387 | query_node_idx = [k for i, j, k in info]
388 | 
389 | if self.task_setup == 'Disjoint':
390 | unique = np.unique(support_y)
391 | random.shuffle(unique)
392 | # relative means the label ranges from 0 to n_way - 1
393 | support_y_relative = np.zeros(self.setsz)
394 | query_y_relative = np.zeros(self.querysz)
395 | for idx, l in enumerate(unique):
396 | support_y_relative[support_y == l] = idx
397 | query_y_relative[query_y == l] = idx
398 | # this is a set of subgraphs for one task.
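# For illustration: with n_way = 3 and original support labels {2, 7, 9}, `unique` starts as the
# sorted array [2, 7, 9]; after shuffling it might become [9, 7, 2], so every subgraph labelled 9
# is mapped to relative label 0, 7 to 1, and 2 to 2. The query labels are remapped with the same
# assignment, so each task is always presented to the model as a 0..n_way-1 classification problem.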
399 | batched_graph_spt = dgl.batch(support_x) 400 | batched_graph_qry = dgl.batch(query_x) 401 | 402 | return batched_graph_spt, torch.LongTensor(support_y_relative), batched_graph_qry, torch.LongTensor(query_y_relative), torch.LongTensor(support_center), torch.LongTensor(query_center), support_node_idx, query_node_idx, support_graph_idx, query_graph_idx 403 | elif self.task_setup == 'Shared': 404 | 405 | batched_graph_spt = dgl.batch(support_x) 406 | batched_graph_qry = dgl.batch(query_x) 407 | 408 | return batched_graph_spt, torch.LongTensor(support_y), batched_graph_qry, torch.LongTensor(query_y), torch.LongTensor(support_center), torch.LongTensor(query_center), support_node_idx, query_node_idx, support_graph_idx, query_graph_idx 409 | 410 | def __len__(self): 411 | # as we have built up to batchsz of sets, you can sample some small batch size of sets. 412 | return self.batchsz 413 | 414 | def collate(samples): 415 | # The input `samples` is a list of pairs 416 | # (graph, label). 417 | graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx = map(list, zip(*samples)) 418 | 419 | return graphs_spt, labels_spt, graph_qry, labels_qry, center_spt, center_qry, nodeidx_spt, nodeidx_qry, support_graph_idx, query_graph_idx 420 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using backend: pytorch\n", 13 | "Namespace(attention_size=32, batchsz=10000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/arxiv/', epoch=10, h=2, hidden_dim=256, input_dim=1, k_qry=24, k_spt=3, link_pred_mode='False', meta_lr=0.001, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=32, task_setup='Disjoint', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=20, val_result_report_steps=100)\n", 14 | "There are 40 classes \n", 15 | "Meta(\n", 16 | " (net): Classifier(\n", 17 | " (vars): ParameterList(\n", 18 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 128x256 (GPU 0)]\n", 19 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n", 20 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 256x256 (GPU 0)]\n", 21 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n", 22 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 3x256 (GPU 0)]\n", 23 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]\n", 24 | " )\n", 25 | " )\n", 26 | ")\n", 27 | "Total trainable tensors: 99587\n", 28 | "shuffle DB :train, b:10000, 3-way, 3-shot, 24-query, 2-hops\n", 29 | "shuffle DB :val, b:100, 3-way, 3-shot, 24-query, 2-hops\n", 30 | "shuffle DB :test, b:100, 3-way, 3-shot, 24-query, 2-hops\n", 31 | "------ Start Training ------\n", 32 | "Epoch: 1 Step: 0 training acc: 0.338 time elapsed: 6.295 data loading takes: 4.125 Memory usage: 28.71\n", 33 | "Epoch: 1 Step: 200 training acc: 0.462 time elapsed: 4.871 data loading takes: 0.591 Memory usage: 31.05\n", 34 | "Epoch: 1 Val acc: 0.456\n", 35 | "Epoch: 2 Step: 0 training acc: 0.440 time elapsed: 4.613 data loading takes: 0.377 Memory usage: 31.74\n", 36 | "Epoch: 2 Step: 200 training acc: 0.480 time 
elapsed: 4.770 data loading takes: 0.483 Memory usage: 39.74\n", 37 | "Epoch: 2 Val acc: 0.445\n", 38 | "Epoch: 3 Step: 0 training acc: 0.521 time elapsed: 4.731 data loading takes: 0.442 Memory usage: 40.24\n", 39 | "Epoch: 3 Step: 200 training acc: 0.497 time elapsed: 4.776 data loading takes: 0.541 Memory usage: 40.33\n", 40 | "Epoch: 3 Val acc: 0.445\n", 41 | "Epoch: 4 Step: 0 training acc: 0.509 time elapsed: 4.752 data loading takes: 0.476 Memory usage: 39.76\n", 42 | "Epoch: 4 Step: 200 training acc: 0.505 time elapsed: 4.812 data loading takes: 0.471 Memory usage: 39.87\n", 43 | "Epoch: 4 Val acc: 0.443\n", 44 | "Epoch: 5 Step: 0 training acc: 0.503 time elapsed: 4.892 data loading takes: 0.473 Memory usage: 40.82\n", 45 | "Epoch: 5 Step: 200 training acc: 0.477 time elapsed: 4.825 data loading takes: 0.487 Memory usage: 40.92\n", 46 | "Epoch: 5 Val acc: 0.441\n", 47 | "Epoch: 6 Step: 0 training acc: 0.509 time elapsed: 4.829 data loading takes: 0.412 Memory usage: 40.33\n", 48 | "Epoch: 6 Step: 200 training acc: 0.519 time elapsed: 4.861 data loading takes: 0.453 Memory usage: 40.39\n", 49 | "Epoch: 6 Val acc: 0.444\n", 50 | "Epoch: 7 Step: 0 training acc: 0.553 time elapsed: 4.757 data loading takes: 0.399 Memory usage: 41.46\n", 51 | "Epoch: 7 Step: 200 training acc: 0.503 time elapsed: 4.825 data loading takes: 0.519 Memory usage: 41.73\n", 52 | "Epoch: 7 Val acc: 0.439\n", 53 | "Epoch: 8 Step: 0 training acc: 0.523 time elapsed: 4.859 data loading takes: 0.490 Memory usage: 41.53\n", 54 | "Epoch: 8 Step: 200 training acc: 0.516 time elapsed: 4.675 data loading takes: 0.554 Memory usage: 41.43\n", 55 | "Epoch: 8 Val acc: 0.441\n", 56 | "Epoch: 9 Step: 0 training acc: 0.519 time elapsed: 4.795 data loading takes: 0.444 Memory usage: 40.89\n", 57 | "Epoch: 9 Step: 200 training acc: 0.507 time elapsed: 4.875 data loading takes: 3.407 Memory usage: 41.21\n", 58 | "Epoch: 9 Val acc: 0.443\n", 59 | "Epoch: 10 Step: 0 training acc: 0.503 time elapsed: 4.898 data loading takes: 0.423 Memory usage: 41.88\n", 60 | "Epoch: 10 Step: 200 training acc: 0.560 time elapsed: 4.960 data loading takes: 0.492 Memory usage: 42.23\n", 61 | "Epoch: 10 Val acc: 0.44\n", 62 | "Test acc: 0.421\n", 63 | "Early Stopped Test acc: 0.436\n", 64 | "Total Time: 17206\n", 65 | "Max Momory: 42.52\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/arxiv/ \\\n", 71 | " --epoch 10 \\\n", 72 | " --task_setup Disjoint \\\n", 73 | " --k_spt 3 \\\n", 74 | " --k_qry 24 \\\n", 75 | " --n_way 3 \\\n", 76 | " --update_step 10 \\\n", 77 | " --update_lr 0.01 \\\n", 78 | " --num_workers 0 \\\n", 79 | " --train_result_report_steps 200 \\\n", 80 | " --hidden_dim 256 \\\n", 81 | " --update_step_test 20 \\\n", 82 | " --task_num 32 \\\n", 83 | " --batchsz 10000 " 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 1, 89 | "metadata": { 90 | "scrolled": false 91 | }, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Using backend: pytorch\n", 98 | "Namespace(attention_size=32, batchsz=1000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/tissue_PPI/', epoch=15, h=2, hidden_dim=128, input_dim=1, k_qry=10, k_spt=3, link_pred_mode='False', meta_lr=0.005, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='True', task_n=4, task_num=4, task_setup='Shared', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=10, val_result_report_steps=100)\n", 99 | 
"There are 2 classes \n", 100 | "Meta(\n", 101 | " (net): Classifier(\n", 102 | " (vars): ParameterList(\n", 103 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 50x128 (GPU 0)]\n", 104 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 105 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n", 106 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 107 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x128 (GPU 0)]\n", 108 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n", 109 | " )\n", 110 | " )\n", 111 | ")\n", 112 | "Total trainable tensors: 23298\n", 113 | "shuffle DB :train, b:1000, 3-way, 3-shot, 10-query, 2-hops\n", 114 | "shuffle DB :val, b:100, 3-way, 3-shot, 10-query, 2-hops\n", 115 | "shuffle DB :test, b:100, 3-way, 3-shot, 10-query, 2-hops\n", 116 | "------ Start Training ------\n", 117 | "Epoch: 1 Step: 0 training acc: 0.575 time elapsed: 1.461 data loading takes: 1.735 Memory usage: 11.21\n", 118 | "Epoch: 1 Step: 200 training acc: 0.575 time elapsed: 0.578 data loading takes: 1.836 Memory usage: 31.95\n", 119 | "Epoch: 1 Val acc: 0.563\n", 120 | "Epoch: 2 Step: 0 training acc: 0.587 time elapsed: 0.532 data loading takes: 0.465 Memory usage: 37.52\n", 121 | "Epoch: 2 Step: 200 training acc: 0.612 time elapsed: 0.560 data loading takes: 0.549 Memory usage: 37.47\n", 122 | "Epoch: 2 Val acc: 0.608\n", 123 | "Epoch: 3 Step: 0 training acc: 0.65 time elapsed: 0.559 data loading takes: 0.484 Memory usage: 37.46\n", 124 | "Epoch: 3 Step: 200 training acc: 0.65 time elapsed: 0.551 data loading takes: 0.536 Memory usage: 37.52\n", 125 | "Epoch: 3 Val acc: 0.623\n", 126 | "Epoch: 4 Step: 0 training acc: 0.662 time elapsed: 0.540 data loading takes: 0.449 Memory usage: 37.59\n", 127 | "Epoch: 4 Step: 200 training acc: 0.724 time elapsed: 0.593 data loading takes: 0.649 Memory usage: 37.61\n", 128 | "Epoch: 4 Val acc: 0.644\n", 129 | "Epoch: 5 Step: 0 training acc: 0.787 time elapsed: 0.548 data loading takes: 0.479 Memory usage: 37.65\n", 130 | "Epoch: 5 Step: 200 training acc: 0.725 time elapsed: 0.555 data loading takes: 0.556 Memory usage: 37.71\n", 131 | "Epoch: 5 Val acc: 0.645\n", 132 | "Epoch: 6 Step: 0 training acc: 0.612 time elapsed: 0.537 data loading takes: 0.443 Memory usage: 37.56\n", 133 | "Epoch: 6 Step: 200 training acc: 0.674 time elapsed: 0.560 data loading takes: 0.562 Memory usage: 37.70\n", 134 | "Epoch: 6 Val acc: 0.666\n", 135 | "Epoch: 7 Step: 0 training acc: 0.575 time elapsed: 0.574 data loading takes: 0.579 Memory usage: 37.63\n", 136 | "Epoch: 7 Step: 200 training acc: 0.775 time elapsed: 0.551 data loading takes: 0.523 Memory usage: 37.67\n", 137 | "Epoch: 7 Val acc: 0.672\n", 138 | "Epoch: 8 Step: 0 training acc: 0.812 time elapsed: 0.531 data loading takes: 0.423 Memory usage: 37.71\n", 139 | "Epoch: 8 Step: 200 training acc: 0.737 time elapsed: 0.520 data loading takes: 0.460 Memory usage: 37.77\n", 140 | "Epoch: 8 Val acc: 0.681\n", 141 | "Epoch: 9 Step: 0 training acc: 0.825 time elapsed: 0.528 data loading takes: 0.437 Memory usage: 37.70\n", 142 | "Epoch: 9 Step: 200 training acc: 0.825 time elapsed: 0.595 data loading takes: 0.622 Memory usage: 37.74\n", 143 | "Epoch: 9 Val acc: 0.689\n", 144 | "Epoch: 10 Step: 0 training acc: 0.812 time elapsed: 0.572 data loading takes: 0.543 Memory usage: 37.72\n", 145 | "Epoch: 10 Step: 200 training acc: 0.762 time elapsed: 0.551 data loading takes: 0.555 Memory usage: 
37.76\n", 146 | "Epoch: 10 Val acc: 0.688\n", 147 | "Epoch: 11 Step: 0 training acc: 0.825 time elapsed: 0.554 data loading takes: 0.524 Memory usage: 37.78\n", 148 | "Epoch: 11 Step: 200 training acc: 0.75 time elapsed: 0.608 data loading takes: 0.654 Memory usage: 37.76\n", 149 | "Epoch: 11 Val acc: 0.716\n", 150 | "Epoch: 12 Step: 0 training acc: 0.787 time elapsed: 0.524 data loading takes: 0.445 Memory usage: 37.90\n", 151 | "Epoch: 12 Step: 200 training acc: 0.787 time elapsed: 0.604 data loading takes: 0.683 Memory usage: 37.66\n", 152 | "Epoch: 12 Val acc: 0.703\n", 153 | "Epoch: 13 Step: 0 training acc: 0.762 time elapsed: 0.551 data loading takes: 0.478 Memory usage: 37.59\n", 154 | "Epoch: 13 Step: 200 training acc: 0.85 time elapsed: 0.538 data loading takes: 0.458 Memory usage: 37.62\n", 155 | "Epoch: 13 Val acc: 0.723\n", 156 | "Epoch: 14 Step: 0 training acc: 0.85 time elapsed: 0.540 data loading takes: 0.459 Memory usage: 37.66\n", 157 | "Epoch: 14 Step: 200 training acc: 0.762 time elapsed: 0.568 data loading takes: 0.571 Memory usage: 37.65\n", 158 | "Epoch: 14 Val acc: 0.723\n", 159 | "Epoch: 15 Step: 0 training acc: 0.8 time elapsed: 0.499 data loading takes: 0.321 Memory usage: 37.57\n", 160 | "Epoch: 15 Step: 200 training acc: 0.762 time elapsed: 0.536 data loading takes: 0.481 Memory usage: 37.69\n", 161 | "Epoch: 15 Val acc: 0.730\n", 162 | "Test acc: 0.78\n", 163 | "Early Stopped Test acc: 0.774\n", 164 | "Total Time: 4852.\n", 165 | "Max Momory: 37.90\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/tissue_PPI/ \\\n", 171 | " --epoch 15 \\\n", 172 | " --task_setup Shared \\\n", 173 | " --task_mode True \\\n", 174 | " --task_n 4 \\\n", 175 | " --k_qry 10 \\\n", 176 | " --k_spt 3 \\\n", 177 | " --update_lr 0.01 \\\n", 178 | " --update_step 10 \\\n", 179 | " --meta_lr 5e-3 \\\n", 180 | " --num_workers 0 \\\n", 181 | " --train_result_report_steps 200 \\\n", 182 | " --hidden_dim 128 \\\n", 183 | " --task_num 4 \\\n", 184 | " --batchsz 1000" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 3, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "Using backend: pytorch\n", 197 | "Namespace(attention_size=32, batchsz=4000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/fold_PPI/', epoch=5, h=2, hidden_dim=128, input_dim=1, k_qry=24, k_spt=3, link_pred_mode='False', meta_lr=0.001, method='G-Meta', n_way=3, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=16, task_setup='Disjoint', train_result_report_steps=100, update_lr=0.005, update_step=5, update_step_test=20, val_result_report_steps=100)\n", 198 | "There are 30 classes \n", 199 | "Meta(\n", 200 | " (net): Classifier(\n", 201 | " (vars): ParameterList(\n", 202 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 512x128 (GPU 0)]\n", 203 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 204 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n", 205 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 206 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 3x128 (GPU 0)]\n", 207 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 3 (GPU 0)]\n", 208 | " )\n", 209 | " )\n", 210 | ")\n", 211 | "Total trainable tensors: 82563\n", 212 | "shuffle DB :train, b:4000, 3-way, 3-shot, 24-query, 2-hops\n", 213 | 
"shuffle DB :val, b:100, 3-way, 3-shot, 24-query, 2-hops\n", 214 | "shuffle DB :test, b:100, 3-way, 3-shot, 24-query, 2-hops\n", 215 | "------ Start Training ------\n", 216 | "Epoch: 1 Step: 0 training acc: 0.427 time elapsed: 3.852 data loading takes: 6.680 Memory usage: 27.73\n", 217 | "Epoch: 1 Step: 100 training acc: 0.451 time elapsed: 3.616 data loading takes: 1.766 Memory usage: 41.06\n", 218 | "Epoch: 1 Step: 200 training acc: 0.447 time elapsed: 4.523 data loading takes: 1.990 Memory usage: 41.58\n", 219 | "Epoch: 1 Val acc: 0.478\n", 220 | "Epoch: 2 Step: 0 training acc: 0.471 time elapsed: 3.576 data loading takes: 1.379 Memory usage: 43.33\n", 221 | "Epoch: 2 Step: 100 training acc: 0.571 time elapsed: 3.528 data loading takes: 1.434 Memory usage: 42.90\n", 222 | "Epoch: 2 Step: 200 training acc: 0.589 time elapsed: 3.381 data loading takes: 1.399 Memory usage: 43.79\n", 223 | "Epoch: 2 Val acc: 0.494\n", 224 | "Epoch: 3 Step: 0 training acc: 0.566 time elapsed: 3.575 data loading takes: 1.249 Memory usage: 44.07\n", 225 | "Epoch: 3 Step: 100 training acc: 0.662 time elapsed: 3.753 data loading takes: 1.538 Memory usage: 43.48\n", 226 | "Epoch: 3 Step: 200 training acc: 0.663 time elapsed: 3.829 data loading takes: 1.579 Memory usage: 44.14\n", 227 | "Epoch: 3 Val acc: 0.522\n", 228 | "Epoch: 4 Step: 0 training acc: 0.594 time elapsed: 3.804 data loading takes: 1.329 Memory usage: 43.42\n", 229 | "Epoch: 4 Step: 100 training acc: 0.598 time elapsed: 3.962 data loading takes: 1.630 Memory usage: 43.99\n", 230 | "Epoch: 4 Step: 200 training acc: 0.743 time elapsed: 3.637 data loading takes: 1.465 Memory usage: 43.46\n", 231 | "Epoch: 4 Val acc: 0.513\n", 232 | "Epoch: 5 Step: 0 training acc: 0.705 time elapsed: 3.728 data loading takes: 1.319 Memory usage: 43.72\n", 233 | "Epoch: 5 Step: 100 training acc: 0.812 time elapsed: 3.654 data loading takes: 1.521 Memory usage: 44.23\n", 234 | "Epoch: 5 Step: 200 training acc: 0.724 time elapsed: 3.919 data loading takes: 1.637 Memory usage: 44.02\n", 235 | "Epoch: 5 Val acc: 0.543\n", 236 | "Test acc: 0.578\n", 237 | "Early Stopped Test acc: 0.656\n", 238 | "Total Time: 7150.\n", 239 | "Max Momory: 44.39\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/fold_PPI/ \\\n", 245 | " --epoch 5 \\\n", 246 | " --task_setup Disjoint \\\n", 247 | " --k_qry 24 \\\n", 248 | " --k_spt 3 \\\n", 249 | " --n_way 3 \\\n", 250 | " --update_lr 0.005 \\\n", 251 | " --meta_lr 1e-3 \\\n", 252 | " --num_workers 0 \\\n", 253 | " --train_result_report_steps 100 \\\n", 254 | " --hidden_dim 128 \\\n", 255 | " --update_step_test 20 \\\n", 256 | " --task_num 16 \\\n", 257 | " --batchsz 4000" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 2, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "Using backend: pytorch\n", 270 | "Namespace(attention_size=32, batchsz=1500, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/FirstMM_DB/', epoch=15, h=2, hidden_dim=128, input_dim=1, k_qry=32, k_spt=16, link_pred_mode='True', meta_lr=0.0005, method='G-Meta', n_way=2, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=8, task_setup='Shared', train_result_report_steps=200, update_lr=0.01, update_step=10, update_step_test=20, val_result_report_steps=100)\n", 271 | "There are 2 classes \n", 272 | "Meta(\n", 273 | " (net): Classifier(\n", 274 | " (vars): ParameterList(\n", 275 | " 
(0): Parameter containing: [torch.cuda.FloatTensor of size 5x128 (GPU 0)]\n", 276 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 277 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 128x128 (GPU 0)]\n", 278 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]\n", 279 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x256 (GPU 0)]\n", 280 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n", 281 | " )\n", 282 | " )\n", 283 | ")\n", 284 | "Total trainable tensors: 17794\n", 285 | "shuffle DB :train, b:1500, 2-way, 16-shot, 32-query, 2-hops\n", 286 | "shuffle DB :val, b:100, 2-way, 16-shot, 32-query, 2-hops\n", 287 | "shuffle DB :test, b:100, 2-way, 16-shot, 32-query, 2-hops\n", 288 | "------ Start Training ------\n", 289 | "Epoch: 1 Step: 0 training acc: 0.492 time elapsed: 1.279 data loading takes: 1.136 Memory usage: 13.26\n", 290 | "Epoch: 1 Val acc: 0.691\n", 291 | "Epoch: 2 Step: 0 training acc: 0.695 time elapsed: 0.635 data loading takes: 0.080 Memory usage: 14.51\n", 292 | "Epoch: 2 Val acc: 0.735\n", 293 | "Epoch: 3 Step: 0 training acc: 0.693 time elapsed: 0.636 data loading takes: 0.075 Memory usage: 14.53\n", 294 | "Epoch: 3 Val acc: 0.762\n", 295 | "Epoch: 4 Step: 0 training acc: 0.705 time elapsed: 0.642 data loading takes: 0.074 Memory usage: 14.51\n", 296 | "Epoch: 4 Val acc: 0.769\n", 297 | "Epoch: 5 Step: 0 training acc: 0.728 time elapsed: 0.630 data loading takes: 0.077 Memory usage: 14.52\n", 298 | "Epoch: 5 Val acc: 0.778\n", 299 | "Epoch: 6 Step: 0 training acc: 0.75 time elapsed: 0.633 data loading takes: 0.077 Memory usage: 14.52\n", 300 | "Epoch: 6 Val acc: 0.780\n", 301 | "Epoch: 7 Step: 0 training acc: 0.734 time elapsed: 0.640 data loading takes: 0.080 Memory usage: 14.58\n", 302 | "Epoch: 7 Val acc: 0.785\n", 303 | "Epoch: 8 Step: 0 training acc: 0.744 time elapsed: 0.639 data loading takes: 0.072 Memory usage: 14.52\n", 304 | "Epoch: 8 Val acc: 0.786\n", 305 | "Epoch: 9 Step: 0 training acc: 0.732 time elapsed: 1.201 data loading takes: 0.149 Memory usage: 14.52\n", 306 | "Epoch: 9 Val acc: 0.785\n", 307 | "Epoch: 10 Step: 0 training acc: 0.787 time elapsed: 0.633 data loading takes: 0.083 Memory usage: 14.54\n", 308 | "Epoch: 10 Val acc: 0.785\n", 309 | "Epoch: 11 Step: 0 training acc: 0.751 time elapsed: 0.633 data loading takes: 0.072 Memory usage: 14.59\n", 310 | "Epoch: 11 Val acc: 0.789\n", 311 | "Epoch: 12 Step: 0 training acc: 0.748 time elapsed: 0.635 data loading takes: 0.074 Memory usage: 14.58\n", 312 | "Epoch: 12 Val acc: 0.791\n", 313 | "Epoch: 13 Step: 0 training acc: 0.753 time elapsed: 0.635 data loading takes: 0.076 Memory usage: 14.55\n", 314 | "Epoch: 13 Val acc: 0.791\n", 315 | "Epoch: 14 Step: 0 training acc: 0.789 time elapsed: 0.660 data loading takes: 0.082 Memory usage: 14.63\n", 316 | "Epoch: 14 Val acc: 0.793\n", 317 | "Epoch: 15 Step: 0 training acc: 0.738 time elapsed: 0.628 data loading takes: 0.080 Memory usage: 14.59\n", 318 | "Epoch: 15 Val acc: 0.799\n", 319 | "Test acc: 0.769\n", 320 | "Early Stopped Test acc: 0.756\n", 321 | "Total Time: 2536.\n", 322 | "Max Momory: 14.86\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/FirstMM_DB/ \\\n", 328 | " --epoch 15 \\\n", 329 | " --task_setup Shared \\\n", 330 | " --k_qry 32 \\\n", 331 | " --k_spt 16 \\\n", 332 | " --n_way 2 \\\n", 333 | " --update_lr 0.01 \\\n", 334 | " --update_step 10 \\\n", 335 | " --meta_lr 5e-4 
\\\n", 336 | " --num_workers 0 \\\n", 337 | " --train_result_report_steps 200 \\\n", 338 | " --hidden_dim 128 \\\n", 339 | " --update_step_test 20 \\\n", 340 | " --task_num 8 \\\n", 341 | " --batchsz 1500 \\\n", 342 | " --link_pred_mod True" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 4, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "Using backend: pytorch\n", 355 | "Namespace(attention_size=32, batchsz=5000, data_dir='/n/scratch3/users/k/kh278/G-Meta_Data/tree-of-life/', epoch=15, h=2, hidden_dim=256, input_dim=1, k_qry=16, k_spt=16, link_pred_mode='True', meta_lr=0.0005, method='G-Meta', n_way=2, no_finetune=True, num_workers=0, sample_nodes=1000, task_mode='False', task_n=1, task_num=8, task_setup='Shared', train_result_report_steps=200, update_lr=0.005, update_step=10, update_step_test=20, val_result_report_steps=100)\n", 356 | "There are 2 classes \n", 357 | "Meta(\n", 358 | " (net): Classifier(\n", 359 | " (vars): ParameterList(\n", 360 | " (0): Parameter containing: [torch.cuda.FloatTensor of size 1x256 (GPU 0)]\n", 361 | " (1): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n", 362 | " (2): Parameter containing: [torch.cuda.FloatTensor of size 256x256 (GPU 0)]\n", 363 | " (3): Parameter containing: [torch.cuda.FloatTensor of size 256 (GPU 0)]\n", 364 | " (4): Parameter containing: [torch.cuda.FloatTensor of size 2x512 (GPU 0)]\n", 365 | " (5): Parameter containing: [torch.cuda.FloatTensor of size 2 (GPU 0)]\n", 366 | " )\n", 367 | " )\n", 368 | ")\n", 369 | "Total trainable tensors: 67330\n", 370 | "shuffle DB :train, b:5000, 2-way, 16-shot, 16-query, 2-hops\n", 371 | "shuffle DB :val, b:100, 2-way, 16-shot, 16-query, 2-hops\n", 372 | "shuffle DB :test, b:100, 2-way, 16-shot, 16-query, 2-hops\n", 373 | "------ Start Training ------\n", 374 | "Epoch: 1 Step: 0 training acc: 0.628 time elapsed: 1.257 data loading takes: 3.446 Memory usage: 28.58\n", 375 | "Epoch: 1 Step: 200 training acc: 0.621 time elapsed: 0.734 data loading takes: 26.83 Memory usage: 41.77\n", 376 | "Epoch: 1 Step: 400 training acc: 0.738 time elapsed: 0.701 data loading takes: 3.756 Memory usage: 54.42\n", 377 | "Epoch: 1 Step: 600 training acc: 0.695 time elapsed: 0.648 data loading takes: 1.814 Memory usage: 65.41\n", 378 | "Epoch: 1 Val acc: 0.694\n", 379 | "Epoch: 2 Step: 0 training acc: 0.667 time elapsed: 0.672 data loading takes: 0.160 Memory usage: 67.54\n", 380 | "Epoch: 2 Step: 200 training acc: 0.660 time elapsed: 0.665 data loading takes: 0.151 Memory usage: 68.14\n", 381 | "Epoch: 2 Step: 400 training acc: 0.691 time elapsed: 0.673 data loading takes: 0.144 Memory usage: 67.23\n", 382 | "Epoch: 2 Step: 600 training acc: 0.714 time elapsed: 0.678 data loading takes: 0.172 Memory usage: 67.63\n", 383 | "Epoch: 2 Val acc: 0.702\n", 384 | "Epoch: 3 Step: 0 training acc: 0.695 time elapsed: 0.666 data loading takes: 0.166 Memory usage: 67.67\n", 385 | "Epoch: 3 Step: 200 training acc: 0.714 time elapsed: 0.718 data loading takes: 0.297 Memory usage: 68.02\n", 386 | "Epoch: 3 Step: 400 training acc: 0.710 time elapsed: 0.743 data loading takes: 0.393 Memory usage: 68.19\n", 387 | "Epoch: 3 Step: 600 training acc: 0.699 time elapsed: 0.726 data loading takes: 0.286 Memory usage: 68.07\n", 388 | "Epoch: 3 Val acc: 0.709\n", 389 | "Epoch: 4 Step: 0 training acc: 0.671 time elapsed: 0.656 data loading takes: 0.151 Memory usage: 68.10\n", 390 | "Epoch: 4 Step: 200 training acc: 
0.722 time elapsed: 0.698 data loading takes: 0.172 Memory usage: 67.22\n", 391 | "Epoch: 4 Step: 400 training acc: 0.718 time elapsed: 0.680 data loading takes: 0.232 Memory usage: 67.23\n", 392 | "Epoch: 4 Step: 600 training acc: 0.726 time elapsed: 0.682 data loading takes: 0.201 Memory usage: 67.62\n", 393 | "Epoch: 4 Val acc: 0.723\n", 394 | "Epoch: 5 Step: 0 training acc: 0.679 time elapsed: 0.670 data loading takes: 0.190 Memory usage: 67.73\n", 395 | "Epoch: 5 Step: 200 training acc: 0.703 time elapsed: 0.696 data loading takes: 0.143 Memory usage: 68.12\n", 396 | "Epoch: 5 Step: 400 training acc: 0.679 time elapsed: 0.704 data loading takes: 0.279 Memory usage: 68.20\n", 397 | "Epoch: 5 Step: 600 training acc: 0.660 time elapsed: 0.704 data loading takes: 0.232 Memory usage: 68.31\n", 398 | "Epoch: 5 Val acc: 0.722\n", 399 | "Epoch: 6 Step: 0 training acc: 0.730 time elapsed: 0.722 data loading takes: 0.373 Memory usage: 67.17\n", 400 | "Epoch: 6 Step: 200 training acc: 0.726 time elapsed: 0.700 data loading takes: 0.216 Memory usage: 67.19\n", 401 | "Epoch: 6 Step: 400 training acc: 0.714 time elapsed: 0.740 data loading takes: 0.332 Memory usage: 67.80\n", 402 | "Epoch: 6 Step: 600 training acc: 0.707 time elapsed: 0.692 data loading takes: 0.199 Memory usage: 68.09\n", 403 | "Epoch: 6 Val acc: 0.730\n", 404 | "Epoch: 7 Step: 0 training acc: 0.703 time elapsed: 0.707 data loading takes: 0.271 Memory usage: 68.06\n", 405 | "Epoch: 7 Step: 200 training acc: 0.703 time elapsed: 0.691 data loading takes: 0.227 Memory usage: 68.29\n", 406 | "Epoch: 7 Step: 400 training acc: 0.738 time elapsed: 0.682 data loading takes: 0.178 Memory usage: 68.44\n", 407 | "Epoch: 7 Step: 600 training acc: 0.667 time elapsed: 0.696 data loading takes: 0.192 Memory usage: 67.66\n", 408 | "Epoch: 7 Val acc: 0.712\n", 409 | "Epoch: 8 Step: 0 training acc: 0.707 time elapsed: 0.673 data loading takes: 0.227 Memory usage: 67.77\n", 410 | "Epoch: 8 Step: 200 training acc: 0.765 time elapsed: 0.665 data loading takes: 0.216 Memory usage: 55.23\n", 411 | "Epoch: 8 Step: 400 training acc: 0.738 time elapsed: 0.743 data loading takes: 0.315 Memory usage: 57.62\n", 412 | "Epoch: 8 Step: 600 training acc: 0.734 time elapsed: 0.759 data loading takes: 0.259 Memory usage: 57.61\n", 413 | "Epoch: 8 Val acc: 0.734\n", 414 | "Epoch: 9 Step: 0 training acc: 0.710 time elapsed: 0.740 data loading takes: 0.258 Memory usage: 57.70\n", 415 | "Epoch: 9 Step: 200 training acc: 0.699 time elapsed: 0.701 data loading takes: 0.218 Memory usage: 57.63\n", 416 | "Epoch: 9 Step: 400 training acc: 0.734 time elapsed: 0.706 data loading takes: 0.172 Memory usage: 57.51\n", 417 | "Epoch: 9 Step: 600 training acc: 0.738 time elapsed: 0.701 data loading takes: 0.196 Memory usage: 57.50\n", 418 | "Epoch: 9 Val acc: 0.742\n", 419 | "Epoch: 10 Step: 0 training acc: 0.707 time elapsed: 0.697 data loading takes: 0.194 Memory usage: 57.48\n", 420 | "Epoch: 10 Step: 200 training acc: 0.722 time elapsed: 0.713 data loading takes: 0.213 Memory usage: 57.55\n", 421 | "Epoch: 10 Step: 400 training acc: 0.761 time elapsed: 0.688 data loading takes: 0.111 Memory usage: 57.53\n", 422 | "Epoch: 10 Step: 600 training acc: 0.765 time elapsed: 0.689 data loading takes: 0.160 Memory usage: 57.51\n", 423 | "Epoch: 10 Val acc: 0.752\n", 424 | "Epoch: 11 Step: 0 training acc: 0.671 time elapsed: 0.708 data loading takes: 0.240 Memory usage: 57.48\n", 425 | "Epoch: 11 Step: 200 training acc: 0.734 time elapsed: 0.726 data loading takes: 0.251 Memory usage: 
57.54\n", 426 | "Epoch: 11 Step: 400 training acc: 0.726 time elapsed: 0.725 data loading takes: 0.280 Memory usage: 57.47\n", 427 | "Epoch: 11 Step: 600 training acc: 0.726 time elapsed: 0.757 data loading takes: 0.285 Memory usage: 57.47\n", 428 | "Epoch: 11 Val acc: 0.748\n", 429 | "Epoch: 12 Step: 0 training acc: 0.691 time elapsed: 0.702 data loading takes: 0.229 Memory usage: 57.46\n", 430 | "Epoch: 12 Step: 200 training acc: 0.765 time elapsed: 0.687 data loading takes: 0.184 Memory usage: 57.47\n", 431 | "Epoch: 12 Step: 400 training acc: 0.710 time elapsed: 0.711 data loading takes: 0.180 Memory usage: 57.56\n", 432 | "Epoch: 12 Step: 600 training acc: 0.722 time elapsed: 0.716 data loading takes: 0.271 Memory usage: 57.52\n", 433 | "Epoch: 12 Val acc: 0.721\n", 434 | "Epoch: 13 Step: 0 training acc: 0.777 time elapsed: 0.681 data loading takes: 0.166 Memory usage: 57.58\n", 435 | "Epoch: 13 Step: 200 training acc: 0.664 time elapsed: 0.765 data loading takes: 0.432 Memory usage: 57.44\n", 436 | "Epoch: 13 Step: 400 training acc: 0.753 time elapsed: 0.730 data loading takes: 0.226 Memory usage: 57.48\n", 437 | "Epoch: 13 Step: 600 training acc: 0.652 time elapsed: 0.713 data loading takes: 0.218 Memory usage: 57.46\n", 438 | "Epoch: 13 Val acc: 0.729\n", 439 | "Epoch: 14 Step: 0 training acc: 0.746 time elapsed: 0.661 data loading takes: 0.121 Memory usage: 57.49\n", 440 | "Epoch: 14 Step: 200 training acc: 0.664 time elapsed: 0.721 data loading takes: 0.315 Memory usage: 57.47\n", 441 | "Epoch: 14 Step: 400 training acc: 0.687 time elapsed: 0.706 data loading takes: 0.227 Memory usage: 57.48\n", 442 | "Epoch: 14 Step: 600 training acc: 0.730 time elapsed: 0.712 data loading takes: 0.239 Memory usage: 57.43\n", 443 | "Epoch: 14 Val acc: 0.724\n", 444 | "Epoch: 15 Step: 0 training acc: 0.679 time elapsed: 0.714 data loading takes: 0.282 Memory usage: 57.50\n", 445 | "Epoch: 15 Step: 200 training acc: 0.75 time elapsed: 0.695 data loading takes: 0.167 Memory usage: 57.58\n", 446 | "Epoch: 15 Step: 400 training acc: 0.75 time elapsed: 0.736 data loading takes: 0.296 Memory usage: 57.46\n", 447 | "Epoch: 15 Step: 600 training acc: 0.644 time elapsed: 0.725 data loading takes: 0.230 Memory usage: 57.53\n" 448 | ] 449 | }, 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "Epoch: 15 Val acc: 0.721\n", 455 | "Test acc: 0.694\n", 456 | "Early Stopped Test acc: 0.723\n", 457 | "Total Time: 11569\n", 458 | "Max Momory: 68.59\n" 459 | ] 460 | } 461 | ], 462 | "source": [ 463 | "!python G-Meta/train.py --data_dir DATA_PATH/G-Meta_Data/tree-of-life/ \\\n", 464 | " --epoch 15 \\\n", 465 | " --task_setup Shared \\\n", 466 | " --k_qry 16 \\\n", 467 | " --k_spt 16 \\\n", 468 | " --n_way 2 \\\n", 469 | " --update_lr 0.005 \\\n", 470 | " --update_step 10 \\\n", 471 | " --meta_lr 0.0005 \\\n", 472 | " --num_workers 0 \\\n", 473 | " --train_result_report_steps 200 \\\n", 474 | " --hidden_dim 256 \\\n", 475 | " --update_step_test 20 \\\n", 476 | " --task_num 8 \\\n", 477 | " --batchsz 5000 \\\n", 478 | " --link_pred_mod True" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [] 487 | } 488 | ], 489 | "metadata": { 490 | "kernelspec": { 491 | "display_name": "Python 3", 492 | "language": "python", 493 | "name": "python3" 494 | }, 495 | "language_info": { 496 | "codemirror_mode": { 497 | "name": "ipython", 498 | "version": 3 499 | }, 500 | "file_extension": ".py", 501 | 
"mimetype": "text/x-python", 502 | "name": "python", 503 | "nbconvert_exporter": "python", 504 | "pygments_lexer": "ipython3", 505 | "version": "3.7.4" 506 | } 507 | }, 508 | "nbformat": 4, 509 | "nbformat_minor": 4 510 | } 511 | --------------------------------------------------------------------------------