├── LICENSE ├── README.md ├── graphproppred ├── construct_two_hops.py ├── conv.py ├── gnn.py ├── main_pyg.py ├── moe.py ├── test.py └── utils.py ├── linkproppred └── ddi │ ├── gnn.py │ ├── logger.py │ ├── mf.py │ ├── mlp.py │ └── node2vec.py └── nodeproppred └── proteins ├── gnn.py ├── logger.py ├── mlp.py ├── moe.py └── node2vec.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 VITA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Mixture of Experts: Learning on Large-Scale Graphs with Explicit Diversity Modeling 2 | 3 | Official code for "Graph Mixture of Experts: Learning on Large-Scale Graphs with Explicit Diversity Modeling" in NeurIPS 2023. 4 | 5 | ## Introduction 6 | 7 | In this work, we propose the Graph Mixture of Experts (GMoE) model structure to enhance the ability of GNNs to accommodate the diversity of training graph structures, without incurring computational overheads at inference. 8 | 9 | ## How to run the code 10 | 11 | To train the GMoE model, run 12 | 13 | ``` 14 | python main_pyg.py --dataset $dataset -n $total_number_of_experts --n1 $number_of_one_hop_experts -k $number_of_selected_experts -d $feature_dimension --device 0 --gnn gcn-spmoe --coef 1 15 | ``` 16 | 17 | For example, on ogbg-molhiv dataset, run 18 | 19 | ``` 20 | python main_pyg.py --dataset ogbg-molhiv -n 8 --n1 4 -k 4 -d 150 --device 0 --gnn gcn-spmoe --coef 1 21 | ``` 22 | 23 | The test results for the best performing model on validation set will be recorded in the output files generated by the training code. 24 | 25 | 26 | ## Acknowledgement 27 | 28 | Our code is built upon the [official OGB code](https://github.com/snap-stanford/ogb/tree/master/examples). 29 | 30 | ## Citation 31 | 32 | ``` 33 | @inproceedings{wang2023gmoe, 34 | author = {Wang, Haotao and Jiang, Ziyu and You, Yuning and Han, Yan and Liu, Gaowen and Srinivasa, Jayanth and Kompella, Ramana Rao and Wang, Zhangyang}, 35 | title = {Graph Mixture of Experts: Learning on Large-Scale Graphs with Explicit Diversity Modeling}, 36 | booktitle = {NeurIPS}, 37 | year = {2023} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /graphproppred/construct_two_hops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.loader import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn import GNN 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | import pickle 12 | 13 | ### importing OGB 14 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 15 | 16 | def main(): 17 | # Training settings 18 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 19 | parser.add_argument('--dataset', type=str, default="ogbg-molhiv", choices=['ogbg-molhiv', 'ogbg-molpcba', 'ogbg-molmuv'], 20 | help='dataset name (default: ogbg-molhiv)') 21 | 22 | parser.add_argument('--feature', type=str, default="full", 23 | help='full feature or simple feature') 24 | args = parser.parse_args() 25 | 26 | ### automatic dataloading and splitting 27 | dataset = PygGraphPropPredDataset(name = args.dataset) 28 | 29 | if args.feature == 'full': 30 | pass 31 | elif args.feature == 'simple': 32 | print('using simple feature') 33 | # only retain the top two node/edge features 34 | dataset.data.x = dataset.data.x[:,:2] 35 | dataset.data.edge_attr = dataset.data.edge_attr[:,:2] 36 | 37 | import copy 38 | two_hop_dataset = copy.deepcopy(dataset) 39 | two_hop_edge_index = [] 40 | two_hop_edge_attr = [] 41 | two_hop_edge_slices = [0] 42 | for i in tqdm(range(len(dataset))): # loop though each graph 43 | graph = dataset[i] 44 | edge_index, edge_attr = graph.edge_index, graph.edge_attr # shape=(2, num_edges), shape=(num_edges,3) 45 | num_edges = graph.num_edges 46 | num_nodes = graph.num_nodes 47 | # construct hash table: 48 | hash_table = {} 49 | for j in range(num_edges): 50 | start_node_idx = edge_index[0,j].item() 51 | if start_node_idx in hash_table: 52 | hash_table[start_node_idx].append(j) 53 | else: 54 | hash_table[start_node_idx] = [j] 55 | for node_idx in range(num_nodes): 56 | if node_idx not in hash_table: # this is weird but some graphs has isolated nodes. 57 | continue 58 | for first_edge_idx in hash_table[node_idx]: 59 | first_edge = edge_index[:,first_edge_idx] 60 | first_edge_attr = edge_attr[first_edge_idx,:] 61 | hop_node_idx = first_edge[1].item() 62 | for second_edge_idx in hash_table[hop_node_idx]: 63 | second_edge = edge_index[:,second_edge_idx] 64 | if second_edge[1].item() == first_edge[0].item(): 65 | continue # we don't consider 1->2 and 2->1 as 1--two-hop-->1 66 | if second_edge[1].item() in hash_table[first_edge[0].item()]: # note: first_edge[0].item() == node_idx 67 | continue # we don't consider 1->2 as a two-hop path if there is a one-hop path between 1 & 2. 68 | second_edge_attr = edge_attr[second_edge_idx,:] 69 | two_hop_edge = [first_edge[0].item(), second_edge[1].item()] 70 | two_hope_edge_attr = torch.cat([first_edge_attr, second_edge_attr], dim=-1) 71 | two_hop_edge_index.append(two_hop_edge) 72 | two_hop_edge_attr.append(two_hope_edge_attr) 73 | 74 | two_hop_edge_slices.append(len(two_hop_edge_index)) 75 | 76 | two_hop_edge_index = torch.Tensor(two_hop_edge_index).T.long() 77 | two_hop_edge_attr = torch.stack(two_hop_edge_attr, dim=0) 78 | two_hop_edge_slices = torch.Tensor(two_hop_edge_slices) 79 | two_hop_dataset.edge_index, two_hop_dataset.edge_attr = two_hop_edge_index, two_hop_edge_attr 80 | two_hop_dataset.data['edge_index'], two_hop_dataset.data['edge_attr'] = two_hop_edge_index, two_hop_edge_attr 81 | two_hop_dataset.slices['edge_index'] = two_hop_edge_slices.long() 82 | two_hop_dataset.slices['edge_attr'] = two_hop_edge_slices.long() 83 | # two_hop_dataset.num_edge_features *= 2 # 3 -> 6, concatenating features of two edges. 84 | 85 | pickle.dump(two_hop_dataset, open('/home/haotao/GNN-MoE/mol/two_hop_%s_dataset.pkl' % args.dataset, 'wb')) 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /graphproppred/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import global_mean_pool, global_add_pool 5 | from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder 6 | from torch_geometric.utils import degree 7 | 8 | import math 9 | 10 | from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 11 | full_bond_feature_dims = get_bond_feature_dims() 12 | class TwoHopBondEncoder(torch.nn.Module): 13 | 14 | def __init__(self, emb_dim): 15 | super(TwoHopBondEncoder, self).__init__() 16 | 17 | self.bond_embedding_list = torch.nn.ModuleList() 18 | 19 | for i, dim in enumerate(full_bond_feature_dims+full_bond_feature_dims): 20 | emb = torch.nn.Embedding(dim, emb_dim) 21 | torch.nn.init.xavier_uniform_(emb.weight.data) 22 | self.bond_embedding_list.append(emb) 23 | 24 | def forward(self, edge_attr): 25 | bond_embedding = 0 26 | for i in range(edge_attr.shape[1]): 27 | bond_embedding += self.bond_embedding_list[i](edge_attr[:,i]) 28 | 29 | return bond_embedding 30 | 31 | ### GIN convolution along the graph structure 32 | class GINConv(MessagePassing): 33 | def __init__(self, emb_dim, hop=1): 34 | ''' 35 | emb_dim (int): node embedding dimensionality 36 | ''' 37 | 38 | super(GINConv, self).__init__(aggr = "add") 39 | 40 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim)) 41 | self.eps = torch.nn.Parameter(torch.Tensor([0])) 42 | 43 | if hop==1: 44 | self.bond_encoder = BondEncoder(emb_dim = emb_dim) 45 | elif hop==2: 46 | self.bond_encoder = TwoHopBondEncoder(emb_dim = emb_dim) 47 | else: 48 | raise Exception('Unimplemented hop %d' % hop) 49 | 50 | def forward(self, x, edge_index, edge_attr): 51 | edge_embedding = self.bond_encoder(edge_attr) 52 | out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) 53 | 54 | return out 55 | 56 | def message(self, x_j, edge_attr): 57 | return F.relu(x_j + edge_attr) 58 | 59 | def update(self, aggr_out): 60 | return aggr_out 61 | 62 | ### GCN convolution along the graph structure 63 | class GCNConv(MessagePassing): 64 | def __init__(self, emb_dim, hop=1): 65 | super(GCNConv, self).__init__(aggr='add') 66 | 67 | self.linear = torch.nn.Linear(emb_dim, emb_dim) 68 | self.root_emb = torch.nn.Embedding(1, emb_dim) 69 | if hop==1: 70 | self.bond_encoder = BondEncoder(emb_dim = emb_dim) 71 | elif hop==2: 72 | self.bond_encoder = TwoHopBondEncoder(emb_dim = emb_dim) 73 | else: 74 | raise Exception('Unimplemented hop %d' % hop) 75 | 76 | def forward(self, x, edge_index, edge_attr): 77 | x = self.linear(x) 78 | edge_embedding = self.bond_encoder(edge_attr) 79 | 80 | row, col = edge_index 81 | 82 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) 83 | deg = degree(row, x.size(0), dtype = x.dtype) + 1 84 | deg_inv_sqrt = deg.pow(-0.5) 85 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 86 | 87 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] 88 | 89 | return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) 90 | 91 | def message(self, x_j, edge_attr, norm): 92 | return norm.view(-1, 1) * F.relu(x_j + edge_attr) 93 | 94 | def update(self, aggr_out): 95 | return aggr_out 96 | 97 | 98 | ### GNN to generate node embedding 99 | class GNN_node(torch.nn.Module): 100 | """ 101 | Output: 102 | node representations 103 | """ 104 | def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin', hop=1): 105 | ''' 106 | emb_dim (int): node embedding dimensionality 107 | num_layer (int): number of GNN message passing layers 108 | JK: Jumping knowledge refers to "Representation Learning on Graphs with Jumping Knowledge Networks" 109 | ''' 110 | 111 | super(GNN_node, self).__init__() 112 | self.num_layer = num_layer 113 | self.drop_ratio = drop_ratio 114 | self.JK = JK 115 | ### add residual connection or not 116 | self.residual = residual 117 | 118 | if self.num_layer < 2: 119 | raise ValueError("Number of GNN layers must be greater than 1.") 120 | 121 | self.atom_encoder = AtomEncoder(emb_dim) 122 | 123 | ###List of GNNs 124 | self.convs = torch.nn.ModuleList() 125 | self.batch_norms = torch.nn.ModuleList() 126 | 127 | for layer in range(num_layer): 128 | if gnn_type == 'gin': 129 | self.convs.append(GINConv(emb_dim)) 130 | elif gnn_type == 'gcn': 131 | self.convs.append(GCNConv(emb_dim, hop=hop)) 132 | else: 133 | raise ValueError('Undefined GNN type called {}'.format(gnn_type)) 134 | 135 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 136 | 137 | def forward(self, batched_data): 138 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 139 | 140 | ### computing input node embedding 141 | 142 | h_list = [self.atom_encoder(x)] 143 | for layer in range(self.num_layer): 144 | 145 | h = self.convs[layer](h_list[layer], edge_index, edge_attr) 146 | h = self.batch_norms[layer](h) 147 | 148 | if layer == self.num_layer - 1: 149 | #remove relu for the last layer 150 | h = F.dropout(h, self.drop_ratio, training = self.training) 151 | else: 152 | h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 153 | 154 | if self.residual: 155 | h += h_list[layer] 156 | 157 | h_list.append(h) 158 | 159 | ### Different implementations of Jk-concat 160 | if self.JK == "last": 161 | node_representation = h_list[-1] 162 | elif self.JK == "sum": 163 | node_representation = 0 164 | for layer in range(self.num_layer + 1): 165 | node_representation += h_list[layer] 166 | 167 | return node_representation 168 | 169 | class GNN_MoE_node(torch.nn.Module): 170 | """ 171 | Output: 172 | node representations 173 | """ 174 | def __init__(self, num_layer, emb_dim, num_experts=3, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin', num_experts_1hop=None): 175 | ''' 176 | emb_dim (int): node embedding dimensionality 177 | num_layer (int): number of GNN message passing layers 178 | JK: Jumping knowledge refers to "Representation Learning on Graphs with Jumping Knowledge Networks" 179 | ''' 180 | 181 | super(GNN_MoE_node, self).__init__() 182 | self.num_layer = num_layer 183 | self.num_experts = num_experts 184 | self.drop_ratio = drop_ratio 185 | self.JK = JK 186 | ### add residual connection or not 187 | self.residual = residual 188 | 189 | if not num_experts_1hop: 190 | self.num_experts_1hop = num_experts # by default, all experts are hop-1 experts. 191 | else: 192 | assert num_experts_1hop <= num_experts 193 | self.num_experts_1hop = num_experts_1hop 194 | 195 | if self.num_layer < 2: 196 | raise ValueError("Number of GNN layers must be greater than 1.") 197 | 198 | self.atom_encoder = AtomEncoder(emb_dim) 199 | 200 | ###List of GNNs 201 | self.convs = torch.nn.ModuleList() 202 | self.batch_norms = torch.nn.ModuleList() 203 | 204 | for layer in range(num_layer): 205 | convs_list = torch.nn.ModuleList() 206 | bn_list = torch.nn.ModuleList() 207 | for expert_idx in range(num_experts): 208 | if gnn_type == 'gin': 209 | convs_list.append(GINConv(emb_dim)) 210 | elif gnn_type == 'gcn': 211 | if expert_idx < self.num_experts_1hop: 212 | convs_list.append(GCNConv(emb_dim, hop=1)) 213 | else: 214 | convs_list.append(GCNConv(emb_dim, hop=2)) 215 | else: 216 | raise ValueError('Undefined GNN type called {}'.format(gnn_type)) 217 | 218 | bn_list.append(torch.nn.BatchNorm1d(emb_dim)) 219 | 220 | self.convs.append(convs_list) 221 | self.batch_norms.append(bn_list) 222 | 223 | # self.mix_fn = lambda h_expert_list: torch.mean(torch.stack(h_expert_list, dim=0), dim=0) 224 | 225 | def forward(self, batched_data): 226 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 227 | # edge_index: shape=(2, N_batch) 228 | # edge_attr: shape=(N_batch, d_attr) 229 | 230 | ### computing input node embedding 231 | 232 | h_list = [self.atom_encoder(x)] 233 | for layer in range(self.num_layer): 234 | 235 | h_expert_list = [] 236 | for expert in range(self.num_experts): 237 | 238 | h = self.convs[layer][expert](h_list[layer], edge_index, edge_attr) # TODO: use different edge_index and edge_attr for each expert 239 | h = self.batch_norms[layer][expert](h) 240 | h_expert_list.append(h) 241 | 242 | h = torch.stack(h_expert_list, dim=0) # shape=[num_experts, num_nodes, d_features] 243 | h = torch.mean(h, dim=0) # shape=[num_nodes, d_features] 244 | 245 | if layer == self.num_layer - 1: 246 | #remove relu for the last layer 247 | h = F.dropout(h, self.drop_ratio, training = self.training) 248 | else: 249 | h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 250 | 251 | if self.residual: 252 | h += h_list[layer] 253 | 254 | h_list.append(h) 255 | 256 | ### Different implementations of Jk-concat 257 | if self.JK == "last": 258 | node_representation = h_list[-1] 259 | elif self.JK == "sum": 260 | node_representation = 0 261 | for layer in range(self.num_layer + 1): 262 | node_representation += h_list[layer] 263 | 264 | return node_representation 265 | 266 | from moe import MoE 267 | class GNN_SpMoE_node(torch.nn.Module): 268 | """ 269 | Output: 270 | node representations 271 | """ 272 | def __init__(self, num_layer, emb_dim, num_experts=3, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gcn', k=1, coef=1e-2, num_experts_1hop=None): 273 | ''' 274 | emb_dim (int): node embedding dimensionality 275 | num_layer (int): number of GNN message passing layers 276 | JK: Jumping knowledge refers to "Representation Learning on Graphs with Jumping Knowledge Networks" 277 | k: k value for top-k sparse gating. 278 | num_experts: total number of experts in each layer. 279 | num_experts_1hop: number of hop-1 experts in each layer. The first num_experts_1hop are hop-1 experts. The rest num_experts-num_experts_1hop are hop-2 experts. 280 | ''' 281 | 282 | super(GNN_SpMoE_node, self).__init__() 283 | self.num_layer = num_layer 284 | self.num_experts = num_experts 285 | self.k = k 286 | self.drop_ratio = drop_ratio 287 | self.JK = JK 288 | ### add residual connection or not 289 | self.residual = residual 290 | 291 | if not num_experts_1hop: 292 | self.num_experts_1hop = num_experts # by default, all experts are hop-1 experts. 293 | else: 294 | assert num_experts_1hop <= num_experts 295 | self.num_experts_1hop = num_experts_1hop 296 | 297 | if self.num_layer < 2: 298 | raise ValueError("Number of GNN layers must be greater than 1.") 299 | 300 | self.atom_encoder = AtomEncoder(emb_dim) 301 | 302 | ###List of GNNs 303 | self.ffns = torch.nn.ModuleList() 304 | 305 | for layer in range(num_layer): 306 | convs_list = torch.nn.ModuleList() 307 | bn_list = torch.nn.ModuleList() 308 | for expert_idx in range(num_experts): 309 | if gnn_type == 'gin': 310 | if expert_idx < self.num_experts_1hop: 311 | convs_list.append(GINConv(emb_dim, hop=1)) 312 | else: 313 | convs_list.append(GINConv(emb_dim, hop=2)) 314 | elif gnn_type == 'gcn': 315 | if expert_idx < self.num_experts_1hop: 316 | convs_list.append(GCNConv(emb_dim, hop=1)) 317 | else: 318 | convs_list.append(GCNConv(emb_dim, hop=2)) 319 | else: 320 | raise ValueError('Undefined GNN type called {}'.format(gnn_type)) 321 | 322 | bn_list.append(torch.nn.BatchNorm1d(emb_dim)) 323 | 324 | ffn = MoE(input_size=emb_dim, output_size=emb_dim, num_experts=num_experts, experts_conv=convs_list, experts_bn=bn_list, 325 | k=k, coef=coef, num_experts_1hop=self.num_experts_1hop) 326 | 327 | self.ffns.append(ffn) 328 | 329 | # self.mix_fn = lambda h_expert_list: torch.mean(torch.stack(h_expert_list, dim=0), dim=0) 330 | 331 | def forward(self, batched_data, batched_data_2hop=None): 332 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 333 | if batched_data_2hop: 334 | x_2hop, edge_index_2hop, edge_attr_2hop, batch_2hop = batched_data_2hop.x, batched_data_2hop.edge_index, batched_data_2hop.edge_attr, batched_data_2hop.batch 335 | 336 | ### computing input node embedding 337 | 338 | h_list = [self.atom_encoder(x)] 339 | self.load_balance_loss = 0 # initialize load_balance_loss to 0 at the beginning of each forward pass. 340 | for layer in range(self.num_layer): 341 | 342 | if batched_data_2hop: 343 | h, _layer_load_balance_loss = self.ffns[layer](h_list[layer], edge_index, edge_attr, edge_index_2hop, edge_attr_2hop) 344 | else: 345 | h, _layer_load_balance_loss = self.ffns[layer](h_list[layer], edge_index, edge_attr, None, None) 346 | self.load_balance_loss += _layer_load_balance_loss 347 | 348 | if layer == self.num_layer - 1: 349 | #remove relu for the last layer 350 | h = F.dropout(h, self.drop_ratio, training = self.training) 351 | else: 352 | h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 353 | 354 | if self.residual: 355 | h += h_list[layer] 356 | 357 | h_list.append(h) 358 | 359 | self.load_balance_loss /= self.num_layer 360 | 361 | ### Different implementations of Jk-concat 362 | if self.JK == "last": 363 | node_representation = h_list[-1] 364 | elif self.JK == "sum": 365 | node_representation = 0 366 | for layer in range(self.num_layer + 1): 367 | node_representation += h_list[layer] 368 | 369 | return node_representation 370 | 371 | ### Virtual GNN to generate node embedding 372 | class GNN_node_Virtualnode(torch.nn.Module): 373 | """ 374 | Output: 375 | node representations 376 | """ 377 | def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'): 378 | ''' 379 | emb_dim (int): node embedding dimensionality 380 | ''' 381 | 382 | super(GNN_node_Virtualnode, self).__init__() 383 | self.num_layer = num_layer 384 | self.drop_ratio = drop_ratio 385 | self.JK = JK 386 | ### add residual connection or not 387 | self.residual = residual 388 | 389 | if self.num_layer < 2: 390 | raise ValueError("Number of GNN layers must be greater than 1.") 391 | 392 | self.atom_encoder = AtomEncoder(emb_dim) 393 | 394 | ### set the initial virtual node embedding to 0. 395 | self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim) 396 | torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) 397 | 398 | ### List of GNNs 399 | self.convs = torch.nn.ModuleList() 400 | ### batch norms applied to node embeddings 401 | self.batch_norms = torch.nn.ModuleList() 402 | 403 | ### List of MLPs to transform virtual node at every layer 404 | self.mlp_virtualnode_list = torch.nn.ModuleList() 405 | 406 | for layer in range(num_layer): 407 | if gnn_type == 'gin': 408 | self.convs.append(GINConv(emb_dim)) 409 | elif gnn_type == 'gcn': 410 | self.convs.append(GCNConv(emb_dim)) 411 | else: 412 | raise ValueError('Undefined GNN type called {}'.format(gnn_type)) 413 | 414 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 415 | 416 | for layer in range(num_layer - 1): 417 | self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \ 418 | torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU())) 419 | 420 | 421 | def forward(self, batched_data): 422 | 423 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 424 | 425 | ### virtual node embeddings for graphs 426 | virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) 427 | 428 | h_list = [self.atom_encoder(x)] 429 | for layer in range(self.num_layer): 430 | ### add message from virtual nodes to graph nodes 431 | h_list[layer] = h_list[layer] + virtualnode_embedding[batch] 432 | 433 | ### Message passing among graph nodes 434 | h = self.convs[layer](h_list[layer], edge_index, edge_attr) 435 | 436 | h = self.batch_norms[layer](h) 437 | if layer == self.num_layer - 1: 438 | #remove relu for the last layer 439 | h = F.dropout(h, self.drop_ratio, training = self.training) 440 | else: 441 | h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 442 | 443 | if self.residual: 444 | h = h + h_list[layer] 445 | 446 | h_list.append(h) 447 | 448 | ### update the virtual nodes 449 | if layer < self.num_layer - 1: 450 | ### add message from graph nodes to virtual nodes 451 | virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding 452 | ### transform virtual nodes using MLP 453 | 454 | if self.residual: 455 | virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) 456 | else: 457 | virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) 458 | 459 | ### Different implementations of Jk-concat 460 | if self.JK == "last": 461 | node_representation = h_list[-1] 462 | elif self.JK == "sum": 463 | node_representation = 0 464 | for layer in range(self.num_layer + 1): 465 | node_representation += h_list[layer] 466 | 467 | return node_representation 468 | 469 | 470 | if __name__ == "__main__": 471 | pass 472 | -------------------------------------------------------------------------------- /graphproppred/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import MessagePassing 3 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 4 | import torch.nn.functional as F 5 | from torch_geometric.nn.inits import uniform 6 | 7 | from conv import GNN_node, GNN_node_Virtualnode, GNN_MoE_node, GNN_SpMoE_node 8 | 9 | from torch_scatter import scatter_mean 10 | 11 | class GNN(torch.nn.Module): 12 | 13 | def __init__(self, num_tasks, num_layer = 5, emb_dim = 300, 14 | gnn_type = 'gin', virtual_node = True, moe=True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean", 15 | num_experts=3, k=1, coef=1e-2, hop=1, num_experts_1hop=None): 16 | ''' 17 | num_tasks (int): number of labels to be predicted 18 | virtual_node (bool): whether to add virtual node or not 19 | ''' 20 | 21 | super(GNN, self).__init__() 22 | 23 | self.num_layer = num_layer 24 | self.drop_ratio = drop_ratio 25 | self.JK = JK 26 | self.emb_dim = emb_dim 27 | self.num_tasks = num_tasks 28 | self.graph_pooling = graph_pooling 29 | 30 | if self.num_layer < 2: 31 | raise ValueError("Number of GNN layers must be greater than 1.") 32 | 33 | ### GNN to generate node embeddings 34 | if virtual_node: 35 | self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type) 36 | elif moe=='dense': 37 | self.gnn_node = GNN_MoE_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type, num_experts=num_experts, num_experts_1hop=num_experts_1hop) 38 | elif moe=='sparse': 39 | self.gnn_node = GNN_SpMoE_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type, num_experts=num_experts, k=k, coef=coef, num_experts_1hop=num_experts_1hop) 40 | else: 41 | self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type, hop=hop) 42 | 43 | 44 | ### Pooling function to generate whole-graph embeddings 45 | if self.graph_pooling == "sum": 46 | self.pool = global_add_pool 47 | elif self.graph_pooling == "mean": 48 | self.pool = global_mean_pool 49 | elif self.graph_pooling == "max": 50 | self.pool = global_max_pool 51 | elif self.graph_pooling == "attention": 52 | self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) 53 | elif self.graph_pooling == "set2set": 54 | self.pool = Set2Set(emb_dim, processing_steps = 2) 55 | else: 56 | raise ValueError("Invalid graph pooling type.") 57 | 58 | if graph_pooling == "set2set": 59 | self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks) 60 | else: 61 | self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks) 62 | 63 | def forward(self, batched_data, batched_data_2hop=None): 64 | if batched_data_2hop: 65 | h_node = self.gnn_node(batched_data, batched_data_2hop) 66 | else: 67 | h_node = self.gnn_node(batched_data) 68 | 69 | h_graph = self.pool(h_node, batched_data.batch) 70 | # batched_data.batch: shape=[num_of_nodes]. The graph index of each node. 71 | # h_graph: shape=[batch_size/num_of_graphs_per_batch(default=32)] 72 | 73 | return self.graph_pred_linear(h_graph) 74 | 75 | 76 | if __name__ == '__main__': 77 | GNN(num_tasks = 10) -------------------------------------------------------------------------------- /graphproppred/main_pyg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.loader import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.utils.data import RandomSampler, SequentialSampler 6 | from gnn import GNN 7 | from conv import GNN_SpMoE_node 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | import time 12 | import numpy as np 13 | import os 14 | from utils import RandomDropNode 15 | 16 | import pickle 17 | 18 | ### importing OGB 19 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 20 | 21 | cls_criterion = torch.nn.BCEWithLogitsLoss() 22 | reg_criterion = torch.nn.MSELoss() 23 | 24 | def train(model, device, loader, optimizer, task_type): 25 | model.train() 26 | 27 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 28 | batch = batch.to(device) 29 | 30 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 31 | pass 32 | else: 33 | pred = model(batch) 34 | optimizer.zero_grad() 35 | ## ignore nan targets (unlabeled) when computing training loss. 36 | is_labeled = batch.y == batch.y 37 | if "classification" in task_type: 38 | utility_loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 39 | else: 40 | utility_loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 41 | if isinstance(model.gnn_node, GNN_SpMoE_node): 42 | load_balance_loss = model.gnn_node.load_balance_loss 43 | loss = utility_loss + load_balance_loss 44 | else: 45 | loss = utility_loss 46 | loss.backward() 47 | optimizer.step() 48 | 49 | if isinstance(model.gnn_node, GNN_SpMoE_node): 50 | loss_str = 'loss: %.4f (utility: %.4f, load balance: %.4f)' % (loss.item(), utility_loss.item(), load_balance_loss.item()) 51 | print(loss_str) 52 | 53 | def train_mixed(model, device, loader, loader_2hop, optimizer, task_type): 54 | ''' 55 | Deprecated 56 | ''' 57 | model.train() 58 | 59 | for step, (batch, batch_2hop) in enumerate(tqdm(zip(loader, loader_2hop), desc="Iteration")): 60 | batch = batch.to(device) 61 | batch_2hop = batch_2hop.to(device) 62 | 63 | if batch.x.shape[0] == 1 or batch.batch[-1] == 0: 64 | pass 65 | else: 66 | pred = model(batch, batch_2hop) 67 | optimizer.zero_grad() 68 | ## ignore nan targets (unlabeled) when computing training loss. 69 | is_labeled = batch.y == batch.y 70 | if "classification" in task_type: 71 | utility_loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 72 | else: 73 | utility_loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled]) 74 | if isinstance(model.gnn_node, GNN_SpMoE_node): 75 | load_balance_loss = model.gnn_node.load_balance_loss 76 | loss = utility_loss + load_balance_loss 77 | else: 78 | loss = utility_loss 79 | loss.backward() 80 | optimizer.step() 81 | 82 | if isinstance(model.gnn_node, GNN_SpMoE_node): 83 | loss_str = 'loss: %.4f (utility: %.4f, load balance: %.4f)' % (loss.item(), utility_loss.item(), load_balance_loss.item()) 84 | print(loss_str) 85 | 86 | def eval(model, device, loader, evaluator): 87 | model.eval() 88 | y_true = [] 89 | y_pred = [] 90 | 91 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 92 | batch = batch.to(device) 93 | 94 | if batch.x.shape[0] == 1: 95 | pass 96 | else: 97 | with torch.no_grad(): 98 | pred = model(batch) 99 | 100 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 101 | y_pred.append(pred.detach().cpu()) 102 | 103 | y_true = torch.cat(y_true, dim = 0).numpy() 104 | y_pred = torch.cat(y_pred, dim = 0).numpy() 105 | 106 | input_dict = {"y_true": y_true, "y_pred": y_pred} 107 | 108 | return evaluator.eval(input_dict) 109 | 110 | def eval_mixed(model, device, loader, loader_2hop, evaluator): 111 | ''' 112 | Deprecated 113 | ''' 114 | model.eval() 115 | y_true = [] 116 | y_pred = [] 117 | 118 | for step, (batch, batch_2hop) in enumerate(tqdm(zip(loader, loader_2hop), desc="Iteration")): 119 | batch = batch.to(device) 120 | batch_2hop = batch_2hop.to(device) 121 | 122 | if batch.x.shape[0] == 1: 123 | pass 124 | else: 125 | with torch.no_grad(): 126 | pred = model(batch, batch_2hop) 127 | 128 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 129 | y_pred.append(pred.detach().cpu()) 130 | 131 | y_true = torch.cat(y_true, dim = 0).numpy() 132 | y_pred = torch.cat(y_pred, dim = 0).numpy() 133 | 134 | input_dict = {"y_true": y_true, "y_pred": y_pred} 135 | 136 | return evaluator.eval(input_dict) 137 | 138 | def main(): 139 | # Training settings 140 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 141 | parser.add_argument('--device', type=int, default=0, 142 | help='which gpu to use if any (default: 0)') 143 | parser.add_argument('--gnn', type=str, default='gcn-spmoe', 144 | help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') 145 | parser.add_argument('--drop_ratio', type=float, default=0.5, 146 | help='dropout ratio (default: 0.5)') 147 | parser.add_argument('--lr', type=float, default=0.001, 148 | help='learning rate (default: 0.001)') 149 | parser.add_argument('--weight_decay', '--wd', type=float, default=0, 150 | help='weight decay (default: 0)') 151 | parser.add_argument('--num_layer', type=int, default=5, 152 | help='number of GNN message passing layers (default: 5)') 153 | parser.add_argument('--emb_dim', '-d', type=int, default=150, 154 | help='dimensionality of hidden units in GNNs (default: 300)') 155 | parser.add_argument('--batch_size', type=int, default=32, 156 | help='input batch size for training (default: 32)') 157 | parser.add_argument('--epochs', type=int, default=100, 158 | help='number of epochs to train (default: 100)') 159 | parser.add_argument('--num_workers', type=int, default=8, 160 | help='number of workers (default: 0)') 161 | parser.add_argument('--dataset', type=str, default="ogbg-moltox21", 162 | choices=['ogbg-molhiv', 'ogbg-molmuv', 'ogbg-molpcba', 163 | 'ogbg-molbace', 'ogbg-molbbbp', 'ogbg-molclintox', 164 | 'ogbg-molsider', 'ogbg-moltox21', 'ogbg-moltoxcast', 'ogbg-molesol', 165 | 'ogbg-molfreesolv', 'ogbg-mollipo'], 166 | help='dataset name') 167 | 168 | parser.add_argument('--drop_node_ratio', type=float, default=0.2, 169 | help='randomly drop node with a ratio') 170 | 171 | parser.add_argument('--num_experts', '-n', type=int, default=8, 172 | help='total number of experts in GCN-MoE') 173 | parser.add_argument('-k', type=int, default=4, 174 | help='selected number of experts in GCN-MoE') 175 | parser.add_argument('--hop', type=int, default=1, 176 | help='number of GCN hops') 177 | parser.add_argument('--num_experts_1hop', '--n1', type=int, default=8, 178 | help='number of hop-1 experts in GCN-MoE. Only used when --hop>1.') 179 | parser.add_argument('--coef', type=float, default=1, 180 | help='loss coefficient for load balancing loss in sparse MoE training') 181 | parser.add_argument('--pretrain', default='', 182 | help='pretrained ckpt file name') 183 | args = parser.parse_args() 184 | 185 | exp_str = '%s-%s-dropout%s-lr%s-wd%s' % (args.dataset, args.gnn, args.drop_ratio, args.lr, args.weight_decay) 186 | # exp_str = '%s-%s' % (args.dataset, args.gnn) 187 | if 'spmoe' in args.gnn: 188 | exp_str += '-d%d-n%d-k%d-coef%s' % (args.emb_dim, args.num_experts, args.k, args.coef) 189 | elif 'moe' in args.gnn: 190 | exp_str += '-d%d-n%d' % (args.emb_dim, args.num_experts) 191 | if 'moe' in args.gnn: 192 | args.num_experts_1hop = args.num_experts if args.num_experts_1hop < 0 else args.num_experts_1hop 193 | exp_str += '-split-%d-%d' % (args.num_experts_1hop, args.num_experts-args.num_experts_1hop) # e.g., gcn-spmoe-split-2-2 194 | elif args.hop > 1: 195 | exp_str += '-hop%d' % args.hop # e.g., gcn-hop2 196 | if args.pretrain: 197 | exp_str += '-pretrained' 198 | print('Saving as %s' % exp_str) 199 | 200 | from datetime import datetime 201 | current_date_and_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 202 | exp_str += '-%s' % current_date_and_time 203 | 204 | save_dir = os.path.join('results', args.dataset) 205 | if not os.path.exists(save_dir): 206 | os.makedirs(save_dir) 207 | 208 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 209 | 210 | ### automatic dataloading and splitting 211 | train_transform = RandomDropNode(p=args.drop_node_ratio) 212 | transformed_dataset = PygGraphPropPredDataset(name = args.dataset, transform=train_transform) 213 | dataset = PygGraphPropPredDataset(name = args.dataset) 214 | split_idx = dataset.get_idx_split() 215 | train_set = transformed_dataset[split_idx["train"]] 216 | valid_set, test_set = dataset[split_idx["valid"]], dataset[split_idx["test"]] 217 | # try: 218 | # dataset_2hop = pickle.load(open('/home/haotao/GNN-MoE/mol/two_hop_%s_dataset.pkl' % args.dataset, 'rb')) 219 | # except: 220 | # dataset_2hop = None 221 | 222 | def get_loaders(train_set, valid_set, test_set, train_sampler): 223 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, sampler=train_sampler, num_workers = args.num_workers) 224 | valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers) 225 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers) 226 | 227 | return train_loader, valid_loader, test_loader 228 | 229 | # if dataset_2hop: 230 | # train_set_2hop, valid_set_2hop, test_set_2hop = split_dataset(dataset_2hop) 231 | 232 | # seed = np.random.randint(0,1000) 233 | # seed=0 234 | # _train_generator = torch.Generator() 235 | # _train_generator.manual_seed(seed) 236 | # torch.manual_seed(seed) 237 | # train_sampler = RandomSampler(train_set, generator=_train_generator) # use the same random sampler to sync the two datasets. 238 | train_sampler = SequentialSampler(train_set) 239 | 240 | ### automatic evaluator. takes dataset name as input 241 | evaluator = Evaluator(args.dataset) 242 | 243 | if args.gnn == 'gin': 244 | model = GNN(gnn_type = 'gin', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 245 | elif args.gnn == 'gin-spmoe': 246 | model = GNN(gnn_type = 'gin', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, 247 | moe='sparse', num_experts=args.num_experts, k=args.k, coef=args.coef, num_experts_1hop=args.num_experts_1hop).to(device) 248 | elif args.gnn == 'gin-virtual': 249 | model = GNN(gnn_type = 'gin', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 250 | elif args.gnn == 'gcn': 251 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, 252 | hop=args.hop).to(device) 253 | elif args.gnn == 'gcn-moe': 254 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, 255 | moe='dense', num_experts=args.num_experts, num_experts_1hop=args.num_experts_1hop).to(device) 256 | elif args.gnn == 'gcn-spmoe': 257 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, 258 | moe='sparse', num_experts=args.num_experts, k=args.k, coef=args.coef, num_experts_1hop=args.num_experts_1hop).to(device) 259 | elif args.gnn == 'gcn-virtual': 260 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 261 | else: 262 | raise ValueError('Invalid GNN type') 263 | 264 | if args.pretrain: 265 | ckpt = torch.load(os.path.join('results/ogbg-molpcba', args.pretrain)) 266 | pretrained_state_dict = ckpt['model'] 267 | pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if 'graph_pred_linear' not in k} 268 | model.state_dict().update(pretrained_state_dict) 269 | # model.load_state_dict(ckpt['model']) 270 | 271 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 272 | 273 | valid_curve = [] 274 | test_curve = [] 275 | train_curve = [] 276 | 277 | for epoch in range(1, args.epochs + 1): 278 | print("=====Epoch {}".format(epoch)) 279 | 280 | random_shuffled_idx = np.random.permutation(len(train_set)) 281 | train_set = train_set[random_shuffled_idx] 282 | train_loader, valid_loader, test_loader = \ 283 | get_loaders(train_set, valid_set, test_set, train_sampler) 284 | # if dataset_2hop: 285 | # train_set_2hop = train_set_2hop[random_shuffled_idx] 286 | # train_loader_2hop, valid_loader_2hop, test_loader_2hop = \ 287 | # get_loaders(train_set_2hop, valid_set_2hop, test_set_2hop, train_sampler) 288 | 289 | if 'moe' in args.gnn: 290 | if 'moe' in args.gnn and args.num_experts_1hop == 0: 291 | raise Exception("No longer support two hop datasets after applying random node drop!") 292 | print('Training...') 293 | train(model, device, train_loader_2hop, optimizer, dataset.task_type) 294 | print('Evaluating...') 295 | train_perf = eval(model, device, train_loader_2hop, evaluator) 296 | valid_perf = eval(model, device, valid_loader_2hop, evaluator) 297 | test_perf = eval(model, device, test_loader_2hop, evaluator) 298 | elif 'moe' in args.gnn and args.num_experts_1hop == args.num_experts: 299 | print('Training...') 300 | train(model, device, train_loader, optimizer, dataset.task_type) 301 | print('Evaluating...') 302 | # train_perf = eval(model, device, train_loader, evaluator) 303 | valid_perf = eval(model, device, valid_loader, evaluator) 304 | test_perf = eval(model, device, test_loader, evaluator) 305 | else: 306 | raise Exception("Now using drop node, which doesn't support mixed training/eval yet!") 307 | print('Training mixed...') 308 | train_mixed(model, device, train_loader, train_loader_2hop, optimizer, dataset.task_type) 309 | print('Evaluating mixed...') 310 | train_perf = eval_mixed(model, device, train_loader, train_loader_2hop, evaluator) 311 | valid_perf = eval_mixed(model, device, valid_loader, valid_loader_2hop, evaluator) 312 | test_perf = eval_mixed(model, device, test_loader, test_loader_2hop, evaluator) 313 | else: 314 | if args.hop == 1: 315 | print('Training...') 316 | train(model, device, train_loader, optimizer, dataset.task_type) 317 | print('Evaluating...') 318 | # train_perf = eval(model, device, train_loader, evaluator) 319 | valid_perf = eval(model, device, valid_loader, evaluator) 320 | test_perf = eval(model, device, test_loader, evaluator) 321 | elif args.hop == 2: 322 | raise Exception("No longer support two hop datasets after applying random node drop!") 323 | print('Training...') 324 | train(model, device, train_loader_2hop, optimizer, dataset.task_type) 325 | print('Evaluating...') 326 | train_perf = eval(model, device, train_loader_2hop, evaluator) 327 | valid_perf = eval(model, device, valid_loader_2hop, evaluator) 328 | test_perf = eval(model, device, test_loader_2hop, evaluator) 329 | 330 | print({'Validation': valid_perf, 'Test': test_perf}) 331 | with open(os.path.join(save_dir, '%s.txt' % exp_str), 'a+') as fp: 332 | # fp.write(str({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})) 333 | fp.write(str({'Validation': valid_perf, 'Test': test_perf})) 334 | fp.write('\n') 335 | fp.flush() 336 | fp.close() 337 | # eval_str = 'Train: %.4f | Validation: %.4f | Test: %.4f' % (train_perf, valid_perf, test_perf) 338 | # print(eval_str) 339 | # with open('%s.txt' % exp_str) as fp: 340 | # fp.write(eval_str+'\n') 341 | # fp.flush() 342 | # fp.close() 343 | 344 | # train_curve.append(train_perf[dataset.eval_metric]) 345 | valid_curve.append(valid_perf[dataset.eval_metric]) 346 | test_curve.append(test_perf[dataset.eval_metric]) 347 | 348 | if 'classification' in dataset.task_type: 349 | best_val_epoch = np.argmax(np.array(valid_curve)) 350 | # best_train = max(train_curve) 351 | else: 352 | best_val_epoch = np.argmin(np.array(valid_curve)) 353 | # best_train = min(train_curve) 354 | 355 | print('Finished training!') 356 | print('Best validation score: {}'.format(valid_curve[best_val_epoch])) 357 | print('Test score: {}'.format(test_curve[best_val_epoch])) 358 | with open(os.path.join(save_dir, '%s.txt' % exp_str), 'a+') as fp: 359 | fp.write('Best validation score: {}\n'.format(valid_curve[best_val_epoch])) 360 | fp.write('Test score: {}\n'.format(test_curve[best_val_epoch])) 361 | fp.flush() 362 | fp.close() 363 | 364 | filename = os.path.join(save_dir, '%s.pth' % exp_str) 365 | # torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'Train': train_curve[best_val_epoch], 'BestTrain': best_train}, filename) 366 | if args.dataset == 'ogbg-molpcba': 367 | torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'model': model.state_dict()}, filename) 368 | else: 369 | torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch]}, filename) 370 | 371 | 372 | if __name__ == "__main__": 373 | main() 374 | -------------------------------------------------------------------------------- /graphproppred/moe.py: -------------------------------------------------------------------------------- 1 | # Note by Haotao Wang: 2 | # Adapted form https://raw.githubusercontent.com/davidmrau/mixture-of-experts/master/moe.py 3 | 4 | # Sparsely-Gated Mixture-of-Experts Layers. 5 | # See "Outrageously Large Neural Networks" 6 | # https://arxiv.org/abs/1701.06538 7 | # 8 | # Author: David Rau 9 | # 10 | # The code is based on the TensorFlow implementation: 11 | # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py 12 | 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.distributions.normal import Normal 17 | import numpy as np 18 | 19 | 20 | class SparseDispatcher(object): 21 | """Helper for implementing a mixture of experts. 22 | The purpose of this class is to create input minibatches for the 23 | experts and to combine the results of the experts to form a unified 24 | output tensor. 25 | There are two functions: 26 | dispatch - take an input Tensor and create input Tensors for each expert. 27 | combine - take output Tensors from each expert and form a combined output 28 | Tensor. Outputs from different experts for the same batch element are 29 | summed together, weighted by the provided "gates". 30 | The class is initialized with a "gates" Tensor, which specifies which 31 | batch elements go to which experts, and the weights to use when combining 32 | the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. 33 | The inputs and outputs are all two-dimensional [batch, depth]. 34 | Caller is responsible for collapsing additional dimensions prior to 35 | calling this class and reshaping the output to the original shape. 36 | See common_layers.reshape_like(). 37 | Example use: 38 | gates: a float32 `Tensor` with shape `[batch_size, num_experts]` 39 | inputs: a float32 `Tensor` with shape `[batch_size, input_size]` 40 | experts: a list of length `num_experts` containing sub-networks. 41 | dispatcher = SparseDispatcher(num_experts, gates) 42 | expert_inputs = dispatcher.dispatch(inputs) 43 | expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] 44 | outputs = dispatcher.combine(expert_outputs) 45 | The preceding code sets the output for a particular example b to: 46 | output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) 47 | This class takes advantage of sparsity in the gate matrix by including in the 48 | `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. 49 | """ 50 | 51 | def __init__(self, num_experts, gates): 52 | """Create a SparseDispatcher.""" 53 | 54 | self._gates = gates 55 | self._num_experts = num_experts 56 | # sort experts 57 | sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) 58 | # drop indices 59 | _, self._expert_index = sorted_experts.split(1, dim=1) 60 | # get according batch index for each expert 61 | self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] 62 | # calculate num samples that each expert gets 63 | self._part_sizes = (gates > 0).sum(0).tolist() 64 | # expand gates to match with self._batch_index 65 | gates_exp = gates[self._batch_index.flatten()] 66 | self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) 67 | 68 | def dispatch(self, inp, edge_index, edge_attr): 69 | """Create one input Tensor for each expert. 70 | The `Tensor` for a expert `i` contains the slices of `inp` corresponding 71 | to the batch elements `b` where `gates[b, i] > 0`. 72 | Args: 73 | inp: a `Tensor` of shape "[batch_size, ]` 74 | Returns: 75 | a list of `num_experts` `Tensor`s with shapes 76 | `[expert_batch_size_i, ]`. 77 | """ 78 | 79 | # Note by Haotao: 80 | # self._batch_index: shape=(N_batch). The re-order indices from 0 to N_batch-1. 81 | # inp_exp: shape=inp.shape. The input Tensor re-ordered by self._batch_index along the batch dimension. 82 | # self._part_sizes: shape=(N_experts), sum=N_batch. self._part_sizes[i] is the number of samples routed towards expert[i]. 83 | # return value: list [Tensor with shape[0]=self._part_sizes[i] for i in range(N_experts)] 84 | 85 | # assigns samples to experts whose gate is nonzero 86 | 87 | # expand according to batch index so we can just split by _part_sizes 88 | inp_exp = inp[self._batch_index].squeeze(1) 89 | edge_index_exp = edge_index[:,self._batch_index] 90 | edge_attr_exp = edge_attr[self._batch_index] 91 | return torch.split(inp_exp, self._part_sizes, dim=0), torch.split(edge_index_exp, self._part_sizes, dim=1), torch.split(edge_attr_exp, self._part_sizes, dim=0) 92 | 93 | def combine(self, expert_out, multiply_by_gates=True): 94 | """Sum together the expert output, weighted by the gates. 95 | The slice corresponding to a particular batch element `b` is computed 96 | as the sum over all experts `i` of the expert output, weighted by the 97 | corresponding gate values. If `multiply_by_gates` is set to False, the 98 | gate values are ignored. 99 | Args: 100 | expert_out: a list of `num_experts` `Tensor`s, each with shape 101 | `[expert_batch_size_i, ]`. 102 | multiply_by_gates: a boolean 103 | Returns: 104 | a `Tensor` with shape `[batch_size, ]`. 105 | """ 106 | # apply exp to expert outputs, so we are not longer in log space 107 | stitched = torch.cat(expert_out, 0).exp() 108 | 109 | if multiply_by_gates: 110 | stitched = stitched.mul(self._nonzero_gates) 111 | zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device) 112 | # combine samples that have been processed by the same k experts 113 | combined = zeros.index_add(0, self._batch_index, stitched.float()) 114 | # add eps to all zero values in order to avoid nans when going back to log space 115 | combined[combined == 0] = np.finfo(float).eps 116 | # back to log space 117 | return combined.log() 118 | 119 | def expert_to_gates(self): 120 | """Gate values corresponding to the examples in the per-expert `Tensor`s. 121 | Returns: 122 | a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` 123 | and shapes `[expert_batch_size_i]` 124 | """ 125 | # split nonzero gates for each expert 126 | return torch.split(self._nonzero_gates, self._part_sizes, dim=0) 127 | 128 | class MoE(nn.Module): 129 | 130 | """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. 131 | Args: 132 | input_size: integer - size of the input 133 | output_size: integer - size of the input 134 | num_experts: an integer - number of experts 135 | hidden_size: an integer - hidden size of the experts 136 | noisy_gating: a boolean 137 | k: an integer - how many experts to use for each batch element 138 | """ 139 | 140 | def __init__(self, input_size, output_size, num_experts, experts_conv, experts_bn, noisy_gating=True, k=4, coef=1e-2, num_experts_1hop=None): 141 | super(MoE, self).__init__() 142 | self.noisy_gating = noisy_gating 143 | self.num_experts = num_experts 144 | self.output_size = output_size 145 | self.input_size = input_size 146 | self.k = k 147 | self.loss_coef = coef 148 | if not num_experts_1hop: 149 | self.num_experts_1hop = num_experts # by default, all experts are hop-1 experts. 150 | else: 151 | assert num_experts_1hop <= num_experts 152 | self.num_experts_1hop = num_experts_1hop 153 | # instantiate experts 154 | # self.experts = nn.ModuleList([MLP(self.input_size, self.output_size, self.hidden_size) for i in range(self.num_experts)]) 155 | self.experts_conv = experts_conv 156 | self.experts_bn = experts_bn 157 | self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) 158 | self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) 159 | 160 | self.softplus = nn.Softplus() 161 | self.softmax = nn.Softmax(1) 162 | self.register_buffer("mean", torch.tensor([0.0])) 163 | self.register_buffer("std", torch.tensor([1.0])) 164 | assert(self.k <= self.num_experts) 165 | 166 | def cv_squared(self, x): 167 | """The squared coefficient of variation of a sample. 168 | Useful as a loss to encourage a positive distribution to be more uniform. 169 | Epsilons added for numerical stability. 170 | Returns 0 for an empty Tensor. 171 | Args: 172 | x: a `Tensor`. 173 | Returns: 174 | a `Scalar`. 175 | """ 176 | eps = 1e-10 177 | # if only num_experts = 1 178 | 179 | if x.shape[0] == 1: 180 | return torch.tensor([0], device=x.device, dtype=x.dtype) 181 | return x.float().var() / (x.float().mean()**2 + eps) 182 | 183 | def _gates_to_load(self, gates): 184 | """Compute the true load per expert, given the gates. 185 | The load is the number of examples for which the corresponding gate is >0. 186 | Args: 187 | gates: a `Tensor` of shape [batch_size, n] 188 | Returns: 189 | a float32 `Tensor` of shape [n] 190 | """ 191 | return (gates > 0).sum(0) 192 | 193 | def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): 194 | """Helper function to NoisyTopKGating. 195 | Computes the probability that value is in top k, given different random noise. 196 | This gives us a way of backpropagating from a loss that balances the number 197 | of times each expert is in the top k experts per example. 198 | In the case of no noise, pass in None for noise_stddev, and the result will 199 | not be differentiable. 200 | Args: 201 | clean_values: a `Tensor` of shape [batch, n]. 202 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 203 | normally distributed noise with standard deviation noise_stddev. 204 | noise_stddev: a `Tensor` of shape [batch, n], or None 205 | noisy_top_values: a `Tensor` of shape [batch, m]. 206 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 207 | Returns: 208 | a `Tensor` of shape [batch, n]. 209 | """ 210 | batch = clean_values.size(0) 211 | m = noisy_top_values.size(1) 212 | top_values_flat = noisy_top_values.flatten() 213 | 214 | threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k 215 | threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) 216 | is_in = torch.gt(noisy_values, threshold_if_in) 217 | threshold_positions_if_out = threshold_positions_if_in - 1 218 | threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) 219 | # is each value currently in the top k. 220 | normal = Normal(self.mean, self.std) 221 | prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) 222 | prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) 223 | prob = torch.where(is_in, prob_if_in, prob_if_out) 224 | return prob 225 | 226 | def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): 227 | """Noisy top-k gating. 228 | See paper: https://arxiv.org/abs/1701.06538. 229 | Args: 230 | x: input Tensor with shape [batch_size, input_size] 231 | train: a boolean - we only add noise at training time. 232 | noise_epsilon: a float 233 | Returns: 234 | gates: a Tensor with shape [batch_size, num_experts] 235 | load: a Tensor with shape [num_experts] 236 | """ 237 | clean_logits = x @ self.w_gate 238 | if self.noisy_gating and train: 239 | raw_noise_stddev = x @ self.w_noise 240 | noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) 241 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 242 | logits = noisy_logits 243 | else: 244 | logits = clean_logits 245 | 246 | # calculate topk + 1 that will be needed for the noisy gates 247 | top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) 248 | top_k_logits = top_logits[:, :self.k] 249 | top_k_indices = top_indices[:, :self.k] 250 | top_k_gates = self.softmax(top_k_logits) 251 | 252 | zeros = torch.zeros_like(logits, requires_grad=True) 253 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 254 | 255 | if self.noisy_gating and self.k < self.num_experts and train: 256 | load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) 257 | else: 258 | load = self._gates_to_load(gates) 259 | return gates, load 260 | 261 | def forward(self, x, edge_index, edge_attr, edge_index_2hop=None, edge_attr_2hop=None): 262 | """Args: 263 | x: tensor shape [batch_size, input_size] 264 | train: a boolean scalar. 265 | loss_coef: a scalar - multiplier on load-balancing losses 266 | 267 | Returns: 268 | y: a tensor with shape [batch_size, output_size]. 269 | extra_training_loss: a scalar. This should be added into the overall 270 | training loss of the model. The backpropagation of this loss 271 | encourages all experts to be approximately equally used across a batch. 272 | """ 273 | gates, load = self.noisy_top_k_gating(x, self.training) 274 | # calculate importance loss 275 | importance = gates.sum(0) 276 | # 277 | loss = self.cv_squared(importance) + self.cv_squared(load) 278 | loss *= self.loss_coef 279 | 280 | 281 | 282 | expert_outputs = [] 283 | for i in range(self.num_experts): 284 | if i < self.num_experts_1hop: 285 | expert_i_output = self.experts_conv[i](x, edge_index, edge_attr) 286 | else: 287 | expert_i_output = self.experts_conv[i](x, edge_index_2hop, edge_attr_2hop) 288 | expert_i_output = self.experts_bn[i](expert_i_output) 289 | expert_outputs.append(expert_i_output) 290 | expert_outputs = torch.stack(expert_outputs, dim=1) # shape=[num_nodes, num_experts, d_feature] 291 | 292 | # gates: shape=[num_nodes, num_experts] 293 | y = gates.unsqueeze(dim=-1) * expert_outputs 294 | y = y.mean(dim=1) 295 | 296 | return y, loss 297 | 298 | if __name__ == '__main__': 299 | class MLP(nn.Module): 300 | def __init__(self, input_size, output_size, hidden_size): 301 | super(MLP, self).__init__() 302 | self.fc1 = nn.Linear(input_size, hidden_size) 303 | self.fc2 = nn.Linear(hidden_size, output_size) 304 | self.relu = nn.ReLU() 305 | self.soft = nn.Softmax(1) 306 | 307 | def forward(self, x): 308 | out = self.fc1(x) 309 | out = self.relu(out) 310 | out = self.fc2(out) 311 | out = self.soft(out) 312 | return out 313 | 314 | moe_model = MoE(10, 10, 3, [MLP(10,10,10), MLP(10,10,10), MLP(10,10,10)], k=1) 315 | x = torch.ones((8,10)) 316 | h = moe_model(x) 317 | print(h.shape) -------------------------------------------------------------------------------- /graphproppred/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.loader import DataLoader 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from gnn import GNN 6 | 7 | from tqdm import tqdm 8 | import argparse 9 | import time 10 | import numpy as np 11 | import pickle 12 | 13 | ### importing OGB 14 | from ogb.graphproppred import PygGraphPropPredDataset, Evaluator 15 | 16 | cls_criterion = torch.nn.BCEWithLogitsLoss() 17 | reg_criterion = torch.nn.MSELoss() 18 | 19 | def eval(model, device, loader, evaluator): 20 | model.eval() 21 | y_true = [] 22 | y_pred = [] 23 | 24 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 25 | batch = batch.to(device) 26 | 27 | if batch.x.shape[0] == 1: 28 | pass 29 | else: 30 | with torch.no_grad(): 31 | pred = model(batch) 32 | 33 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 34 | y_pred.append(pred.detach().cpu()) 35 | 36 | y_true = torch.cat(y_true, dim = 0).numpy() 37 | y_pred = torch.cat(y_pred, dim = 0).numpy() 38 | 39 | input_dict = {"y_true": y_true, "y_pred": y_pred} 40 | 41 | return evaluator.eval(input_dict) 42 | 43 | 44 | def main(): 45 | # Training settings 46 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 47 | parser.add_argument('--device', type=int, default=0, 48 | help='which gpu to use if any (default: 0)') 49 | parser.add_argument('--gnn', type=str, default='gcn-spmoe', 50 | help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') 51 | parser.add_argument('--drop_ratio', type=float, default=0.5, 52 | help='dropout ratio (default: 0.5)') 53 | parser.add_argument('--num_layer', type=int, default=5, 54 | help='number of GNN message passing layers (default: 5)') 55 | parser.add_argument('--emb_dim', type=int, default=300, 56 | help='dimensionality of hidden units in GNNs (default: 300)') 57 | parser.add_argument('--batch_size', type=int, default=32, 58 | help='input batch size for training (default: 32)') 59 | parser.add_argument('--epochs', type=int, default=100, 60 | help='number of epochs to train (default: 100)') 61 | parser.add_argument('--num_workers', type=int, default=0, 62 | help='number of workers (default: 0)') 63 | parser.add_argument('--dataset', type=str, default="ogbg-molhiv", 64 | help='dataset name (default: ogbg-molhiv)') 65 | 66 | parser.add_argument('--hop', type=int, default=2, 67 | help='number of GCN hops') 68 | 69 | parser.add_argument('--feature', type=str, default="full", 70 | help='full feature or simple feature') 71 | parser.add_argument('--filename', type=str, default="", 72 | help='filename to output result (default: )') 73 | args = parser.parse_args() 74 | 75 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 76 | 77 | ### automatic dataloading and splitting 78 | # dataset = PygGraphPropPredDataset(name = args.dataset) 79 | dataset = pickle.load(open('/home/haotao/GNN-MoE/mol/two_hop_dataset.pkl', 'rb')) 80 | 81 | if args.feature == 'full': 82 | pass 83 | elif args.feature == 'simple': 84 | print('using simple feature') 85 | # only retain the top two node/edge features 86 | dataset.data.x = dataset.data.x[:,:2] 87 | dataset.data.edge_attr = dataset.data.edge_attr[:,:2] 88 | 89 | split_idx = dataset.get_idx_split() 90 | 91 | ### automatic evaluator. takes dataset name as input 92 | evaluator = Evaluator(args.dataset) 93 | 94 | train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers) 95 | valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers) 96 | test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers) 97 | 98 | if args.gnn == 'gin': 99 | model = GNN(gnn_type = 'gin', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False).to(device) 100 | elif args.gnn == 'gin-virtual': 101 | model = GNN(gnn_type = 'gin', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 102 | elif args.gnn == 'gcn': 103 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, hop=args.hop).to(device) 104 | elif args.gnn == 'gcn-spmoe': 105 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, moe='sparse', hop=args.hop).to(device) 106 | elif args.gnn == 'gcn-virtual': 107 | model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_tasks, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True).to(device) 108 | else: 109 | raise ValueError('Invalid GNN type') 110 | 111 | test_perf = eval(model, device, test_loader, evaluator) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /graphproppred/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | import torch 4 | 5 | from torch_geometric.data import Data 6 | from torch_geometric.data.datapipes import functional_transform 7 | from torch_geometric.transforms import BaseTransform 8 | from torch_geometric.utils import dropout_node, subgraph 9 | 10 | 11 | @functional_transform('random_node_drop') 12 | class RandomDropNode(BaseTransform): 13 | r"""Randomly drop nodes from a graph with ratio p for graph prediction task only. 14 | 15 | Args: 16 | p (float): randomly drop nodes with probability p. 17 | """ 18 | def __init__(self, p: float): 19 | assert isinstance(p, float) and 0 <= p <=1 20 | self.p = p 21 | 22 | def __call__(self, data: Data) -> Data: 23 | if data.x.size(0) < 5: 24 | if 'num_nodes' not in data.keys: 25 | import numpy as np 26 | print('found bad data') 27 | np.save('bad_data.npy', data) 28 | return Data(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, y=data.y) 29 | # edge_index, edge_mask, node_mask = dropout_node(data.edge_index, p=self.p, num_nodes=data.num_nodes) 30 | node_mask = (torch.empty(data.x.size(0)).uniform_(0, 1) > self.p) 31 | # node_mask = (torch.rand(1).item() > self.p) 32 | new_edge_index, new_edge_attr = subgraph(node_mask, data.edge_index, data.edge_attr, relabel_nodes =True) 33 | new_data = Data(x=data.x[node_mask], edge_index=new_edge_index, edge_attr=new_edge_attr, y=data.y) # TODO: Now for graph prediction task only. 34 | if torch.sum(node_mask)==0: 35 | import numpy as np 36 | print('found bad data') 37 | np.save('bad_data.npy', data) 38 | return Data(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, y=data.y) 39 | else: 40 | return new_data 41 | 42 | def __repr__(self) -> str: 43 | return f'{self.__class__.__name__}({self.p})' -------------------------------------------------------------------------------- /linkproppred/ddi/gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | from torch_geometric.utils import negative_sampling 7 | 8 | import torch_geometric.transforms as T 9 | from torch_geometric.nn import GCNConv, SAGEConv 10 | 11 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 12 | 13 | from logger import Logger 14 | 15 | 16 | class GCN(torch.nn.Module): 17 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 18 | dropout): 19 | super(GCN, self).__init__() 20 | 21 | self.convs = torch.nn.ModuleList() 22 | self.convs.append(GCNConv(in_channels, hidden_channels, cached=True)) 23 | for _ in range(num_layers - 2): 24 | self.convs.append( 25 | GCNConv(hidden_channels, hidden_channels, cached=True)) 26 | self.convs.append(GCNConv(hidden_channels, out_channels, cached=True)) 27 | 28 | self.dropout = dropout 29 | 30 | def reset_parameters(self): 31 | for conv in self.convs: 32 | conv.reset_parameters() 33 | 34 | def forward(self, x, adj_t): 35 | for conv in self.convs[:-1]: 36 | x = conv(x, adj_t) 37 | x = F.relu(x) 38 | x = F.dropout(x, p=self.dropout, training=self.training) 39 | x = self.convs[-1](x, adj_t) 40 | return x 41 | 42 | 43 | class SAGE(torch.nn.Module): 44 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 45 | dropout): 46 | super(SAGE, self).__init__() 47 | 48 | self.convs = torch.nn.ModuleList() 49 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 50 | for _ in range(num_layers - 2): 51 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 52 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 53 | 54 | self.dropout = dropout 55 | 56 | def reset_parameters(self): 57 | for conv in self.convs: 58 | conv.reset_parameters() 59 | 60 | def forward(self, x, adj_t): 61 | for conv in self.convs[:-1]: 62 | x = conv(x, adj_t) 63 | x = F.relu(x) 64 | x = F.dropout(x, p=self.dropout, training=self.training) 65 | x = self.convs[-1](x, adj_t) 66 | return x 67 | 68 | 69 | class LinkPredictor(torch.nn.Module): 70 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 71 | dropout): 72 | super(LinkPredictor, self).__init__() 73 | 74 | self.lins = torch.nn.ModuleList() 75 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 76 | for _ in range(num_layers - 2): 77 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 78 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 79 | 80 | self.dropout = dropout 81 | 82 | def reset_parameters(self): 83 | for lin in self.lins: 84 | lin.reset_parameters() 85 | 86 | def forward(self, x_i, x_j): 87 | x = x_i * x_j 88 | for lin in self.lins[:-1]: 89 | x = lin(x) 90 | x = F.relu(x) 91 | x = F.dropout(x, p=self.dropout, training=self.training) 92 | x = self.lins[-1](x) 93 | return torch.sigmoid(x) 94 | 95 | 96 | def train(model, predictor, x, adj_t, split_edge, optimizer, batch_size): 97 | 98 | row, col, _ = adj_t.coo() 99 | edge_index = torch.stack([col, row], dim=0) 100 | 101 | model.train() 102 | predictor.train() 103 | 104 | pos_train_edge = split_edge['train']['edge'].to(x.device) 105 | 106 | total_loss = total_examples = 0 107 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, 108 | shuffle=True): 109 | optimizer.zero_grad() 110 | 111 | h = model(x, adj_t) 112 | 113 | edge = pos_train_edge[perm].t() 114 | 115 | pos_out = predictor(h[edge[0]], h[edge[1]]) 116 | pos_loss = -torch.log(pos_out + 1e-15).mean() 117 | 118 | edge = negative_sampling(edge_index, num_nodes=x.size(0), 119 | num_neg_samples=perm.size(0), method='dense') 120 | 121 | neg_out = predictor(h[edge[0]], h[edge[1]]) 122 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 123 | 124 | loss = pos_loss + neg_loss 125 | loss.backward() 126 | 127 | torch.nn.utils.clip_grad_norm_(x, 1.0) 128 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 129 | torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0) 130 | 131 | optimizer.step() 132 | 133 | num_examples = pos_out.size(0) 134 | total_loss += loss.item() * num_examples 135 | total_examples += num_examples 136 | 137 | return total_loss / total_examples 138 | 139 | 140 | @torch.no_grad() 141 | def test(model, predictor, x, adj_t, split_edge, evaluator, batch_size): 142 | model.eval() 143 | predictor.eval() 144 | 145 | h = model(x, adj_t) 146 | 147 | pos_train_edge = split_edge['eval_train']['edge'].to(x.device) 148 | pos_valid_edge = split_edge['valid']['edge'].to(x.device) 149 | neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device) 150 | pos_test_edge = split_edge['test']['edge'].to(x.device) 151 | neg_test_edge = split_edge['test']['edge_neg'].to(x.device) 152 | 153 | pos_train_preds = [] 154 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size): 155 | edge = pos_train_edge[perm].t() 156 | pos_train_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 157 | pos_train_pred = torch.cat(pos_train_preds, dim=0) 158 | 159 | pos_valid_preds = [] 160 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size): 161 | edge = pos_valid_edge[perm].t() 162 | pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 163 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0) 164 | 165 | neg_valid_preds = [] 166 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size): 167 | edge = neg_valid_edge[perm].t() 168 | neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 169 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0) 170 | 171 | pos_test_preds = [] 172 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): 173 | edge = pos_test_edge[perm].t() 174 | pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 175 | pos_test_pred = torch.cat(pos_test_preds, dim=0) 176 | 177 | neg_test_preds = [] 178 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): 179 | edge = neg_test_edge[perm].t() 180 | neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()] 181 | neg_test_pred = torch.cat(neg_test_preds, dim=0) 182 | 183 | results = {} 184 | for K in [10, 20, 30]: 185 | evaluator.K = K 186 | train_hits = evaluator.eval({ 187 | 'y_pred_pos': pos_train_pred, 188 | 'y_pred_neg': neg_valid_pred, 189 | })[f'hits@{K}'] 190 | valid_hits = evaluator.eval({ 191 | 'y_pred_pos': pos_valid_pred, 192 | 'y_pred_neg': neg_valid_pred, 193 | })[f'hits@{K}'] 194 | test_hits = evaluator.eval({ 195 | 'y_pred_pos': pos_test_pred, 196 | 'y_pred_neg': neg_test_pred, 197 | })[f'hits@{K}'] 198 | 199 | results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) 200 | 201 | return results 202 | 203 | 204 | def main(): 205 | parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)') 206 | parser.add_argument('--device', type=int, default=0) 207 | parser.add_argument('--log_steps', type=int, default=1) 208 | parser.add_argument('--use_sage', action='store_true') 209 | parser.add_argument('--num_layers', type=int, default=2) 210 | parser.add_argument('--hidden_channels', type=int, default=256) 211 | parser.add_argument('--dropout', type=float, default=0.5) 212 | parser.add_argument('--batch_size', type=int, default=64 * 1024) 213 | parser.add_argument('--lr', type=float, default=0.005) 214 | parser.add_argument('--epochs', type=int, default=200) 215 | parser.add_argument('--eval_steps', type=int, default=5) 216 | parser.add_argument('--runs', type=int, default=10) 217 | args = parser.parse_args() 218 | print(args) 219 | 220 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 221 | device = torch.device(device) 222 | 223 | dataset = PygLinkPropPredDataset(name='ogbl-ddi', 224 | transform=T.ToSparseTensor()) 225 | data = dataset[0] 226 | adj_t = data.adj_t.to(device) 227 | 228 | split_edge = dataset.get_edge_split() 229 | 230 | # We randomly pick some training samples that we want to evaluate on: 231 | torch.manual_seed(12345) 232 | idx = torch.randperm(split_edge['train']['edge'].size(0)) 233 | idx = idx[:split_edge['valid']['edge'].size(0)] 234 | split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]} 235 | 236 | if args.use_sage: 237 | model = SAGE(args.hidden_channels, args.hidden_channels, 238 | args.hidden_channels, args.num_layers, 239 | args.dropout).to(device) 240 | else: 241 | model = GCN(args.hidden_channels, args.hidden_channels, 242 | args.hidden_channels, args.num_layers, 243 | args.dropout).to(device) 244 | 245 | emb = torch.nn.Embedding(data.adj_t.size(0), 246 | args.hidden_channels).to(device) 247 | predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1, 248 | args.num_layers, args.dropout).to(device) 249 | 250 | evaluator = Evaluator(name='ogbl-ddi') 251 | loggers = { 252 | 'Hits@10': Logger(args.runs, args), 253 | 'Hits@20': Logger(args.runs, args), 254 | 'Hits@30': Logger(args.runs, args), 255 | } 256 | 257 | for run in range(args.runs): 258 | torch.nn.init.xavier_uniform_(emb.weight) 259 | model.reset_parameters() 260 | predictor.reset_parameters() 261 | optimizer = torch.optim.Adam( 262 | list(model.parameters()) + list(emb.parameters()) + 263 | list(predictor.parameters()), lr=args.lr) 264 | 265 | for epoch in range(1, 1 + args.epochs): 266 | loss = train(model, predictor, emb.weight, adj_t, split_edge, 267 | optimizer, args.batch_size) 268 | 269 | if epoch % args.eval_steps == 0: 270 | results = test(model, predictor, emb.weight, adj_t, split_edge, 271 | evaluator, args.batch_size) 272 | for key, result in results.items(): 273 | loggers[key].add_result(run, result) 274 | 275 | if epoch % args.log_steps == 0: 276 | for key, result in results.items(): 277 | train_hits, valid_hits, test_hits = result 278 | print(key) 279 | print(f'Run: {run + 1:02d}, ' 280 | f'Epoch: {epoch:02d}, ' 281 | f'Loss: {loss:.4f}, ' 282 | f'Train: {100 * train_hits:.2f}%, ' 283 | f'Valid: {100 * valid_hits:.2f}%, ' 284 | f'Test: {100 * test_hits:.2f}%') 285 | print('---') 286 | 287 | for key in loggers.keys(): 288 | print(key) 289 | loggers[key].print_statistics(run) 290 | 291 | for key in loggers.keys(): 292 | print(key) 293 | loggers[key].print_statistics() 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | -------------------------------------------------------------------------------- /linkproppred/ddi/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 = r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') 45 | -------------------------------------------------------------------------------- /linkproppred/ddi/mf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | from torch_geometric.utils import negative_sampling 7 | 8 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 9 | 10 | from logger import Logger 11 | 12 | 13 | class LinkPredictor(torch.nn.Module): 14 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 15 | dropout): 16 | super(LinkPredictor, self).__init__() 17 | 18 | self.lins = torch.nn.ModuleList() 19 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 20 | for _ in range(num_layers - 2): 21 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 22 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 23 | 24 | self.dropout = dropout 25 | 26 | def reset_parameters(self): 27 | for lin in self.lins: 28 | lin.reset_parameters() 29 | 30 | def forward(self, x_i, x_j): 31 | x = x_i * x_j 32 | for lin in self.lins[:-1]: 33 | x = lin(x) 34 | x = F.relu(x) 35 | x = F.dropout(x, p=self.dropout, training=self.training) 36 | x = self.lins[-1](x) 37 | return torch.sigmoid(x) 38 | 39 | 40 | def train(predictor, x, edge_index, split_edge, optimizer, batch_size): 41 | predictor.train() 42 | 43 | pos_train_edge = split_edge['train']['edge'].to(x.device) 44 | 45 | total_loss = total_examples = 0 46 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, 47 | shuffle=True): 48 | optimizer.zero_grad() 49 | 50 | edge = pos_train_edge[perm].t() 51 | 52 | pos_out = predictor(x[edge[0]], x[edge[1]]) 53 | pos_loss = -torch.log(pos_out + 1e-15).mean() 54 | 55 | edge = negative_sampling(edge_index, num_nodes=x.size(0), 56 | num_neg_samples=perm.size(0), method='dense') 57 | 58 | neg_out = predictor(x[edge[0]], x[edge[1]]) 59 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 60 | 61 | loss = pos_loss + neg_loss 62 | loss.backward() 63 | optimizer.step() 64 | 65 | num_examples = pos_out.size(0) 66 | total_loss += loss.item() * num_examples 67 | total_examples += num_examples 68 | 69 | return total_loss / total_examples 70 | 71 | 72 | @torch.no_grad() 73 | def test(predictor, x, split_edge, evaluator, batch_size): 74 | predictor.eval() 75 | 76 | pos_train_edge = split_edge['eval_train']['edge'].to(x.device) 77 | pos_valid_edge = split_edge['valid']['edge'].to(x.device) 78 | neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device) 79 | pos_test_edge = split_edge['test']['edge'].to(x.device) 80 | neg_test_edge = split_edge['test']['edge_neg'].to(x.device) 81 | 82 | pos_train_preds = [] 83 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size): 84 | edge = pos_train_edge[perm].t() 85 | pos_train_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 86 | pos_train_pred = torch.cat(pos_train_preds, dim=0) 87 | 88 | pos_valid_preds = [] 89 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size): 90 | edge = pos_valid_edge[perm].t() 91 | pos_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 92 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0) 93 | 94 | neg_valid_preds = [] 95 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size): 96 | edge = neg_valid_edge[perm].t() 97 | neg_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 98 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0) 99 | 100 | pos_test_preds = [] 101 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): 102 | edge = pos_test_edge[perm].t() 103 | pos_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 104 | pos_test_pred = torch.cat(pos_test_preds, dim=0) 105 | 106 | neg_test_preds = [] 107 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): 108 | edge = neg_test_edge[perm].t() 109 | neg_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 110 | neg_test_pred = torch.cat(neg_test_preds, dim=0) 111 | 112 | results = {} 113 | for K in [10, 20, 30]: 114 | evaluator.K = K 115 | train_hits = evaluator.eval({ 116 | 'y_pred_pos': pos_train_pred, 117 | 'y_pred_neg': neg_valid_pred, 118 | })[f'hits@{K}'] 119 | valid_hits = evaluator.eval({ 120 | 'y_pred_pos': pos_valid_pred, 121 | 'y_pred_neg': neg_valid_pred, 122 | })[f'hits@{K}'] 123 | test_hits = evaluator.eval({ 124 | 'y_pred_pos': pos_test_pred, 125 | 'y_pred_neg': neg_test_pred, 126 | })[f'hits@{K}'] 127 | 128 | results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) 129 | 130 | return results 131 | 132 | 133 | def main(): 134 | parser = argparse.ArgumentParser(description='OGBL-DDI (MF)') 135 | parser.add_argument('--device', type=int, default=0) 136 | parser.add_argument('--log_steps', type=int, default=1) 137 | parser.add_argument('--num_layers', type=int, default=3) 138 | parser.add_argument('--hidden_channels', type=int, default=256) 139 | parser.add_argument('--dropout', type=float, default=0.5) 140 | parser.add_argument('--batch_size', type=int, default=64 * 1024) 141 | parser.add_argument('--lr', type=float, default=0.01) 142 | parser.add_argument('--epochs', type=int, default=200) 143 | parser.add_argument('--eval_steps', type=int, default=5) 144 | parser.add_argument('--runs', type=int, default=10) 145 | args = parser.parse_args() 146 | print(args) 147 | 148 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 149 | device = torch.device(device) 150 | 151 | dataset = PygLinkPropPredDataset(name='ogbl-ddi') 152 | data = dataset[0] 153 | split_edge = dataset.get_edge_split() 154 | 155 | # We randomly pick some training samples that we want to evaluate on: 156 | torch.manual_seed(12345) 157 | idx = torch.randperm(split_edge['train']['edge'].size(0)) 158 | idx = idx[:split_edge['valid']['edge'].size(0)] 159 | split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]} 160 | 161 | emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device) 162 | predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1, 163 | args.num_layers, args.dropout).to(device) 164 | 165 | evaluator = Evaluator(name='ogbl-ddi') 166 | loggers = { 167 | 'Hits@10': Logger(args.runs, args), 168 | 'Hits@20': Logger(args.runs, args), 169 | 'Hits@30': Logger(args.runs, args), 170 | } 171 | 172 | for run in range(args.runs): 173 | emb.reset_parameters() 174 | predictor.reset_parameters() 175 | optimizer = torch.optim.Adam( 176 | list(emb.parameters()) + list(predictor.parameters()), lr=args.lr) 177 | 178 | for epoch in range(1, 1 + args.epochs): 179 | loss = train(predictor, emb.weight, data.edge_index, split_edge, 180 | optimizer, args.batch_size) 181 | 182 | if epoch % args.eval_steps == 0: 183 | results = test(predictor, emb.weight, split_edge, evaluator, 184 | args.batch_size) 185 | for key, result in results.items(): 186 | loggers[key].add_result(run, result) 187 | 188 | if epoch % args.log_steps == 0: 189 | for key, result in results.items(): 190 | train_hits, valid_hits, test_hits = result 191 | print(key) 192 | print(f'Run: {run + 1:02d}, ' 193 | f'Epoch: {epoch:02d}, ' 194 | f'Loss: {loss:.4f}, ' 195 | f'Train: {100 * train_hits:.2f}%, ' 196 | f'Valid: {100 * valid_hits:.2f}%, ' 197 | f'Test: {100 * test_hits:.2f}%') 198 | print('---') 199 | 200 | for key in loggers.keys(): 201 | print(key) 202 | loggers[key].print_statistics(run) 203 | 204 | for key in loggers.keys(): 205 | print(key) 206 | loggers[key].print_statistics() 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /linkproppred/ddi/mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | from torch_geometric.utils import negative_sampling 7 | 8 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 9 | 10 | from logger import Logger 11 | 12 | 13 | class LinkPredictor(torch.nn.Module): 14 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 15 | dropout): 16 | super(LinkPredictor, self).__init__() 17 | 18 | self.lins = torch.nn.ModuleList() 19 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 20 | for _ in range(num_layers - 2): 21 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 22 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 23 | 24 | self.dropout = dropout 25 | 26 | def reset_parameters(self): 27 | for lin in self.lins: 28 | lin.reset_parameters() 29 | 30 | def forward(self, x_i, x_j): 31 | x = x_i * x_j 32 | for lin in self.lins[:-1]: 33 | x = lin(x) 34 | x = F.relu(x) 35 | x = F.dropout(x, p=self.dropout, training=self.training) 36 | x = self.lins[-1](x) 37 | return torch.sigmoid(x) 38 | 39 | 40 | def train(predictor, x, edge_index, split_edge, optimizer, batch_size): 41 | predictor.train() 42 | 43 | pos_train_edge = split_edge['train']['edge'].to(x.device) 44 | 45 | total_loss = total_examples = 0 46 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, 47 | shuffle=True): 48 | optimizer.zero_grad() 49 | 50 | edge = pos_train_edge[perm].t() 51 | 52 | pos_out = predictor(x[edge[0]], x[edge[1]]) 53 | pos_loss = -torch.log(pos_out + 1e-15).mean() 54 | 55 | edge = negative_sampling(edge_index, num_nodes=x.size(0), 56 | num_neg_samples=perm.size(0), method='dense') 57 | 58 | neg_out = predictor(x[edge[0]], x[edge[1]]) 59 | neg_loss = -torch.log(1 - neg_out + 1e-15).mean() 60 | 61 | loss = pos_loss + neg_loss 62 | loss.backward() 63 | optimizer.step() 64 | 65 | num_examples = pos_out.size(0) 66 | total_loss += loss.item() * num_examples 67 | total_examples += num_examples 68 | 69 | return total_loss / total_examples 70 | 71 | 72 | @torch.no_grad() 73 | def test(predictor, x, split_edge, evaluator, batch_size): 74 | predictor.eval() 75 | 76 | pos_train_edge = split_edge['eval_train']['edge'].to(x.device) 77 | pos_valid_edge = split_edge['valid']['edge'].to(x.device) 78 | neg_valid_edge = split_edge['valid']['edge_neg'].to(x.device) 79 | pos_test_edge = split_edge['test']['edge'].to(x.device) 80 | neg_test_edge = split_edge['test']['edge_neg'].to(x.device) 81 | 82 | pos_train_preds = [] 83 | for perm in DataLoader(range(pos_train_edge.size(0)), batch_size): 84 | edge = pos_train_edge[perm].t() 85 | pos_train_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 86 | pos_train_pred = torch.cat(pos_train_preds, dim=0) 87 | 88 | pos_valid_preds = [] 89 | for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size): 90 | edge = pos_valid_edge[perm].t() 91 | pos_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 92 | pos_valid_pred = torch.cat(pos_valid_preds, dim=0) 93 | 94 | neg_valid_preds = [] 95 | for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size): 96 | edge = neg_valid_edge[perm].t() 97 | neg_valid_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 98 | neg_valid_pred = torch.cat(neg_valid_preds, dim=0) 99 | 100 | pos_test_preds = [] 101 | for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): 102 | edge = pos_test_edge[perm].t() 103 | pos_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 104 | pos_test_pred = torch.cat(pos_test_preds, dim=0) 105 | 106 | neg_test_preds = [] 107 | for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): 108 | edge = neg_test_edge[perm].t() 109 | neg_test_preds += [predictor(x[edge[0]], x[edge[1]]).squeeze().cpu()] 110 | neg_test_pred = torch.cat(neg_test_preds, dim=0) 111 | 112 | results = {} 113 | for K in [10, 20, 30]: 114 | evaluator.K = K 115 | train_hits = evaluator.eval({ 116 | 'y_pred_pos': pos_train_pred, 117 | 'y_pred_neg': neg_valid_pred, 118 | })[f'hits@{K}'] 119 | valid_hits = evaluator.eval({ 120 | 'y_pred_pos': pos_valid_pred, 121 | 'y_pred_neg': neg_valid_pred, 122 | })[f'hits@{K}'] 123 | test_hits = evaluator.eval({ 124 | 'y_pred_pos': pos_test_pred, 125 | 'y_pred_neg': neg_test_pred, 126 | })[f'hits@{K}'] 127 | 128 | results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) 129 | 130 | return results 131 | 132 | 133 | def main(): 134 | parser = argparse.ArgumentParser(description='OGBL-DDI (MLP)') 135 | parser.add_argument('--device', type=int, default=0) 136 | parser.add_argument('--log_steps', type=int, default=1) 137 | parser.add_argument('--num_layers', type=int, default=3) 138 | parser.add_argument('--hidden_channels', type=int, default=256) 139 | parser.add_argument('--dropout', type=float, default=0.5) 140 | parser.add_argument('--batch_size', type=int, default=64 * 1024) 141 | parser.add_argument('--lr', type=float, default=0.01) 142 | parser.add_argument('--epochs', type=int, default=200) 143 | parser.add_argument('--eval_steps', type=int, default=5) 144 | parser.add_argument('--runs', type=int, default=10) 145 | args = parser.parse_args() 146 | print(args) 147 | 148 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 149 | device = torch.device(device) 150 | 151 | dataset = PygLinkPropPredDataset(name='ogbl-ddi') 152 | data = dataset[0] 153 | split_edge = dataset.get_edge_split() 154 | 155 | # We randomly pick some training samples that we want to evaluate on: 156 | torch.manual_seed(12345) 157 | idx = torch.randperm(split_edge['train']['edge'].size(0)) 158 | idx = idx[:split_edge['valid']['edge'].size(0)] 159 | split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]} 160 | 161 | x = torch.load('embedding.pt', map_location='cpu').to(device) 162 | 163 | predictor = LinkPredictor(x.size(-1), args.hidden_channels, 1, 164 | args.num_layers, args.dropout).to(device) 165 | 166 | evaluator = Evaluator(name='ogbl-ddi') 167 | loggers = { 168 | 'Hits@10': Logger(args.runs, args), 169 | 'Hits@20': Logger(args.runs, args), 170 | 'Hits@30': Logger(args.runs, args), 171 | } 172 | 173 | for run in range(args.runs): 174 | predictor.reset_parameters() 175 | optimizer = torch.optim.Adam(predictor.parameters(), lr=args.lr) 176 | 177 | for epoch in range(1, 1 + args.epochs): 178 | loss = train(predictor, x, data.edge_index, split_edge, optimizer, 179 | args.batch_size) 180 | 181 | if epoch % args.eval_steps == 0: 182 | results = test(predictor, x, split_edge, evaluator, 183 | args.batch_size) 184 | for key, result in results.items(): 185 | loggers[key].add_result(run, result) 186 | 187 | if epoch % args.log_steps == 0: 188 | for key, result in results.items(): 189 | train_hits, valid_hits, test_hits = result 190 | print(key) 191 | print(f'Run: {run + 1:02d}, ' 192 | f'Epoch: {epoch:02d}, ' 193 | f'Loss: {loss:.4f}, ' 194 | f'Train: {100 * train_hits:.2f}%, ' 195 | f'Valid: {100 * valid_hits:.2f}%, ' 196 | f'Test: {100 * test_hits:.2f}%') 197 | print('---') 198 | 199 | for key in loggers.keys(): 200 | print(key) 201 | loggers[key].print_statistics(run) 202 | 203 | for key in loggers.keys(): 204 | print(key) 205 | loggers[key].print_statistics() 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /linkproppred/ddi/node2vec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch_geometric.nn import Node2Vec 5 | 6 | from ogb.linkproppred import PygLinkPropPredDataset 7 | 8 | 9 | def save_embedding(model): 10 | torch.save(model.embedding.weight.data.cpu(), 'embedding.pt') 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description='OGBL-DDI (Node2Vec)') 15 | parser.add_argument('--device', type=int, default=0) 16 | parser.add_argument('--embedding_dim', type=int, default=128) 17 | parser.add_argument('--walk_length', type=int, default=40) 18 | parser.add_argument('--context_size', type=int, default=20) 19 | parser.add_argument('--walks_per_node', type=int, default=10) 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--lr', type=float, default=0.01) 22 | parser.add_argument('--epochs', type=int, default=100) 23 | parser.add_argument('--log_steps', type=int, default=1) 24 | args = parser.parse_args() 25 | 26 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 27 | device = torch.device(device) 28 | 29 | dataset = PygLinkPropPredDataset(name='ogbl-ddi') 30 | data = dataset[0] 31 | 32 | model = Node2Vec(data.edge_index, args.embedding_dim, args.walk_length, 33 | args.context_size, args.walks_per_node, 34 | sparse=True).to(device) 35 | 36 | loader = model.loader(batch_size=args.batch_size, shuffle=True, 37 | num_workers=4) 38 | optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr) 39 | 40 | model.train() 41 | for epoch in range(1, args.epochs + 1): 42 | for i, (pos_rw, neg_rw) in enumerate(loader): 43 | optimizer.zero_grad() 44 | loss = model.loss(pos_rw.to(device), neg_rw.to(device)) 45 | loss.backward() 46 | optimizer.step() 47 | 48 | if (i + 1) % args.log_steps == 0: 49 | print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, ' 50 | f'Loss: {loss:.4f}') 51 | 52 | if (i + 1) % 100 == 0: # Save model every 100 steps. 53 | save_embedding(model) 54 | save_embedding(model) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /nodeproppred/proteins/gnn.py: -------------------------------------------------------------------------------- 1 | import argparse, os, math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import torch_geometric.transforms as T 7 | from torch_geometric.nn import GCNConv, SAGEConv 8 | 9 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 10 | 11 | from logger import Logger 12 | from moe import MoE 13 | import numpy as np 14 | 15 | 16 | class GCN(torch.nn.Module): 17 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 18 | dropout): 19 | super(GCN, self).__init__() 20 | 21 | self.convs = torch.nn.ModuleList() 22 | self.convs.append( 23 | GCNConv(in_channels, hidden_channels, normalize=False)) 24 | for _ in range(num_layers - 2): 25 | self.convs.append( 26 | GCNConv(hidden_channels, hidden_channels, normalize=False)) 27 | self.convs.append( 28 | GCNConv(hidden_channels, out_channels, normalize=False)) 29 | 30 | self.dropout = dropout 31 | 32 | def reset_parameters(self): 33 | for conv in self.convs: 34 | conv.reset_parameters() 35 | 36 | def forward(self, x, adj_t): 37 | for conv in self.convs[:-1]: 38 | x = conv(x, adj_t) 39 | x = F.relu(x) 40 | x = F.dropout(x, p=self.dropout, training=self.training) 41 | x = self.convs[-1](x, adj_t) 42 | return x 43 | 44 | class GCN_SpMoE(torch.nn.Module): 45 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 46 | dropout, num_experts=4, k=1, coef=1e-2): 47 | super(GCN_SpMoE, self).__init__() 48 | 49 | self.num_layers = num_layers 50 | 51 | self.convs = torch.nn.ModuleList() 52 | self.convs.append( 53 | GCNConv(in_channels, hidden_channels, normalize=False)) 54 | for layer_idx in range(num_layers - 2): 55 | if layer_idx % 2 == 0: 56 | ffn = MoE(input_size=hidden_channels, output_size=hidden_channels, num_experts=num_experts, k=k, coef=coef) 57 | self.convs.append(ffn) 58 | else: 59 | self.convs.append( 60 | GCNConv(hidden_channels, hidden_channels, normalize=False)) 61 | self.convs.append( 62 | GCNConv(hidden_channels, out_channels, normalize=False)) 63 | 64 | self.dropout = dropout 65 | 66 | def forward(self, x, adj_t): 67 | self.load_balance_loss = 0 # initialize load_balance_loss to 0 at the beginning of each forward pass. 68 | for conv in self.convs[:-1]: 69 | if isinstance(conv, MoE): 70 | x, _layer_load_balance_loss = conv(x, adj_t) 71 | self.load_balance_loss += _layer_load_balance_loss 72 | else: 73 | x = conv(x, adj_t) 74 | x = F.relu(x) 75 | x = F.dropout(x, p=self.dropout, training=self.training) 76 | x = self.convs[-1](x, adj_t) 77 | self.load_balance_loss /= math.ceil((self.num_layers-2)/2) 78 | return x 79 | 80 | 81 | class SAGE(torch.nn.Module): 82 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 83 | dropout): 84 | super(SAGE, self).__init__() 85 | 86 | self.convs = torch.nn.ModuleList() 87 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 88 | for _ in range(num_layers - 2): 89 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 90 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 91 | 92 | self.dropout = dropout 93 | 94 | def reset_parameters(self): 95 | for conv in self.convs: 96 | conv.reset_parameters() 97 | 98 | def forward(self, x, adj_t): 99 | for conv in self.convs[:-1]: 100 | x = conv(x, adj_t) 101 | x = F.relu(x) 102 | x = F.dropout(x, p=self.dropout, training=self.training) 103 | x = self.convs[-1](x, adj_t) 104 | return x 105 | 106 | class SAGE_SpMoE(torch.nn.Module): 107 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 108 | dropout, num_experts=4, k=1, coef=1e-2): 109 | super(SAGE_SpMoE, self).__init__() 110 | 111 | self.num_layers = num_layers 112 | 113 | self.convs = torch.nn.ModuleList() 114 | self.convs.append(SAGEConv(in_channels, hidden_channels)) 115 | for layer_idx in range(num_layers - 2): 116 | if layer_idx % 2 == 0: 117 | ffn = MoE(input_size=hidden_channels, output_size=hidden_channels, num_experts=num_experts, k=k, coef=coef, sage=True) 118 | self.convs.append(ffn) 119 | else: 120 | self.convs.append(SAGEConv(hidden_channels, hidden_channels)) 121 | self.convs.append(SAGEConv(hidden_channels, out_channels)) 122 | 123 | self.dropout = dropout 124 | 125 | def reset_parameters(self): 126 | for conv in self.convs: 127 | conv.reset_parameters() 128 | 129 | def forward(self, x, adj_t): 130 | self.load_balance_loss = 0 # initialize load_balance_loss to 0 at the beginning of each forward pass. 131 | for conv in self.convs[:-1]: 132 | if isinstance(conv, MoE): 133 | x, _layer_load_balance_loss = conv(x, adj_t) 134 | self.load_balance_loss += _layer_load_balance_loss 135 | else: 136 | x = conv(x, adj_t) 137 | x = F.relu(x) 138 | x = F.dropout(x, p=self.dropout, training=self.training) 139 | x = self.convs[-1](x, adj_t) 140 | self.load_balance_loss /= math.ceil((self.num_layers-2)/2) 141 | return x 142 | 143 | 144 | def train(model, data, train_idx, optimizer): 145 | model.train() 146 | criterion = torch.nn.BCEWithLogitsLoss() 147 | 148 | optimizer.zero_grad() 149 | out = model(data.x, data.adj_t)[train_idx] 150 | loss = criterion(out, data.y[train_idx].to(torch.float)) 151 | if isinstance(model, GCN_SpMoE): 152 | loss += model.load_balance_loss 153 | loss.backward() 154 | optimizer.step() 155 | 156 | return loss.item() 157 | 158 | 159 | @torch.no_grad() 160 | def test(model, data, split_idx, evaluator): 161 | model.eval() 162 | 163 | y_pred = model(data.x, data.adj_t) 164 | 165 | train_rocauc = evaluator.eval({ 166 | 'y_true': data.y[split_idx['train']], 167 | 'y_pred': y_pred[split_idx['train']], 168 | })['rocauc'] 169 | valid_rocauc = evaluator.eval({ 170 | 'y_true': data.y[split_idx['valid']], 171 | 'y_pred': y_pred[split_idx['valid']], 172 | })['rocauc'] 173 | test_rocauc = evaluator.eval({ 174 | 'y_true': data.y[split_idx['test']], 175 | 'y_pred': y_pred[split_idx['test']], 176 | })['rocauc'] 177 | 178 | return train_rocauc, valid_rocauc, test_rocauc 179 | 180 | 181 | def main(): 182 | parser = argparse.ArgumentParser(description='OGBN-Proteins (GNN)') 183 | parser.add_argument('--device', type=int, default=0) 184 | parser.add_argument('--log_steps', type=int, default=1) 185 | parser.add_argument('--gnn', default='gcn-spmoe', choices=['gcn', 'sage', 'gcn-spmoe', 'sage-spmoe']) 186 | parser.add_argument('--num_layers', type=int, default=3) 187 | parser.add_argument('--hidden_channels', '-d', type=int, default=256) 188 | parser.add_argument('--dropout', type=float, default=0.0) 189 | parser.add_argument('--lr', type=float, default=0.01) 190 | parser.add_argument('--epochs', '-e', type=int, default=1000) 191 | parser.add_argument('--eval_steps', type=int, default=5) 192 | parser.add_argument('--num_experts', '-n', type=int, default=8, 193 | help='total number of experts in GCN-MoE') 194 | parser.add_argument('-k', type=int, default=4, 195 | help='selected number of experts in GCN-MoE') 196 | parser.add_argument('--coef', type=float, default=1, 197 | help='loss coefficient for load balancing loss in sparse MoE training') 198 | 199 | args = parser.parse_args() 200 | print(args) 201 | 202 | exp_str = '%s-dropout%s-lr%s-e%d' % (args.gnn, args.dropout, args.lr, args.epochs) 203 | if 'spmoe' in args.gnn: 204 | exp_str += '-d%d-n%d-k%d-coef%s' % (args.hidden_channels, args.num_experts, args.k, args.coef) 205 | 206 | from datetime import datetime 207 | current_date_and_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 208 | exp_str += '-%s' % current_date_and_time 209 | 210 | save_dir = os.path.join('results') 211 | if not os.path.exists(save_dir): 212 | os.makedirs(save_dir) 213 | 214 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 215 | device = torch.device(device) 216 | 217 | dataset = PygNodePropPredDataset( 218 | name='ogbn-proteins', transform=T.ToSparseTensor(attr='edge_attr')) 219 | data = dataset[0] 220 | 221 | # Move edge features to node features. 222 | data.x = data.adj_t.mean(dim=1) 223 | data.adj_t.set_value_(None) 224 | 225 | split_idx = dataset.get_idx_split() 226 | train_idx = split_idx['train'].to(device) 227 | 228 | if args.gnn == 'sage': 229 | model = SAGE(data.num_features, args.hidden_channels, 112, 230 | args.num_layers, args.dropout).to(device) 231 | elif args.gnn == 'gcn': 232 | model = GCN(data.num_features, args.hidden_channels, 112, 233 | args.num_layers, args.dropout).to(device) 234 | elif args.gnn == 'gcn-spmoe': 235 | model = GCN_SpMoE(data.num_features, args.hidden_channels, 112, 236 | args.num_layers, args.dropout, 237 | num_experts=args.num_experts, k=args.k, coef=args.coef).to(device) 238 | elif args.gnn == 'sage-spmoe': 239 | model = SAGE_SpMoE(data.num_features, args.hidden_channels, 112, 240 | args.num_layers, args.dropout, 241 | num_experts=args.num_experts, k=args.k, coef=args.coef).to(device) 242 | 243 | if args.gnn in ['gcn', 'gcn-spmoe']: 244 | # Pre-compute GCN normalization. 245 | adj_t = data.adj_t.set_diag() 246 | deg = adj_t.sum(dim=1).to(torch.float) 247 | deg_inv_sqrt = deg.pow(-0.5) 248 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 249 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) 250 | data.adj_t = adj_t 251 | 252 | data = data.to(device) 253 | 254 | evaluator = Evaluator(name='ogbn-proteins') 255 | 256 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 257 | 258 | valid_curve = [] 259 | test_curve = [] 260 | train_curve = [] 261 | for epoch in range(1, 1 + args.epochs): 262 | loss = train(model, data, train_idx, optimizer) 263 | 264 | if epoch % args.eval_steps == 0: 265 | result = test(model, data, split_idx, evaluator) 266 | train_rocauc, valid_rocauc, test_rocauc = result 267 | valid_curve.append(valid_rocauc) 268 | test_curve.append(test_rocauc) 269 | 270 | if epoch % args.log_steps == 0: 271 | log_str = f'Epoch: {epoch:02d}, '\ 272 | f'Loss: {loss:.4f}, '\ 273 | f'Train: {100 * train_rocauc:.2f}%, '\ 274 | f'Valid: {100 * valid_rocauc:.2f}% '\ 275 | f'Test: {100 * test_rocauc:.2f}%' 276 | print(log_str) 277 | 278 | with open(os.path.join(save_dir, '%s.txt' % exp_str), 'a+') as fp: 279 | fp.write(log_str) 280 | fp.write('\n') 281 | fp.flush() 282 | fp.close() 283 | 284 | best_val_epoch = np.argmax(np.array(valid_curve)) 285 | print('Finished training!') 286 | print('Best validation score: {}'.format(valid_curve[best_val_epoch])) 287 | print('Test score: {}'.format(test_curve[best_val_epoch])) 288 | with open(os.path.join(save_dir, '%s.txt' % exp_str), 'a+') as fp: 289 | fp.write('Best validation score: {}\n'.format(valid_curve[best_val_epoch])) 290 | fp.write('Test score: {}\n'.format(test_curve[best_val_epoch])) 291 | fp.flush() 292 | fp.close() 293 | 294 | filename = os.path.join(save_dir, '%s.pth' % exp_str) 295 | torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch]}, filename) 296 | 297 | 298 | if __name__ == "__main__": 299 | main() 300 | -------------------------------------------------------------------------------- /nodeproppred/proteins/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, runs, info=None): 6 | self.info = info 7 | self.results = [[] for _ in range(runs)] 8 | 9 | def add_result(self, run, result): 10 | assert len(result) == 3 11 | assert run >= 0 and run < len(self.results) 12 | self.results[run].append(result) 13 | 14 | def print_statistics(self, run=None): 15 | if run is not None: 16 | result = 100 * torch.tensor(self.results[run]) 17 | argmax = result[:, 1].argmax().item() 18 | print(f'Run {run + 1:02d}:') 19 | print(f'Highest Train: {result[:, 0].max():.2f}') 20 | print(f'Highest Valid: {result[:, 1].max():.2f}') 21 | print(f' Final Train: {result[argmax, 0]:.2f}') 22 | print(f' Final Test: {result[argmax, 2]:.2f}') 23 | else: 24 | result = 100 * torch.tensor(self.results) 25 | 26 | best_results = [] 27 | for r in result: 28 | train1 = r[:, 0].max().item() 29 | valid = r[:, 1].max().item() 30 | train2 = r[r[:, 1].argmax(), 0].item() 31 | test = r[r[:, 1].argmax(), 2].item() 32 | best_results.append((train1, valid, train2, test)) 33 | 34 | best_result = torch.tensor(best_results) 35 | 36 | print(f'All runs:') 37 | r = best_result[:, 0] 38 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 39 | r = best_result[:, 1] 40 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 41 | r = best_result[:, 2] 42 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 43 | r = best_result[:, 3] 44 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') 45 | -------------------------------------------------------------------------------- /nodeproppred/proteins/mlp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch_scatter import scatter 6 | 7 | from ogb.nodeproppred import PygNodePropPredDataset, Evaluator 8 | 9 | from logger import Logger 10 | 11 | 12 | class MLP(torch.nn.Module): 13 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 14 | dropout): 15 | super(MLP, self).__init__() 16 | 17 | self.lins = torch.nn.ModuleList() 18 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 19 | for _ in range(num_layers - 2): 20 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 21 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 22 | 23 | self.dropout = dropout 24 | 25 | def reset_parameters(self): 26 | for lin in self.lins: 27 | lin.reset_parameters() 28 | 29 | def forward(self, x): 30 | for lin in self.lins[:-1]: 31 | x = lin(x) 32 | x = F.relu(x) 33 | x = F.dropout(x, p=self.dropout, training=self.training) 34 | x = self.lins[-1](x) 35 | return x 36 | 37 | 38 | def train(model, x, y_true, train_idx, optimizer): 39 | model.train() 40 | criterion = torch.nn.BCEWithLogitsLoss() 41 | 42 | optimizer.zero_grad() 43 | out = model(x)[train_idx] 44 | loss = criterion(out, y_true[train_idx].to(torch.float)) 45 | loss.backward() 46 | optimizer.step() 47 | 48 | return loss.item() 49 | 50 | 51 | @torch.no_grad() 52 | def test(model, x, y_true, split_idx, evaluator): 53 | model.eval() 54 | 55 | y_pred = model(x) 56 | 57 | train_rocauc = evaluator.eval({ 58 | 'y_true': y_true[split_idx['train']], 59 | 'y_pred': y_pred[split_idx['train']], 60 | })['rocauc'] 61 | valid_rocauc = evaluator.eval({ 62 | 'y_true': y_true[split_idx['valid']], 63 | 'y_pred': y_pred[split_idx['valid']], 64 | })['rocauc'] 65 | test_rocauc = evaluator.eval({ 66 | 'y_true': y_true[split_idx['test']], 67 | 'y_pred': y_pred[split_idx['test']], 68 | })['rocauc'] 69 | 70 | return train_rocauc, valid_rocauc, test_rocauc 71 | 72 | 73 | def main(): 74 | parser = argparse.ArgumentParser(description='OGBN-Proteins (MLP)') 75 | parser.add_argument('--device', type=int, default=0) 76 | parser.add_argument('--log_steps', type=int, default=1) 77 | parser.add_argument('--use_node_embedding', action='store_true') 78 | parser.add_argument('--num_layers', type=int, default=3) 79 | parser.add_argument('--hidden_channels', type=int, default=256) 80 | parser.add_argument('--dropout', type=float, default=0.5) 81 | parser.add_argument('--lr', type=float, default=0.01) 82 | parser.add_argument('--epochs', type=int, default=1000) 83 | parser.add_argument('--eval_steps', type=int, default=5) 84 | parser.add_argument('--runs', type=int, default=10) 85 | args = parser.parse_args() 86 | print(args) 87 | 88 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 89 | device = torch.device(device) 90 | 91 | dataset = PygNodePropPredDataset(name='ogbn-proteins') 92 | split_idx = dataset.get_idx_split() 93 | data = dataset[0] 94 | 95 | x = scatter(data.edge_attr, data.edge_index[0], dim=0, 96 | dim_size=data.num_nodes, reduce='mean').to('cpu') 97 | 98 | if args.use_node_embedding: 99 | embedding = torch.load('embedding.pt', map_location='cpu') 100 | x = torch.cat([x, embedding], dim=-1) 101 | 102 | x = x.to(device) 103 | y_true = data.y.to(device) 104 | train_idx = split_idx['train'].to(device) 105 | 106 | model = MLP(x.size(-1), args.hidden_channels, 112, args.num_layers, 107 | args.dropout).to(device) 108 | 109 | evaluator = Evaluator(name='ogbn-proteins') 110 | logger = Logger(args.runs, args) 111 | 112 | for run in range(args.runs): 113 | model.reset_parameters() 114 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 115 | for epoch in range(1, 1 + args.epochs): 116 | loss = train(model, x, y_true, train_idx, optimizer) 117 | 118 | if epoch % args.eval_steps == 0: 119 | result = test(model, x, y_true, split_idx, evaluator) 120 | logger.add_result(run, result) 121 | 122 | if epoch % args.log_steps == 0: 123 | train_rocauc, valid_rocauc, test_rocauc = result 124 | print(f'Run: {run + 1:02d}, ' 125 | f'Epoch: {epoch:02d}, ' 126 | f'Loss: {loss:.4f}, ' 127 | f'Train: {100 * train_rocauc:.2f}%, ' 128 | f'Valid: {100 * valid_rocauc:.2f}% ' 129 | f'Test: {100 * test_rocauc:.2f}%') 130 | 131 | logger.print_statistics(run) 132 | logger.print_statistics() 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /nodeproppred/proteins/moe.py: -------------------------------------------------------------------------------- 1 | # Note by Haotao Wang: 2 | # Adapted form https://raw.githubusercontent.com/davidmrau/mixture-of-experts/master/moe.py 3 | 4 | # Sparsely-Gated Mixture-of-Experts Layers. 5 | # See "Outrageously Large Neural Networks" 6 | # https://arxiv.org/abs/1701.06538 7 | # 8 | # Author: David Rau 9 | # 10 | # The code is based on the TensorFlow implementation: 11 | # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py 12 | 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.distributions.normal import Normal 17 | import numpy as np 18 | from torch_geometric.nn import GCNConv, SAGEConv 19 | 20 | 21 | class SparseDispatcher(object): 22 | """Helper for implementing a mixture of experts. 23 | The purpose of this class is to create input minibatches for the 24 | experts and to combine the results of the experts to form a unified 25 | output tensor. 26 | There are two functions: 27 | dispatch - take an input Tensor and create input Tensors for each expert. 28 | combine - take output Tensors from each expert and form a combined output 29 | Tensor. Outputs from different experts for the same batch element are 30 | summed together, weighted by the provided "gates". 31 | The class is initialized with a "gates" Tensor, which specifies which 32 | batch elements go to which experts, and the weights to use when combining 33 | the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. 34 | The inputs and outputs are all two-dimensional [batch, depth]. 35 | Caller is responsible for collapsing additional dimensions prior to 36 | calling this class and reshaping the output to the original shape. 37 | See common_layers.reshape_like(). 38 | Example use: 39 | gates: a float32 `Tensor` with shape `[batch_size, num_experts]` 40 | inputs: a float32 `Tensor` with shape `[batch_size, input_size]` 41 | experts: a list of length `num_experts` containing sub-networks. 42 | dispatcher = SparseDispatcher(num_experts, gates) 43 | expert_inputs = dispatcher.dispatch(inputs) 44 | expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] 45 | outputs = dispatcher.combine(expert_outputs) 46 | The preceding code sets the output for a particular example b to: 47 | output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) 48 | This class takes advantage of sparsity in the gate matrix by including in the 49 | `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. 50 | """ 51 | 52 | def __init__(self, num_experts, gates): 53 | """Create a SparseDispatcher.""" 54 | 55 | self._gates = gates 56 | self._num_experts = num_experts 57 | # sort experts 58 | sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) 59 | # drop indices 60 | _, self._expert_index = sorted_experts.split(1, dim=1) 61 | # get according batch index for each expert 62 | self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] 63 | # calculate num samples that each expert gets 64 | self._part_sizes = (gates > 0).sum(0).tolist() 65 | # expand gates to match with self._batch_index 66 | gates_exp = gates[self._batch_index.flatten()] 67 | self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) 68 | 69 | def dispatch(self, inp, edge_index, edge_attr): 70 | """Create one input Tensor for each expert. 71 | The `Tensor` for a expert `i` contains the slices of `inp` corresponding 72 | to the batch elements `b` where `gates[b, i] > 0`. 73 | Args: 74 | inp: a `Tensor` of shape "[batch_size, ]` 75 | Returns: 76 | a list of `num_experts` `Tensor`s with shapes 77 | `[expert_batch_size_i, ]`. 78 | """ 79 | 80 | # Note by Haotao: 81 | # self._batch_index: shape=(N_batch). The re-order indices from 0 to N_batch-1. 82 | # inp_exp: shape=inp.shape. The input Tensor re-ordered by self._batch_index along the batch dimension. 83 | # self._part_sizes: shape=(N_experts), sum=N_batch. self._part_sizes[i] is the number of samples routed towards expert[i]. 84 | # return value: list [Tensor with shape[0]=self._part_sizes[i] for i in range(N_experts)] 85 | 86 | # assigns samples to experts whose gate is nonzero 87 | 88 | # expand according to batch index so we can just split by _part_sizes 89 | inp_exp = inp[self._batch_index].squeeze(1) 90 | edge_index_exp = edge_index[:,self._batch_index] 91 | edge_attr_exp = edge_attr[self._batch_index] 92 | return torch.split(inp_exp, self._part_sizes, dim=0), torch.split(edge_index_exp, self._part_sizes, dim=1), torch.split(edge_attr_exp, self._part_sizes, dim=0) 93 | 94 | def combine(self, expert_out, multiply_by_gates=True): 95 | """Sum together the expert output, weighted by the gates. 96 | The slice corresponding to a particular batch element `b` is computed 97 | as the sum over all experts `i` of the expert output, weighted by the 98 | corresponding gate values. If `multiply_by_gates` is set to False, the 99 | gate values are ignored. 100 | Args: 101 | expert_out: a list of `num_experts` `Tensor`s, each with shape 102 | `[expert_batch_size_i, ]`. 103 | multiply_by_gates: a boolean 104 | Returns: 105 | a `Tensor` with shape `[batch_size, ]`. 106 | """ 107 | # apply exp to expert outputs, so we are not longer in log space 108 | stitched = torch.cat(expert_out, 0).exp() 109 | 110 | if multiply_by_gates: 111 | stitched = stitched.mul(self._nonzero_gates) 112 | zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device) 113 | # combine samples that have been processed by the same k experts 114 | combined = zeros.index_add(0, self._batch_index, stitched.float()) 115 | # add eps to all zero values in order to avoid nans when going back to log space 116 | combined[combined == 0] = np.finfo(float).eps 117 | # back to log space 118 | return combined.log() 119 | 120 | def expert_to_gates(self): 121 | """Gate values corresponding to the examples in the per-expert `Tensor`s. 122 | Returns: 123 | a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` 124 | and shapes `[expert_batch_size_i]` 125 | """ 126 | # split nonzero gates for each expert 127 | return torch.split(self._nonzero_gates, self._part_sizes, dim=0) 128 | 129 | class MoE(nn.Module): 130 | 131 | """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. 132 | Args: 133 | input_size: integer - size of the input 134 | output_size: integer - size of the input 135 | num_experts: an integer - number of experts 136 | hidden_size: an integer - hidden size of the experts 137 | noisy_gating: a boolean 138 | k: an integer - how many experts to use for each batch element 139 | """ 140 | 141 | def __init__(self, input_size, output_size, num_experts, noisy_gating=True, k=4, coef=1e-2, sage=False): 142 | super(MoE, self).__init__() 143 | self.noisy_gating = noisy_gating 144 | self.num_experts = num_experts 145 | self.k = k 146 | self.loss_coef = coef 147 | # instantiate experts 148 | conv = SAGEConv if sage else GCNConv 149 | self.experts = nn.ModuleList([conv(input_size, output_size, normalize=False) for i in range(self.num_experts)]) 150 | self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) 151 | self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) 152 | 153 | self.softplus = nn.Softplus() 154 | self.softmax = nn.Softmax(1) 155 | self.register_buffer("mean", torch.tensor([0.0])) 156 | self.register_buffer("std", torch.tensor([1.0])) 157 | assert(self.k <= self.num_experts) 158 | 159 | def cv_squared(self, x): 160 | """The squared coefficient of variation of a sample. 161 | Useful as a loss to encourage a positive distribution to be more uniform. 162 | Epsilons added for numerical stability. 163 | Returns 0 for an empty Tensor. 164 | Args: 165 | x: a `Tensor`. 166 | Returns: 167 | a `Scalar`. 168 | """ 169 | eps = 1e-10 170 | # if only num_experts = 1 171 | 172 | if x.shape[0] == 1: 173 | return torch.tensor([0], device=x.device, dtype=x.dtype) 174 | return x.float().var() / (x.float().mean()**2 + eps) 175 | 176 | def _gates_to_load(self, gates): 177 | """Compute the true load per expert, given the gates. 178 | The load is the number of examples for which the corresponding gate is >0. 179 | Args: 180 | gates: a `Tensor` of shape [batch_size, n] 181 | Returns: 182 | a float32 `Tensor` of shape [n] 183 | """ 184 | return (gates > 0).sum(0) 185 | 186 | def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): 187 | """Helper function to NoisyTopKGating. 188 | Computes the probability that value is in top k, given different random noise. 189 | This gives us a way of backpropagating from a loss that balances the number 190 | of times each expert is in the top k experts per example. 191 | In the case of no noise, pass in None for noise_stddev, and the result will 192 | not be differentiable. 193 | Args: 194 | clean_values: a `Tensor` of shape [batch, n]. 195 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 196 | normally distributed noise with standard deviation noise_stddev. 197 | noise_stddev: a `Tensor` of shape [batch, n], or None 198 | noisy_top_values: a `Tensor` of shape [batch, m]. 199 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 200 | Returns: 201 | a `Tensor` of shape [batch, n]. 202 | """ 203 | batch = clean_values.size(0) 204 | m = noisy_top_values.size(1) 205 | top_values_flat = noisy_top_values.flatten() 206 | 207 | threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k 208 | threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) 209 | is_in = torch.gt(noisy_values, threshold_if_in) 210 | threshold_positions_if_out = threshold_positions_if_in - 1 211 | threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) 212 | # is each value currently in the top k. 213 | normal = Normal(self.mean, self.std) 214 | prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) 215 | prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) 216 | prob = torch.where(is_in, prob_if_in, prob_if_out) 217 | return prob 218 | 219 | def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): 220 | """Noisy top-k gating. 221 | See paper: https://arxiv.org/abs/1701.06538. 222 | Args: 223 | x: input Tensor with shape [batch_size, input_size] 224 | train: a boolean - we only add noise at training time. 225 | noise_epsilon: a float 226 | Returns: 227 | gates: a Tensor with shape [batch_size, num_experts] 228 | load: a Tensor with shape [num_experts] 229 | """ 230 | clean_logits = x @ self.w_gate 231 | if self.noisy_gating and train: 232 | raw_noise_stddev = x @ self.w_noise 233 | noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) 234 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 235 | logits = noisy_logits 236 | else: 237 | logits = clean_logits 238 | 239 | # calculate topk + 1 that will be needed for the noisy gates 240 | top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) 241 | top_k_logits = top_logits[:, :self.k] 242 | top_k_indices = top_indices[:, :self.k] 243 | top_k_gates = self.softmax(top_k_logits) 244 | 245 | zeros = torch.zeros_like(logits, requires_grad=True) 246 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 247 | 248 | if self.noisy_gating and self.k < self.num_experts and train: 249 | load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) 250 | else: 251 | load = self._gates_to_load(gates) 252 | return gates, load 253 | 254 | def forward(self, x, edge_index, edge_attr=None): 255 | """Args: 256 | x: tensor shape [batch_size, input_size] 257 | train: a boolean scalar. 258 | loss_coef: a scalar - multiplier on load-balancing losses 259 | 260 | Returns: 261 | y: a tensor with shape [batch_size, output_size]. 262 | extra_training_loss: a scalar. This should be added into the overall 263 | training loss of the model. The backpropagation of this loss 264 | encourages all experts to be approximately equally used across a batch. 265 | """ 266 | gates, load = self.noisy_top_k_gating(x, self.training) 267 | # calculate importance loss 268 | importance = gates.sum(0) 269 | # 270 | loss = self.cv_squared(importance) + self.cv_squared(load) 271 | loss *= self.loss_coef 272 | 273 | 274 | 275 | expert_outputs = [] 276 | for i in range(self.num_experts): 277 | expert_i_output = self.experts[i](x, edge_index, edge_attr) 278 | expert_outputs.append(expert_i_output) 279 | expert_outputs = torch.stack(expert_outputs, dim=1) # shape=[num_nodes, num_experts, d_feature] 280 | 281 | # gates: shape=[num_nodes, num_experts] 282 | y = gates.unsqueeze(dim=-1) * expert_outputs 283 | y = y.mean(dim=1) 284 | 285 | return y, loss 286 | 287 | -------------------------------------------------------------------------------- /nodeproppred/proteins/node2vec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch_geometric.nn import Node2Vec 5 | 6 | from ogb.nodeproppred import PygNodePropPredDataset 7 | 8 | 9 | def save_embedding(model): 10 | torch.save(model.embedding.weight.data.cpu(), 'embedding.pt') 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description='OGBN-Proteins (Node2Vec)') 15 | parser.add_argument('--device', type=int, default=0) 16 | parser.add_argument('--embedding_dim', type=int, default=128) 17 | parser.add_argument('--walk_length', type=int, default=80) 18 | parser.add_argument('--context_size', type=int, default=20) 19 | parser.add_argument('--walks_per_node', type=int, default=10) 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--lr', type=float, default=0.01) 22 | parser.add_argument('--epochs', type=int, default=1) 23 | parser.add_argument('--log_steps', type=int, default=1) 24 | args = parser.parse_args() 25 | 26 | device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' 27 | device = torch.device(device) 28 | 29 | dataset = PygNodePropPredDataset(name='ogbn-proteins') 30 | data = dataset[0] 31 | 32 | model = Node2Vec(data.edge_index, args.embedding_dim, args.walk_length, 33 | args.context_size, args.walks_per_node, 34 | sparse=True).to(device) 35 | 36 | loader = model.loader(batch_size=args.batch_size, shuffle=True, 37 | num_workers=4) 38 | optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.lr) 39 | 40 | model.train() 41 | for epoch in range(1, args.epochs + 1): 42 | for i, (pos_rw, neg_rw) in enumerate(loader): 43 | optimizer.zero_grad() 44 | loss = model.loss(pos_rw.to(device), neg_rw.to(device)) 45 | loss.backward() 46 | optimizer.step() 47 | 48 | if (i + 1) % args.log_steps == 0: 49 | print(f'Epoch: {epoch:02d}, Step: {i+1:03d}/{len(loader)}, ' 50 | f'Loss: {loss:.4f}') 51 | 52 | if (i + 1) % 100 == 0: # Save model every 100 steps. 53 | save_embedding(model) 54 | save_embedding(model) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | --------------------------------------------------------------------------------