├── GPT_GNN ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── conv.cpython-37.pyc │ ├── data.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── utils.cpython-37.pyc ├── conv.py ├── data.py ├── model.py └── utils.py ├── LICENSE ├── README.md ├── example_OAG ├── GPT_GNN │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── conv.cpython-37.pyc │ │ ├── data.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── conv.py │ ├── data.py │ ├── model.py │ └── utils.py ├── finetune_OAG_AD.py ├── finetune_OAG_PF.py ├── finetune_OAG_PV.py ├── preprocess_OAG.py └── pretrain_OAG.py ├── example_reddit ├── .ipynb_checkpoints │ └── pretrain_reddit-checkpoint.py ├── GPT_GNN │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── conv.cpython-37.pyc │ │ ├── data.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ ├── conv.py │ ├── data.py │ ├── model.py │ └── utils.py ├── finetune_reddit.py ├── preprocess_reddit.py └── pretrain_reddit.py ├── images ├── gpt-intro.png └── pretrain_OAG.gif └── requirements.txt /GPT_GNN/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /GPT_GNN/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/GPT_GNN/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /GPT_GNN/__pycache__/conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/GPT_GNN/__pycache__/conv.cpython-37.pyc -------------------------------------------------------------------------------- 
/GPT_GNN/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/GPT_GNN/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /GPT_GNN/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/GPT_GNN/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /GPT_GNN/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/GPT_GNN/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /GPT_GNN/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch_geometric.nn import GCNConv, GATConv 6 | from torch_geometric.nn.conv import MessagePassing 7 | from torch_geometric.nn.inits import glorot, uniform 8 | from torch_geometric.utils import softmax 9 | import math 10 | 11 | class HGTConv(MessagePassing): 12 | def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs): 13 | super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs) 14 | 15 | self.in_dim = in_dim 16 | self.out_dim = out_dim 17 | self.num_types = num_types 18 | self.num_relations = num_relations 19 | self.total_rel = num_types * num_relations * num_types 20 | self.n_heads = n_heads 21 | self.d_k = out_dim // n_heads 22 | self.sqrt_dk = math.sqrt(self.d_k) 23 | 
self.use_norm = use_norm 24 | self.use_RTE = use_RTE 25 | self.att = None 26 | 27 | 28 | self.k_linears = nn.ModuleList() 29 | self.q_linears = nn.ModuleList() 30 | self.v_linears = nn.ModuleList() 31 | self.a_linears = nn.ModuleList() 32 | self.norms = nn.ModuleList() 33 | 34 | for t in range(num_types): 35 | self.k_linears.append(nn.Linear(in_dim, out_dim)) 36 | self.q_linears.append(nn.Linear(in_dim, out_dim)) 37 | self.v_linears.append(nn.Linear(in_dim, out_dim)) 38 | self.a_linears.append(nn.Linear(out_dim, out_dim)) 39 | if use_norm: 40 | self.norms.append(nn.LayerNorm(out_dim)) 41 | ''' 42 | TODO: make relation_pri smaller, as not all pair exist in meta relation list. 43 | ''' 44 | self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads)) 45 | self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 46 | self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 47 | self.skip = nn.Parameter(torch.ones(num_types)) 48 | self.drop = nn.Dropout(dropout) 49 | 50 | if self.use_RTE: 51 | self.emb = RelTemporalEncoding(in_dim) 52 | 53 | glorot(self.relation_att) 54 | glorot(self.relation_msg) 55 | 56 | def forward(self, node_inp, node_type, edge_index, edge_type, edge_time): 57 | return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \ 58 | edge_type=edge_type, edge_time = edge_time) 59 | 60 | def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time): 61 | ''' 62 | j: source, i: target; 63 | ''' 64 | data_size = edge_index_i.size(0) 65 | ''' 66 | Create Attention and Message tensor beforehand. 
67 | ''' 68 | res_att = torch.zeros(data_size, self.n_heads).to(node_inp_i.device) 69 | res_msg = torch.zeros(data_size, self.n_heads, self.d_k).to(node_inp_i.device) 70 | 71 | for source_type in range(self.num_types): 72 | sb = (node_type_j == int(source_type)) 73 | k_linear = self.k_linears[source_type] 74 | v_linear = self.v_linears[source_type] 75 | for target_type in range(self.num_types): 76 | tb = (node_type_i == int(target_type)) & sb 77 | q_linear = self.q_linears[target_type] 78 | for relation_type in range(self.num_relations): 79 | ''' 80 | idx is all the edges with meta relation 81 | ''' 82 | idx = (edge_type == int(relation_type)) & tb 83 | if idx.sum() == 0: 84 | continue 85 | ''' 86 | Get the corresponding input node representations by idx. 87 | Add tempotal encoding to source representation (j) 88 | ''' 89 | target_node_vec = node_inp_i[idx] 90 | source_node_vec = node_inp_j[idx] 91 | if self.use_RTE: 92 | source_node_vec = self.emb(source_node_vec, edge_time[idx]) 93 | ''' 94 | Step 1: Heterogeneous Mutual Attention 95 | ''' 96 | q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k) 97 | k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k) 98 | k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0) 99 | res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk 100 | ''' 101 | Step 2: Heterogeneous Message Passing 102 | ''' 103 | v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k) 104 | res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0) 105 | ''' 106 | Softmax based on target node's id (edge_index_i). Store attention value in self.att for later visualization. 
107 | ''' 108 | self.att = softmax(res_att, edge_index_i) 109 | res = res_msg * self.att.view(-1, self.n_heads, 1) 110 | del res_att, res_msg 111 | return res.view(-1, self.out_dim) 112 | 113 | 114 | def update(self, aggr_out, node_inp, node_type): 115 | ''' 116 | Step 3: Target-specific Aggregation 117 | x = W[node_type] * gelu(Agg(x)) + x 118 | ''' 119 | aggr_out = F.gelu(aggr_out) 120 | res = torch.zeros(aggr_out.size(0), self.out_dim).to(node_inp.device) 121 | for target_type in range(self.num_types): 122 | idx = (node_type == int(target_type)) 123 | if idx.sum() == 0: 124 | continue 125 | trans_out = self.a_linears[target_type](aggr_out[idx]) 126 | ''' 127 | Add skip connection with learnable weight self.skip[t_id] 128 | ''' 129 | alpha = torch.sigmoid(self.skip[target_type]) 130 | if self.use_norm: 131 | res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha)) 132 | else: 133 | res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha) 134 | return self.drop(res) 135 | 136 | def __repr__(self): 137 | return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format( 138 | self.__class__.__name__, self.in_dim, self.out_dim, 139 | self.num_types, self.num_relations) 140 | 141 | 142 | class RelTemporalEncoding(nn.Module): 143 | ''' 144 | Implement the Temporal Encoding (Sinusoid) function. 
145 | ''' 146 | def __init__(self, n_hid, max_len = 240, dropout = 0.2): 147 | super(RelTemporalEncoding, self).__init__() 148 | self.drop = nn.Dropout(dropout) 149 | position = torch.arange(0., max_len).unsqueeze(1) 150 | div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.)) / n_hid / 2) 151 | self.emb = nn.Embedding(max_len, n_hid * 2) 152 | self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid) 153 | self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid) 154 | self.emb.requires_grad = False 155 | self.lin = nn.Linear(n_hid * 2, n_hid) 156 | def forward(self, x, t): 157 | return x + self.lin(self.drop(self.emb(t))) 158 | 159 | 160 | 161 | class GeneralConv(nn.Module): 162 | def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True): 163 | super(GeneralConv, self).__init__() 164 | self.conv_name = conv_name 165 | if self.conv_name == 'hgt': 166 | self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE) 167 | elif self.conv_name == 'gcn': 168 | self.base_conv = GCNConv(in_hid, out_hid) 169 | elif self.conv_name == 'gat': 170 | self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads) 171 | def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time): 172 | if self.conv_name == 'hgt': 173 | return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time) 174 | elif self.conv_name == 'gcn': 175 | return self.base_conv(meta_xs, edge_index) 176 | elif self.conv_name == 'gat': 177 | return self.base_conv(meta_xs, edge_index) 178 | 179 | -------------------------------------------------------------------------------- /GPT_GNN/data.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import math, copy, time 3 | import numpy as np 4 | from collections import defaultdict 5 | import pandas as pd 6 | 7 | import math 8 | 
from tqdm import tqdm 9 | 10 | import seaborn as sb 11 | import matplotlib.pyplot as plt 12 | import matplotlib.cm as cm 13 | 14 | from .utils import * 15 | 16 | import dill 17 | from functools import partial 18 | import multiprocessing as mp 19 | 20 | class Graph(): 21 | def __init__(self): 22 | super(Graph, self).__init__() 23 | ''' 24 | node_forward and bacward are only used when building the data. 25 | Afterwards will be transformed into node_feature by DataFrame 26 | 27 | node_forward: name -> node_id 28 | node_bacward: node_id -> feature_dict 29 | node_feature: a DataFrame containing all features 30 | ''' 31 | self.node_forward = defaultdict(lambda: {}) 32 | self.node_bacward = defaultdict(lambda: []) 33 | self.node_feature = defaultdict(lambda: []) 34 | 35 | ''' 36 | edge_list: index the adjacancy matrix (time) by 37 | 38 | ''' 39 | self.edge_list = defaultdict( #target_type 40 | lambda: defaultdict( #source_type 41 | lambda: defaultdict( #relation_type 42 | lambda: defaultdict( #target_id 43 | lambda: defaultdict( #source_id( 44 | lambda: int # time 45 | ))))) 46 | self.times = {} 47 | def add_node(self, node): 48 | nfl = self.node_forward[node['type']] 49 | if node['id'] not in nfl: 50 | self.node_bacward[node['type']] += [node] 51 | ser = len(nfl) 52 | nfl[node['id']] = ser 53 | return ser 54 | return nfl[node['id']] 55 | def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True): 56 | edge = [self.add_node(source_node), self.add_node(target_node)] 57 | ''' 58 | Add bi-directional edges with different relation type 59 | ''' 60 | self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time 61 | if directed: 62 | self.edge_list[source_node['type']][target_node['type']]['rev_' + relation_type][edge[0]][edge[1]] = time 63 | else: 64 | self.edge_list[source_node['type']][target_node['type']][relation_type][edge[0]][edge[1]] = time 65 | self.times[time] = True 66 | 67 | def 
update_node(self, node): 68 | nbl = self.node_bacward[node['type']] 69 | ser = self.add_node(node) 70 | for k in node: 71 | if k not in nbl[ser]: 72 | nbl[ser][k] = node[k] 73 | 74 | def get_meta_graph(self): 75 | types = self.get_types() 76 | metas = [] 77 | for target_type in self.edge_list: 78 | for source_type in self.edge_list[target_type]: 79 | for r_type in self.edge_list[target_type][source_type]: 80 | metas += [(target_type, source_type, r_type)] 81 | return metas 82 | 83 | def get_types(self): 84 | return list(self.node_feature.keys()) 85 | 86 | 87 | 88 | def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG): 89 | ''' 90 | Sample Sub-Graph based on the connection of other nodes with currently sampled nodes 91 | We maintain budgets for each node type, indexed by . 92 | Currently sampled nodes are stored in layer_data. 93 | After nodes are sampled, we construct the sampled adjacancy matrix. 94 | ''' 95 | layer_data = defaultdict( #target_type 96 | lambda: {} # {target_id: [ser, time]} 97 | ) 98 | budget = defaultdict( #source_type 99 | lambda: defaultdict( #source_id 100 | lambda: [0., 0] #[sampled_score, time] 101 | )) 102 | new_layer_adj = defaultdict( #target_type 103 | lambda: defaultdict( #source_type 104 | lambda: defaultdict( #relation_type 105 | lambda: [] #[target_id, source_id] 106 | ))) 107 | ''' 108 | For each node being sampled, we find out all its neighborhood, 109 | adding the degree count of these nodes in the budget. 
110 | Note that there exist some nodes that have many neighborhoods 111 | (such as fields, venues), for those case, we only consider 112 | ''' 113 | def add_budget(te, target_id, target_time, layer_data, budget): 114 | for source_type in te: 115 | tes = te[source_type] 116 | for relation_type in tes: 117 | if relation_type == 'self' or target_id not in tes[relation_type]: 118 | continue 119 | adl = tes[relation_type][target_id] 120 | if len(adl) < sampled_number: 121 | sampled_ids = list(adl.keys()) 122 | else: 123 | sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False) 124 | for source_id in sampled_ids: 125 | source_time = adl[source_id] 126 | if source_time == None: 127 | source_time = target_time 128 | if source_time > np.max(list(time_range.keys())) or source_id in layer_data[source_type]: 129 | continue 130 | budget[source_type][source_id][0] += 1. / len(sampled_ids) 131 | budget[source_type][source_id][1] = source_time 132 | 133 | ''' 134 | First adding the sampled nodes then updating budget. 135 | ''' 136 | for _type in inp: 137 | for _id, _time in inp[_type]: 138 | layer_data[_type][_id] = [len(layer_data[_type]), _time] 139 | for _type in inp: 140 | te = graph.edge_list[_type] 141 | for _id, _time in inp[_type]: 142 | add_budget(te, _id, _time, layer_data, budget) 143 | ''' 144 | We recursively expand the sampled graph by sampled_depth. 145 | Each time we sample a fixed number of nodes for each budget, 146 | based on the accumulated degree. 
147 | ''' 148 | for layer in range(sampled_depth): 149 | sts = list(budget.keys()) 150 | for source_type in sts: 151 | te = graph.edge_list[source_type] 152 | keys = np.array(list(budget[source_type].keys())) 153 | if sampled_number > len(keys): 154 | ''' 155 | Directly sample all the nodes 156 | ''' 157 | sampled_ids = np.arange(len(keys)) 158 | else: 159 | ''' 160 | Sample based on accumulated degree 161 | ''' 162 | score = np.array(list(budget[source_type].values()))[:,0] ** 2 163 | score = score / np.sum(score) 164 | sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False) 165 | sampled_keys = keys[sampled_ids] 166 | ''' 167 | First adding the sampled nodes then updating budget. 168 | ''' 169 | for k in sampled_keys: 170 | layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]] 171 | for k in sampled_keys: 172 | add_budget(te, k, budget[source_type][k][1], layer_data, budget) 173 | budget[source_type].pop(k) 174 | ''' 175 | Prepare feature, time and adjacency matrix for the sampled graph 176 | ''' 177 | feature, times, indxs, texts = feature_extractor(layer_data, graph) 178 | 179 | edge_list = defaultdict( #target_type 180 | lambda: defaultdict( #source_type 181 | lambda: defaultdict( #relation_type 182 | lambda: [] # [target_id, source_id] 183 | ))) 184 | for _type in layer_data: 185 | for _key in layer_data[_type]: 186 | _ser = layer_data[_type][_key][0] 187 | edge_list[_type][_type]['self'] += [[_ser, _ser]] 188 | ''' 189 | Reconstruct sampled adjacancy matrix by checking whether each 190 | link exist in the original graph 191 | ''' 192 | for target_type in graph.edge_list: 193 | te = graph.edge_list[target_type] 194 | tld = layer_data[target_type] 195 | for source_type in te: 196 | tes = te[source_type] 197 | sld = layer_data[source_type] 198 | for relation_type in tes: 199 | tesr = tes[relation_type] 200 | for target_key in tld: 201 | if target_key not in tesr: 202 | continue 203 | target_ser = 
tld[target_key][0] 204 | for source_key in tesr[target_key]: 205 | ''' 206 | Check whether each link (target_id, source_id) exist in original adjacancy matrix 207 | ''' 208 | if source_key in sld: 209 | source_ser = sld[source_key][0] 210 | edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]] 211 | return feature, times, edge_list, indxs, texts 212 | 213 | def to_torch(feature, time, edge_list, graph): 214 | ''' 215 | Transform a sampled sub-graph into pytorch Tensor 216 | node_dict: {node_type: } node_number is used to trace back the nodes in original graph. 217 | edge_dict: {edge_type: edge_type_ID} 218 | ''' 219 | node_dict = {} 220 | node_feature = [] 221 | node_type = [] 222 | node_time = [] 223 | edge_index = [] 224 | edge_type = [] 225 | edge_time = [] 226 | 227 | node_num = 0 228 | types = graph.get_types() 229 | for t in types: 230 | node_dict[t] = [node_num, len(node_dict)] 231 | node_num += len(feature[t]) 232 | 233 | if 'fake_paper' in feature: 234 | node_dict['fake_paper'] = [node_num, node_dict['paper'][1]] 235 | node_num += len(feature['fake_paper']) 236 | types += ['fake_paper'] 237 | 238 | for t in types: 239 | node_feature += list(feature[t]) 240 | node_time += list(time[t]) 241 | node_type += [node_dict[t][1] for _ in range(len(feature[t]))] 242 | 243 | edge_dict = {e[2]: i for i, e in enumerate(graph.get_meta_graph())} 244 | edge_dict['self'] = len(edge_dict) 245 | 246 | for target_type in edge_list: 247 | for source_type in edge_list[target_type]: 248 | for relation_type in edge_list[target_type][source_type]: 249 | for ii, (ti, si) in enumerate(edge_list[target_type][source_type][relation_type]): 250 | tid, sid = ti + node_dict[target_type][0], si + node_dict[source_type][0] 251 | edge_index += [[sid, tid]] 252 | edge_type += [edge_dict[relation_type]] 253 | ''' 254 | Our time ranges from 1900 - 2020, largest span is 120. 
255 | ''' 256 | edge_time += [node_time[tid] - node_time[sid] + 120] 257 | node_feature = torch.FloatTensor(node_feature) 258 | node_type = torch.LongTensor(node_type) 259 | edge_time = torch.LongTensor(edge_time) 260 | edge_index = torch.LongTensor(edge_index).t() 261 | edge_type = torch.LongTensor(edge_type) 262 | return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict 263 | 264 | -------------------------------------------------------------------------------- /GPT_GNN/model.py: -------------------------------------------------------------------------------- 1 | from .conv import * 2 | import numpy as np 3 | from gensim.parsing.preprocessing import * 4 | 5 | 6 | class GPT_GNN(nn.Module): 7 | def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0): 8 | super(GPT_GNN, self).__init__() 9 | self.types = types 10 | self.gnn = gnn 11 | self.params = nn.ModuleList() 12 | self.neg_queue_size = neg_queue_size 13 | self.link_dec_dict = {} 14 | self.neg_queue = {} 15 | for source_type in rem_edge_list: 16 | self.link_dec_dict[source_type] = {} 17 | self.neg_queue[source_type] = {} 18 | for relation_type in rem_edge_list[source_type]: 19 | print(source_type, relation_type) 20 | matcher = Matcher(gnn.n_hid, gnn.n_hid) 21 | self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device) 22 | self.link_dec_dict[source_type][relation_type] = matcher 23 | self.params.append(matcher) 24 | self.attr_decoder = attr_decoder 25 | self.init_emb = nn.Parameter(torch.randn(gnn.in_dim)) 26 | self.ce = nn.CrossEntropyLoss(reduction = 'none') 27 | self.neg_samp_num = neg_samp_num 28 | 29 | def neg_sample(self, souce_node_list, pos_node_list): 30 | np.random.shuffle(souce_node_list) 31 | neg_nodes = [] 32 | keys = {key : True for key in pos_node_list} 33 | tot = 0 34 | for node_id in souce_node_list: 35 | if node_id not in keys: 36 | neg_nodes += [node_id] 37 | tot += 1 38 | if tot == 
self.neg_samp_num: 39 | break 40 | return neg_nodes 41 | 42 | def forward(self, node_feature, node_type, edge_time, edge_index, edge_type): 43 | return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type) 44 | def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False): 45 | losses = 0 46 | ress = [] 47 | for source_type in rem_edge_list: 48 | if source_type not in self.link_dec_dict: 49 | continue 50 | for relation_type in rem_edge_list[source_type]: 51 | if relation_type not in self.link_dec_dict[source_type]: 52 | continue 53 | rem_edges = rem_edge_list[source_type][relation_type] 54 | if len(rem_edges) <= 8: 55 | continue 56 | ori_edges = ori_edge_list[source_type][relation_type] 57 | matcher = self.link_dec_dict[source_type][relation_type] 58 | 59 | target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1) 60 | n_nodes = len(target_ids) 61 | source_node_ids = np.unique(ori_edges[:, 1]) 62 | 63 | negative_source_ids = [self.neg_sample(source_node_ids, \ 64 | ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids] 65 | sn = min([len(neg_ids) for neg_ids in negative_source_ids]) 66 | 67 | negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids] 68 | 69 | source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + node_dict[source_type][0]) 70 | emb = node_emb[source_ids] 71 | 72 | if use_queue and len(self.neg_queue[source_type][relation_type]) // n_nodes > 0: 73 | tmp = self.neg_queue[source_type][relation_type] 74 | stx = len(tmp) // n_nodes 75 | tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1) 76 | rep_size = sn + 1 + stx 77 | source_emb = torch.cat([emb, tmp], dim=1) 78 | source_emb = source_emb.reshape(n_nodes * rep_size, -1) 79 | else: 80 | rep_size = sn + 1 81 | source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1) 82 | 83 | target_ids = 
target_ids.repeat(rep_size, 1) + node_dict[target_type][0] 84 | target_emb = node_emb[target_ids.reshape(-1)] 85 | res = matcher.forward(target_emb, source_emb) 86 | res = res.reshape(n_nodes, rep_size) 87 | ress += [res.detach()] 88 | losses += F.log_softmax(res, dim=-1)[:,0].mean() 89 | if update_queue and 'L1' not in relation_type and 'L2' not in relation_type: 90 | tmp = self.neg_queue[source_type][relation_type] 91 | self.neg_queue[source_type][relation_type] = \ 92 | torch.cat([node_emb[source_node_ids].detach(), tmp], dim=0)[:int(self.neg_queue_size * n_nodes)] 93 | return -losses / len(ress), ress 94 | 95 | 96 | def text_loss(self, reps, texts, w2v_model, device): 97 | def parse_text(texts, w2v_model, device): 98 | idxs = [] 99 | pad = w2v_model.wv.vocab['eos'].index 100 | for text in texts: 101 | idx = [] 102 | for word in ['bos'] + preprocess_string(text) + ['eos']: 103 | if word in w2v_model.wv.vocab: 104 | idx += [w2v_model.wv.vocab[word].index] 105 | idxs += [idx] 106 | mxl = np.max([len(s) for s in idxs]) + 1 107 | inp_idxs = [] 108 | out_idxs = [] 109 | masks = [] 110 | for i, idx in enumerate(idxs): 111 | inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]] 112 | out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]] 113 | masks += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]] 114 | return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \ 115 | torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device) 116 | inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device) 117 | pred_prob = self.attr_decoder(inp_idxs, reps.repeat(inp_idxs.shape[0], 1, 1)) 118 | return self.ce(pred_prob[masks], out_idxs[masks]).mean() 119 | 120 | def feat_loss(self, reps, out): 121 | return -self.attr_decoder(reps, out).mean() 122 | 123 | 124 | class Classifier(nn.Module): 125 | def __init__(self, n_hid, n_out): 126 | super(Classifier, self).__init__() 127 | self.n_hid = 
n_hid 128 | self.n_out = n_out 129 | self.linear = nn.Linear(n_hid, n_out) 130 | def forward(self, x): 131 | tx = self.linear(x) 132 | return torch.log_softmax(tx.squeeze(), dim=-1) 133 | def __repr__(self): 134 | return '{}(n_hid={}, n_out={})'.format( 135 | self.__class__.__name__, self.n_hid, self.n_out) 136 | 137 | 138 | class Matcher(nn.Module): 139 | ''' 140 | Matching between a pair of nodes to conduct link prediction. 141 | Use multi-head attention as matching model. 142 | ''' 143 | 144 | def __init__(self, n_hid, n_out, temperature = 0.1): 145 | super(Matcher, self).__init__() 146 | self.n_hid = n_hid 147 | self.linear = nn.Linear(n_hid, n_out) 148 | self.sqrt_hd = math.sqrt(n_out) 149 | self.drop = nn.Dropout(0.2) 150 | self.cosine = nn.CosineSimilarity(dim=1) 151 | self.cache = None 152 | self.temperature = temperature 153 | def forward(self, x, ty, use_norm = True): 154 | tx = self.drop(self.linear(x)) 155 | if use_norm: 156 | return self.cosine(tx, ty) / self.temperature 157 | else: 158 | return (tx * ty).sum(dim=-1) / self.sqrt_hd 159 | def __repr__(self): 160 | return '{}(n_hid={})'.format( 161 | self.__class__.__name__, self.n_hid) 162 | 163 | 164 | class GNN(nn.Module): 165 | def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.2, conv_name = 'hgt', prev_norm = False, last_norm = False, use_RTE = True): 166 | super(GNN, self).__init__() 167 | self.gcs = nn.ModuleList() 168 | self.num_types = num_types 169 | self.in_dim = in_dim 170 | self.n_hid = n_hid 171 | self.adapt_ws = nn.ModuleList() 172 | self.drop = nn.Dropout(dropout) 173 | for t in range(num_types): 174 | self.adapt_ws.append(nn.Linear(in_dim, n_hid)) 175 | for l in range(n_layers - 1): 176 | self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = prev_norm, use_RTE = use_RTE)) 177 | self.gcs.append(GeneralConv(conv_name, n_hid, n_hid, num_types, num_relations, n_heads, dropout, use_norm = last_norm, 
use_RTE = use_RTE)) 178 | 179 | def forward(self, node_feature, node_type, edge_time, edge_index, edge_type): 180 | res = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device) 181 | for t_id in range(self.num_types): 182 | idx = (node_type == int(t_id)) 183 | if idx.sum() == 0: 184 | continue 185 | res[idx] = torch.tanh(self.adapt_ws[t_id](node_feature[idx])) 186 | meta_xs = self.drop(res) 187 | del res 188 | for gc in self.gcs: 189 | meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time) 190 | return meta_xs 191 | 192 | 193 | class RNNModel(nn.Module): 194 | """Container module with an encoder, a recurrent module, and a decoder.""" 195 | def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2): 196 | super(RNNModel, self).__init__() 197 | self.drop = nn.Dropout(dropout) 198 | self.rnn = nn.LSTM(nhid, nhid, nlayers) 199 | self.encoder = nn.Embedding(n_word, nhid) 200 | self.decoder = nn.Linear(nhid, n_word) 201 | self.adp = nn.Linear(ninp + nhid, nhid) 202 | def forward(self, inp, hidden = None): 203 | emb = self.encoder(inp) 204 | if hidden is not None: 205 | emb = torch.cat((emb, hidden), dim=-1) 206 | emb = F.gelu(self.adp(emb)) 207 | output, _ = self.rnn(emb) 208 | decoded = self.decoder(self.drop(output)) 209 | return decoded 210 | def from_w2v(self, w2v): 211 | initrange = 0.1 212 | self.encoder.weight.data = w2v 213 | self.decoder.weight = self.encoder.weight 214 | 215 | self.encoder.weight.requires_grad = False 216 | self.decoder.weight.requires_grad = False 217 | -------------------------------------------------------------------------------- /GPT_GNN/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | 5 | 6 | def dcg_at_k(r, k): 7 | r = np.asfarray(r)[:k] 8 | if r.size: 9 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 10 | return 0. 
11 | 12 | def ndcg_at_k(r, k): 13 | dcg_max = dcg_at_k(sorted(r, reverse=True), k) 14 | if not dcg_max: 15 | return 0. 16 | return dcg_at_k(r, k) / dcg_max 17 | 18 | 19 | def mean_reciprocal_rank(rs): 20 | rs = (np.asarray(r).nonzero()[0] for r in rs) 21 | return [1. / (r[0] + 1) if r.size else 0. for r in rs] 22 | 23 | 24 | def normalize(mx): 25 | """Row-normalize sparse matrix""" 26 | rowsum = np.array(mx.sum(1)) 27 | r_inv = np.power(rowsum, -1).flatten() 28 | r_inv[np.isinf(r_inv)] = 0. 29 | r_mat_inv = sp.diags(r_inv) 30 | mx = r_mat_inv.dot(mx) 31 | return mx 32 | 33 | 34 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 35 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 36 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 37 | indices = torch.from_numpy( 38 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 39 | values = torch.from_numpy(sparse_mx.data) 40 | shape = torch.Size(sparse_mx.shape) 41 | return torch.sparse.FloatTensor(indices, values, shape) 42 | 43 | def randint(): 44 | return np.random.randint(2**32 - 1) 45 | 46 | def feature_OAG(layer_data, graph): 47 | feature = {} 48 | times = {} 49 | indxs = {} 50 | attr = [] 51 | for _type in layer_data: 52 | if len(layer_data[_type]) == 0: 53 | continue 54 | idxs = np.array(list(layer_data[_type].keys())) 55 | tims = np.array(list(layer_data[_type].values()))[:,1] 56 | 57 | if 'node_emb' in graph.node_feature[_type]: 58 | feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'node_emb']), dtype=float) 59 | else: 60 | feature[_type] = np.zeros([len(idxs), 400]) 61 | feature[_type] = np.concatenate((feature[_type], list(graph.node_feature[_type].loc[idxs, 'emb']),\ 62 | np.log10(np.array(list(graph.node_feature[_type].loc[idxs, 'citation'])).reshape(-1, 1) + 0.01)), axis=1) 63 | 64 | times[_type] = tims 65 | indxs[_type] = idxs 66 | 67 | if _type == 'paper': 68 | attr = np.array(list(graph.node_feature[_type].loc[idxs, 'title']), dtype=str) 69 | return 
feature, times, indxs, attr 70 | 71 | def feature_reddit(layer_data, graph): 72 | feature = {} 73 | times = {} 74 | indxs = {} 75 | attr = [] 76 | for _type in layer_data: 77 | if len(layer_data[_type]) == 0: 78 | continue 79 | idxs = np.array(list(layer_data[_type].keys())) 80 | tims = np.array(list(layer_data[_type].values()))[:,1] 81 | 82 | feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'emb']), dtype=float) 83 | times[_type] = tims 84 | indxs[_type] = idxs 85 | 86 | if _type == 'def': 87 | attr = feature[_type] 88 | return feature, times, indxs, attr -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2020 acbull 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPT-GNN: Generative Pre-Training of Graph Neural Networks 2 | 3 |

4 | 5 |
6 |
7 |

8 | 9 | 10 | GPT-GNN is a pre-training framework to initialize GNNs by generative pre-training. It can be applied to large-scale and heterogeneous graphs. 11 | 12 | You can see our KDD 2020 paper [“**Generative Pre-Training of Graph Neural Networks**”](https://arxiv.org/abs/2006.15437) for more details. 13 | 14 | 15 | ## Overview 16 | The key package is GPT_GNN, which contains the high-level GPT-GNN pretraining framework, base GNN models, and base graph structure and data loader. 17 | 18 | To illustrate how to apply the GPT_GNN framework to arbitrary graphs, we provide examples of pre-training on both heterogeneous (OAG) and homogeneous (reddit) graphs. Both are large-scale. 19 | 20 | Within each `example_*` package, there is a `pretrain_*.py` file for pre-training a GNN on the given graph, and also multiple `finetune_*.py` files for training and validating on downstream tasks. 21 | 22 | ## DataSet 23 | For **Open Academic Graph (OAG)**, we provide a heterogeneous graph containing highly-cited CS papers (8.1G) spanning 1900-2020. You can download the preprocessed graph via [this link](https://drive.google.com/open?id=1a85skqsMBwnJ151QpurLFSa9o2ymc_rq). We split the data by time: Pre-training ( t < 2014 ); Training ( 2014 <= t < 2017); Validation ( t = 2017 ); Testing ( 2018 <= t ). Since the attribute generation task for OAG uses raw text, we provide a pre-trained word2vec model via [this link](https://drive.google.com/file/d/1ArdaMlPKVqdRGyiw4YSdUOV6CeFb2AmD/view?usp=sharing). 24 | 25 | If you want to process the raw data directly, you can download it via [this link](https://drive.google.com/open?id=1yDdVaartOCOSsQlUZs8cJcAUhmvRiBSz). After downloading it, run `preprocess_OAG.py` to extract features and store them in our data structure. 26 | 27 | For **Reddit**, we simply download the preprocessed graph using the pyG.datasets API, and then turn it into our own data structure using `preprocess_reddit.py`.
We randomly split the data into different sets. 28 | 29 | ## Setup 30 | 31 | This implementation is based on pytorch_geometric. To run the code, you need the following dependencies: 32 | 33 | - [Pytorch 1.3.0](https://pytorch.org/) 34 | - [pytorch_geometric 1.3.2](https://pytorch-geometric.readthedocs.io/) 35 | - torch-cluster==1.4.5 36 | - torch-scatter==1.3.2 37 | - torch-sparse==0.4.3 38 | - [gensim](https://github.com/RaRe-Technologies/gensim) 39 | - [sklearn](https://github.com/scikit-learn/scikit-learn) 40 | - [tqdm](https://github.com/tqdm/tqdm) 41 | - [dill](https://github.com/uqfoundation/dill) 42 | - [pandas](https://github.com/pandas-dev/pandas) 43 | 44 | You can simply run ```pip install -r requirements.txt``` to install all the necessary packages. 45 | 46 | ## Usage 47 | We first introduce the arguments to control hyperparameters. There are mainly three types of arguments: for pre-training, for the dataset, and for the model and optimization. 48 | 49 | For pre-training, we provide arguments to control different modules for attribute and edge generation tasks: 50 | ``` 51 | --attr_ratio FLOAT The ratio (0~1) of attribute generation loss. Default is 0.5. 52 | --attr_type STR Type of attribute decoder ['text' or 'vec']. Default is 'vec'. 53 | --neg_samp_num INT Maximum number of negative samples drawn for each target node. 54 | --queue_size INT Max size of adaptive embedding queue. Default is 256. 55 | ``` 56 | 57 | For datasets, we provide arguments to control mini-batch sampling: 58 | ``` 59 | --data_dir STR The address of preprocessed graph. 60 | --pretrain_model_dir STR The address for storing the pre-trained models. 61 | --sample_depth INT How many layers within a mini-batch subgraph. Default is 6. 62 | --sample_width INT How many nodes to be sampled per layer per type. Default is 128. 63 | ``` 64 | 65 | For both pre-training and fine-tuning, we provide arguments to control model and optimizer hyperparameters.
We highlight some key arguments below: 66 | 67 | ``` 68 | --conv_name STR Name of GNN filter (model). Default is hgt. 69 | --scheduler STR Name of learning rate scheduler. Default is cycle (for pretrain) and cosine (for fine-tuning). 70 | --n_hid INT Number of hidden dimensions. Default is 400. 71 | --n_layers INT Number of GNN layers. Default is 3. 72 | --prev_norm BOOL Whether to use layer-norm on previous layers. Default is False. 73 | --last_norm BOOL Whether to use layer-norm on the last layer. Default is False. 74 | --max_lr FLOAT Maximum learning rate. Default is 1e-3 (for pretrain) and 5e-4 (for fine-tuning). 75 | ``` 76 | 77 | The following command pretrains a 3-layer HGT over OAG-CS: 78 | ```bash 79 | python pretrain_OAG.py --attr_type text --conv_name hgt --n_layers 3 --pretrain_model_dir /datadrive/models/gta_all_cs3 80 | ``` 81 | 82 |

83 | 84 |

85 | 86 | The following command uses the pre-trained model as initialization and fine-tunes on the paper-field classification task using 10% of the training and validation data: 87 | ```bash 88 | python finetune_OAG_PF.py --use_pretrain --pretrain_model_dir /datadrive/models/gta_all_cs3 --n_layers 3 --data_percentage 0.1 89 | ``` 90 | 91 | 92 | ## Pre-trained Models 93 | 94 | 1. The 3-layer HGT model pre-trained over OAG-CS under the Time-Transfer setting is available via [this link](https://drive.google.com/file/d/1OyIRfpNXjaD0TiRF-_Upfl5hix3is5ca/view?usp=sharing) 95 | 2. The 3-layer HGT model pre-trained over Reddit is available via [this link](https://drive.google.com/file/d/1Ja4PJT2bkFH0qgoWXjGBjByIFPco4h-S/view?usp=sharing) 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | ### Citation 115 | 116 | Please consider citing the following paper when using our code for your application. 117 | 118 | ```bibtex 119 | @inproceedings{gpt_gnn, 120 | title={GPT-GNN: Generative Pre-Training of Graph Neural Networks}, 121 | author={Ziniu Hu and Yuxiao Dong and Kuansan Wang and Kai-Wei Chang and Yizhou Sun}, 122 | booktitle={Proceedings of the 26th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 123 | year={2020} 124 | } 125 | ``` 126 | 127 | 128 | This implementation is mainly based on the [pyHGT](https://github.com/acbull/pyHGT) API.
129 | -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_OAG/GPT_GNN/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__pycache__/conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_OAG/GPT_GNN/__pycache__/conv.cpython-37.pyc -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_OAG/GPT_GNN/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_OAG/GPT_GNN/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_OAG/GPT_GNN/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch_geometric.nn import GCNConv, GATConv, RGCNConv 6 | from torch_geometric.nn.conv import MessagePassing 7 | from torch_geometric.nn.inits import glorot, uniform 8 | from torch_geometric.utils import softmax 9 | import math 10 | 11 | class HGTConv(MessagePassing): 12 | def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs): 13 | super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs) 14 | 15 | self.in_dim = in_dim 16 | self.out_dim = out_dim 17 | self.num_types = num_types 18 | self.num_relations = num_relations 19 | self.total_rel = num_types * num_relations * num_types 20 | self.n_heads = n_heads 21 | self.d_k = out_dim // n_heads 22 | self.sqrt_dk = math.sqrt(self.d_k) 23 | self.use_norm = use_norm 24 | self.att = None 25 | 26 | 27 | self.k_linears = nn.ModuleList() 28 | self.q_linears = nn.ModuleList() 29 | self.v_linears = nn.ModuleList() 30 | self.a_linears = nn.ModuleList() 31 | self.norms = nn.ModuleList() 32 | 33 | for t in range(num_types): 34 | self.k_linears.append(nn.Linear(in_dim, out_dim)) 35 | self.q_linears.append(nn.Linear(in_dim, out_dim)) 36 | self.v_linears.append(nn.Linear(in_dim, out_dim)) 37 | self.a_linears.append(nn.Linear(out_dim, out_dim)) 38 | if use_norm: 39 | self.norms.append(nn.LayerNorm(out_dim)) 40 | ''' 41 | TODO: make relation_pri smaller, as not all pair exist in meta relation list. 
42 | ''' 43 | self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads)) 44 | self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 45 | self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k)) 46 | self.skip = nn.Parameter(torch.ones(num_types)) 47 | self.drop = nn.Dropout(dropout) 48 | self.emb = RelTemporalEncoding(in_dim) 49 | 50 | glorot(self.relation_att) 51 | glorot(self.relation_msg) 52 | 53 | def forward(self, node_inp, node_type, edge_index, edge_type, edge_time): 54 | return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \ 55 | edge_type=edge_type, edge_time = edge_time) 56 | 57 | def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time): 58 | ''' 59 | j: source, i: target; 60 | ''' 61 | data_size = edge_index_i.size(0) 62 | ''' 63 | Create Attention and Message tensor beforehand. 64 | ''' 65 | res_att = torch.zeros(data_size, self.n_heads).to(node_inp_i.device) 66 | res_msg = torch.zeros(data_size, self.n_heads, self.d_k).to(node_inp_i.device) 67 | 68 | for source_type in range(self.num_types): 69 | sb = (node_type_j == int(source_type)) 70 | k_linear = self.k_linears[source_type] 71 | v_linear = self.v_linears[source_type] 72 | for target_type in range(self.num_types): 73 | tb = (node_type_i == int(target_type)) & sb 74 | q_linear = self.q_linears[target_type] 75 | for relation_type in range(self.num_relations): 76 | ''' 77 | idx is all the edges with meta relation 78 | ''' 79 | idx = (edge_type == int(relation_type)) & tb 80 | if idx.sum() == 0: 81 | continue 82 | ''' 83 | Get the corresponding input node representations by idx. 
84 | Add tempotal encoding to source representation (j) 85 | ''' 86 | target_node_vec = node_inp_i[idx] 87 | source_node_vec = self.emb(node_inp_j[idx], edge_time[idx]) 88 | 89 | ''' 90 | Step 1: Heterogeneous Mutual Attention 91 | ''' 92 | q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k) 93 | k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k) 94 | k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0) 95 | res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk 96 | ''' 97 | Step 2: Heterogeneous Message Passing 98 | ''' 99 | v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k) 100 | res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0) 101 | ''' 102 | Softmax based on target node's id (edge_index_i). Store attention value in self.att for later visualization. 103 | ''' 104 | self.att = softmax(res_att, edge_index_i) 105 | res = res_msg * self.att.view(-1, self.n_heads, 1) 106 | del res_att, res_msg 107 | return res.view(-1, self.out_dim) 108 | 109 | 110 | def update(self, aggr_out, node_inp, node_type): 111 | ''' 112 | Step 3: Target-specific Aggregation 113 | x = W[node_type] * gelu(Agg(x)) + x 114 | ''' 115 | aggr_out = F.gelu(aggr_out) 116 | res = torch.zeros(aggr_out.size(0), self.out_dim).to(node_inp.device) 117 | for target_type in range(self.num_types): 118 | idx = (node_type == int(target_type)) 119 | if idx.sum() == 0: 120 | continue 121 | trans_out = self.a_linears[target_type](aggr_out[idx]) 122 | ''' 123 | Add skip connection with learnable weight self.skip[t_id] 124 | ''' 125 | alpha = torch.sigmoid(self.skip[target_type]) 126 | if self.use_norm: 127 | res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha)) 128 | else: 129 | res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha) 130 | return self.drop(res) 131 | 132 | def __repr__(self): 133 | return 
'{}(in_dim={}, out_dim={}, num_types={}, num_types={})'.format( 134 | self.__class__.__name__, self.in_dim, self.out_dim, 135 | self.num_types, self.num_relations) 136 | 137 | 138 | class RelTemporalEncoding(nn.Module): 139 | ''' 140 | Implement the Temporal Encoding (Sinusoid) function. 141 | ''' 142 | def __init__(self, n_hid, max_len = 240, dropout = 0.2): 143 | super(RelTemporalEncoding, self).__init__() 144 | self.drop = nn.Dropout(dropout) 145 | position = torch.arange(0., max_len).unsqueeze(1) 146 | div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.)) / n_hid / 2) 147 | self.emb = nn.Embedding(max_len, n_hid * 2) 148 | self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid) 149 | self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid) 150 | self.emb.requires_grad = False 151 | self.lin = nn.Linear(n_hid * 2, n_hid) 152 | def forward(self, x, t): 153 | return x + self.lin(self.drop(self.emb(t))) 154 | 155 | 156 | 157 | class GeneralConv(nn.Module): 158 | def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True): 159 | super(GeneralConv, self).__init__() 160 | self.conv_name = conv_name 161 | if self.conv_name == 'hgt': 162 | self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE) 163 | elif self.conv_name == 'gcn': 164 | self.base_conv = GCNConv(in_hid, out_hid) 165 | elif self.conv_name == 'gat': 166 | self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads) 167 | elif self.conv_name == 'rgcn': 168 | self.base_conv = RGCNConv(in_hid, out_hid, num_relations) 169 | def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time): 170 | if self.conv_name == 'hgt': 171 | return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time) 172 | elif self.conv_name == 'gcn': 173 | return self.base_conv(meta_xs, edge_index) 174 | elif self.conv_name == 'gat': 175 | 
return self.base_conv(meta_xs, edge_index) 176 | elif self.conv_name == 'rgcn': 177 | return self.base_conv(meta_xs, edge_index, edge_type) 178 | -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/data.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import math, copy, time 3 | import numpy as np 4 | from collections import defaultdict 5 | import pandas as pd 6 | 7 | import math 8 | from tqdm import tqdm 9 | 10 | import seaborn as sb 11 | import matplotlib.pyplot as plt 12 | import matplotlib.cm as cm 13 | 14 | from .utils import * 15 | 16 | import dill 17 | from functools import partial 18 | import multiprocessing as mp 19 | 20 | class Graph(): 21 | def __init__(self): 22 | super(Graph, self).__init__() 23 | ''' 24 | node_forward and bacward are only used when building the data. 25 | Afterwards will be transformed into node_feature by DataFrame 26 | 27 | node_forward: name -> node_id 28 | node_bacward: node_id -> feature_dict 29 | node_feature: a DataFrame containing all features 30 | ''' 31 | self.node_forward = defaultdict(lambda: {}) 32 | self.node_bacward = defaultdict(lambda: []) 33 | self.node_feature = defaultdict(lambda: []) 34 | 35 | ''' 36 | edge_list: index the adjacancy matrix (time) by 37 | 38 | ''' 39 | self.edge_list = defaultdict( #target_type 40 | lambda: defaultdict( #source_type 41 | lambda: defaultdict( #relation_type 42 | lambda: defaultdict( #target_id 43 | lambda: defaultdict( #source_id( 44 | lambda: int # time 45 | ))))) 46 | self.times = {} 47 | def add_node(self, node): 48 | nfl = self.node_forward[node['type']] 49 | if node['id'] not in nfl: 50 | self.node_bacward[node['type']] += [node] 51 | ser = len(nfl) 52 | nfl[node['id']] = ser 53 | return ser 54 | return nfl[node['id']] 55 | def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True): 56 | edge = [self.add_node(source_node), 
self.add_node(target_node)] 57 | ''' 58 | Add bi-directional edges with different relation type 59 | ''' 60 | self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time 61 | if directed: 62 | self.edge_list[source_node['type']][target_node['type']]['rev_' + relation_type][edge[0]][edge[1]] = time 63 | else: 64 | self.edge_list[source_node['type']][target_node['type']][relation_type][edge[0]][edge[1]] = time 65 | self.times[time] = True 66 | 67 | def update_node(self, node): 68 | nbl = self.node_bacward[node['type']] 69 | ser = self.add_node(node) 70 | for k in node: 71 | if k not in nbl[ser]: 72 | nbl[ser][k] = node[k] 73 | 74 | def get_meta_graph(self): 75 | types = self.get_types() 76 | metas = [] 77 | for target_type in self.edge_list: 78 | for source_type in self.edge_list[target_type]: 79 | for r_type in self.edge_list[target_type][source_type]: 80 | metas += [(target_type, source_type, r_type)] 81 | return metas 82 | 83 | def get_types(self): 84 | return list(self.node_feature.keys()) 85 | 86 | 87 | 88 | def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG): 89 | ''' 90 | Sample Sub-Graph based on the connection of other nodes with currently sampled nodes 91 | We maintain budgets for each node type, indexed by . 92 | Currently sampled nodes are stored in layer_data. 93 | After nodes are sampled, we construct the sampled adjacancy matrix. 
94 | ''' 95 | layer_data = defaultdict( #target_type 96 | lambda: {} # {target_id: [ser, time]} 97 | ) 98 | budget = defaultdict( #source_type 99 | lambda: defaultdict( #source_id 100 | lambda: [0., 0] #[sampled_score, time] 101 | )) 102 | new_layer_adj = defaultdict( #target_type 103 | lambda: defaultdict( #source_type 104 | lambda: defaultdict( #relation_type 105 | lambda: [] #[target_id, source_id] 106 | ))) 107 | ''' 108 | For each node being sampled, we find out all its neighborhood, 109 | adding the degree count of these nodes in the budget. 110 | Note that there exist some nodes that have many neighborhoods 111 | (such as fields, venues), for those case, we only consider 112 | ''' 113 | def add_budget(te, target_id, target_time, layer_data, budget): 114 | for source_type in te: 115 | tes = te[source_type] 116 | for relation_type in tes: 117 | if relation_type == 'self' or target_id not in tes[relation_type]: 118 | continue 119 | adl = tes[relation_type][target_id] 120 | if len(adl) < sampled_number: 121 | sampled_ids = list(adl.keys()) 122 | else: 123 | sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False) 124 | for source_id in sampled_ids: 125 | source_time = adl[source_id] 126 | if source_time == None: 127 | source_time = target_time 128 | if source_time > np.max(list(time_range.keys())) or source_id in layer_data[source_type]: 129 | continue 130 | budget[source_type][source_id][0] += 1. / len(sampled_ids) 131 | budget[source_type][source_id][1] = source_time 132 | 133 | ''' 134 | First adding the sampled nodes then updating budget. 135 | ''' 136 | for _type in inp: 137 | for _id, _time in inp[_type]: 138 | layer_data[_type][_id] = [len(layer_data[_type]), _time] 139 | for _type in inp: 140 | te = graph.edge_list[_type] 141 | for _id, _time in inp[_type]: 142 | add_budget(te, _id, _time, layer_data, budget) 143 | ''' 144 | We recursively expand the sampled graph by sampled_depth. 
145 | Each time we sample a fixed number of nodes for each budget, 146 | based on the accumulated degree. 147 | ''' 148 | for layer in range(sampled_depth): 149 | sts = list(budget.keys()) 150 | for source_type in sts: 151 | te = graph.edge_list[source_type] 152 | keys = np.array(list(budget[source_type].keys())) 153 | if sampled_number > len(keys): 154 | ''' 155 | Directly sample all the nodes 156 | ''' 157 | sampled_ids = np.arange(len(keys)) 158 | else: 159 | ''' 160 | Sample based on accumulated degree 161 | ''' 162 | score = np.array(list(budget[source_type].values()))[:,0] ** 2 163 | score = score / np.sum(score) 164 | sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False) 165 | sampled_keys = keys[sampled_ids] 166 | ''' 167 | First adding the sampled nodes then updating budget. 168 | ''' 169 | for k in sampled_keys: 170 | layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]] 171 | for k in sampled_keys: 172 | add_budget(te, k, budget[source_type][k][1], layer_data, budget) 173 | budget[source_type].pop(k) 174 | ''' 175 | Prepare feature, time and adjacency matrix for the sampled graph 176 | ''' 177 | feature, times, indxs, texts = feature_extractor(layer_data, graph) 178 | 179 | edge_list = defaultdict( #target_type 180 | lambda: defaultdict( #source_type 181 | lambda: defaultdict( #relation_type 182 | lambda: [] # [target_id, source_id] 183 | ))) 184 | for _type in layer_data: 185 | for _key in layer_data[_type]: 186 | _ser = layer_data[_type][_key][0] 187 | edge_list[_type][_type]['self'] += [[_ser, _ser]] 188 | ''' 189 | Reconstruct sampled adjacancy matrix by checking whether each 190 | link exist in the original graph 191 | ''' 192 | for target_type in graph.edge_list: 193 | te = graph.edge_list[target_type] 194 | tld = layer_data[target_type] 195 | for source_type in te: 196 | tes = te[source_type] 197 | sld = layer_data[source_type] 198 | for relation_type in tes: 199 | tesr = 
tes[relation_type] 200 | for target_key in tld: 201 | if target_key not in tesr: 202 | continue 203 | target_ser = tld[target_key][0] 204 | for source_key in tesr[target_key]: 205 | ''' 206 | Check whether each link (target_id, source_id) exist in original adjacancy matrix 207 | ''' 208 | if source_key in sld: 209 | source_ser = sld[source_key][0] 210 | edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]] 211 | return feature, times, edge_list, indxs, texts 212 | 213 | def to_torch(feature, time, edge_list, graph): 214 | ''' 215 | Transform a sampled sub-graph into pytorch Tensor 216 | node_dict: {node_type: } node_number is used to trace back the nodes in original graph. 217 | edge_dict: {edge_type: edge_type_ID} 218 | ''' 219 | node_dict = {} 220 | node_feature = [] 221 | node_type = [] 222 | node_time = [] 223 | edge_index = [] 224 | edge_type = [] 225 | edge_time = [] 226 | 227 | node_num = 0 228 | types = graph.get_types() 229 | for t in types: 230 | node_dict[t] = [node_num, len(node_dict)] 231 | node_num += len(feature[t]) 232 | 233 | if 'fake_paper' in feature: 234 | node_dict['fake_paper'] = [node_num, node_dict['paper'][1]] 235 | node_num += len(feature['fake_paper']) 236 | types += ['fake_paper'] 237 | 238 | for t in types: 239 | node_feature += list(feature[t]) 240 | node_time += list(time[t]) 241 | node_type += [node_dict[t][1] for _ in range(len(feature[t]))] 242 | 243 | edge_dict = {e[2]: i for i, e in enumerate(graph.get_meta_graph())} 244 | edge_dict['self'] = len(edge_dict) 245 | 246 | for target_type in edge_list: 247 | for source_type in edge_list[target_type]: 248 | for relation_type in edge_list[target_type][source_type]: 249 | for ii, (ti, si) in enumerate(edge_list[target_type][source_type][relation_type]): 250 | tid, sid = ti + node_dict[target_type][0], si + node_dict[source_type][0] 251 | edge_index += [[sid, tid]] 252 | edge_type += [edge_dict[relation_type]] 253 | ''' 254 | Our time ranges from 1900 - 2020, 
largest span is 120. 255 | ''' 256 | edge_time += [node_time[tid] - node_time[sid] + 120] 257 | node_feature = torch.FloatTensor(node_feature) 258 | node_type = torch.LongTensor(node_type) 259 | edge_time = torch.LongTensor(edge_time) 260 | edge_index = torch.LongTensor(edge_index).t() 261 | edge_type = torch.LongTensor(edge_type) 262 | return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict 263 | 264 | 265 | class RenameUnpickler(dill.Unpickler): 266 | def find_class(self, module, name): 267 | renamed_module = module 268 | if module == "pyHGT.data" or module == 'data': 269 | renamed_module = "GPT_GNN.data" 270 | return super(RenameUnpickler, self).find_class(renamed_module, name) 271 | 272 | 273 | def renamed_load(file_obj): 274 | return RenameUnpickler(file_obj).load() 275 | -------------------------------------------------------------------------------- /example_OAG/GPT_GNN/model.py: -------------------------------------------------------------------------------- 1 | from .conv import * 2 | import numpy as np 3 | from gensim.parsing.preprocessing import * 4 | 5 | 6 | class GPT_GNN(nn.Module): 7 | def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0): 8 | super(GPT_GNN, self).__init__() 9 | if gnn is None: 10 | return 11 | self.types = types 12 | self.gnn = gnn 13 | self.params = nn.ModuleList() 14 | self.neg_queue_size = neg_queue_size 15 | self.link_dec_dict = {} 16 | self.neg_queue = {} 17 | for source_type in rem_edge_list: 18 | self.link_dec_dict[source_type] = {} 19 | self.neg_queue[source_type] = {} 20 | for relation_type in rem_edge_list[source_type]: 21 | print(source_type, relation_type) 22 | matcher = Matcher(gnn.n_hid, gnn.n_hid) 23 | self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device) 24 | self.link_dec_dict[source_type][relation_type] = matcher 25 | self.params.append(matcher) 26 | self.attr_decoder = attr_decoder 27 | self.init_emb = 
nn.Parameter(torch.randn(gnn.in_dim)) 28 | self.ce = nn.CrossEntropyLoss(reduction = 'none') 29 | self.neg_samp_num = neg_samp_num 30 | 31 | def neg_sample(self, souce_node_list, pos_node_list): 32 | np.random.shuffle(souce_node_list) 33 | neg_nodes = [] 34 | keys = {key : True for key in pos_node_list} 35 | tot = 0 36 | for node_id in souce_node_list: 37 | if node_id not in keys: 38 | neg_nodes += [node_id] 39 | tot += 1 40 | if tot == self.neg_samp_num: 41 | break 42 | return neg_nodes 43 | 44 | def forward(self, node_feature, node_type, edge_time, edge_index, edge_type): 45 | return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type) 46 | def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False): 47 | losses = 0 48 | ress = [] 49 | for source_type in rem_edge_list: 50 | if source_type not in self.link_dec_dict: 51 | continue 52 | for relation_type in rem_edge_list[source_type]: 53 | if relation_type not in self.link_dec_dict[source_type]: 54 | continue 55 | rem_edges = rem_edge_list[source_type][relation_type] 56 | if len(rem_edges) <= 8: 57 | continue 58 | ori_edges = ori_edge_list[source_type][relation_type] 59 | matcher = self.link_dec_dict[source_type][relation_type] 60 | 61 | target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1) 62 | n_nodes = len(target_ids) 63 | source_node_ids = np.unique(ori_edges[:, 1]) 64 | 65 | negative_source_ids = [self.neg_sample(source_node_ids, \ 66 | ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids] 67 | sn = min([len(neg_ids) for neg_ids in negative_source_ids]) 68 | 69 | negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids] 70 | 71 | source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + node_dict[source_type][0]) 72 | emb = node_emb[source_ids] 73 | 74 | if use_queue and len(self.neg_queue[source_type][relation_type]) 
// n_nodes > 0: 75 | tmp = self.neg_queue[source_type][relation_type] 76 | stx = len(tmp) // n_nodes 77 | tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1) 78 | rep_size = sn + 1 + stx 79 | source_emb = torch.cat([emb, tmp], dim=1) 80 | source_emb = source_emb.reshape(n_nodes * rep_size, -1) 81 | else: 82 | rep_size = sn + 1 83 | source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1) 84 | 85 | target_ids = target_ids.repeat(rep_size, 1) + node_dict[target_type][0] 86 | target_emb = node_emb[target_ids.reshape(-1)] 87 | res = matcher.forward(target_emb, source_emb) 88 | res = res.reshape(n_nodes, rep_size) 89 | ress += [res.detach()] 90 | losses += F.log_softmax(res, dim=-1)[:,0].mean() 91 | if update_queue and 'L1' not in relation_type and 'L2' not in relation_type: 92 | tmp = self.neg_queue[source_type][relation_type] 93 | self.neg_queue[source_type][relation_type] = \ 94 | torch.cat([node_emb[source_node_ids].detach(), tmp], dim=0)[:int(self.neg_queue_size * n_nodes)] 95 | return -losses / len(ress), ress 96 | 97 | 98 | def text_loss(self, reps, texts, w2v_model, device): 99 | def parse_text(texts, w2v_model, device): 100 | idxs = [] 101 | pad = w2v_model.wv.vocab['eos'].index 102 | for text in texts: 103 | idx = [] 104 | for word in ['bos'] + preprocess_string(text) + ['eos']: 105 | if word in w2v_model.wv.vocab: 106 | idx += [w2v_model.wv.vocab[word].index] 107 | idxs += [idx] 108 | mxl = np.max([len(s) for s in idxs]) + 1 109 | inp_idxs = [] 110 | out_idxs = [] 111 | masks = [] 112 | for i, idx in enumerate(idxs): 113 | inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]] 114 | out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]] 115 | masks += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]] 116 | return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \ 117 | torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device) 118 | inp_idxs, out_idxs, masks = 
class Classifier(nn.Module):
    """Linear classification head that returns log-probabilities.

    Args:
        n_hid: size of the incoming node representation.
        n_out: number of classes.
    """

    def __init__(self, n_hid, n_out):
        super(Classifier, self).__init__()
        self.n_hid = n_hid
        self.n_out = n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Return ``log_softmax(linear(x))`` over the last (class) axis.

        The leading axes of ``x`` are preserved. The previous unconditional
        ``squeeze()`` removed the batch axis for a one-row batch and, for
        ``n_out == 1``, made the softmax normalise across the batch instead
        of across classes.
        """
        return torch.log_softmax(self.linear(x), dim=-1)

    def __repr__(self):
        return '{}(n_hid={}, n_out={})'.format(
            self.__class__.__name__, self.n_hid, self.n_out)
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Token embeddings are concatenated with a conditioning vector of width
    ``ninp``, projected back to ``nhid`` by ``adp``, run through an LSTM and
    decoded into vocabulary logits.

    Args:
        n_word: vocabulary size.
        ninp: width of the conditioning vector passed as ``hidden``.
        nhid: embedding and LSTM hidden size.
        nlayers: number of LSTM layers.
        dropout: dropout probability applied to the LSTM output.
    """

    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        self.adp = nn.Linear(ninp + nhid, nhid)

    def forward(self, inp, hidden=None):
        """Return vocabulary logits for every position of ``inp``.

        Args:
            inp: integer token ids; the output keeps the same leading shape.
            hidden: optional conditioning tensor with the same leading shape
                as ``inp`` and a last axis of width ``ninp``. When omitted, a
                zero vector is used so that ``adp`` always receives the input
                width it was built for. Previously the default crashed
                whenever ``ninp > 0``.
        """
        emb = self.encoder(inp)
        if hidden is None:
            cond_dim = self.adp.in_features - emb.size(-1)
            hidden = emb.new_zeros(tuple(emb.shape[:-1]) + (cond_dim,))
        emb = torch.cat((emb, hidden), dim=-1)
        emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        decoded = self.decoder(self.drop(output))
        return decoded

    def from_w2v(self, w2v):
        """Load frozen word vectors and tie the decoder to them.

        Args:
            w2v: tensor used as the embedding matrix. The decoder shares the
                same parameter afterwards, so it is expected to be shaped
                ``(n_word, nhid)``.
        """
        self.encoder.weight.data = w2v
        # Encoder and decoder share one parameter object from here on.
        self.decoder.weight = self.encoder.weight

        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
def dcg_at_k(r, k):
    """Return the discounted cumulative gain of the first ``k`` scores in ``r``.

    The first score is taken at full weight; the score at 1-based rank ``i``
    (``i >= 2``) is divided by ``log2(i)``.

    Args:
        r: relevance scores in ranked order (array-like).
        k: number of leading results to consider.

    Returns:
        The DCG value, or ``0.`` when there are no scores to consider.
    """
    # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is equivalent.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.
def feature_OAG(layer_data, graph):
    """Assemble per-type feature matrices for a sampled OAG sub-graph.

    Args:
        layer_data: mapping ``node_type -> {node_id: info}``; index 1 of each
            ``info`` entry is used as the node's time. Empty types are skipped.
        graph: object whose ``node_feature[node_type]`` supports
            ``.loc[ids, column]`` and column membership tests (presumably a
            pandas DataFrame; confirm against the graph class).

    Returns:
        ``(feature, times, indxs, attr)``. The first three are dicts keyed by
        node type: the feature matrix, the time array and the node-id array.
        ``attr`` is the array of paper titles, or ``None`` when no ``'paper'``
        nodes were sampled. Previously that case raised UnboundLocalError.
    """
    feature = {}
    times = {}
    indxs = {}
    attr = None
    for _type in layer_data:
        if len(layer_data[_type]) == 0:
            continue
        idxs = np.array(list(layer_data[_type].keys()))
        tims = np.array(list(layer_data[_type].values()))[:, 1]
        frame = graph.node_feature[_type]

        if 'node_emb' in frame:
            # np.float was removed in NumPy 1.24; the builtin is the same dtype.
            node_emb = np.array(list(frame.loc[idxs, 'node_emb']), dtype=float)
        else:
            # Types without a stored 'node_emb' get a zero block of width 400
            # (presumably the node-embedding width; confirm against the data).
            node_emb = np.zeros([len(idxs), 400])
        citations = np.array(list(frame.loc[idxs, 'citation'])).reshape(-1, 1)
        # Layout: [node_emb | emb | log10(citation + 0.01)]
        feature[_type] = np.concatenate(
            (node_emb, list(frame.loc[idxs, 'emb']), np.log10(citations + 0.01)),
            axis=1)

        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'paper':
            # np.str was removed in NumPy 1.24; the builtin is the same dtype.
            attr = np.array(list(frame.loc[idxs, 'title']), dtype=str)
    return feature, times, indxs, attr
def load_gnn(_dict):
    """Extract the GNN sub-module weights from a pre-trained state dict.

    Only entries whose key starts with ``'gnn.'`` are kept, with that prefix
    removed so the result can be passed to ``GNN.load_state_dict``. The former
    substring test (``'gnn' in key``) also matched keys that merely contained
    "gnn" and then cut off their first four characters, producing bogus names.

    Args:
        _dict: mapping of parameter name to value.

    Returns:
        An ``OrderedDict`` of the matching entries in their original order.
    """
    prefix = 'gnn.'
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in _dict.items()
        if key.startswith(prefix)
    )
choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'], 37 | help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)') 38 | parser.add_argument('--n_hid', type=int, default=400, 39 | help='Number of hidden dimension') 40 | parser.add_argument('--n_heads', type=int, default=8, 41 | help='Number of attention head') 42 | parser.add_argument('--n_layers', type=int, default=3, 43 | help='Number of GNN layers') 44 | parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true') 45 | parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true') 46 | parser.add_argument('--dropout', type=int, default=0.2, 47 | help='Dropout ratio') 48 | 49 | 50 | ''' 51 | Optimization arguments 52 | ''' 53 | parser.add_argument('--optimizer', type=str, default='adamw', 54 | choices=['adamw', 'adam', 'sgd', 'adagrad'], 55 | help='optimizer to use.') 56 | parser.add_argument('--scheduler', type=str, default='cycle', 57 | help='Name of learning rate scheduler.' 
def node_classification_sample(seed, pairs, time_range):
    """Sample a sub-graph and build multi-label targets for Paper-Field (L2).

    Args:
        seed: value used to seed NumPy's global RNG for this sample.
        pairs: mapping ``paper_id -> [list_of_field_ids, time]``.
        time_range: time filter forwarded to ``sample_subgraph``.

    Returns:
        ``(node_feature, node_type, edge_time, edge_index, edge_type, x_ids,
        ylabel)``. ``x_ids`` are the positions of the sampled papers in the
        sub-graph and ``ylabel`` is a row-normalised multi-hot matrix of shape
        ``(batch_size, len(cand_list))``.

    Raises:
        ValueError: if a paper is labelled with a field missing from
            ``cand_list``.
    """
    # (1) Sample batch_size output nodes (papers) and look up their time.
    np.random.seed(seed)
    target_ids = np.random.choice(list(pairs.keys()), args.batch_size, replace = False)
    target_info = [[target_id, pairs[target_id][1]] for target_id in target_ids]

    # (2) Based on the seed nodes, sample a subgraph with 'sampled_depth' and 'sampled_number'.
    feature, times, edge_list, _, _ = sample_subgraph(graph, time_range, \
                inp = {'paper': np.array(target_info)}, \
                sampled_depth = args.sample_depth, sampled_number = args.sample_width)

    # (3) Mask out the edges between the output target nodes (paper) and the
    #     output source nodes (L2 field). The batch_size seed papers take the
    #     first indices of the 'paper' type, which the threshold relies on.
    edge_list['paper']['field']['rev_PF_in_L2'] = [
        edge for edge in edge_list['paper']['field']['rev_PF_in_L2']
        if edge[0] >= args.batch_size]
    edge_list['field']['paper']['PF_in_L2'] = [
        edge for edge in edge_list['field']['paper']['PF_in_L2']
        if edge[1] >= args.batch_size]

    # (4) Transform the subgraph into torch Tensors (edge_index is in pytorch_geometric format).
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)

    # (5) Prepare the labels of each output target node (paper) and its index
    #     in the sampled graph (node_dict[type][0] is the start index of a type).
    #     One dict lookup per label replaces two linear scans of cand_list.
    cand_index = {cand: pos for pos, cand in enumerate(cand_list)}
    ylabel = np.zeros([args.batch_size, len(cand_list)])
    for x_id, target_id in enumerate(target_ids):
        for source_id in pairs[target_id][0]:
            col = cand_index.get(source_id)
            if col is None:
                # Same exception type list.index() raised, with a usable message.
                raise ValueError('field %s of paper %s is not in cand_list'
                                 % (source_id, target_id))
            ylabel[x_id][col] = 1

    ylabel /= ylabel.sum(axis=1).reshape(-1, 1)
    x_ids = np.arange(args.batch_size) + node_dict['paper'][0]
    return node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel
np.random.seed(43)
# Only train and validate with a certain percentage of the data, if necessary.
sel_train_pairs = {p : train_pairs[p] for p in np.random.choice(
    list(train_pairs.keys()), int(len(train_pairs) * args.data_percentage), replace = False)}
sel_valid_pairs = {p : valid_pairs[p] for p in np.random.choice(
    list(valid_pairs.keys()), int(len(valid_pairs) * args.data_percentage), replace = False)}


# Initialize GNN (model is specified by conv_name) and Classifier.
gnn = GNN(conv_name = args.conv_name,
          in_dim = len(graph.node_feature[target_type]['emb'].values[0]) + 401,
          n_hid = args.n_hid, n_heads = args.n_heads, n_layers = args.n_layers,
          dropout = args.dropout, num_types = len(types),
          num_relations = len(graph.get_meta_graph()) + 1,
          prev_norm = args.prev_norm, last_norm = args.last_norm)
if args.use_pretrain:
    # Load onto CPU first so a checkpoint saved on another GPU id (or loaded
    # with --cuda -1) still works; the whole model is moved to `device` below.
    pretrained = torch.load(args.pretrain_model_dir, map_location = 'cpu')
    gnn.load_state_dict(load_gnn(pretrained), strict = False)
    print('Load Pre-trained Model from (%s)' % args.pretrain_model_dir)
classifier = Classifier(args.n_hid, len(cand_list))

model = nn.Sequential(gnn, classifier).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4)

stats = []
res = []
best_val = 0
train_step = 0

pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)

for epoch in np.arange(args.n_epoch) + 1:
    # Collect the training and validation batches sampled in the background.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    # Start sampling the next epoch's data while this epoch trains. On the
    # last epoch nothing would consume those jobs, so no new pool is opened
    # (previously a pool with queued work was left running at exit).
    if epoch < args.n_epoch:
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # Train (2014 <= time <= 2016)
    model.train()
    train_losses = []
    for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data:
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, torch.FloatTensor(ylabel).to(device))

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        train_losses += [loss.cpu().detach().tolist()]
        train_step += 1
        scheduler.step(train_step)
        del res, loss

    # Valid (time == 2017)
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, torch.FloatTensor(ylabel).to(device))

        # Calculate Valid NDCG. Update the best model based on highest NDCG score.
        valid_res = []
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            valid_res += [ai[bi.cpu().numpy()]]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
            print('UPDATE!!!')

        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \
               loss.cpu().detach().tolist(), valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
    del train_data, valid_data


# Evaluate the best checkpoint on the test set (time >= 2018).
best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = \
                    node_classification_sample(randint(), test_pairs, test_range)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                edge_time.to(device), edge_index.to(device), edge_type.to(device))[x_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [ai[bi.cpu().numpy()]]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print('Best Test NDCG: %.4f' % np.average(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print('Best Test MRR:  %.4f' % np.average(test_mrr))
parser = argparse.ArgumentParser(description='Fine-Tuning on Paper-Venue (Journal) classification task')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset',
                    help='The address of preprocessed graph.')
parser.add_argument('--use_pretrain', help='Whether to use pre-trained model', action='store_true')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_cs',
                    help='The address for pretrained model.')
parser.add_argument('--model_dir', type=str, default='/datadrive/models',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='PV',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Avaiable GPU ID')
parser.add_argument('--domain', type=str, default='_CS',
                    help='CS, Medicion or All: _CS or _Med or (empty)')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many numbers to sample the graph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Fractional value: with type=int any command-line override such as
# "--dropout 0.3" was rejected by argparse.
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--optimizer', type=str, default='adamw',
                    choices=['adamw', 'adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.' , choices=['cycle', 'cosine'])
# Fractional value, see --dropout.
parser.add_argument('--data_percentage', type=float, default=0.1,
                    help='Percentage of training and validation data to use')
parser.add_argument('--n_epoch', type=int, default=50,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=16,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
# Fractional value, see --dropout.
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')

args = parser.parse_args()
args_print(args)

if args.cuda != -1:
    device = torch.device("cuda:" + str(args.cuda))
else:
    device = torch.device("cpu")

print('Start Loading Graph Data...')
# Close the file handle as soon as the graph has been unpickled.
with open(os.path.join(args.data_dir, 'graph%s.pk' % args.domain), 'rb') as graph_file:
    graph = renamed_load(graph_file)
print('Finish Loading Graph Data!')

target_type = 'paper'

types = graph.get_types()
# cand_list stores all the Journals, which form the classification domain.
cand_list = list(graph.edge_list['venue']['paper']['PV_Journal'].keys())
# Use CrossEntropy (log-softmax + NLL) here, since each paper is associated with one venue.
criterion = nn.NLLLoss()
# Time splits used by the sampler and by the pair bucketing below. They were
# referenced but never defined in this script (NameError); the boundaries
# mirror the Paper-Field fine-tuning script.
train_range = {t: True for t in graph.times if t is not None and t >= 2014 and t <= 2016}
valid_range = {t: True for t in graph.times if t is not None and t > 2016 and t <= 2017}
test_range  = {t: True for t in graph.times if t is not None and t > 2017}

train_pairs = {}
valid_pairs = {}
test_pairs  = {}
# Record, for each target node (paper), its source node (Journal) and time.
# Only the first venue seen in each split is kept for a paper.
# NOTE(review): every time outside train/valid, including years before 2014,
# lands in test_pairs. Confirm whether that is intended.
for target_id in graph.edge_list['paper']['venue']['rev_PV_Journal']:
    for source_id in graph.edge_list['paper']['venue']['rev_PV_Journal'][target_id]:
        _time = graph.edge_list['paper']['venue']['rev_PV_Journal'][target_id][source_id]
        if _time in train_range:
            if target_id not in train_pairs:
                train_pairs[target_id] = [source_id, _time]
        elif _time in valid_range:
            if target_id not in valid_pairs:
                valid_pairs[target_id] = [source_id, _time]
        else:
            if target_id not in test_pairs:
                test_pairs[target_id] = [source_id, _time]


np.random.seed(43)
# Only train and validate with a certain percentage of the data, if necessary.
sel_train_pairs = {p : train_pairs[p] for p in np.random.choice(
    list(train_pairs.keys()), int(len(train_pairs) * args.data_percentage), replace = False)}
sel_valid_pairs = {p : valid_pairs[p] for p in np.random.choice(
    list(valid_pairs.keys()), int(len(valid_pairs) * args.data_percentage), replace = False)}


# Initialize GNN (model is specified by conv_name) and Classifier.
gnn = GNN(conv_name = args.conv_name,
          in_dim = len(graph.node_feature[target_type]['emb'].values[0]) + 401,
          n_hid = args.n_hid, n_heads = args.n_heads, n_layers = args.n_layers,
          dropout = args.dropout, num_types = len(types),
          num_relations = len(graph.get_meta_graph()) + 1,
          prev_norm = args.prev_norm, last_norm = args.last_norm)
if args.use_pretrain:
    # Load onto CPU first so a checkpoint saved on another GPU id (or loaded
    # with --cuda -1) still works; the whole model is moved to `device` below.
    pretrained = torch.load(args.pretrain_model_dir, map_location = 'cpu')
    gnn.load_state_dict(load_gnn(pretrained), strict = False)
    print('Load Pre-trained Model from (%s)' % args.pretrain_model_dir)
classifier = Classifier(args.n_hid, len(cand_list))

# Move the GNN and the classifier together. Previously only the classifier
# was moved, leaving the GNN on the CPU while its inputs were on `device`.
model = nn.Sequential(gnn, classifier).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4)

stats = []
res = []
best_val = 0
train_step = 0

pool = mp.Pool(args.n_pool)
st = time.time()
jobs = prepare_data(pool)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)

for epoch in np.arange(args.n_epoch) + 1:
    # Collect the training and validation batches sampled in the background.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    # Start sampling the next epoch's data while this epoch trains. On the
    # last epoch nothing would consume those jobs, so no new pool is opened
    # (previously a pool with queued work was left running at exit).
    if epoch < args.n_epoch:
        pool = mp.Pool(args.n_pool)
        jobs = prepare_data(pool)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # Train (2014 <= time <= 2016)
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data:
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))

        optimizer.zero_grad()
        torch.cuda.empty_cache()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        train_losses += [loss.cpu().detach().tolist()]
        train_step += 1
        scheduler.step(train_step)
        del res, loss

    # Valid (time == 2017)
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res = classifier.forward(node_rep[x_ids])
        loss = criterion(res, ylabel.to(device))

        # Calculate Valid NDCG. Update the best model based on highest NDCG score.
        valid_res = []
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            valid_res += [(bi == ai).int().tolist()]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])

        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
            print('UPDATE!!!')

        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \
               loss.cpu().detach().tolist(), valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
    del train_data, valid_data


# Evaluate the best checkpoint on the test set (time >= 2018).
best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name))
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = \
                    node_classification_sample(randint(), test_pairs, test_range)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                edge_time.to(device), edge_index.to(device), edge_type.to(device))[x_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [(bi == ai).int().tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print('Best Test NDCG: %.4f' % np.average(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print('Best Test MRR:  %.4f' % np.average(test_mrr))
-------------------------------------------------------------------------------- 1 | from transformers import * 2 | 3 | from data import * 4 | import gensim 5 | from gensim.models import Word2Vec 6 | from tqdm import tqdm 7 | # from tqdm import tqdm_notebook as tqdm # Comment this line if using jupyter notebook 8 | 9 | 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser(description='Preprocess OAG (CS/Med/All) Data') 13 | 14 | ''' 15 | Dataset arguments 16 | ''' 17 | parser.add_argument('--input_dir', type=str, default='./data/oag_raw', 18 | help='The address to store the original data directory.') 19 | parser.add_argument('--output_dir', type=str, default='./data/oag_output', 20 | help='The address to output the preprocessed graph.') 21 | parser.add_argument('--cuda', type=int, default=0, 22 | help='Avaiable GPU ID') 23 | parser.add_argument('--domain', type=str, default='_CS', 24 | help='CS, Medical or All: _CS or _Med or (empty)') 25 | parser.add_argument('--citation_bar', type=int, default=1, 26 | help='Only consider papers with citation larger than (2020 - year) * citation_bar') 27 | 28 | args = parser.parse_args() 29 | 30 | 31 | test_time_bar = 2016 32 | 33 | cite_dict = defaultdict(lambda: 0) 34 | with open(args.input_dir + '/PR%s_20190919.tsv' % args.domain) as fin: 35 | fin.readline() 36 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PR%s_20190919.tsv' % args.domain))): 37 | l = l[:-1].split('\t') 38 | cite_dict[l[1]] += 1 39 | 40 | 41 | pfl = defaultdict(lambda: {}) 42 | with open(args.input_dir + '/Papers%s_20190919.tsv' % args.domain) as fin: 43 | fin.readline() 44 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/Papers%s_20190919.tsv' % args.domain))): 45 | l = l[:-1].split('\t') 46 | bound = min(2020 - int(l[1]), 20) * args.citation_bar 47 | if cite_dict[l[0]] < bound or l[0] == '' or l[1] == '' or l[2] == '' or l[3] == '' and l[4] == '' or int(l[1]) < 1900: 48 | continue 49 | pi = {'id': l[0], 
'title': l[2], 'type': 'paper', 'time': int(l[1])} 50 | pfl[l[0]] = pi 51 | 52 | 53 | if args.cuda != -1: 54 | device = torch.device("cuda:" + str(args.cuda)) 55 | else: 56 | device = torch.device("cpu") 57 | 58 | tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased') 59 | model = XLNetModel.from_pretrained('xlnet-base-cased', 60 | output_hidden_states=True, 61 | output_attentions=True).to(device) 62 | 63 | 64 | with open(args.input_dir + '/PAb%s_20190919.tsv' % args.domain) as fin: 65 | fin.readline() 66 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PAb%s_20190919.tsv' % args.domain, 'r'))): 67 | try: 68 | l = l.split('\t') 69 | if l[0] in pfl: 70 | input_ids = torch.tensor([tokenizer.encode(pfl[l[0]]['title'])]).to(device)[:, :64] 71 | if len(input_ids[0]) < 4: 72 | continue 73 | all_hidden_states, all_attentions = model(input_ids)[-2:] 74 | rep = (all_hidden_states[-2][0] * all_attentions[-2][0].mean(dim=0).mean(dim=0).view(-1, 1)).sum(dim=0) 75 | pfl[l[0]]['emb'] = rep.tolist() 76 | except Exception as e: 77 | print(e) 78 | 79 | 80 | 81 | vfi_ids = {} 82 | with open(args.input_dir + '/vfi_vector.tsv') as fin: 83 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/vfi_vector.tsv'))): 84 | l = l[:-1].split('\t') 85 | vfi_ids[l[0]] = True 86 | 87 | 88 | graph = Graph() 89 | rem = [] 90 | with open(args.input_dir + '/Papers%s_20190919.tsv' % args.domain) as fin: 91 | fin.readline() 92 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/Papers%s_20190919.tsv' % args.domain, 'r'))): 93 | l = l[:-1].split('\t') 94 | if l[0] not in pfl or l[4] != 'en' or 'emb' not in pfl[l[0]] or l[3] not in vfi_ids: 95 | continue 96 | rem += [l[0]] 97 | vi = {'id': l[3], 'type': 'venue', 'attr': l[-2]} 98 | graph.add_edge(pfl[l[0]], vi, time = int(l[1]), relation_type = 'PV_' + l[-2]) 99 | pfl = {i: pfl[i] for i in rem} 100 | print(len(pfl)) 101 | 102 | 103 | with open(args.input_dir + '/PR%s_20190919.tsv' % 
args.domain) as fin: 104 | fin.readline() 105 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PR%s_20190919.tsv' % args.domain))): 106 | l = l[:-1].split('\t') 107 | if l[0] in pfl and l[1] in pfl: 108 | p1 = pfl[l[0]] 109 | p2 = pfl[l[1]] 110 | if p1['time'] >= p2['time']: 111 | graph.add_edge(p1, p2, time = p1['time'], relation_type = 'PP_cite') 112 | 113 | 114 | 115 | ffl = {} 116 | with open(args.input_dir + '/PF%s_20190919.tsv' % args.domain) as fin: 117 | fin.readline() 118 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PF%s_20190919.tsv' % args.domain))): 119 | l = l[:-1].split('\t') 120 | if l[0] in pfl and l[1] in vfi_ids: 121 | ffl[l[1]] = True 122 | 123 | 124 | 125 | 126 | with open(args.input_dir + '/FHierarchy_20190919.tsv') as fin: 127 | fin.readline() 128 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/FHierarchy_20190919.tsv'))): 129 | l = l[:-1].split('\t') 130 | if l[0] in ffl and l[1] in ffl: 131 | fi = {'id': l[0], 'type': 'field', 'attr': l[2]} 132 | fj = {'id': l[1], 'type': 'field', 'attr': l[3]} 133 | graph.add_edge(fi, fj, relation_type = 'FF_in') 134 | ffl[l[0]] = fi 135 | ffl[l[1]] = fj 136 | 137 | 138 | 139 | 140 | with open(args.input_dir + '/PF%s_20190919.tsv' % args.domain) as fin: 141 | fin.readline() 142 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PF%s_20190919.tsv' % args.domain))): 143 | l = l[:-1].split('\t') 144 | if l[0] in pfl and l[1] in ffl and type(ffl[l[1]]) == dict: 145 | pi = pfl[l[0]] 146 | fi = ffl[l[1]] 147 | graph.add_edge(pi, fi, time = pi['time'], relation_type = 'PF_in_' + fi['attr']) 148 | 149 | 150 | 151 | 152 | coa = defaultdict(lambda: {}) 153 | with open(args.input_dir + '/PAuAf%s_20190919.tsv' % args.domain) as fin: 154 | fin.readline() 155 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/PAuAf%s_20190919.tsv' % args.domain))): 156 | l = l[:-1].split('\t') 157 | if l[0] in pfl and l[2] in 
vfi_ids: 158 | pi = pfl[l[0]] 159 | ai = {'id': l[1], 'type': 'author'} 160 | fi = {'id': l[2], 'type': 'affiliation'} 161 | coa[l[0]][int(l[-1])] = ai 162 | graph.add_edge(ai, fi, relation_type = 'in') 163 | 164 | for pid in tqdm(coa): 165 | pi = pfl[pid] 166 | max_seq = max(coa[pid].keys()) 167 | for seq_i in coa[pid]: 168 | ai = coa[pid][seq_i] 169 | if seq_i == 1: 170 | graph.add_edge(ai, pi, time = pi['time'], relation_type = 'AP_write_first') 171 | elif seq_i == max_seq: 172 | graph.add_edge(ai, pi, time = pi['time'], relation_type = 'AP_write_last') 173 | else: 174 | graph.add_edge(ai, pi, time = pi['time'], relation_type = 'AP_write_other') 175 | 176 | 177 | 178 | 179 | with open(args.input_dir + '/vfi_vector.tsv') as fin: 180 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/vfi_vector.tsv'))): 181 | l = l[:-1].split('\t') 182 | ser = l[0] 183 | for idx in ['venue', 'field', 'affiliation']: 184 | if ser in graph.node_forward[idx]: 185 | graph.node_bacward[idx][graph.node_forward[idx][ser]]['node_emb'] = np.array(l[1].split(' ')) 186 | 187 | 188 | 189 | 190 | with open(args.input_dir + '/SeqName%s_20190919.tsv' % args.domain) as fin: 191 | for l in tqdm(fin, total = sum(1 for line in open(args.input_dir + '/SeqName%s_20190919.tsv' % args.domain))): 192 | l = l[:-1].split('\t') 193 | key = l[2] 194 | if key in ['conference', 'journal', 'repository', 'patent']: 195 | key = 'venue' 196 | if key == 'fos': 197 | key = 'field' 198 | if l[0] in graph.node_forward[key]: 199 | s = graph.node_bacward[key][graph.node_forward[key][l[0]]] 200 | s['name'] = l[1] 201 | 202 | ''' 203 | Calculate the total citation information as node attributes. 
204 | ''' 205 | 206 | for idx, pi in enumerate(graph.node_bacward['paper']): 207 | pi['citation'] = len(graph.edge_list['paper']['paper']['PP_cite'][idx]) 208 | for idx, ai in enumerate(graph.node_bacward['author']): 209 | citation = 0 210 | for rel in graph.edge_list['author']['paper'].keys(): 211 | for pid in graph.edge_list['author']['paper'][rel][idx]: 212 | citation += graph.node_bacward['paper'][pid]['citation'] 213 | ai['citation'] = citation 214 | for idx, fi in enumerate(graph.node_bacward['affiliation']): 215 | citation = 0 216 | for aid in graph.edge_list['affiliation']['author']['in'][idx]: 217 | citation += graph.node_bacward['author'][aid]['citation'] 218 | fi['citation'] = citation 219 | for idx, vi in enumerate(graph.node_bacward['venue']): 220 | citation = 0 221 | for rel in graph.edge_list['venue']['paper'].keys(): 222 | for pid in graph.edge_list['venue']['paper'][rel][idx]: 223 | citation += graph.node_bacward['paper'][pid]['citation'] 224 | vi['citation'] = citation 225 | for idx, fi in enumerate(graph.node_bacward['field']): 226 | citation = 0 227 | for rel in graph.edge_list['field']['paper'].keys(): 228 | for pid in graph.edge_list['field']['paper'][rel][idx]: 229 | citation += graph.node_bacward['paper'][pid]['citation'] 230 | fi['citation'] = citation 231 | 232 | 233 | 234 | 235 | ''' 236 | Since only paper have w2v embedding, we simply propagate its 237 | feature to other nodes by averaging neighborhoods. 238 | Then, we construct the Dataframe for each node type. 
239 | ''' 240 | d = pd.DataFrame(graph.node_bacward['paper']) 241 | graph.node_feature = {'paper': d} 242 | cv = np.array(list(d['emb'])) 243 | for _type in graph.node_bacward: 244 | if _type not in ['paper', 'affiliation']: 245 | d = pd.DataFrame(graph.node_bacward[_type]) 246 | i = [] 247 | for _rel in graph.edge_list[_type]['paper']: 248 | for t in graph.edge_list[_type]['paper'][_rel]: 249 | for s in graph.edge_list[_type]['paper'][_rel][t]: 250 | if graph.edge_list[_type]['paper'][_rel][t][s] <= test_time_bar: 251 | i += [[t, s]] 252 | if len(i) == 0: 253 | continue 254 | i = np.array(i).T 255 | v = np.ones(i.shape[1]) 256 | m = normalize(sp.coo_matrix((v, i), \ 257 | shape=(len(graph.node_bacward[_type]), len(graph.node_bacward['paper'])))) 258 | out = m.dot(cv) 259 | d['emb'] = list(out) 260 | graph.node_feature[_type] = d 261 | ''' 262 | Affiliation is not directly linked with Paper, so we average the author embedding. 263 | ''' 264 | cv = np.array(list(graph.node_feature['author']['emb'])) 265 | d = pd.DataFrame(graph.node_bacward['affiliation']) 266 | i = [] 267 | for _rel in graph.edge_list['affiliation']['author']: 268 | for j in graph.edge_list['affiliation']['author'][_rel]: 269 | for t in graph.edge_list['affiliation']['author'][_rel][j]: 270 | i += [[j, t]] 271 | i = np.array(i).T 272 | v = np.ones(i.shape[1]) 273 | m = normalize(sp.coo_matrix((v, i), \ 274 | shape=(len(graph.node_bacward['affiliation']), len(graph.node_bacward['author'])))) 275 | out = m.dot(cv) 276 | d['emb'] = list(out) 277 | graph.node_feature['affiliation'] = d 278 | 279 | 280 | edg = {} 281 | for k1 in graph.edge_list: 282 | if k1 not in edg: 283 | edg[k1] = {} 284 | for k2 in graph.edge_list[k1]: 285 | if k2 not in edg[k1]: 286 | edg[k1][k2] = {} 287 | for k3 in graph.edge_list[k1][k2]: 288 | if k3 not in edg[k1][k2]: 289 | edg[k1][k2][k3] = {} 290 | for e1 in graph.edge_list[k1][k2][k3]: 291 | if len(graph.edge_list[k1][k2][k3][e1]) == 0: 292 | continue 293 | 
edg[k1][k2][k3][e1] = {} 294 | for e2 in graph.edge_list[k1][k2][k3][e1]: 295 | edg[k1][k2][k3][e1][e2] = graph.edge_list[k1][k2][k3][e1][e2] 296 | print(k1, k2, k3, len(edg[k1][k2][k3])) 297 | graph.edge_list = edg 298 | 299 | 300 | del graph.node_bacward 301 | dill.dump(graph, open(args.output_dir + '/graph%s.pk' % args.domain, 'wb')) 302 | 303 | 304 | 305 | -------------------------------------------------------------------------------- /example_OAG/pretrain_OAG.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from GPT_GNN.data import * 3 | from GPT_GNN.model import * 4 | from warnings import filterwarnings 5 | filterwarnings("ignore") 6 | 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Pre-training HGT on a given graph (heterogeneous / homogeneous)') 10 | 11 | ''' 12 | GPT-GNN arguments 13 | ''' 14 | parser.add_argument('--attr_ratio', type=float, default=0.5, 15 | help='Ratio of attr-loss against link-loss, range: [0-1]') 16 | parser.add_argument('--attr_type', type=str, default='text', 17 | choices=['text', 'vec'], 18 | help='The type of attribute decoder') 19 | parser.add_argument('--neg_samp_num', type=int, default=255, 20 | help='Maximum number of negative sample for each target node.') 21 | parser.add_argument('--queue_size', type=int, default=256, 22 | help='Max size of adaptive embedding queue.') 23 | parser.add_argument('--w2v_dir', type=str, default='/datadrive/dataset/w2v_all', 24 | help='The address of preprocessed graph.') 25 | 26 | ''' 27 | Dataset arguments 28 | ''' 29 | parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_CS.pk', 30 | help='The address of preprocessed graph.') 31 | parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/test', 32 | help='The address for storing the models and optimization results.') 33 | parser.add_argument('--cuda', type=int, default=2, 34 | help='Avaiable GPU ID') 35 | 
parser.add_argument('--sample_depth', type=int, default=6, 36 | help='How many layers within a mini-batch subgraph') 37 | parser.add_argument('--sample_width', type=int, default=128, 38 | help='How many nodes to be sampled per layer per type') 39 | 40 | ''' 41 | Model arguments 42 | ''' 43 | parser.add_argument('--conv_name', type=str, default='hgt', 44 | choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'], 45 | help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)') 46 | parser.add_argument('--n_hid', type=int, default=400, 47 | help='Number of hidden dimension') 48 | parser.add_argument('--n_heads', type=int, default=8, 49 | help='Number of attention head') 50 | parser.add_argument('--n_layers', type=int, default=3, 51 | help='Number of GNN layers') 52 | parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true') 53 | parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true') 54 | parser.add_argument('--dropout', type=int, default=0.2, 55 | help='Dropout ratio') 56 | 57 | ''' 58 | Optimization arguments 59 | ''' 60 | parser.add_argument('--max_lr', type=float, default=1e-3, 61 | help='Maximum learning rate.') 62 | parser.add_argument('--scheduler', type=str, default='cycle', 63 | help='Name of learning rate scheduler.' 
, choices=['cycle', 'cosine']) 64 | parser.add_argument('--n_epoch', type=int, default=20, 65 | help='Number of epoch to run') 66 | parser.add_argument('--n_pool', type=int, default=8, 67 | help='Number of process to sample subgraph') 68 | parser.add_argument('--n_batch', type=int, default=32, 69 | help='Number of batch (sampled graphs) for each epoch') 70 | parser.add_argument('--batch_size', type=int, default=256, 71 | help='Number of output nodes for training') 72 | parser.add_argument('--clip', type=float, default=0.5, 73 | help='Gradient Norm Clipping') 74 | 75 | 76 | args = parser.parse_args() 77 | args_print(args) 78 | 79 | if args.cuda != -1: 80 | device = torch.device("cuda:" + str(args.cuda)) 81 | else: 82 | device = torch.device("cpu") 83 | 84 | print('Start Loading Graph Data...') 85 | graph = renamed_load(open(args.data_dir, 'rb')) 86 | print('Finish Loading Graph Data!') 87 | 88 | pre_range = {t: True for t in graph.times if t != None and t < 2014} 89 | train_range = {t: True for t in graph.times if t != None and t >= 2014 and t <= 2016} 90 | valid_range = {t: True for t in graph.times if t != None and t > 2016 and t <= 2017} 91 | test_range = {t: True for t in graph.times if t != None and t > 2017} 92 | 93 | pre_target_nodes = [] 94 | train_target_nodes = [] 95 | target_type = 'paper' 96 | rel_stop_list = ['self', 'rev_PF_in_L0', 'rev_PF_in_L5', 'rev_PV_Repository', 'rev_PV_Patent'] 97 | 98 | 99 | for p_id, _time in graph.node_feature[target_type]['time'].iteritems(): 100 | if _time in pre_range: 101 | pre_target_nodes += [[p_id, _time]] 102 | elif _time in train_range: 103 | train_target_nodes += [[p_id, _time]] 104 | pre_target_nodes = np.array(pre_target_nodes) 105 | train_target_nodes = np.array(train_target_nodes) 106 | 107 | 108 | def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor): 109 | np.random.seed(seed) 110 | samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)] 111 | threshold = 0.5 
112 | feature, times, edge_list, _, attr = sample_subgraph(graph, time_range, \ 113 | inp = {target_type: samp_target_nodes}, feature_extractor = feature_extractor, \ 114 | sampled_depth = args.sample_depth, sampled_number = args.sample_width) 115 | rem_edge_list = defaultdict( #source_type 116 | lambda: defaultdict( #relation_type 117 | lambda: [] # [target_id, source_id] 118 | )) 119 | 120 | ori_list = {} 121 | for source_type in edge_list[target_type]: 122 | ori_list[source_type] = {} 123 | for relation_type in edge_list[target_type][source_type]: 124 | ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type]) 125 | el = [] 126 | for target_ser, source_ser in edge_list[target_type][source_type][relation_type]: 127 | if relation_type not in rel_stop_list and target_ser < batch_size and np.random.random() > threshold: 128 | rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]] 129 | continue 130 | el += [[target_ser, source_ser]] 131 | el = np.array(el) 132 | edge_list[target_type][source_type][relation_type] = el 133 | 134 | if relation_type == 'self': 135 | continue 136 | else: 137 | if 'rev_' in relation_type: 138 | rev_relation = relation_type[4:] 139 | else: 140 | rev_relation = 'rev_' + relation_type 141 | edge_list[source_type]['paper'][rev_relation] = list(np.stack((el[:,1], el[:,0])).T) 142 | 143 | ''' 144 | Adding feature nodes: 145 | ''' 146 | n_target_nodes = len(feature[target_type]) 147 | feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]]))) 148 | times[target_type] = np.concatenate((times[target_type], times[target_type][:batch_size])) 149 | 150 | for source_type in edge_list[target_type]: 151 | for relation_type in edge_list[target_type][source_type]: 152 | el = [] 153 | for target_ser, source_ser in edge_list[target_type][source_type][relation_type]: 154 | if target_ser < batch_size: 155 | if relation_type == 'self': 156 | el 
+= [[target_ser + n_target_nodes, target_ser + n_target_nodes]] 157 | else: 158 | el += [[target_ser + n_target_nodes, source_ser]] 159 | if len(el) > 0: 160 | edge_list[target_type][source_type][relation_type] = \ 161 | np.concatenate((edge_list[target_type][source_type][relation_type], el)) 162 | 163 | 164 | rem_edge_lists = {} 165 | for source_type in rem_edge_list: 166 | rem_edge_lists[source_type] = {} 167 | for relation_type in rem_edge_list[source_type]: 168 | rem_edge_lists[source_type][relation_type] = np.array(rem_edge_list[source_type][relation_type]) 169 | del rem_edge_list 170 | 171 | return to_torch(feature, times, edge_list, graph), rem_edge_lists, ori_list, \ 172 | attr[:batch_size], (n_target_nodes, n_target_nodes + batch_size) 173 | 174 | 175 | 176 | def prepare_data(pool): 177 | jobs = [] 178 | for _ in np.arange(args.n_batch - 1): 179 | jobs.append(pool.apply_async(GPT_sample, args=(randint(), pre_target_nodes, pre_range, args.batch_size, feature_OAG))) 180 | jobs.append(pool.apply_async(GPT_sample, args=(randint(), train_target_nodes, train_range, args.batch_size, feature_OAG))) 181 | return jobs 182 | 183 | 184 | pool = mp.Pool(args.n_pool) 185 | st = time.time() 186 | jobs = prepare_data(pool) 187 | repeat_num = int(len(pre_target_nodes) / args.batch_size // args.n_batch) 188 | 189 | 190 | data, rem_edge_list, ori_edge_list, _, _ = GPT_sample(randint(), pre_target_nodes, pre_range, args.batch_size, feature_OAG) 191 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 192 | types = graph.get_types() 193 | 194 | 195 | gnn = GNN(conv_name = args.conv_name, in_dim = len(graph.node_feature[target_type]['emb'].values[0]) + 401, n_hid = args.n_hid, \ 196 | n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \ 197 | num_relations = len(graph.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm) 198 | 199 | 200 | if args.attr_type == 'text': 201 
| from gensim.models import Word2Vec 202 | w2v_model = Word2Vec.load(args.w2v_dir) 203 | n_tokens = len(w2v_model.wv.vocab) 204 | attr_decoder = RNNModel(n_word = n_tokens, ninp = gnn.n_hid, \ 205 | nhid = w2v_model.vector_size, nlayers = 2) 206 | attr_decoder.from_w2v(torch.FloatTensor(w2v_model.wv.vectors)) 207 | else: 208 | attr_decoder = Matcher(gnn.n_hid, gnn.in_dim) 209 | 210 | gpt_gnn = GPT_GNN(gnn = gnn, rem_edge_list = rem_edge_list, attr_decoder = attr_decoder, \ 211 | neg_queue_size = 0, types = types, neg_samp_num = args.neg_samp_num, device = device) 212 | gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach() 213 | gpt_gnn = gpt_gnn.to(device) 214 | 215 | 216 | 217 | best_val = 100000 218 | train_step = 0 219 | stats = [] 220 | optimizer = torch.optim.AdamW(gpt_gnn.parameters(), weight_decay = 1e-2, eps=1e-06, lr = args.max_lr) 221 | 222 | if args.scheduler == 'cycle': 223 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\ 224 | max_lr = args.max_lr, total_steps = repeat_num * args.n_batch * args.n_epoch + 1) 225 | elif args.scheduler == 'cosine': 226 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6) 227 | 228 | print('Start Pretraining...') 229 | for epoch in np.arange(args.n_epoch) + 1: 230 | gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch 231 | for batch in np.arange(repeat_num) + 1: 232 | train_data = [job.get() for job in jobs[:-1]] 233 | valid_data = jobs[-1].get() 234 | pool.close() 235 | pool.join() 236 | pool = mp.Pool(args.n_pool) 237 | jobs = prepare_data(pool) 238 | et = time.time() 239 | print('Data Preparation: %.1fs' % (et - st)) 240 | 241 | train_link_losses = [] 242 | train_attr_losses = [] 243 | gpt_gnn.train() 244 | for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data: 245 | node_feature, node_type, edge_time, 
edge_index, edge_type, node_dict, edge_dict = data 246 | node_feature = node_feature.detach() 247 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 248 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 249 | edge_index.to(device), edge_type.to(device)) 250 | 251 | loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True) 252 | if args.attr_type == 'text': 253 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 254 | else: 255 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 256 | 257 | 258 | loss = loss_link + loss_attr * args.attr_ratio 259 | 260 | 261 | optimizer.zero_grad() 262 | loss.backward() 263 | torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip) 264 | optimizer.step() 265 | 266 | train_link_losses += [loss_link.item()] 267 | train_attr_losses += [loss_attr.item()] 268 | scheduler.step() 269 | ''' 270 | Valid 271 | ''' 272 | gpt_gnn.eval() 273 | with torch.no_grad(): 274 | data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data 275 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 276 | node_feature = node_feature.detach() 277 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 278 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 279 | edge_index.to(device), edge_type.to(device)) 280 | loss_link, ress = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True) 281 | loss_link = loss_link.item() 282 | if args.attr_type == 'text': 283 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 284 | else: 285 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 286 | 287 | ndcgs = [] 288 | for i 
in ress: 289 | ai = np.zeros(len(i[0])) 290 | ai[0] = 1 291 | ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)] 292 | 293 | valid_loss = loss_link + loss_attr * args.attr_ratio 294 | st = time.time() 295 | print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \ 296 | (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \ 297 | loss_link, loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size)) 298 | 299 | if valid_loss < best_val: 300 | best_val = valid_loss 301 | print('UPDATE!!!') 302 | torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir) 303 | stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]] 304 | -------------------------------------------------------------------------------- /example_reddit/.ipynb_checkpoints/pretrain_reddit-checkpoint.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from GPT_GNN.data import * 3 | from GPT_GNN.model import * 4 | from warnings import filterwarnings 5 | filterwarnings("ignore") 6 | 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='Pre-training HGT on a given graph (heterogeneous / homogeneous)') 10 | 11 | ''' 12 | GPT-GNN arguments 13 | ''' 14 | parser.add_argument('--attr_ratio', type=float, default=0.5, 15 | help='Ratio of attr-loss against link-loss, range: [0-1]') 16 | parser.add_argument('--attr_type', type=str, default='vec', 17 | choices=['text', 'vec'], 18 | help='The type of attribute decoder') 19 | parser.add_argument('--neg_samp_num', type=int, default=255, 20 | help='Maximum number of negative sample for each target node.') 21 | parser.add_argument('--queue_size', type=int, default=256, 22 | help='Max size of adaptive embedding queue.') 23 | parser.add_argument('--w2v_dir', type=str, 
default='/datadrive/dataset/w2v_all', 24 | help='The address of preprocessed graph.') 25 | 26 | ''' 27 | Dataset arguments 28 | ''' 29 | parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_reddit.pk', 30 | help='The address of preprocessed graph.') 31 | parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_reddit', 32 | help='The address for storing the pre-trained models.') 33 | parser.add_argument('--cuda', type=int, default=1, 34 | help='Avaiable GPU ID') 35 | parser.add_argument('--sample_depth', type=int, default=6, 36 | help='How many layers within a mini-batch subgraph') 37 | parser.add_argument('--sample_width', type=int, default=128, 38 | help='How many nodes to be sampled per layer per type') 39 | 40 | ''' 41 | Model arguments 42 | ''' 43 | parser.add_argument('--conv_name', type=str, default='hgt', 44 | choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'], 45 | help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)') 46 | parser.add_argument('--n_hid', type=int, default=400, 47 | help='Number of hidden dimension') 48 | parser.add_argument('--n_heads', type=int, default=8, 49 | help='Number of attention head') 50 | parser.add_argument('--n_layers', type=int, default=3, 51 | help='Number of GNN layers') 52 | parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true') 53 | parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true') 54 | parser.add_argument('--dropout', type=int, default=0.2, 55 | help='Dropout ratio') 56 | 57 | ''' 58 | Optimization arguments 59 | ''' 60 | parser.add_argument('--max_lr', type=float, default=1e-3, 61 | help='Maximum learning rate.') 62 | parser.add_argument('--scheduler', type=str, default='cycle', 63 | help='Name of learning rate scheduler.' 
def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor):
    '''
        Sample one GPT-GNN pre-training mini-batch.

        Steps:
          1. Seed numpy, draw `batch_size` rows of `target_nodes` (with replacement)
             and sample a sub-graph around them with `sample_subgraph`.
          2. For every relation between `target_type` nodes, keep an untouched copy
             of the edges (`ori_list`), then randomly move roughly half of the
             edges that touch a batch node into `rem_edge_list` (the edges the
             model has to predict); the remaining edges are stored in both directions.
          3. Append `batch_size` zero-feature "attribute" nodes that copy the
             remaining edges of the batch nodes; the caller overwrites their
             features with the learned `init_emb`.

        seed              : seed for numpy's global RNG (makes a worker reproducible).
        target_nodes      : array of [node_id, time] rows to draw from.
        time_range        : dict whose keys are the admissible timestamps.
        batch_size        : number of target nodes to draw.
        feature_extractor : callable forwarded to `sample_subgraph`.

        Returns (to_torch(...) tuple, removed edges, original edges,
                 attr[:batch_size], (first attribute-node index, last + 1)).

        Relies on the module-level `graph_reddit`, `target_type`, `rel_stop_list`
        and `args`.
    '''
    np.random.seed(seed)
    samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)]
    # Probability with which an eligible edge is kept (otherwise it is removed).
    threshold = 0.5
    # The loaded graph is bound to `graph_reddit` at module level.
    feature, times, edge_list, _, attr = sample_subgraph(graph_reddit, time_range, \
                inp = {target_type: samp_target_nodes}, feature_extractor = feature_extractor, \
                    sampled_depth = args.sample_depth, sampled_number = args.sample_width)
    rem_edge_list = defaultdict(  #source_type
                        lambda: defaultdict(  #relation_type
                            lambda: [] # [target_id, source_id]
                                ))

    ori_list = {}
    for source_type in edge_list[target_type]:
        ori_list[source_type] = {}
        for relation_type in edge_list[target_type][source_type]:
            ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type])
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                # Only the (smaller, larger) orientation is inspected; kept edges are
                # re-added in both directions below.
                if target_ser < source_ser:
                    if relation_type not in rel_stop_list and target_ser < batch_size and \
                           np.random.random() > threshold:
                        rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]]
                        continue
                    el += [[target_ser, source_ser]]
                    el += [[source_ser, target_ser]]
            el = np.array(el)
            edge_list[target_type][source_type][relation_type] = el

    '''
        Adding feature nodes:
    '''
    n_target_nodes = len(feature[target_type])
    feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]])))
    times[target_type]   = np.concatenate((times[target_type], times[target_type][:batch_size]))

    for source_type in edge_list[target_type]:
        for relation_type in edge_list[target_type][source_type]:
            el = []
            for target_ser, source_ser in edge_list[target_type][source_type][relation_type]:
                if target_ser < batch_size:
                    # The attribute node of batch node i lives at index i + n_target_nodes.
                    if relation_type == 'self':
                        el += [[target_ser + n_target_nodes, target_ser + n_target_nodes]]
                    else:
                        el += [[target_ser + n_target_nodes, source_ser]]
            if len(el) > 0:
                edge_list[target_type][source_type][relation_type] = \
                    np.concatenate((edge_list[target_type][source_type][relation_type], el))

    # Freeze the nested defaultdicts into plain dicts of numpy arrays so the
    # result can be sent back from a worker process.
    rem_edge_lists = {}
    for source_type in rem_edge_list:
        rem_edge_lists[source_type] = {}
        for relation_type in rem_edge_list[source_type]:
            rem_edge_lists[source_type][relation_type] = np.array(rem_edge_list[source_type][relation_type])
    del rem_edge_list

    return to_torch(feature, times, edge_list, graph_reddit), rem_edge_lists, ori_list, \
            attr[:batch_size], (n_target_nodes, n_target_nodes + batch_size)
def prepare_data(pool):
    '''
        Queue one epoch-chunk of sampling jobs on `pool`.

        The first `args.n_batch - 1` jobs draw from `pre_target_nodes`
        (training batches); the last job draws from `train_target_nodes`
        and is used by the caller as the validation batch.
        Returns the list of AsyncResult handles in submission order.
    '''
    def submit(nodes):
        # Every job gets its own fresh random seed.
        return pool.apply_async(GPT_sample, args=(randint(), nodes, {1: True}, args.batch_size, feature_reddit))

    jobs = [submit(pre_target_nodes) for _ in np.arange(args.n_batch - 1)]
    jobs.append(submit(train_target_nodes))
    return jobs
gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach() 201 | gpt_gnn = gpt_gnn.to(device) 202 | 203 | 204 | best_val = 100000 205 | train_step = 0 206 | stats = [] 207 | optimizer = torch.optim.AdamW(gpt_gnn.parameters(), weight_decay = 1e-2, eps=1e-06, lr = args.max_lr) 208 | 209 | if args.scheduler == 'cycle': 210 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\ 211 | max_lr = args.max_lr, total_steps = repeat_num * args.n_batch * args.n_epoch + 1) 212 | elif args.scheduler == 'cosine': 213 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6) 214 | 215 | print('Start Pretraining...') 216 | for epoch in np.arange(args.n_epoch) + 1: 217 | gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch 218 | for batch in np.arange(repeat_num) + 1: 219 | train_data = [job.get() for job in jobs[:-1]] 220 | valid_data = jobs[-1].get() 221 | pool.close() 222 | pool.join() 223 | pool = mp.Pool(args.n_pool) 224 | jobs = prepare_data(pool) 225 | et = time.time() 226 | print('Data Preparation: %.1fs' % (et - st)) 227 | 228 | train_link_losses = [] 229 | train_attr_losses = [] 230 | gpt_gnn.train() 231 | for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data: 232 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 233 | node_feature = node_feature.detach() 234 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 235 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 236 | edge_index.to(device), edge_type.to(device)) 237 | 238 | loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True) 239 | if args.attr_type == 'text': 240 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 241 | 
else: 242 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 243 | 244 | 245 | loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio 246 | 247 | 248 | optimizer.zero_grad() 249 | loss.backward() 250 | torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip) 251 | optimizer.step() 252 | 253 | train_link_losses += [loss_link.item()] 254 | train_attr_losses += [loss_attr.item()] 255 | scheduler.step() 256 | ''' 257 | Valid 258 | ''' 259 | gpt_gnn.eval() 260 | with torch.no_grad(): 261 | data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data 262 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 263 | node_feature = node_feature.detach() 264 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 265 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 266 | edge_index.to(device), edge_type.to(device)) 267 | loss_link, ress = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True) 268 | loss_link = loss_link.item() 269 | if args.attr_type == 'text': 270 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 271 | else: 272 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 273 | 274 | ndcgs = [] 275 | for i in ress: 276 | ai = np.zeros(len(i[0])) 277 | ai[0] = 1 278 | ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)] 279 | 280 | valid_loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio 281 | st = time.time() 282 | print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \ 283 | (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \ 284 | loss_link, 
loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size)) 285 | 286 | if valid_loss < best_val: 287 | best_val = valid_loss 288 | print('UPDATE!!!') 289 | torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir) 290 | stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]] 291 | -------------------------------------------------------------------------------- /example_reddit/GPT_GNN/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example_reddit/GPT_GNN/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_reddit/GPT_GNN/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /example_reddit/GPT_GNN/__pycache__/conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_reddit/GPT_GNN/__pycache__/conv.cpython-37.pyc -------------------------------------------------------------------------------- /example_reddit/GPT_GNN/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/example_reddit/GPT_GNN/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /example_reddit/GPT_GNN/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- 
class HGTConv(MessagePassing):
    '''
        Heterogeneous Graph Transformer convolution.

        Every node type owns its own key / query / value / output projections,
        every relation type owns a per-head attention matrix, message matrix and
        prior. Messages are attention-weighted per target node and summed
        (aggr='add').

        in_dim, out_dim : input / output feature sizes (out_dim is split over n_heads).
        num_types       : number of node types.
        num_relations   : number of relation (edge) types.
        n_heads         : number of attention heads.
        dropout         : dropout applied to the layer output.
        use_norm        : apply a per-type LayerNorm after the skip connection.
        use_RTE         : accepted for interface compatibility.
                          NOTE(review): this implementation never reads it; the
                          relative temporal encoding is always applied.
    '''
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout = 0.2, use_norm = True, use_RTE = True, **kwargs):
        super(HGTConv, self).__init__(node_dim=0, aggr='add', **kwargs)

        self.in_dim        = in_dim
        self.out_dim       = out_dim
        self.num_types     = num_types
        self.num_relations = num_relations
        self.total_rel     = num_types * num_relations * num_types
        self.n_heads       = n_heads
        self.d_k           = out_dim // n_heads
        self.sqrt_dk       = math.sqrt(self.d_k)
        self.use_norm      = use_norm
        self.att           = None

        self.k_linears   = nn.ModuleList()
        self.q_linears   = nn.ModuleList()
        self.v_linears   = nn.ModuleList()
        self.a_linears   = nn.ModuleList()
        self.norms       = nn.ModuleList()

        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim,   out_dim))
            self.q_linears.append(nn.Linear(in_dim,   out_dim))
            self.v_linears.append(nn.Linear(in_dim,   out_dim))
            self.a_linears.append(nn.Linear(out_dim,  out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        '''
            TODO: make relation_pri smaller, as not all pair exist in meta relation list.
        '''
        self.relation_pri   = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att   = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg   = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.skip           = nn.Parameter(torch.ones(num_types))
        self.drop           = nn.Dropout(dropout)
        self.emb            = RelTemporalEncoding(in_dim)

        glorot(self.relation_att)
        glorot(self.relation_msg)

    def forward(self, node_inp, node_type, edge_index, edge_type, edge_time):
        '''
            Run one round of message passing; the per-edge / per-node tensors are
            routed to `message` and `update` by `propagate`.
        '''
        return self.propagate(edge_index, node_inp=node_inp, node_type=node_type, \
                              edge_type=edge_type, edge_time = edge_time)

    def message(self, edge_index_i, node_inp_i, node_inp_j, node_type_i, node_type_j, edge_type, edge_time):
        '''
            j: source, i: target;
            Returns one attention-weighted message of size out_dim per edge.
        '''
        data_size = edge_index_i.size(0)
        '''
            Create Attention and Message tensor beforehand.
            They are allocated directly on the input's device (no CPU round trip).
        '''
        res_att     = torch.zeros(data_size, self.n_heads, device=node_inp_i.device)
        res_msg     = torch.zeros(data_size, self.n_heads, self.d_k, device=node_inp_i.device)

        for source_type in range(self.num_types):
            sb = (node_type_j == int(source_type))
            k_linear = self.k_linears[source_type]
            v_linear = self.v_linears[source_type]
            for target_type in range(self.num_types):
                tb = (node_type_i == int(target_type)) & sb
                q_linear = self.q_linears[target_type]
                for relation_type in range(self.num_relations):
                    '''
                        idx is all the edges with meta relation <source_type, relation_type, target_type>
                    '''
                    idx = (edge_type == int(relation_type)) & tb
                    if idx.sum() == 0:
                        continue
                    '''
                        Get the corresponding input node representations by idx.
                        Add temporal encoding to source representation (j)
                    '''
                    target_node_vec = node_inp_i[idx]
                    source_node_vec = self.emb(node_inp_j[idx], edge_time[idx])

                    '''
                        Step 1: Heterogeneous Mutual Attention
                    '''
                    q_mat = q_linear(target_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = k_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    k_mat = torch.bmm(k_mat.transpose(1,0), self.relation_att[relation_type]).transpose(1,0)
                    res_att[idx] = (q_mat * k_mat).sum(dim=-1) * self.relation_pri[relation_type] / self.sqrt_dk
                    '''
                        Step 2: Heterogeneous Message Passing
                    '''
                    v_mat = v_linear(source_node_vec).view(-1, self.n_heads, self.d_k)
                    res_msg[idx] = torch.bmm(v_mat.transpose(1,0), self.relation_msg[relation_type]).transpose(1,0)
        '''
            Softmax based on target node's id (edge_index_i). Store attention value in self.att for later visualization.
        '''
        self.att = softmax(res_att, edge_index_i)
        res = res_msg * self.att.view(-1, self.n_heads, 1)
        del res_att, res_msg
        return res.view(-1, self.out_dim)


    def update(self, aggr_out, node_inp, node_type):
        '''
            Step 3: Target-specific Aggregation
            x = W[node_type] * gelu(Agg(x)) + x
        '''
        aggr_out = F.gelu(aggr_out)
        res = torch.zeros(aggr_out.size(0), self.out_dim, device=node_inp.device)
        for target_type in range(self.num_types):
            idx = (node_type == int(target_type))
            if idx.sum() == 0:
                continue
            trans_out = self.a_linears[target_type](aggr_out[idx])
            '''
                Add skip connection with learnable weight self.skip[t_id]
            '''
            alpha = torch.sigmoid(self.skip[target_type])
            if self.use_norm:
                res[idx] = self.norms[target_type](trans_out * alpha + node_inp[idx] * (1 - alpha))
            else:
                res[idx] = trans_out * alpha + node_inp[idx] * (1 - alpha)
        return self.drop(res)

    def __repr__(self):
        # The last field is the relation count (it used to be labelled num_types twice).
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)
class RelTemporalEncoding(nn.Module):
    '''
        Implement the Temporal Encoding (Sinusoid) function.

        A frozen sinusoid table of width 2 * n_hid is looked up with the integer
        time gap `t`, projected back to n_hid by a learnable linear layer and
        added to the input representation.

        n_hid   : size of the node representation.
        max_len : number of distinct time gaps the table can encode.
        dropout : dropout applied to the looked-up encoding.
    '''
    def __init__(self, n_hid, max_len = 240, dropout = 0.2):
        super(RelTemporalEncoding, self).__init__()
        self.drop = nn.Dropout(dropout)
        position = torch.arange(0., max_len).unsqueeze(1)
        # Frequencies 10000^(-k / n_hid) for k = 0 .. n_hid - 1. The exponent has to
        # be scaled *before* the power is taken; otherwise 10000 ** (2k) overflows
        # float32 and almost every frequency collapses to 0.
        div_term = 1 / (10000 ** (torch.arange(0., n_hid * 2, 2.) / n_hid / 2))
        self.emb = nn.Embedding(max_len, n_hid * 2)
        self.emb.weight.data[:, 0::2] = torch.sin(position * div_term) / math.sqrt(n_hid)
        self.emb.weight.data[:, 1::2] = torch.cos(position * div_term) / math.sqrt(n_hid)
        # Freeze the table: the flag lives on the weight parameter, setting it on
        # the Embedding module itself has no effect.
        self.emb.weight.requires_grad = False
        self.lin = nn.Linear(n_hid * 2, n_hid)
    def forward(self, x, t):
        '''
            x : node representations, last dimension n_hid.
            t : integer time gaps (indices into the table, i.e. 0 <= t < max_len).
        '''
        return x + self.lin(self.drop(self.emb(t)))



class GeneralConv(nn.Module):
    '''
        Thin wrapper that selects the graph convolution by name
        ('hgt', 'gcn' or 'gat') and hides their different call signatures.
        Raises ValueError for any other `conv_name`.
    '''
    def __init__(self, conv_name, in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm = True, use_RTE = True):
        super(GeneralConv, self).__init__()
        self.conv_name = conv_name
        if self.conv_name == 'hgt':
            self.base_conv = HGTConv(in_hid, out_hid, num_types, num_relations, n_heads, dropout, use_norm, use_RTE)
        elif self.conv_name == 'gcn':
            self.base_conv = GCNConv(in_hid, out_hid)
        elif self.conv_name == 'gat':
            self.base_conv = GATConv(in_hid, out_hid // n_heads, heads=n_heads)
        else:
            # Fail at construction time instead of building a layer without
            # `base_conv` whose forward silently returns None.
            raise ValueError("Unsupported conv_name: {!r} (expected 'hgt', 'gcn' or 'gat')".format(conv_name))
    def forward(self, meta_xs, node_type, edge_index, edge_type, edge_time):
        if self.conv_name == 'hgt':
            return self.base_conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        # 'gcn' and 'gat' are homogeneous: they only need features and connectivity.
        return self.base_conv(meta_xs, edge_index)
class Graph():
    '''
        In-memory heterogeneous graph with time-stamped edges.
    '''
    def __init__(self):
        super(Graph, self).__init__()
        '''
            node_forward and bacward are only used when building the data.
            Afterwards will be transformed into node_feature by DataFrame

            node_forward: name -> node_id
            node_bacward: node_id -> feature_dict
            node_feature: a DataFrame containing all features
        '''
        self.node_forward = defaultdict(lambda: {})
        self.node_bacward = defaultdict(lambda: [])
        self.node_feature = defaultdict(lambda: [])

        '''
            edge_list: index the adjacancy matrix (time) by
            <target_type, source_type, relation_type, target_id, source_id>
        '''
        # NOTE(review): the innermost default is the *type* `int`, not 0; it is
        # only produced when a missing <target_id, source_id> pair is read.
        self.edge_list = defaultdict( #target_type
                            lambda: defaultdict(  #source_type
                                lambda: defaultdict(  #relation_type
                                    lambda: defaultdict(  #target_id
                                        lambda: defaultdict( #source_id(
                                            lambda: int # time
                                        )))))
        self.times = {}
    def add_node(self, node):
        '''
            Register `node` (a dict with at least 'type' and 'id') if it is new
            and return its serial number within its type.
        '''
        nfl = self.node_forward[node['type']]
        if node['id'] not in nfl:
            self.node_bacward[node['type']] += [node]
            ser = len(nfl)
            nfl[node['id']] = ser
            return ser
        return nfl[node['id']]
    def add_edge(self, source_node, target_node, time = None, relation_type = None, directed = True):
        '''
            Add an edge source -> target plus its mirror target -> source.
            For a directed edge the mirror is stored under 'rev_' + relation_type,
            otherwise under the same relation_type. Both nodes are registered
            if necessary and `time` is recorded in self.times.
        '''
        # Resolve the mirror label first: a relation_type that cannot be prefixed
        # (e.g. the default None with directed=True) then raises TypeError before
        # any node or edge has been inserted.
        rev_type = 'rev_' + relation_type if directed else relation_type
        edge = [self.add_node(source_node), self.add_node(target_node)]
        '''
            Add bi-directional edges with different relation type
        '''
        self.edge_list[target_node['type']][source_node['type']][relation_type][edge[1]][edge[0]] = time
        self.edge_list[source_node['type']][target_node['type']][rev_type][edge[0]][edge[1]] = time
        self.times[time] = True

    def update_node(self, node):
        '''
            Register `node` if needed and copy over every key the stored
            feature dict does not have yet (existing keys are kept).
        '''
        nbl = self.node_bacward[node['type']]
        ser = self.add_node(node)
        for k in node:
            if k not in nbl[ser]:
                nbl[ser][k] = node[k]

    def get_meta_graph(self):
        '''
            Return every (target_type, source_type, relation_type) triple
            present in edge_list.
        '''
        metas = []
        for target_type in self.edge_list:
            for source_type in self.edge_list[target_type]:
                for r_type in self.edge_list[target_type][source_type]:
                    metas += [(target_type, source_type, r_type)]
        return metas

    def get_types(self):
        '''
            Return the node types that have an entry in node_feature.
        '''
        return list(self.node_feature.keys())
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None, feature_extractor = feature_OAG):
    '''
        Sample Sub-Graph based on the connection of other nodes with currently sampled nodes
        We maintain budgets for each node type, indexed by <node_id, time>.
        Currently sampled nodes are stored in layer_data.
        After nodes are sampled, we construct the sampled adjacancy matrix.

        graph             : object exposing `edge_list` indexed as
                            [target_type][source_type][relation_type][target_id][source_id] -> time.
        time_range        : non-empty dict whose keys are timestamps; neighbours newer
                            than the largest key are never added to the budget.
        sampled_depth     : number of expansion rounds.
        sampled_number    : max nodes sampled per type per round (also the max
                            neighbours inspected per relation of one node).
        inp               : {node_type: iterable of (node_id, time)} seed nodes (required).
        feature_extractor : callable (layer_data, graph) -> (feature, times, indxs, texts).

        Returns (feature, times, edge_list, indxs, texts), where edge_list maps
        [target_type][source_type][relation_type] to a list of
        [target_ser, source_ser] pairs (serial numbers within the sample).
    '''
    layer_data  = defaultdict( #target_type
                        lambda: {} # {target_id: [ser, time]}
                    )
    budget      = defaultdict( #source_type
                                    lambda: defaultdict(  #source_id
                                        lambda: [0., 0] #[sampled_score, time]
                            ))
    # Loop-invariant upper bound on admissible timestamps (it used to be
    # recomputed for every candidate neighbour).
    max_time = np.max(list(time_range.keys()))
    '''
        For each node being sampled, we find out all its neighborhood,
        adding the degree count of these nodes in the budget.
        Note that there exist some nodes that have many neighborhoods
        (such as fields, venues), for those case, we only consider
        at most `sampled_number` randomly chosen neighbours per relation.
    '''
    def add_budget(te, target_id, target_time, layer_data, budget):
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self' or target_id not in tes[relation_type]:
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < sampled_number:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    # Edges without a timestamp inherit the target's time.
                    if source_time is None:
                        source_time = target_time
                    if source_time > max_time or source_id in layer_data[source_type]:
                        continue
                    budget[source_type][source_id][0] += 1. / len(sampled_ids)
                    budget[source_type][source_id][1] = source_time

    '''
        First adding the sampled nodes then updating budget.
    '''
    for _type in inp:
        for _id, _time in inp[_type]:
            layer_data[_type][_id] = [len(layer_data[_type]), _time]
    for _type in inp:
        te = graph.edge_list[_type]
        for _id, _time in inp[_type]:
            add_budget(te, _id, _time, layer_data, budget)
    '''
        We recursively expand the sampled graph by sampled_depth.
        Each time we sample a fixed number of nodes for each budget,
        based on the accumulated degree.
    '''
    for layer in range(sampled_depth):
        # Snapshot of the types: add_budget may add new types to `budget`.
        sts = list(budget.keys())
        for source_type in sts:
            te = graph.edge_list[source_type]
            keys  = np.array(list(budget[source_type].keys()))
            if sampled_number > len(keys):
                '''
                    Directly sample all the nodes
                '''
                sampled_ids = np.arange(len(keys))
            else:
                '''
                    Sample based on accumulated degree
                '''
                score = np.array(list(budget[source_type].values()))[:,0] ** 2
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False)
            sampled_keys = keys[sampled_ids]
            '''
                First adding the sampled nodes then updating budget.
            '''
            for k in sampled_keys:
                layer_data[source_type][k] = [len(layer_data[source_type]), budget[source_type][k][1]]
            for k in sampled_keys:
                add_budget(te, k, budget[source_type][k][1], layer_data, budget)
                budget[source_type].pop(k)
    '''
        Prepare feature, time and adjacency matrix for the sampled graph
    '''
    feature, times, indxs, texts = feature_extractor(layer_data, graph)

    edge_list = defaultdict( #target_type
                        lambda: defaultdict(  #source_type
                            lambda: defaultdict(  #relation_type
                                lambda: [] # [target_id, source_id]
                                    )))
    # Every sampled node gets a self loop.
    for _type in layer_data:
        for _key in layer_data[_type]:
            _ser = layer_data[_type][_key][0]
            edge_list[_type][_type]['self'] += [[_ser, _ser]]
    '''
        Reconstruct sampled adjacancy matrix by checking whether each
        link exist in the original graph
    '''
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        tld = layer_data[target_type]
        for source_type in te:
            tes = te[source_type]
            sld  = layer_data[source_type]
            for relation_type in tes:
                tesr = tes[relation_type]
                for target_key in tld:
                    if target_key not in tesr:
                        continue
                    target_ser = tld[target_key][0]
                    for source_key in tesr[target_key]:
                        '''
                            Check whether each link (target_id, source_id) exist in original adjacancy matrix
                        '''
                        if source_key in sld:
                            source_ser = sld[source_key][0]
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list, indxs, texts
def to_torch(feature, time, edge_list, graph):
    '''
        Transform a sampled sub-graph into pytorch Tensor
        node_dict: {node_type: <node_number, node_type_ID>} node_number is used to trace back the nodes in original graph.
        edge_dict: {edge_type: edge_type_ID}

        feature   : {node_type: per-node feature rows}.
        time      : {node_type: per-node timestamps}, aligned with `feature`.
        edge_list : [target_type][source_type][relation_type] -> [target_ser, source_ser] pairs.
        graph     : object providing get_types() and get_meta_graph().

        Returns (node_feature, node_type, edge_time, edge_index, edge_type,
                 node_dict, edge_dict); edge_index has shape (2, n_edges) with
                 row 0 = source and row 1 = target (global node indices).
    '''
    node_dict = {}
    node_feature = []
    node_type    = []
    node_time    = []
    edge_index   = []
    edge_type    = []
    edge_time    = []

    node_num = 0
    # Work on a copy: 'fake_paper' may be appended below and must not leak into
    # a list owned by the graph object.
    types = list(graph.get_types())
    for t in types:
        node_dict[t] = [node_num, len(node_dict)]
        node_num     += len(feature[t])

    if 'fake_paper' in feature:
        # Fake papers share the type id of real papers but get their own offset.
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num     += len(feature['fake_paper'])
        types += ['fake_paper']

    for t in types:
        node_feature += list(feature[t])
        node_time    += list(time[t])
        node_type    += [node_dict[t][1] for _ in range(len(feature[t]))]

    metas = graph.get_meta_graph()
    edge_dict = {e[2]: i for i, e in enumerate(metas)}
    # 'self' gets the first id after all meta relations. len(edge_dict) would be
    # smaller than that (and collide with a real relation id) whenever the same
    # relation name occurs for several type pairs.
    edge_dict['self'] = len(metas)

    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    tid, sid = ti + node_dict[target_type][0], si + node_dict[source_type][0]
                    edge_index += [[sid, tid]]
                    edge_type  += [edge_dict[relation_type]]
                    '''
                        Our time ranges from 1900 - 2020, largest span is 120.
                    '''
                    edge_time += [node_time[tid] - node_time[sid] + 120]
    # Stack once in numpy: building a tensor from a Python list of arrays is slow.
    node_feature = torch.FloatTensor(np.asarray(node_feature, dtype=np.float32))
    node_type    = torch.LongTensor(node_type)
    edge_time    = torch.LongTensor(edge_time)
    # reshape keeps the (2, n_edges) layout even when there are no edges.
    edge_index   = torch.LongTensor(edge_index).reshape(-1, 2).t()
    edge_type    = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict
class GPT_GNN(nn.Module):
    '''
        Generative pre-training wrapper around a GNN encoder.

        Holds one link decoder (Matcher) and one negative-sample queue per
        <source_type, relation_type> found in `rem_edge_list`, an attribute
        decoder, and the learnable `init_emb` that stands in for the features
        of the attribute nodes.

        gnn            : encoder exposing `n_hid` and `in_dim`.
        rem_edge_list  : {source_type: {relation_type: ...}}; only its keys are used here.
        attr_decoder   : module used by text_loss / feat_loss.
        types          : node types of the graph (stored, not used in this class).
        neg_samp_num   : negatives drawn per removed edge.
        device         : device on which the (initially empty) queues are created.
        neg_queue_size : queue length factor, see link_loss.
    '''
    def __init__(self, gnn, rem_edge_list, attr_decoder, types, neg_samp_num, device, neg_queue_size = 0):
        super(GPT_GNN, self).__init__()
        self.types = types
        self.gnn = gnn
        # The decoders live in plain dicts for lookup; `params` registers them so
        # that they are seen by parameters() / to() / state_dict().
        self.params = nn.ModuleList()
        self.neg_queue_size = neg_queue_size
        self.link_dec_dict = {}
        self.neg_queue = {}
        for source_type in rem_edge_list:
            self.link_dec_dict[source_type] = {}
            self.neg_queue[source_type] = {}
            for relation_type in rem_edge_list[source_type]:
                print(source_type, relation_type)
                matcher = Matcher(gnn.n_hid, gnn.n_hid)
                self.neg_queue[source_type][relation_type] = torch.FloatTensor([]).to(device)
                self.link_dec_dict[source_type][relation_type] = matcher
                self.params.append(matcher)
        self.attr_decoder = attr_decoder
        self.init_emb = nn.Parameter(torch.randn(gnn.in_dim))
        self.ce = nn.CrossEntropyLoss(reduction = 'none')
        self.neg_samp_num = neg_samp_num

    def neg_sample(self, souce_node_list, pos_node_list):
        '''
            Return up to `neg_samp_num` ids from `souce_node_list` that are not in
            `pos_node_list`. NOTE: `souce_node_list` is shuffled in place.
        '''
        np.random.shuffle(souce_node_list)
        neg_nodes = []
        keys = {key : True for key in pos_node_list}
        tot  = 0
        for node_id in souce_node_list:
            if node_id not in keys:
                neg_nodes += [node_id]
                tot += 1
            if tot == self.neg_samp_num:
                break
        return neg_nodes

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        return self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)
    def link_loss(self, node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue = False):
        '''
            Contrastive edge-generation loss.

            For every relation with more than 8 removed edges, each removed edge
            <target, source> is scored against sampled negative sources (and,
            if `use_queue`, against queued embeddings); the loss is the mean
            negative log-softmax of the positive, averaged over relations.

            node_emb      : embeddings of all nodes of the sub-graph.
            rem_edge_list : {source_type: {relation_type: array of [target, source]}} removed edges.
            ori_edge_list : same layout, all original edges (used to exclude true neighbours).
            node_dict     : {node_type: [offset, type_id]} as built by to_torch.
            use_queue     : add queued embeddings as extra negatives.
            update_queue  : push this batch's source embeddings into the queue.

            Returns (loss, list of detached score matrices, positive in column 0).
            When no relation qualifies the loss is a zero scalar tensor and the
            list is empty.
        '''
        losses = 0
        ress   = []
        for source_type in rem_edge_list:
            if source_type not in self.link_dec_dict:
                continue
            for relation_type in rem_edge_list[source_type]:
                if relation_type not in self.link_dec_dict[source_type]:
                    continue
                rem_edges = rem_edge_list[source_type][relation_type]
                if len(rem_edges) <= 8:
                    continue
                ori_edges = ori_edge_list[source_type][relation_type]
                matcher = self.link_dec_dict[source_type][relation_type]

                target_ids, positive_source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1].reshape(-1, 1)
                n_nodes = len(target_ids)
                source_node_ids = np.unique(ori_edges[:, 1])

                negative_source_ids = [self.neg_sample(source_node_ids, \
                    ori_edges[ori_edges[:, 0] == t_id][:, 1].tolist()) for t_id in target_ids]
                # Truncate to the shortest list so the negatives form a rectangle.
                sn = min([len(neg_ids) for neg_ids in negative_source_ids])

                negative_source_ids = [neg_ids[:sn] for neg_ids in negative_source_ids]

                source_ids = torch.LongTensor(np.concatenate((positive_source_ids, negative_source_ids), axis=-1) + node_dict[source_type][0])
                emb = node_emb[source_ids]

                if use_queue and len(self.neg_queue[source_type][relation_type]) // n_nodes > 0:
                    tmp = self.neg_queue[source_type][relation_type]
                    stx = len(tmp) // n_nodes
                    tmp = tmp[: stx * n_nodes].reshape(n_nodes, stx, -1)
                    rep_size = sn + 1 + stx
                    source_emb = torch.cat([emb, tmp], dim=1)
                    source_emb = source_emb.reshape(n_nodes * rep_size, -1)
                else:
                    rep_size = sn + 1
                    source_emb = emb.reshape(source_ids.shape[0] * rep_size, -1)

                # numpy repeat along axis 1: every target id appears rep_size times in a row.
                target_ids = target_ids.repeat(rep_size, 1) + node_dict[target_type][0]
                target_emb = node_emb[target_ids.reshape(-1)]
                # Call the module (not .forward) so that registered hooks run.
                res = matcher(target_emb, source_emb)
                res = res.reshape(n_nodes, rep_size)
                ress += [res.detach()]
                losses += F.log_softmax(res, dim=-1)[:,0].mean()
                if update_queue and 'L1' not in relation_type and 'L2' not in relation_type:
                    tmp = self.neg_queue[source_type][relation_type]
                    # NOTE(review): source_node_ids are indexed here without the
                    # node_dict[source_type][0] offset that is applied to source_ids
                    # above - confirm this is intended for graphs whose source type
                    # does not start at offset 0.
                    self.neg_queue[source_type][relation_type] = \
                        torch.cat([node_emb[source_node_ids].detach(), tmp], dim=0)[:int(self.neg_queue_size * n_nodes)]
        if len(ress) == 0:
            # Nothing to predict in this batch: return a zero loss instead of
            # dividing by zero.
            return node_emb.new_zeros(()), ress
        return -losses / len(ress), ress


    def text_loss(self, reps, texts, w2v_model, device):
        '''
            Attribute-generation loss for text attributes: token-level cross
            entropy of `attr_decoder` conditioned on `reps`, averaged over the
            unmasked positions.
        '''
        def parse_text(texts, w2v_model, device):
            # Tokenise, map to vocabulary ids and pad with the 'eos' id; returns
            # (inputs, shifted targets, mask), each transposed to (seq_len, batch).
            idxs = []
            pad  = w2v_model.wv.vocab['eos'].index
            for text in texts:
                idx = []
                for word in ['bos'] + preprocess_string(text) + ['eos']:
                    if word in w2v_model.wv.vocab:
                        idx += [w2v_model.wv.vocab[word].index]
                idxs += [idx]
            mxl = np.max([len(s) for s in idxs]) + 1
            inp_idxs = []
            out_idxs = []
            masks    = []
            for i, idx in enumerate(idxs):
                inp_idxs += [idx + [pad for _ in range(mxl - len(idx) - 1)]]
                out_idxs += [idx[1:] + [pad for _ in range(mxl - len(idx))]]
                masks    += [[1 for _ in range(len(idx))] + [0 for _ in range(mxl - len(idx) - 1)]]
            return torch.LongTensor(inp_idxs).transpose(0, 1).to(device), \
                   torch.LongTensor(out_idxs).transpose(0, 1).to(device), torch.BoolTensor(masks).transpose(0, 1).to(device)
        inp_idxs, out_idxs, masks = parse_text(texts, w2v_model, device)
        pred_prob = self.attr_decoder(inp_idxs, reps.repeat(inp_idxs.shape[0], 1, 1))
        return self.ce(pred_prob[masks], out_idxs[masks]).mean()

    def feat_loss(self, reps, out):
        '''
            Attribute-generation loss for dense attributes: negative mean score
            of `attr_decoder(reps, out)`.
        '''
        return -self.attr_decoder(reps, out).mean()
class Classifier(nn.Module):
    '''
        Linear classification head: a single linear layer followed by
        log-softmax over the last dimension (size-1 dimensions of the linear
        output are squeezed away first).
    '''
    def __init__(self, n_hid, n_out):
        super().__init__()
        self.n_hid, self.n_out = n_hid, n_out
        self.linear = nn.Linear(n_hid, n_out)

    def forward(self, x):
        logits = self.linear(x).squeeze()
        return torch.log_softmax(logits, dim=-1)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return '{}(n_hid={}, n_out={})'.format(cls_name, self.n_hid, self.n_out)
class Matcher(nn.Module):
    '''
        Scores a pair of node representations for link prediction.
        The first input is linearly projected (with dropout) into the space of
        the second; the score is either their temperature-scaled cosine
        similarity or their dot product scaled by sqrt(n_out).
    '''

    def __init__(self, n_hid, n_out, temperature = 0.1):
        super(Matcher, self).__init__()
        self.n_hid = n_hid
        self.linear = nn.Linear(n_hid, n_out)
        self.sqrt_hd = math.sqrt(n_out)
        self.drop = nn.Dropout(0.2)
        self.cosine = nn.CosineSimilarity(dim=1)
        self.cache = None
        self.temperature = temperature

    def forward(self, x, ty, use_norm = True):
        '''
            x is projected to n_out dimensions and compared row by row with ty.
        '''
        projected = self.linear(x)
        projected = self.drop(projected)
        if not use_norm:
            # Scaled dot-product score.
            dot = (projected * ty).sum(dim=-1)
            return dot / self.sqrt_hd
        # Cosine similarity sharpened by the temperature.
        similarity = self.cosine(projected, ty)
        return similarity / self.temperature

    def __repr__(self):
        cls_name = self.__class__.__name__
        return '{}(n_hid={})'.format(cls_name, self.n_hid)
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    n_word:  vocabulary size (rows of the embedding, outputs of the decoder).
    ninp:    size of the conditioning vector concatenated to every token
             embedding when `hidden` is passed to forward().
    nhid:    embedding size and LSTM hidden size.
    nlayers: number of stacked LSTM layers.
    dropout: dropout applied to the LSTM output before decoding.
    """
    def __init__(self, n_word, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.rnn = nn.LSTM(nhid, nhid, nlayers)
        self.encoder = nn.Embedding(n_word, nhid)
        self.decoder = nn.Linear(nhid, n_word)
        # Maps [token embedding ; conditioning vector] back to nhid features.
        self.adp = nn.Linear(ninp + nhid, nhid)

    def forward(self, inp, hidden = None):
        """Return per-token vocabulary logits for the index tensor `inp`.

        When `hidden` is given it is concatenated to the token embeddings along
        the last dimension and passed through the adapter; without it the raw
        embeddings go to the LSTM directly (the adapter expects ninp + nhid
        input features, so it cannot be applied to embeddings alone).
        """
        emb = self.encoder(inp)
        if hidden is not None:
            emb = torch.cat((emb, hidden), dim=-1)
            emb = F.gelu(self.adp(emb))
        output, _ = self.rnn(emb)
        decoded = self.decoder(self.drop(output))
        return decoded

    def from_w2v(self, w2v):
        """Load the embedding matrix `w2v`, tie the decoder weight to it and freeze both."""
        self.encoder.weight.data = w2v
        # Weight tying: the decoder shares the (n_word, nhid) embedding matrix.
        self.decoder.weight = self.encoder.weight

        self.encoder.weight.requires_grad = False
        self.decoder.weight.requires_grad = False
def dcg_at_k(r, k):
    """Discounted cumulative gain of the first k relevance scores in r.

    The first item is undiscounted; the item at 1-based rank i >= 2 is divided
    by log2(i). Returns 0. for an empty list.
    """
    # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is equivalent.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.


def ndcg_at_k(r, k):
    """DCG of r at k divided by the DCG of r sorted best-first (0. if that is 0)."""
    dcg_max = dcg_at_k(sorted(r, reverse=True), k)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k) / dcg_max


def mean_reciprocal_rank(rs):
    """Return the reciprocal rank of the first non-zero entry of every ranking.

    Despite the name, the result is the list of per-ranking values (0. when a
    ranking has no non-zero entry); averaging is left to the caller.
    """
    rs = (np.asarray(r).nonzero()[0] for r in rs)
    return [1. / (r[0] + 1) if r.size else 0. for r in rs]


def normalize(mx):
    """Row-normalize sparse matrix"""
    # Work in float: NumPy refuses negative powers of integer arrays.
    rowsum = np.asarray(mx.sum(1), dtype=float).flatten()
    # Empty rows give 1/0 = inf, which is reset to 0 right below.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is the deprecated legacy constructor.
    return torch.sparse_coo_tensor(indices, values, shape)


def randint():
    """Draw a random integer in [0, 2**32 - 1), usable as a NumPy seed."""
    # Explicit 64-bit dtype: the bound does not fit a 32-bit default integer.
    return int(np.random.randint(2**32 - 1, dtype=np.int64))
def feature_reddit(layer_data, graph):
    """Collect features, times and ids of the sampled nodes of every type.

    layer_data: {node_type: {node_id: entry}}; element 1 of each entry is
                taken as the node's time.
    graph:      object whose node_feature[node_type] supports
                .loc[ids, 'emb'] (a DataFrame in the preprocessing script).

    Returns (feature, times, indxs, attr): three dicts keyed by node type, and
    attr, the feature matrix of type 'def' (None if no 'def' node was sampled).
    """
    feature = {}
    times = {}
    indxs = {}
    # Stays None when layer_data has no 'def' nodes (was an UnboundLocalError).
    attr = None
    for _type in layer_data:
        if len(layer_data[_type]) == 0:
            continue
        idxs = np.array(list(layer_data[_type].keys()))
        tims = np.array(list(layer_data[_type].values()))[:,1]

        # np.float was only an alias of the builtin and is gone in NumPy >= 1.24.
        feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'emb']), dtype=float)
        times[_type] = tims
        indxs[_type] = idxs

        if _type == 'def':
            attr = feature[_type]
    return feature, times, indxs, attr


def load_gnn(_dict):
    """Extract the GNN weights from a full state dict.

    Keeps the entries whose key starts with 'gnn.' and strips that prefix, so
    the result can be passed to the bare GNN's load_state_dict.
    """
    # Imported locally: this module does not import OrderedDict at the top.
    from collections import OrderedDict
    prefix = 'gnn.'
    out_dict = OrderedDict()
    for key in _dict:
        if key.startswith(prefix):
            out_dict[key[len(prefix):]] = _dict[key]
    return out_dict
parser.add_argument('--model_dir', type=str, default='/datadrive/models/gpt_all_reddit',
                    help='The address for storing the models and optimization results.')
parser.add_argument('--task_name', type=str, default='reddit',
                    help='The name of the stored models and optimization results.')
parser.add_argument('--cuda', type=int, default=2,
                    help='Avaiable GPU ID')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many numbers to sample the graph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# Ratios are floats: with type=int a command-line value such as 0.3 is rejected.
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--optimizer', type=str, default='adamw',
                    choices=['adamw', 'adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
# Read by the 'cycle' scheduler as args.max_lr; it was never declared before.
parser.add_argument('--max_lr', type=float, default=1e-3,
                    help='Maximum learning rate.')
parser.add_argument('--scheduler', type=str, default='cosine',
                    help='Name of learning rate scheduler.', choices=['cycle', 'cosine'])
parser.add_argument('--data_percentage', type=float, default=0.1,
                    help='Percentage of training and validation data to use')
parser.add_argument('--n_epoch', type=int, default=50,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=16,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')
def prepare_data(pool):
    '''
        Queue the sampling jobs of one epoch on the worker pool: args.n_batch
        training sub-graphs followed by a single validation sub-graph.
        Returns the pending results with the validation job last.
    '''
    pending = [
        pool.apply_async(node_classification_sample, args=(randint(), train_target_nodes, {1: True}))
        for _ in range(args.n_batch)
    ]
    # The validation job goes last; the training loop reads it as jobs[-1].
    valid_job = pool.apply_async(node_classification_sample, args=(randint(), valid_target_nodes, {1: True}))
    pending.append(valid_job)
    return pending
args.pretrain_model_dir) 137 | classifier = Classifier(args.n_hid, graph.y.max().item() + 1) 138 | 139 | model = nn.Sequential(gnn, classifier).to(device) 140 | 141 | 142 | optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4) 143 | 144 | 145 | 146 | if args.scheduler == 'cycle': 147 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\ 148 | max_lr = args.max_lr, total_steps = args.n_batch * args.n_epoch + 1) 149 | elif args.scheduler == 'cosine': 150 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6) 151 | 152 | 153 | for epoch in np.arange(args.n_epoch) + 1: 154 | ''' 155 | Prepare Training and Validation Data 156 | ''' 157 | train_data = [job.get() for job in jobs[:-1]] 158 | valid_data = jobs[-1].get() 159 | pool.close() 160 | pool.join() 161 | ''' 162 | After the data is collected, close the pool and then reopen it. 163 | ''' 164 | pool = mp.Pool(args.n_pool) 165 | jobs = prepare_data(pool) 166 | et = time.time() 167 | print('Data Preparation: %.1fs' % (et - st)) 168 | 169 | ''' 170 | Train 171 | ''' 172 | model.train() 173 | train_losses = [] 174 | for node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel in train_data: 175 | node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \ 176 | edge_time.to(device), edge_index.to(device), edge_type.to(device)) 177 | res = classifier.forward(node_rep[x_ids]) 178 | loss = criterion(res, ylabel.to(device)) 179 | 180 | optimizer.zero_grad() 181 | torch.cuda.empty_cache() 182 | loss.backward() 183 | 184 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 185 | optimizer.step() 186 | 187 | train_losses += [loss.cpu().detach().tolist()] 188 | train_step += 1 189 | scheduler.step(train_step) 190 | del res, loss 191 | ''' 192 | Valid 193 | ''' 194 | model.eval() 195 | with torch.no_grad(): 196 | node_feature, node_type, edge_time, edge_index, edge_type, x_ids, 
ylabel = valid_data 197 | node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \ 198 | edge_time.to(device), edge_index.to(device), edge_type.to(device)) 199 | res = classifier.forward(node_rep[x_ids]) 200 | loss = criterion(res, ylabel.to(device)) 201 | 202 | ''' 203 | Calculate Valid F1. Update the best model based on highest F1 score. 204 | ''' 205 | valid_f1 = f1_score(res.argmax(dim=1).cpu().tolist(), ylabel.tolist(), average='micro') 206 | 207 | if valid_f1 > best_val: 208 | best_val = valid_f1 209 | torch.save(model, os.path.join(args.model_dir, args.task_name + '_' + args.conv_name)) 210 | print('UPDATE!!!') 211 | 212 | st = time.time() 213 | print(("Epoch: %d (%.1fs) LR: %.5f Train Loss: %.2f Valid Loss: %.2f Valid F1: %.4f") % \ 214 | (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), \ 215 | loss.cpu().detach().tolist(), valid_f1)) 216 | stats += [[np.average(train_losses), loss.cpu().detach().tolist()]] 217 | del res, loss 218 | del train_data, valid_data 219 | 220 | 221 | 222 | best_model = torch.load(os.path.join(args.model_dir, args.task_name + '_' + args.conv_name)).to(device) 223 | best_model.eval() 224 | gnn, classifier = best_model 225 | with torch.no_grad(): 226 | test_res = [] 227 | for _ in range(10): 228 | node_feature, node_type, edge_time, edge_index, edge_type, x_ids, ylabel = \ 229 | node_classification_sample(randint(), test_target_nodes, {1: True}) 230 | paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \ 231 | edge_time.to(device), edge_index.to(device), edge_type.to(device))[x_ids] 232 | res = classifier.forward(paper_rep) 233 | test_f1 = f1_score(res.argmax(dim=1).cpu().tolist(), ylabel.tolist(), average='micro') 234 | test_res += [test_f1] 235 | print('Best Test F1: %.4f' % np.average(test_res)) 236 | -------------------------------------------------------------------------------- /example_reddit/preprocess_reddit.py: 
from torch_geometric.datasets import Reddit
from GPT_GNN.data import *

# Build the single-type ('def') graph used by the Reddit pre-training and
# fine-tuning scripts and store it with dill.
dataset = Reddit(root='/datadrive/dataset')
graph_reddit = Graph()
# NOTE(review): the innermost factory returns the *type* int, not 0 -- kept
# as in the original; confirm nothing relies on missing-key lookups.
el = defaultdict(  # target_id
    lambda: defaultdict(  # source_id
        lambda: int  # time
    ))
for i, j in tqdm(dataset.data.edge_index.t()):
    el[i.item()][j.item()] = 1

target_type = 'def'
graph_reddit.edge_list['def']['def']['def'] = el

# One degree entry per feature row, so the concatenation below also works when
# the highest-numbered nodes have no edges.
n_nodes = dataset.data.x.shape[0]
degree = np.zeros(n_nodes)
for i in el:
    degree[i] = len(el[i])
# Clamp to 1 before the log: nodes without edges get 0 instead of -inf, while
# the feature of every node with degree >= 1 is unchanged.
log_degree = np.log(np.maximum(degree, 1)).reshape(-1, 1)
x = np.concatenate((dataset.data.x.numpy(), log_degree), axis=-1)
graph_reddit.node_feature['def'] = pd.DataFrame({'emb': list(x)})

# Fixed-seed shuffle, then a 70 / 10 / 10 / 10 split into
# pre-train / train / valid / test target nodes.
idx = np.arange(len(graph_reddit.node_feature[target_type]))
np.random.seed(43)
np.random.shuffle(idx)

graph_reddit.pre_target_nodes   = idx[ : int(len(idx) * 0.7)]
graph_reddit.train_target_nodes = idx[int(len(idx) * 0.7) : int(len(idx) * 0.8)]
graph_reddit.valid_target_nodes = idx[int(len(idx) * 0.8) : int(len(idx) * 0.9)]
graph_reddit.test_target_nodes  = idx[int(len(idx) * 0.9) : ]

graph_reddit.y = dataset.data.y
# Close the file deterministically so the pickle is fully flushed.
with open('/datadrive/dataset/graph_reddit.pk', 'wb') as out_file:
    dill.dump(graph_reddit, out_file)
parser.add_argument('--attr_type', type=str, default='vec',
                    choices=['text', 'vec'],
                    help='The type of attribute decoder')
parser.add_argument('--neg_samp_num', type=int, default=255,
                    help='Maximum number of negative sample for each target node.')
parser.add_argument('--queue_size', type=int, default=256,
                    help='Max size of adaptive embedding queue.')
# The help text used to be a copy of the --data_dir one; this path is passed
# to Word2Vec.load when --attr_type is 'text'.
parser.add_argument('--w2v_dir', type=str, default='/datadrive/dataset/w2v_all',
                    help='The address of the pre-trained word2vec model.')

# Dataset arguments
parser.add_argument('--data_dir', type=str, default='/datadrive/dataset/graph_reddit.pk',
                    help='The address of preprocessed graph.')
parser.add_argument('--pretrain_model_dir', type=str, default='/datadrive/models/gpt_all_reddit',
                    help='The address for storing the pre-trained models.')
parser.add_argument('--cuda', type=int, default=1,
                    help='Avaiable GPU ID')
parser.add_argument('--sample_depth', type=int, default=6,
                    help='How many layers within a mini-batch subgraph')
parser.add_argument('--sample_width', type=int, default=128,
                    help='How many nodes to be sampled per layer per type')

# Model arguments
parser.add_argument('--conv_name', type=str, default='hgt',
                    choices=['hgt', 'gcn', 'gat', 'rgcn', 'han', 'hetgnn'],
                    help='The name of GNN filter. By default is Heterogeneous Graph Transformer (hgt)')
parser.add_argument('--n_hid', type=int, default=400,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=3,
                    help='Number of GNN layers')
parser.add_argument('--prev_norm', help='Whether to add layer-norm on the previous layers', action='store_true')
parser.add_argument('--last_norm', help='Whether to add layer-norm on the last layers', action='store_true')
# A ratio is a float: with type=int a command-line value such as 0.3 is rejected.
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

# Optimization arguments
parser.add_argument('--max_lr', type=float, default=1e-3,
                    help='Maximum learning rate.')
parser.add_argument('--scheduler', type=str, default='cycle',
                    help='Name of learning rate scheduler.', choices=['cycle', 'cosine'])
parser.add_argument('--n_epoch', type=int, default=20,
                    help='Number of epoch to run')
parser.add_argument('--n_pool', type=int, default=8,
                    help='Number of process to sample subgraph')
parser.add_argument('--n_batch', type=int, default=32,
                    help='Number of batch (sampled graphs) for each epoch')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of output nodes for training')
parser.add_argument('--clip', type=float, default=0.5,
                    help='Gradient Norm Clipping')
graph_reddit.train_target_nodes 94 | 95 | pre_target_nodes = np.concatenate([pre_target_nodes, np.ones(len(pre_target_nodes))]).reshape(2, -1).transpose() 96 | train_target_nodes = np.concatenate([train_target_nodes, np.ones(len(train_target_nodes))]).reshape(2, -1).transpose() 97 | 98 | 99 | def GPT_sample(seed, target_nodes, time_range, batch_size, feature_extractor): 100 | np.random.seed(seed) 101 | samp_target_nodes = target_nodes[np.random.choice(len(target_nodes), batch_size)] 102 | threshold = 0.5 103 | feature, times, edge_list, _, attr = sample_subgraph(graph_reddit, time_range, \ 104 | inp = {target_type: samp_target_nodes}, feature_extractor = feature_extractor, \ 105 | sampled_depth = args.sample_depth, sampled_number = args.sample_width) 106 | rem_edge_list = defaultdict( #source_type 107 | lambda: defaultdict( #relation_type 108 | lambda: [] # [target_id, source_id] 109 | )) 110 | 111 | ori_list = {} 112 | for source_type in edge_list[target_type]: 113 | ori_list[source_type] = {} 114 | for relation_type in edge_list[target_type][source_type]: 115 | ori_list[source_type][relation_type] = np.array(edge_list[target_type][source_type][relation_type]) 116 | el = [] 117 | for target_ser, source_ser in edge_list[target_type][source_type][relation_type]: 118 | if target_ser < source_ser: 119 | if relation_type not in rel_stop_list and target_ser < batch_size and \ 120 | np.random.random() > threshold: 121 | rem_edge_list[source_type][relation_type] += [[target_ser, source_ser]] 122 | continue 123 | el += [[target_ser, source_ser]] 124 | el += [[source_ser, target_ser]] 125 | el = np.array(el) 126 | edge_list[target_type][source_type][relation_type] = el 127 | 128 | if relation_type == 'self': 129 | continue 130 | 131 | ''' 132 | Adding feature nodes: 133 | ''' 134 | n_target_nodes = len(feature[target_type]) 135 | feature[target_type] = np.concatenate((feature[target_type], np.zeros([batch_size, feature[target_type].shape[1]]))) 136 | times[target_type] = 
def prepare_data(pool):
    '''
        Queue the sampling jobs of one round on the worker pool:
        args.n_batch - 1 sub-graphs drawn from pre_target_nodes, followed by one
        drawn from train_target_nodes (read by the training loop as jobs[-1]).
    '''
    pending = [
        pool.apply_async(GPT_sample, args=(randint(), pre_target_nodes, {1: True}, args.batch_size, feature_reddit))
        for _ in range(args.n_batch - 1)
    ]
    valid_job = pool.apply_async(GPT_sample, args=(randint(), train_target_nodes, {1: True}, args.batch_size, feature_reddit))
    pending.append(valid_job)
    return pending
= GNN(conv_name = args.conv_name, in_dim = len(graph_reddit.node_feature[target_type]['emb'].values[0]), n_hid = args.n_hid, \ 185 | n_heads = args.n_heads, n_layers = args.n_layers, dropout = args.dropout, num_types = len(types), \ 186 | num_relations = len(graph_reddit.get_meta_graph()) + 1, prev_norm = args.prev_norm, last_norm = args.last_norm, use_RTE = False) 187 | 188 | if args.attr_type == 'text': 189 | from gensim.models import Word2Vec 190 | w2v_model = Word2Vec.load(args.w2v_dir) 191 | n_tokens = len(w2v_model.wv.vocab) 192 | attr_decoder = RNNModel(n_word = n_tokens, ninp = gnn.n_hid, \ 193 | nhid = w2v_model.vector_size, nlayers = 2) 194 | attr_decoder.from_w2v(torch.FloatTensor(w2v_model.wv.vectors)) 195 | else: 196 | attr_decoder = Matcher(gnn.n_hid, gnn.in_dim) 197 | 198 | gpt_gnn = GPT_GNN(gnn = gnn, rem_edge_list = rem_edge_list, attr_decoder = attr_decoder, \ 199 | types = types, neg_samp_num = args.neg_samp_num, device = device) 200 | gpt_gnn.init_emb.data = node_feature[node_type == node_dict[target_type][1]].mean(dim=0).detach() 201 | gpt_gnn = gpt_gnn.to(device) 202 | 203 | 204 | best_val = 100000 205 | train_step = 0 206 | stats = [] 207 | optimizer = torch.optim.AdamW(gpt_gnn.parameters(), weight_decay = 1e-2, eps=1e-06, lr = args.max_lr) 208 | 209 | if args.scheduler == 'cycle': 210 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.02, anneal_strategy='linear', final_div_factor=100,\ 211 | max_lr = args.max_lr, total_steps = repeat_num * args.n_batch * args.n_epoch + 1) 212 | elif args.scheduler == 'cosine': 213 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, repeat_num * args.n_batch, eta_min=1e-6) 214 | 215 | print('Start Pretraining...') 216 | for epoch in np.arange(args.n_epoch) + 1: 217 | gpt_gnn.neg_queue_size = args.queue_size * epoch // args.n_epoch 218 | for batch in np.arange(repeat_num) + 1: 219 | train_data = [job.get() for job in jobs[:-1]] 220 | valid_data = jobs[-1].get() 221 | 
pool.close() 222 | pool.join() 223 | pool = mp.Pool(args.n_pool) 224 | jobs = prepare_data(pool) 225 | et = time.time() 226 | print('Data Preparation: %.1fs' % (et - st)) 227 | 228 | train_link_losses = [] 229 | train_attr_losses = [] 230 | gpt_gnn.train() 231 | for data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) in train_data: 232 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 233 | node_feature = node_feature.detach() 234 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 235 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 236 | edge_index.to(device), edge_type.to(device)) 237 | 238 | loss_link, _ = gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = True, update_queue=True) 239 | if args.attr_type == 'text': 240 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 241 | else: 242 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 243 | 244 | 245 | loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio 246 | 247 | 248 | optimizer.zero_grad() 249 | loss.backward() 250 | torch.nn.utils.clip_grad_norm_(gpt_gnn.parameters(), args.clip) 251 | optimizer.step() 252 | 253 | train_link_losses += [loss_link.item()] 254 | train_attr_losses += [loss_attr.item()] 255 | scheduler.step() 256 | ''' 257 | Valid 258 | ''' 259 | gpt_gnn.eval() 260 | with torch.no_grad(): 261 | data, rem_edge_list, ori_edge_list, attr, (start_idx, end_idx) = valid_data 262 | node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = data 263 | node_feature = node_feature.detach() 264 | node_feature[start_idx : end_idx] = gpt_gnn.init_emb 265 | node_emb = gpt_gnn.gnn(node_feature.to(device), node_type.to(device), edge_time.to(device), \ 266 | edge_index.to(device), edge_type.to(device)) 267 | loss_link, ress = 
gpt_gnn.link_loss(node_emb, rem_edge_list, ori_edge_list, node_dict, target_type, use_queue = False, update_queue=True) 268 | loss_link = loss_link.item() 269 | if args.attr_type == 'text': 270 | loss_attr = gpt_gnn.text_loss(node_emb[start_idx : end_idx], attr, w2v_model, device) 271 | else: 272 | loss_attr = gpt_gnn.feat_loss(node_emb[start_idx : end_idx], torch.FloatTensor(attr).to(device)) 273 | 274 | ndcgs = [] 275 | for i in ress: 276 | ai = np.zeros(len(i[0])) 277 | ai[0] = 1 278 | ndcgs += [ndcg_at_k(ai[j.cpu().numpy()], len(j)) for j in i.argsort(descending = True)] 279 | 280 | valid_loss = loss_link * (1 - args.attr_ratio) + loss_attr * args.attr_ratio 281 | st = time.time() 282 | print(("Epoch: %d, (%d / %d) %.1fs LR: %.5f Train Loss: (%.3f, %.3f) Valid Loss: (%.3f, %.3f) NDCG: %.3f Norm: %.3f queue: %d") % \ 283 | (epoch, batch, repeat_num, (st-et), optimizer.param_groups[0]['lr'], np.average(train_link_losses), np.average(train_attr_losses), \ 284 | loss_link, loss_attr, np.average(ndcgs), node_emb.norm(dim=1).mean(), gpt_gnn.neg_queue_size)) 285 | 286 | if valid_loss < best_val: 287 | best_val = valid_loss 288 | print('UPDATE!!!') 289 | torch.save(gpt_gnn.state_dict(), args.pretrain_model_dir) 290 | stats += [[np.average(train_link_losses), loss_link, loss_attr, valid_loss]] 291 | -------------------------------------------------------------------------------- /images/gpt-intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/images/gpt-intro.png -------------------------------------------------------------------------------- /images/pretrain_OAG.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acbull/GPT-GNN/f26e13c69ddc8a3f2580cb16d0b9a1c73d89f4bc/images/pretrain_OAG.gif -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | dill==0.3.0 2 | numpy==1.22.0 3 | pandas==0.24.2 4 | torch==1.3.0 5 | torch-scatter==1.3.2 6 | torch-cluster==1.4.5 7 | torch-sparse==0.4.3 8 | torch-spline-conv==1.1.1 9 | torch-geometric==1.3.2 10 | torchvision==0.4.1 11 | tqdm==4.31.1 12 | seaborn==0.9.0 13 | matplotlib==3.0.3 14 | transformers==4.30.0 15 | --------------------------------------------------------------------------------