├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── GNN_model
    ├── GCN.py
    ├── GNN_normalizations.py
    ├── __init__.py
    ├── drop_tricks.py
    ├── norm_tricks.py
    └── res_tricks.py
├── LICENSE
├── Label_propagation_model
    ├── LP_Adj.py
    ├── diffusion_feature.py
    ├── norm_spec.jl
    └── outcome_correlation.py
├── Link_prediction_baseline
    ├── .DS_Store
    ├── compute_bound_filepath.py
    ├── compute_bound_pickle.py
    ├── heuristics.py
    ├── models
    │   ├── dgi.py
    │   ├── pretrain_contextpred_gin.py
    │   ├── pretrain_masking_gin.py
    │   ├── structure_pretrain.py
    │   ├── subgi.py
    │   ├── utils.py
    │   └── vgae.py
    └── run_airport.py
├── Link_prediction_model
    ├── edge_LP.py
    ├── layer.py
    ├── logger.py
    ├── loss.py
    ├── model.py
    ├── negative_sample.py
    └── utils.py
├── MLP_model
    └── __init__.py
├── NOTICE
├── base_options.py
├── figs
    ├── coldbrew.png
    ├── gnns.png
    ├── longtail.png
    └── motivation.png
├── func_libs-deprecated.py
├── main.py
├── readme.md
├── requirements.txt
├── trainer_link_prediction.py
├── trainer_node_classification.py
└── utils.py

/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /GNN_model/GCN.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import dgl 5 | import math 6 | import torch as th 7 | import torch.nn.functional as F 8 | 9 | from .drop_tricks import DropoutTrick 10 | from .norm_tricks import * 11 | from .res_tricks import InitialConnection, DenseConnection, ResidualConnection 12 | from dgl import function as fn 13 | from dgl.base import DGLError 14 | from dgl.utils import expand_as_pair 15 | from torch import nn 16 | from torch.nn import init 17 | 18 | class TricksComb(nn.Module): 19 | def __init__(self, args): 20 | super(TricksComb, self).__init__() 21 | self.args = args 22 | self.dglgraph = None 23 | self.alpha = args.res_alpha 24 | self.embedding_dropout = args.dropout 25 | 26 | for k, v in vars(args).items(): 27 | setattr(self, k, v) 28 | # cannot cache graph structure when use graph dropout tricks 29 | self.cached = self.transductive = args.transductive 30 | if AcontainsB(self.type_trick, ['DropEdge', 'DropNode', 'FastGCN', 'LADIES']): 31 | self.cached = False 32 | # set self.has_residual_MLP as True when has residual connection 33 | # to keep same hidden dimension 34 | self.has_residual_MLP = False 35 | if AcontainsB(self.type_trick, ['Jumping', 'Initial', 'Residual', 'Dense']): 36 | self.has_residual_MLP = True 37 | # graph network initialize 38 | self.layers_GCN = nn.ModuleList([]) 39 | self.layers_res = nn.ModuleList([]) 40 | self.layers_norm = nn.ModuleList([]) 41 | self.layers_MLP = nn.ModuleList([]) 42 | # set MLP layer 43 | self.layers_MLP.append(nn.Linear(self.num_feats, self.dim_hidden)) 44 | if not self.has_residual_MLP: 45 | self.layers_GCN.append(GCNConv(self.num_feats, self.dim_hidden, cached=self.cached, args = self.args, whetherHasSE=self.args.TeacherGNN.whetherHasSE[0])) 46 | 47 | for i in range(self.num_layers): 48 | if (not self.has_residual_MLP) and ( 49 | 0 < i < self.num_layers - 1): # if don't want 0_th MLP, then 0-th layer is assigned outside the for loop 50 | self.layers_GCN.append(GCNConv(self.dim_hidden, 
self.dim_hidden, cached=self.cached,args = self.args, whetherHasSE=self.args.TeacherGNN.whetherHasSE[1])) 51 | elif self.has_residual_MLP: 52 | self.layers_GCN.append(GCNConv(self.dim_hidden, self.dim_hidden, cached=self.cached,args = self.args, whetherHasSE=self.args.TeacherGNN.whetherHasSE[1])) 53 | 54 | appendNormLayer(self, args, self.dim_hidden if i < self.num_layers - 1 else self.num_classes) 55 | 56 | # set residual connection type 57 | if AcontainsB(self.type_trick, ['Residual']): 58 | self.layers_res.append(ResidualConnection(alpha=self.alpha)) 59 | elif AcontainsB(self.type_trick, ['Initial']): 60 | self.layers_res.append(InitialConnection(alpha=self.alpha)) 61 | elif AcontainsB(self.type_trick, ['Dense']): 62 | if self.layer_agg in ['concat', 'maxpool']: 63 | self.layers_res.append( 64 | DenseConnection((i + 2) * self.dim_hidden, self.dim_hidden, self.layer_agg)) 65 | elif self.layer_agg == 'attention': 66 | self.layers_res.append( 67 | DenseConnection(self.dim_hidden, self.dim_hidden, self.layer_agg)) 68 | 69 | self.graph_dropout = DropoutTrick(args) 70 | if not self.has_residual_MLP: 71 | self.layers_GCN.append(GCNConv(self.dim_hidden, self.num_classes, cached=self.cached, args = self.args, whetherHasSE=self.args.TeacherGNN.whetherHasSE[2])) 72 | 73 | if AcontainsB(self.type_trick, ['Jumping']): 74 | if self.layer_agg in ['concat', 'maxpool']: 75 | self.layers_res.append( 76 | DenseConnection((self.num_layers + 1) * self.dim_hidden, self.num_classes, self.layer_agg)) 77 | elif self.layer_agg == 'attention': 78 | self.layers_res.append( 79 | DenseConnection(self.dim_hidden, self.num_classes, self.layer_agg)) 80 | else: 81 | self.layers_MLP.append(nn.Linear(self.dim_hidden, self.num_classes)) 82 | 83 | # set lambda 84 | if AcontainsB(self.type_trick, ['IdentityMapping']): 85 | self.lamda = args.lamda 86 | elif self.type_model == 'SGC': 87 | self.lamda = 0. 88 | elif self.type_model == 'GCN': 89 | self.lamda = 1. 
90 | 91 | def forward(self, x, edge_index, want_les=False): 92 | if self.dglgraph is None: 93 | l12 = tonp(edge_index).tolist() 94 | self.dglgraph = dgl.graph((l12[0],l12[1])).to(self.args.device) 95 | graph = self.dglgraph 96 | 97 | x_list = [] 98 | le_collection = [] 99 | se_reg_all = None 100 | 101 | new_adjs = self.graph_dropout(edge_index) 102 | # new_adjs = [edge_index, None] 103 | if self.has_residual_MLP: 104 | x = F.dropout(x, p=self.embedding_dropout, training=self.training) 105 | x = self.layers_MLP[0](x) 106 | x = F.relu(x) 107 | x_list.append(x) 108 | 109 | for i in range(self.num_layers): 110 | x = F.dropout(x, p=self.dropout, training=self.training) 111 | edge_index, _ = new_adjs[i] 112 | beta = math.log(self.lamda / (i + 1) + 1) if AcontainsB(self.type_trick, 113 | ['IdentityMapping']) else self.lamda 114 | # x = self.layers_GCN[i](x, edge_index, beta) 115 | x, se_reg = self.layers_GCN[i](graph, x) 116 | if se_reg is not None: 117 | if se_reg_all is None: 118 | se_reg_all = se_reg 119 | else: 120 | se_reg_all += se_reg 121 | 122 | x = run_norm_if_any(self, x, i) 123 | 124 | if want_les: 125 | le_collection.append(x.clone().detach()) 126 | 127 | if self.has_residual_MLP or i < self.num_layers - 1: 128 | x = F.relu(x) 129 | x_list.append(x) 130 | if AcontainsB(self.type_trick, ['Initial', 'Dense', 'Residual']): 131 | x = self.layers_res[i](x_list) 132 | 133 | x = F.dropout(x, p=self.args.dropout, training=self.training) 134 | if self.has_residual_MLP: 135 | if AcontainsB(self.type_trick, ['Jumping']): 136 | x = self.layers_res[0](x_list) 137 | else: 138 | x = self.layers_MLP[-1](x) 139 | if want_les: 140 | return x, se_reg_all, th.cat(le_collection, dim=-1) 141 | else: 142 | return x, se_reg_all 143 | 144 | def get_se_dim(self, x, edge_index): 145 | _,_,les = self.forward(x, edge_index, want_les=1) 146 | return les.shape[-1] 147 | 148 | def collect_SE(self, x, edge_index): 149 | _,_,les = self.forward(x, edge_index, want_les=1) 150 | return les 151 | 
152 | class GCNConv(nn.Module): 153 | # Cold Brew's GCN; modified from DGL official GCN implementation 154 | def __init__(self, 155 | in_feats, 156 | out_feats, 157 | norm='both', 158 | weight=True, 159 | bias=True, 160 | activation=None, 161 | allow_zero_in_degree=False, cached=None, args=None,whetherHasSE=False): 162 | super(GCNConv, self).__init__() 163 | self.args = args 164 | self._in_feats = in_feats 165 | self._out_feats = out_feats 166 | self._norm = norm 167 | self._allow_zero_in_degree = True 168 | self._allow_zero_in_degree = allow_zero_in_degree 169 | if weight: 170 | self.weight = nn.Parameter(th.Tensor(in_feats, out_feats)) 171 | else: 172 | self.register_parameter('weight', None) 173 | if bias: 174 | self.bias = nn.Parameter(th.Tensor(out_feats)) 175 | else: 176 | self.register_parameter('bias', None) 177 | self.reset_parameters() 178 | self._activation = activation 179 | 180 | self.whetherHasSE = whetherHasSE 181 | if whetherHasSE: 182 | self.le = nn.Parameter(th.randn(args.N_nodes, self._out_feats), requires_grad=True) 183 | 184 | def forward(self, graph, feat, weight=None, edge_weight=None): 185 | 186 | with graph.local_scope(): 187 | if not self._allow_zero_in_degree: 188 | if (graph.in_degrees() == 0).any(): 189 | raise DGLError('There are 0-in-degree nodes in the graph, ' 190 | 'output for those nodes will be invalid. ' 191 | 'This is harmful for some applications, ' 192 | 'causing silent performance regression. ' 193 | 'Adding self-loop on the input graph by ' 194 | 'calling `g = dgl.add_self_loop(g)` will resolve ' 195 | 'the issue. 
Setting ``allow_zero_in_degree`` ' 196 | 'to be `True` when constructing this module will ' 197 | 'suppress the check and let the code run.') 198 | aggregate_fn = fn.copy_u('h', 'm') # copy_u is the current name for the deprecated copy_src 199 | if edge_weight is not None: 200 | assert edge_weight.shape[0] == graph.number_of_edges() 201 | graph.edata['_edge_weight'] = edge_weight 202 | aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') 203 | 204 | feat_src, feat_dst = expand_as_pair(feat, graph) 205 | if self._norm in ['left', 'both']: 206 | degs = graph.out_degrees().float().clamp(min=1) 207 | if self._norm == 'both': 208 | norm = th.pow(degs, -0.5) 209 | else: 210 | norm = 1.0 / degs 211 | shp = norm.shape + (1,) * (feat_src.dim() - 1) 212 | norm = th.reshape(norm, shp) 213 | feat_src = feat_src * norm 214 | 215 | if weight is not None: 216 | if self.weight is not None: 217 | raise DGLError('External weight is provided while at the same time the' 218 | ' module has defined its own weight parameter. Please' 219 | ' create the module with flag weight=False.') 220 | else: 221 | weight = self.weight 222 | 223 | # multiply by W first to reduce the feature size for aggregation. 224 | if weight is not None: 225 | feat_src = th.matmul(feat_src, weight) 226 | 227 | # ______________ add Structural Embeddings ______________ 228 | # Math: X^{(l+1)}=\sigma\left(\tilde{\bm{A}}\left(X^{(l)} W^{(l)}+E^{(l)}\right)\right), X^{(l)} \in R^{N\times d_{1}}, W^{(l)} \in R^{d_1\times d_{2}}, E^{(l)} \in R^{N\times d_{2}} 229 | # X^{(l)} and X^{(l+1)} are the input and output of the current convolution layer; E^{(l)} is the structural embedding; \tilde{\bm{A}} is the normalized adjacency matrix. W^{(l)} and E^{(l)} are learnable. 
230 | if self.whetherHasSE: 231 | graph.srcdata['h'] = feat_src + self.le 232 | se_reg = th.norm(self.le) 233 | 234 | else: 235 | graph.srcdata['h'] = feat_src 236 | se_reg = None 237 | 238 | graph.update_all(aggregate_fn, fn.sum(msg='m', out='h')) 239 | 240 | rst = graph.dstdata['h'] 241 | 242 | if self._norm in ['right', 'both']: 243 | degs = graph.in_degrees().float().clamp(min=1) 244 | if self._norm == 'both': 245 | norm = th.pow(degs, -0.5) 246 | else: 247 | norm = 1.0 / degs 248 | shp = norm.shape + (1,) * (feat_dst.dim() - 1) 249 | norm = th.reshape(norm, shp) 250 | rst = rst * norm 251 | 252 | if self.bias is not None: 253 | rst = rst + self.bias 254 | 255 | if self._activation is not None: 256 | rst = self._activation(rst) 257 | 258 | return rst, se_reg 259 | 260 | def reset_parameters(self): 261 | if self.weight is not None: 262 | init.xavier_uniform_(self.weight) 263 | if self.bias is not None: 264 | init.zeros_(self.bias) 265 | 266 | def set_allow_zero_in_degree(self, set_value): 267 | self._allow_zero_in_degree = set_value 268 | 269 | def extra_repr(self): 270 | """Set the extra representation of the module, 271 | which will come into effect when printing the model. 272 | """ 273 | summary = 'in={_in_feats}, out={_out_feats}' 274 | summary += ', normalization={_norm}' 275 | if '_activation' in self.__dict__: 276 | summary += ', activation={_activation}' 277 | return summary.format(**self.__dict__) 278 | 279 | def tonp(arr): 280 | if isinstance(arr, th.Tensor): # isinstance also covers nn.Parameter 281 | return arr.detach().cpu().numpy() 282 | else: 283 | import numpy as np # local import: this module never imports numpy at the top level 284 | return np.asarray(arr) 285 | -------------------------------------------------------------------------------- /GNN_model/GNN_normalizations.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .norm_tricks import * 5 | from .GCN import TricksComb 6 | from torch import nn 7 | from utils import D 8 | 9 | class TeacherGNN(nn.Module): 10 | # This class is the teacher GCN model (with structural embedding) for cold brew 11 | def __init__(self, args, proj2class=None): 12 | super().__init__() 13 | proj2class = proj2class or nn.Identity() 14 | args.num_classes_bkup = args.num_classes 15 | args.num_classes = args.dim_commonEmb 16 | self.args = args 17 | 18 | if self.args.dim_learnable_input>0: 19 | embs = torch.randn(args.N_nodes, args.dim_learnable_input)*0.001 20 | self.embs = nn.Parameter(embs, requires_grad=True) 21 | self.args.num_feats_bkup = self.args.num_feats 22 | self.args.num_feats = self.args.dim_learnable_input 23 | 24 | from GNN_model.GNN_normalizations import GNN_norm as GNN_trickComb 25 | self.model = GNN_trickComb(args) 26 | 27 | self.proj2linkp = nn.Identity() 28 | self.proj2class = proj2class 29 | self.dglgraph = None 30 | 31 | def forward(self, x, edge_index): 32 | if self.args.TeacherGNN.change_to_featureless: 33 | x = x*0 34 | if self.args.dim_learnable_input>0: 35 | x = self.embs 36 | commonEmb, self.se_reg_all = self.model(x, edge_index) 37 | self.out = commonEmb 38 | return commonEmb 39 | 40 | def get_3_embs(self, x, edge_index, mask=None, want_heads=True): 41 | commonEmb = self.forward(x, edge_index) 42 | emb4classi_full = self.proj2class(commonEmb) 43 | if want_heads: 44 | if mask is not None: 45 | emb4classi = emb4classi_full[mask] 46 | else: 47 | emb4classi = emb4classi_full 48 | 49 | emb4linkp = self.proj2linkp(commonEmb) 50 | else: 51 | emb4linkp = emb4classi = None 52 | res = D() 53 | res.commonEmb, res.emb4classi, res.emb4classi_full, res.emb4linkp = commonEmb, emb4classi, emb4classi_full, emb4linkp 54 | 55 | return res 56 | 57 | def get_emb4linkp(self, x, edge_index, mask=None): 58 | # return ALL nodes; get_3_embs returns a D object with four fields, so read the field instead of unpacking 59 | emb4linkp = self.get_3_embs(x, edge_index, want_heads=True).emb4linkp 60 | 
return emb4linkp 61 | 62 | def graph2commonEmb(self, x, edge_index, train_mask): 63 | commonEmb = self.forward(x, edge_index) 64 | commonEmb_train = commonEmb[train_mask] 65 | return commonEmb_train, commonEmb 66 | 67 | class GNN_norm(nn.Module): 68 | def __init__(self, args): 69 | super(GNN_norm, self).__init__() 70 | self.model = TricksComb(args) 71 | 72 | def forward(self, x, edge_index): 73 | return self.model.forward(x, edge_index) 74 | -------------------------------------------------------------------------------- /GNN_model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | file_dir = os.path.dirname(__file__) 5 | sys.path.append(file_dir) -------------------------------------------------------------------------------- /GNN_model/drop_tricks.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import torch_scatter 6 | from torch import nn 7 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 8 | from torch_geometric.utils import dropout_adj, subgraph 9 | from torch_geometric.utils.num_nodes import maybe_num_nodes 10 | 11 | # implemented based on GCNModel: https://github.com/DropEdge/DropEdge/blob/master/src/models.py 12 | # Baseblock MultiLayerGCNBlock with nbaseblocklayer=1 13 | class DropEdge(nn.Module): 14 | """ 15 | DropEdge: Dropping edges using a uniform distribution. 
16 | """ 17 | def __init__(self, drop_rate): 18 | super(DropEdge, self).__init__() 19 | self.drop_rate = drop_rate 20 | self.undirected = False 21 | 22 | def forward(self, edge_index, edge_attr=None, edge_weight=None, num_nodes=None): 23 | return dropout_adj(edge_index, p=self.drop_rate, edge_attr=edge_attr, 24 | force_undirected=self.undirected, training=self.training) 25 | 26 | class DropNode(nn.Module): 27 | """ 28 | DropNode: Sampling node using a uniform distribution. 29 | """ 30 | 31 | def __init__(self, drop_rate): 32 | super(DropNode, self).__init__() 33 | self.drop_rate = drop_rate 34 | 35 | def forward(self, edge_index, edge_attr=None, edge_weight=None, num_nodes=None): 36 | if not self.training: 37 | return edge_index, edge_attr 38 | 39 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 40 | nodes = torch.arange(num_nodes, dtype=torch.int64) 41 | mask = torch.full_like(nodes, 1 - self.drop_rate, dtype=torch.float32) 42 | mask = torch.bernoulli(mask).to(torch.bool) 43 | subnodes = nodes[mask] 44 | 45 | return subgraph(subnodes, edge_index, edge_attr=edge_attr, num_nodes=num_nodes) 46 | 47 | class FastGCN(nn.Module): 48 | """ 49 | FastGCN: Sampling N nodes using a importance distribution. 
50 | """ 51 | def __init__(self, drop_rate): 52 | super(FastGCN, self).__init__() 53 | self.drop_rate = drop_rate 54 | 55 | def forward(self, edge_index, edge_attr=None, edge_weight=None, num_nodes=None): 56 | if not self.training: 57 | return edge_index, edge_attr 58 | 59 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 60 | 61 | if edge_weight is None: 62 | edge_weight = torch.ones((edge_index.shape[1],), device=edge_index.device) 63 | 64 | # Importance sampling: q(u) \propto \sum_{v \in N(u)} w^2(u,v) 65 | row, col = edge_index[0], edge_index[1] 66 | weight = torch_scatter.scatter_add(edge_weight**2, col, dim=0, dim_size=num_nodes) 67 | subnodes = torch.multinomial(weight, int(num_nodes*(1-self.drop_rate)), replacement=False) 68 | 69 | return subgraph(subnodes, edge_index, edge_attr=edge_attr, num_nodes=num_nodes) 70 | 71 | class LADIES(nn.Module): 72 | """ 73 | LADIES: Sampling N nodes dependent on the sampled nodes in the next layer. 74 | """ 75 | def __init__(self, drop_rate, num_layers): 76 | super(LADIES, self).__init__() 77 | self.drop_rate = drop_rate 78 | self.num_layers = num_layers 79 | 80 | def forward(self, edge_index, edge_attr=None, edge_weight=None, num_nodes=None): 81 | if not self.training: 82 | return [(edge_index, edge_attr)] 83 | 84 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 85 | 86 | if edge_weight is None: 87 | edge_weight = torch.ones((edge_index.shape[1],), device=edge_index.device) 88 | 89 | sampled_edges = [] 90 | last_edge_index = edge_index 91 | row_mask = torch.ones(edge_weight.shape[0], dtype=torch.bool) 92 | for i in range(self.num_layers): 93 | # Importance sampling: q(u) \propto \sum_{v \in N(u)} w^2(u,v) 94 | row, col = edge_index[0], edge_index[1] 95 | new_edge_weight = torch.zeros_like(edge_weight) 96 | new_edge_weight[row_mask] = edge_weight[row_mask] 97 | weight = torch_scatter.scatter_add(new_edge_weight**2, col, dim=0, dim_size=num_nodes) 98 | subnodes = torch.multinomial(weight, 
int(num_nodes*(1-self.drop_rate)), replacement=False) 99 | 100 | # create row mask for next iteration 101 | row_mask = torch.zeros(num_nodes, dtype=torch.bool) 102 | row_mask[subnodes] = True 103 | row_mask = row_mask[row] 104 | assert row_mask.shape[0] == edge_weight.shape[0] 105 | 106 | # record the sampled edges for sampling in the previous layer 107 | new_edge_index, new_edge_attr = subgraph(subnodes, edge_index, edge_attr=edge_attr, num_nodes=num_nodes) 108 | sampled_edges.append((new_edge_index, new_edge_attr)) 109 | # reverse the samples to the layer order 110 | sampled_edges.reverse() 111 | return sampled_edges 112 | 113 | 114 | def AcontainsB(A,listB): 115 | # A: string; listB: list of strings 116 | for s in listB: 117 | if s in A: return True 118 | return False 119 | 120 | class DroppedEdges(list): 121 | def __getitem__(self, i): 122 | if self.__len__() == 1: 123 | return super().__getitem__(0) 124 | else: 125 | return super().__getitem__(i) 126 | 127 | class DropoutTrick(nn.Module): 128 | def __init__(self, args): 129 | super(DropoutTrick, self).__init__() 130 | self.type_trick = args.type_trick 131 | self.num_layers = args.num_layers 132 | self.layerwise_drop = args.layerwise_dropout 133 | 134 | if AcontainsB(self.type_trick, ['DropEdge']): 135 | self.graph_dropout = DropEdge(args.graph_dropout) 136 | elif AcontainsB(self.type_trick, ['DropNode']): 137 | self.graph_dropout = DropNode(args.graph_dropout) 138 | elif AcontainsB(self.type_trick, ['FastGCN']): 139 | self.graph_dropout = FastGCN(args.graph_dropout) 140 | elif AcontainsB(self.type_trick, ['LADIES']): 141 | assert self.layerwise_drop, 'LADIES requires layer-wise dropout flag on' 142 | self.graph_dropout = LADIES(args.graph_dropout, args.num_layers) 143 | else: 144 | self.graph_dropout = None 145 | 146 | 147 | def forward(self, edge_index, edge_weight=None, adj_norm=False, num_nodes=-1): 148 | if self.graph_dropout is not None: 149 | if AcontainsB(self.type_trick, ['LADIES']): 150 | new_adjs = DroppedEdges( # keep one (edges, weights) pair per sampled layer, not only the last one
151 | gcn_norm(dp_edges, dp_weights, num_nodes, False) if adj_norm else (dp_edges, dp_weights) 152 | for dp_edges, dp_weights in self.graph_dropout(edge_index, edge_attr=edge_weight, edge_weight=edge_weight) 153 | ) 154 | else: 155 | new_adjs = DroppedEdges() 156 | if self.layerwise_drop: 157 | for _ in range(self.num_layers): 158 | dp_edges, dp_weights = self.graph_dropout(edge_index, edge_attr=edge_weight, edge_weight=edge_weight) 159 | if adj_norm: 160 | dp_edges, dp_weights = gcn_norm(dp_edges, dp_weights, num_nodes, False) 161 | new_adjs.append((dp_edges, dp_weights)) 162 | else: 163 | dp_edges, dp_weights = self.graph_dropout(edge_index, edge_attr=edge_weight, edge_weight=edge_weight) 164 | if adj_norm: 165 | dp_edges, dp_weights = gcn_norm(dp_edges, dp_weights, num_nodes, False) 166 | new_adjs.append((dp_edges, dp_weights)) 167 | else: 168 | # no dropout 169 | if adj_norm: 170 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, num_nodes, False) 171 | new_adjs = DroppedEdges([(edge_index, edge_weight)]) 172 | return new_adjs 173 | -------------------------------------------------------------------------------- /GNN_model/norm_tricks.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | 9 | class comb_norm(torch.nn.Module): 10 | def __init__(self, norm_list): 11 | super().__init__() 12 | self.norm_list = nn.ModuleList(norm_list) 13 | 14 | def forward(self, x): 15 | for mod in self.norm_list: 16 | x = mod(x) 17 | return x 18 | 19 | 20 | class pair_norm(torch.nn.Module): 21 | def __init__(self): 22 | super(pair_norm, self).__init__() 23 | # print(f'------ ››››››››› {self._get_name()}') 24 | 25 | def forward(self, x): 26 | col_mean = x.mean(dim=0) 27 | x = x - col_mean 28 | rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt() 29 | x = x / rownorm_mean 30 | return x 31 | 32 | 33 | class mean_norm(torch.nn.Module): 34 | def __init__(self): 35 | super(mean_norm, self).__init__() 36 | # print(f'------ ››››››››› {self._get_name()}') 37 | 38 | def forward(self, x): 39 | col_mean = x.mean(dim=0) 40 | x = x - col_mean 41 | return x 42 | 43 | 44 | class node_norm(torch.nn.Module): 45 | def __init__(self, node_norm_type="n", unbiased=False, eps=1e-5, power_root=2, **kwargs): 46 | super(node_norm, self).__init__() 47 | self.unbiased = unbiased 48 | self.eps = eps 49 | self.node_norm_type = node_norm_type 50 | self.power = 1 / power_root 51 | # print(f'------ ››››››››› {self._get_name()}') 52 | 53 | def forward(self, x): 54 | # in GCN+Cora, 55 | # n v srv pr 56 | # 16 layer: _19.8_ 15.7 17.4 17.3 57 | # 32 layer: 20.3 _25.5_ 16.2 16.3 58 | 59 | if self.node_norm_type == "n": 60 | mean = torch.mean(x, dim=1, keepdim=True) 61 | std = ( 62 | torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps 63 | ).sqrt() 64 | x = (x - mean) / std 65 | elif self.node_norm_type == "v": 66 | std = ( 67 | torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps 68 | ).sqrt() 69 | x = x / std 70 | 71 | elif self.node_norm_type == "m": 72 | mean = torch.mean(x, dim=1, keepdim=True) 73 | x = x - mean 74 | elif 
self.node_norm_type == "srv": # divide by the square root of the std, i.e. the fourth root of the variance 75 | std = ( 76 | torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps 77 | ).sqrt() 78 | x = x / torch.sqrt(std) 79 | elif self.node_norm_type == "pr": 80 | std = ( 81 | torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps 82 | ).sqrt() 83 | x = x / torch.pow(std, self.power) 84 | return x 85 | 86 | def __repr__(self): 87 | original_str = super().__repr__() 88 | components = list(original_str) 89 | node_norm_type_str = f"node_norm_type={self.node_norm_type}" 90 | components.insert(-1, node_norm_type_str) 91 | new_str = "".join(components) 92 | return new_str 93 | 94 | 95 | class group_norm(torch.nn.Module): 96 | def __init__(self, dim_to_norm=None, dim_hidden=16, num_groups=None, skip_weight=None, **w): 97 | super(group_norm, self).__init__() 98 | self.num_groups = num_groups 99 | self.skip_weight = skip_weight 100 | 101 | dim_hidden = dim_hidden if dim_to_norm is None else dim_to_norm 102 | self.dim_hidden = dim_hidden 103 | 104 | # print(f'\n\n{dim_to_norm}\n\n');raise 105 | 106 | self.bn = torch.nn.BatchNorm1d(dim_hidden * self.num_groups, momentum=0.3) 107 | self.group_func = torch.nn.Linear(dim_hidden, self.num_groups, bias=True) 108 | # print(f'------ ››››››››› {self._get_name()}') 109 | 110 | def forward(self, x): 111 | if self.num_groups == 1: 112 | x_temp = self.bn(x) 113 | else: 114 | score_cluster = F.softmax(self.group_func(x), dim=1) 115 | x_temp = torch.cat([score_cluster[:, group].unsqueeze(dim=1) * x for group in range(self.num_groups)], 116 | dim=1) 117 | x_temp = self.bn(x_temp).view(-1, self.num_groups, self.dim_hidden).sum(dim=1) 118 | 119 | x = x + x_temp * self.skip_weight 120 | return x 121 | 122 | 123 | def AcontainsB(A, listB): 124 | # A: string; listB: list of strings 125 | for s in listB: 126 | if s in A: return True 127 | return False 128 | 129 | 130 | def appendNormLayer(net, args, dim_to_norm=None): 131 | if AcontainsB(args.type_trick, 
['BatchNorm']): 132 | net.layers_norm.append(torch.nn.BatchNorm1d(net.dim_hidden if dim_to_norm is None else dim_to_norm)) 133 | elif AcontainsB(args.type_trick, ['PairNorm']): 134 | net.layers_norm.append(pair_norm()) 135 | elif AcontainsB(args.type_trick, ['NodeNorm']): 136 | net.layers_norm.append(node_norm(**vars(net.args))) 137 | elif AcontainsB(args.type_trick, ['MeanNorm']): 138 | net.layers_norm.append(mean_norm()) 139 | elif AcontainsB(args.type_trick, ['GroupNorm']): 140 | net.layers_norm.append(group_norm(dim_to_norm, **vars(reset_weight_GroupNorm(args)))) 141 | elif AcontainsB(args.type_trick, ['CombNorm']): 142 | net.layers_norm.append( 143 | comb_norm([group_norm(dim_to_norm, **vars(reset_weight_GroupNorm(args))), node_norm(**vars(net.args))])) 144 | 145 | 146 | def run_norm_if_any(net, x, ilayer): 147 | if net.args.type_trick in ['BatchNorm', 'PairNorm', 'NodeNorm', 'MeanNorm', 'GroupNorm', 'CombNorm']: 148 | return net.layers_norm[ilayer](x) 149 | else: 150 | return x 151 | 152 | 153 | def reset_weight_GroupNorm(args): 154 | if args.num_groups is not None: 155 | return args 156 | 157 | args.miss_rate = 0. 
158 | 159 | if (args.dataset=='Citeseer' or 'CV' in args.dataset) and args.miss_rate == 0.: 160 | if args.type_model in ['GAT', 'GCN']: 161 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.005 162 | else: 163 | args.skip_weight = 0.0005 if args.num_layers < 60 else 0.002 164 | 165 | elif args.dataset == 'ogbn-arxiv' and args.miss_rate == 0.: 166 | if args.type_model in ['GAT', 'GCN']: 167 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.005 168 | else: 169 | args.skip_weight = 0.0005 if args.num_layers < 60 else 0.002 170 | 171 | elif args.dataset == 'Pubmed' and args.miss_rate == 0.: 172 | if args.type_model in ['GCN']: 173 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.01 174 | elif args.type_model in ['GAT']: 175 | args.skip_weight = 0.005 if args.num_layers < 6 else 0.01 176 | else: 177 | args.skip_weight = 0.05 178 | 179 | elif args.dataset == 'Cora' and args.miss_rate == 0.: 180 | if args.type_model in ['GCN']: 181 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.03 182 | elif args.type_model in ['GAT']: 183 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.01 184 | else: 185 | args.skip_weight = 0.01 if args.num_layers < 60 else 0.005 186 | 187 | elif args.dataset == 'CoauthorCS' and args.miss_rate == 0.: 188 | if args.type_model in ['GAT', 'GCN']: 189 | args.skip_weight = 0.001 if args.num_layers < 6 else 0.03 190 | else: 191 | args.epochs = 500 192 | args.skip_weight = 0.001 if args.num_layers < 10 else .5 193 | elif args.dataset in ['CoauthorCS', 'CoauthorPhysics', 'AmazonComputers', 'AmazonPhoto', 194 | 'TEXAS', 'WISCONSIN', 'CORNELL']: 195 | args.skip_weight = 0.005 196 | 197 | else: 198 | raise NotImplementedError 199 | 200 | # -wz 201 | if args.dataset == 'Pubmed': 202 | args.num_groups = 5 203 | else: 204 | args.num_groups = 10 205 | 206 | return args 207 | -------------------------------------------------------------------------------- /GNN_model/res_tricks.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch import nn 6 | 7 | class ResidualConnection(nn.Module): 8 | def __init__(self, alpha=0.5): 9 | super(ResidualConnection, self).__init__() 10 | self.alpha = alpha 11 | 12 | def forward(self, Xs: list): 13 | assert len(Xs) >= 1 14 | return Xs[-1] if len(Xs) == 1 else (1 - self.alpha) * Xs[-1] + self.alpha * Xs[-2] 15 | 16 | class InitialConnection(nn.Module): 17 | def __init__(self, alpha=0.5): 18 | super(InitialConnection, self).__init__() 19 | self.alpha = alpha 20 | 21 | def forward(self, Xs: list): 22 | assert len(Xs) >= 1 23 | return Xs[-1] if len(Xs) == 1 else (1 - self.alpha) * Xs[-1] + self.alpha * Xs[0] 24 | 25 | class DenseConnection(nn.Module): 26 | def __init__(self, in_dim, out_dim, aggregation='concat'): 27 | super(DenseConnection, self).__init__() 28 | self.in_dim = in_dim 29 | self.out_dim = out_dim 30 | self.aggregation = aggregation 31 | if aggregation == 'concat': 32 | self.layer_transform = nn.Linear(in_dim, out_dim, bias=True) 33 | elif aggregation == 'attention': 34 | self.layer_att = nn.Linear(in_dim, 1, bias=True) 35 | 36 | def forward(self, Xs: list): 37 | assert len(Xs) >= 1 38 | if self.aggregation == 'concat': 39 | X = torch.cat(Xs, dim=-1) 40 | X = self.layer_transform(X) 41 | return X 42 | elif self.aggregation == 'maxpool': 43 | X = torch.stack(Xs, dim=-1) 44 | X, _ = torch.max(X, dim=-1, keepdim=False) 45 | return X 46 | # implement with the code from https://github.com/mengliu1998/DeeperGNN/blob/master/DeeperGNN/dagnn.py 47 | elif self.aggregation == 'attention': 48 | # pps n x k+1 x c 49 | pps = torch.stack(Xs, dim=1) 50 | retain_score = self.layer_att(pps).squeeze() 51 | retain_score = torch.sigmoid(retain_score).unsqueeze(1) 52 | X = torch.matmul(retain_score, pps).squeeze() 53 | return X 54 | else: 55 | 
raise Exception("Unknown aggregation") 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Label_propagation_model/LP_Adj.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import copy 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .outcome_correlation import * 10 | from .diffusion_feature import * 11 | from torch import Tensor 12 | from torch_geometric.nn.conv import MessagePassing 13 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 14 | from torch_geometric.typing import Adj, OptTensor 15 | from torch_sparse import SparseTensor, matmul 16 | from typing import Callable, Optional 17 | 18 | class LabelPropagation_Adj(nn.Module): 19 | def __init__(self, args, data, train_mask): 20 | super().__init__() 21 | self.train_cnt = 0 22 | self.args = args 23 | self.num_layers = args.num_layers 24 | self.alpha = args.lpStep.alpha 25 | self.num_classes = args.num_classes 26 | self.num_nodes = data.num_nodes 27 | self.edge_index = data.edge_index 28 | self.train_mask = train_mask 29 | self.preStep = PreStep(args) 30 | self.midStep = None 31 | self.lpStep = None 32 | self.embs_step1 = None 33 | self.x_after_step2 = None 34 | self.data_cpu = copy.deepcopy(data).to('cpu') 35 | self.data = data 36 | 37 | def train_net(self, input_dict): 38 | # only complete ONE-TIME backprop/update for all nodes 39 | self.train_cnt += 1 40 | device, split_masks = input_dict['device'], input_dict['split_masks'] 41 | if self.embs_step1 is None: # only preprocess ONCE; has to be on cpu 42 | self.embs_step1 = self.preStep(self.data_cpu).to(device) 43 | 44 | data = self.data_cpu.to(device) 45 | x, y = data.x, data.y 46 | loss_op = input_dict['loss_op'] 47 | train_mask = split_masks['train'] 48 | 49 | if self.midStep is None: 50 | self.midStep = MidStep(self.args, self.embs_step1, self.data).to(device) 51 | self.optimizer = torch.optim.Adam(self.midStep.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) 52 | 53 | if self.lpStep is None: 54 | self.lpStep = LPStep(self.args, data, split_masks) 55 | 56 | self.x_after_step2, train_loss = 
self.midStep.train_forward(self.embs_step1, y, self.optimizer, loss_op, split_masks) # only place that require opt 57 | 58 | if self.train_cnt>20: 59 | print() 60 | acc = cal_acc_logits(self.x_after_step2[split_masks['test']], data.y[split_masks['test']]) 61 | 62 | self.out = self.lpStep(self.x_after_step2, data) 63 | self.out,y,train_mask = to_device([self.out,y,train_mask], 'cpu') 64 | total_correct = int(self.out[train_mask].argmax(dim=-1).eq(y[train_mask]).sum()) 65 | train_acc = total_correct / int(train_mask.sum()) 66 | return train_loss, train_acc 67 | 68 | def inference(self, input_dict): 69 | return self.out 70 | 71 | @torch.no_grad() 72 | def forward_backup( 73 | self, y: Tensor, edge_index: Adj, mask: Optional[Tensor] = None, 74 | edge_weight: OptTensor = None, 75 | post_step: Callable = lambda y: y.clamp_(0., 1.) 76 | ) -> Tensor: 77 | """""" 78 | if y.dtype == torch.long: 79 | y = F.one_hot(y.view(-1)).to(torch.float) 80 | out = y 81 | if mask is not None: 82 | out = torch.zeros_like(y) 83 | out[mask] = y[mask] 84 | if isinstance(edge_index, SparseTensor) and not edge_index.has_value(): 85 | edge_index = gcn_norm(edge_index, add_self_loops=False) 86 | elif isinstance(edge_index, Tensor) and edge_weight is None: 87 | edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0), 88 | add_self_loops=False) 89 | res = (1 - self.alpha) * out 90 | for _ in range(self.num_layers): 91 | # propagate_type: (y: Tensor, edge_weight: OptTensor) 92 | out = self.propagate(edge_index, x=out, edge_weight=edge_weight, 93 | size=None) 94 | out.mul_(self.alpha).add_(res) 95 | out = post_step(out) 96 | return out 97 | 98 | def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: 99 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 100 | 101 | def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: 102 | return matmul(adj_t, x, reduce=self.aggr) 103 | 104 | def __repr__(self): 105 | return '{}(num_layers={}, 
alpha={})'.format(self.__class__.__name__, 106 | self.num_layers, 107 | self.alpha) 108 | 109 | class LPStep(nn.Module): 110 | """two papers: 111 | http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf 112 | https://github.com/CUAI/CorrectAndSmooth 113 | """ 114 | def __init__(self, args, data, split_masks): 115 | super().__init__() 116 | self.train_cnt = 0 117 | self.args = args 118 | self.train_idx = torch.where(split_masks['train']==True)[0].to(args.device) 119 | self.valid_idx = torch.where(split_masks['valid']==True)[0].to(args.device) 120 | self.test_idx = torch.where(split_masks['test']==True)[0].to(args.device) 121 | self.split_idx = {'train': self.train_idx, 'valid': self.valid_idx, 'test': self.test_idx} 122 | self.no_prep = args.lpStep.no_prep 123 | adj, D_isqrt = process_adj(data) 124 | DAD, DA, AD = gen_normalized_adjs(adj, D_isqrt) 125 | 126 | self.lp_dict = { 127 | 'train_only': True, 128 | 'alpha1': args.lpStep.alpha1, 129 | 'alpha2': args.lpStep.alpha2, 130 | 'A1': eval(args.lpStep.A1), 131 | 'A2': eval(args.lpStep.A2), 132 | 'num_propagations1': args.lpStep.num_propagations1, 133 | 'num_propagations2': args.lpStep.num_propagations2, 134 | 'display': False, 135 | 'device': args.device, 136 | 137 | # below: lp only 138 | 'idxs': ['train'], 139 | 'alpha': args.lpStep.alpha, 140 | 'num_propagations': args.lpStep.num_propagations, 141 | 'A': eval(args.lpStep.A), 142 | } 143 | self.fn = eval(self.args.lpStep.fn) 144 | return 145 | 146 | def forward(self, model_out, data): 147 | # need to pass 'data.y' through 'data' 148 | self.train_cnt += 1 149 | if self.args.lpStep.lp_force_on_cpu: 150 | self.split_idx, data, model_out = to_device([self.split_idx, data, model_out], 'cpu') 151 | else: 152 | self.split_idx, data, model_out = to_device([self.split_idx, data, model_out], self.args.device) 153 | 154 | if self.no_prep: 155 | out = label_propagation(data, self.split_idx, **self.lp_dict) 156 | else: 157 | _, out = self.fn(data, model_out, self.split_idx, 
**self.lp_dict) 158 | 159 | self.split_idx, data, model_out = to_device([self.split_idx, data, model_out], self.args.device) 160 | return out 161 | 162 | class PreStep(nn.Module): 163 | def __init__(self, args): 164 | super().__init__() 165 | self.args = args 166 | return 167 | 168 | def forward(self, data): 169 | embs = [] 170 | if 'diffusion' in self.args.preStep.pre_methods: 171 | embs.append(preprocess(data, 'diffusion', self.args.preStep.num_propagations, post_fix=self.args.dataset)) 172 | if 'spectral' in self.args.preStep.pre_methods: 173 | embs.append(preprocess(data, 'spectral', self.args.preStep.num_propagations, post_fix=self.args.dataset)) 174 | if 'community' in self.args.preStep.pre_methods: 175 | embs.append(preprocess(data, 'community', self.args.preStep.num_propagations, post_fix=self.args.dataset)) 176 | 177 | embeddings = torch.cat(embs, dim=-1) 178 | return embeddings 179 | 180 | class MidStep(nn.Module): 181 | def __init__(self, args, embs, data): 182 | super().__init__() 183 | self.args = args 184 | self.train_cnt = 0 185 | self.best_valid = 0. 
186 | self.data = data 187 | if args.midStep.model == 'mlp': 188 | self.model = MLP(embs.size(-1)+args.num_feats,args.midStep.hidden_channels, args.num_classes, args.midStep.num_layers, 0.5, args.dataset == 'Products').to(args.device) 189 | elif args.midStep.model=='linear': 190 | self.model = MLPLinear(embs.size(-1)+args.num_feats, args.num_classes).to(args.device) 191 | elif args.midStep.model=='plain': 192 | self.model = MLPLinear(embs.size(-1)+args.num_feats, args.num_classes).to(args.device) 193 | return 194 | 195 | def forward(self, x): 196 | return self.model(x) 197 | 198 | def train_forward(self, embs, y, optimizer, loss_op, split_masks): 199 | self.train_cnt += 1 200 | x = torch.cat(to_device([self.data.x, embs], self.args.device), dim=-1) 201 | 202 | y = self.data.y.to(self.args.device) 203 | train_mask = split_masks['train'] 204 | valid_mask = split_masks['valid'] 205 | test_mask = split_masks['test'] 206 | 207 | optimizer.zero_grad() 208 | out = self.model(x) 209 | if isinstance(loss_op, torch.nn.NLLLoss): 210 | out = F.log_softmax(out, dim=-1) 211 | 212 | loss = loss_op(out[train_mask], y[train_mask]) 213 | loss.backward() 214 | optimizer.step() 215 | valid_acc = cal_acc_logits(out[valid_mask], y[valid_mask]) 216 | 217 | print('step2 test_acc = ',cal_acc_logits(out[test_mask], y[test_mask])) 218 | 219 | if valid_acc > self.best_valid: 220 | self.best_valid = valid_acc 221 | self.best_out = out.exp() 222 | print('!!! 
best val', self.train_cnt, f'={self.best_valid*100:.2f}') 223 | loss = float(loss.item()) 224 | return self.best_out, loss 225 | 226 | def cal_acc_logits(output, labels): 227 | # work with model-output-logits, not model-output-indices 228 | assert len(output.shape)==2 and output.shape[1]>1 229 | labels = labels.reshape(-1).to('cpu') 230 | indices = torch.max(output, dim=1)[1].to('cpu') 231 | correct = float(torch.sum(indices == labels)/len(labels)) 232 | return correct 233 | 234 | def cal_acc_indices(output, labels): 235 | assert (len(output.shape)==2 and output.shape[1]==1) or len(output.shape)==1 236 | labels = labels.reshape(-1).to('cpu') 237 | output = output.reshape(-1).to('cpu') 238 | correct = float(torch.sum(output == labels)/len(labels)) 239 | return correct 240 | 241 | def to_device(list1d, device): 242 | newl = [] 243 | for x in list1d: 244 | if isinstance(x, dict): 245 | for k,v in x.items(): 246 | x[k] = v.to(device) 247 | else: 248 | x = x.to(device) 249 | newl.append(x) 250 | return newl 251 | -------------------------------------------------------------------------------- /Label_propagation_model/diffusion_feature.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import h5py 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from copy import deepcopy 11 | from scipy import sparse 12 | from torch_geometric.data import Data 13 | from torch_geometric.utils import to_undirected, dropout_adj 14 | from torch_scatter import scatter 15 | from torch_sparse import SparseTensor 16 | from tqdm import tqdm 17 | 18 | np.random.seed(0) 19 | 20 | class MLP(torch.nn.Module): 21 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 22 | dropout, relu_first = True): 23 | super(MLP, self).__init__() 24 | self.lins = torch.nn.ModuleList() 25 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 26 | self.bns = torch.nn.ModuleList() 27 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 28 | for _ in range(num_layers - 2): 29 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 30 | self.bns.append(torch.nn.BatchNorm1d(hidden_channels)) 31 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 32 | self.dropout = dropout 33 | self.relu_first = relu_first 34 | 35 | def reset_parameters(self): 36 | for lin in self.lins: 37 | lin.reset_parameters() 38 | for bn in self.bns: 39 | bn.reset_parameters() 40 | 41 | def forward(self, x): 42 | for i, lin in enumerate(self.lins[:-1]): 43 | x = lin(x) 44 | if self.relu_first: 45 | x = F.relu(x, inplace=True) 46 | x = self.bns[i](x) 47 | if not self.relu_first: 48 | x = F.relu(x, inplace=True) 49 | x = F.dropout(x, p=self.dropout, training=self.training) 50 | x = self.lins[-1](x) 51 | return F.log_softmax(x, dim=-1) 52 | 53 | class MLPLinear(torch.nn.Module): 54 | def __init__(self, in_channels, out_channels): 55 | super(MLPLinear, self).__init__() 56 | self.lin = torch.nn.Linear(in_channels, out_channels) 57 | 58 | def reset_parameters(self): 59 | self.lin.reset_parameters() 60 | 61 | def forward(self, x): 62 | return F.log_softmax(self.lin(x), 
dim=-1) 63 | 64 | def sgc(x, adj, num_propagations): 65 | for _ in tqdm(range(num_propagations)): 66 | x = adj @ x 67 | return torch.from_numpy(x).to(torch.float) 68 | 69 | def lp(adj, train_idx, labels, num_propagations, p, alpha, preprocess): 70 | if p is None: 71 | p = 0.6 72 | if alpha is None: 73 | alpha = 0.4 74 | c = labels.max() + 1 75 | idx = train_idx 76 | y = np.zeros((labels.shape[0], c)) 77 | y[idx] = F.one_hot(labels[idx],c).numpy().squeeze(1) 78 | result = deepcopy(y) 79 | for i in tqdm(range(num_propagations)): 80 | result = y + alpha * adj @ (result**p) 81 | result = np.clip(result,0,1) 82 | return torch.from_numpy(result).to(torch.float) 83 | 84 | def diffusion(x, adj, num_propagations, p, alpha): 85 | if p is None: 86 | p = 1. 87 | if alpha is None: 88 | alpha = 0.5 89 | initial_features = deepcopy(x) 90 | x = x **p 91 | for i in tqdm(range(num_propagations)): 92 | x = x - alpha * (sparse.eye(adj.shape[0]) - adj) @ x 93 | x = x **p 94 | return torch.from_numpy(x).to(torch.float) 95 | 96 | def community(data, post_fix): 97 | print('Setting up community detection feature') 98 | np_edge_index = np.array(data.edge_index) 99 | import networkx as nx # lazy import: only needed for the community feature 100 | G = nx.Graph() 101 | G.add_edges_from(np_edge_index.T) 102 | import community as community_louvain # assumed to come from the python-louvain package 103 | partition = community_louvain.best_partition(G) 104 | np_partition = np.zeros(data.num_nodes) 105 | for k, v in partition.items(): 106 | np_partition[k] = v 107 | 108 | np_partition = np_partition.astype(int) # np.int was removed in NumPy 1.24 109 | n_values = int(np.max(np_partition) + 1) 110 | one_hot = np.eye(n_values)[np_partition] 111 | result = torch.from_numpy(one_hot).float() 112 | torch.save( result, f'LP/embeddings/community{post_fix}.pt') 113 | return result 114 | 115 | def spectral(data, post_fix): 116 | from julia.api import Julia 117 | jl = Julia(compiled_modules=False) 118 | from julia import Main 119 | Main.include("LP/norm_spec.jl") 120 | print('Setting up spectral embedding') 121 | data.edge_index = to_undirected(data.edge_index) 122 | np_edge_index = 
np.array(data.edge_index.T) 123 | 124 | N = data.num_nodes 125 | row, col = data.edge_index 126 | adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N)) 127 | adj = adj.to_scipy(layout='csr') 128 | result = torch.tensor(Main.main(adj, 128)).float() 129 | torch.save(result, f'LP/embeddings/spectral{post_fix}.pt') 130 | return result 131 | 132 | def preprocess(data, preprocess = "diffusion", num_propagations = 10, p = None, alpha = None, use_cache = True, post_fix = ""): 133 | if use_cache: 134 | try: 135 | x = torch.load(f'LP/embeddings/{preprocess}{post_fix}.pt') 136 | print('Using cache') 137 | return x 138 | except Exception: # any load failure falls through to regeneration; Ctrl-C still propagates 139 | print(f'LP/embeddings/{preprocess}{post_fix}.pt not found or not enough iterations! Regenerating it now') 140 | 141 | if preprocess == "community": 142 | return community(data, post_fix) 143 | 144 | if preprocess == "spectral": 145 | return spectral(data, post_fix) 146 | 147 | print('Computing adj...') 148 | N = data.num_nodes 149 | data.edge_index = to_undirected(data.edge_index, data.num_nodes) 150 | 151 | row, col = data.edge_index 152 | adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N)) 153 | adj = adj.set_diag() 154 | deg = adj.sum(dim=1).to(torch.float) 155 | deg_inv_sqrt = deg.pow(-0.5) 156 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 157 | adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1) 158 | 159 | adj = adj.to_scipy(layout='csr') 160 | sgc_dict = {} 161 | print(f'Start {preprocess} processing') 162 | 163 | if preprocess == "sgc": 164 | result = sgc(data.x.numpy(), adj, num_propagations) 165 | elif preprocess == "diffusion": 166 | result = diffusion(data.x.numpy(), adj, num_propagations, p = p, alpha = alpha) 167 | os.makedirs('LP/embeddings', exist_ok=True) 168 | torch.save(result, f'LP/embeddings/{preprocess}{post_fix}.pt') 169 | return result 170 | -------------------------------------------------------------------------------- /Label_propagation_model/norm_spec.jl: 
-------------------------------------------------------------------------------- 1 | using LinearAlgebra 2 | using LinearMaps 3 | using MAT 4 | using SparseArrays 5 | using Arpack 6 | 7 | using PyCall, SparseArrays 8 | 9 | function scipyCSC_to_julia(A) 10 | m, n = A.shape 11 | colPtr = Int[i+1 for i in PyArray(A."indptr")] 12 | rowVal = Int[i+1 for i in PyArray(A."indices")] 13 | nzVal = Vector{Float64}(PyArray(A."data")) 14 | B = SparseMatrixCSC{Float64,Int}(m, n, colPtr, rowVal, nzVal) 15 | return PyCall.pyjlwrap_new(B) 16 | end 17 | 18 | function read_arxiv(file::String) 19 | I = Int64[] 20 | J = Int64[] 21 | open(file) do f 22 | for line in eachline(f) 23 | if line[1] == '#'; continue; end 24 | data = split(line, ",") 25 | push!(I, parse(Int64, data[1])) 26 | push!(J, parse(Int64, data[2])) 27 | end 28 | end 29 | I .+= 1 30 | J .+= 1 31 | n = max(maximum(I), maximum(J)) 32 | A = sparse(I, J, 1, n, n) 33 | A = max.(A, A') 34 | A = min.(A, 1) 35 | return A 36 | end 37 | 38 | 39 | function main(PyA, k::Int64) 40 | m, n = PyA.shape 41 | colPtr = Int[i+1 for i in PyArray(PyA."indptr")] 42 | rowVal = Int[i+1 for i in PyArray(PyA."indices")] 43 | nzVal = Vector{Float64}(PyArray(PyA."data")) 44 | A = SparseMatrixCSC{Float64,Int}(m, n, colPtr, rowVal, nzVal) 45 | d = vec(sum(A, dims=2)) 46 | τ = sum(d) / length(d) 47 | N = size(A)[1] 48 | 49 | # normalized regularized laplacian 50 | D = Diagonal(1.0 ./ sqrt.(d .+ τ)) 51 | Aop = LinearMap{Float64}(X -> A * X .+ (τ / N) * sum(X), N, N, isposdef=true, issymmetric=true) 52 | NRL = I + D * Aop * D 53 | 54 | (Λ, V) = eigs(NRL, nev=k, tol=1e-6, ncv=2*k+1, which=:LM) 55 | 56 | # axis rotation (not necessary, but could be helpful) 57 | piv = qr(V', Val(true)).jpvt[1:k] 58 | piv_svd = svd(V[piv,:]', full=false) 59 | SCDM_V = V * (piv_svd.U * piv_svd.Vt) 60 | 61 | # save 62 | 63 | return SCDM_V 64 | end 65 | 66 | #A = read_arxiv(ARGS[1]) 67 | #embed = main(A, 128) 68 | #matwrite("$(ARGS[2])_spectral_embedding.mat", Dict("V" => 
embed), compress=true) 69 | 70 | -------------------------------------------------------------------------------- /Label_propagation_model/outcome_correlation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.utils import to_undirected 8 | from torch_sparse import SparseTensor 9 | from tqdm import tqdm 10 | 11 | def general_outcome_correlation_YAG(Y, adj, G, alpha, num_propagations, post_step=lambda x:torch.clamp(x,1e-9,1-1e-9), device='cuda', display=False, use_sparse_mult = False, **trash): 12 | """ 13 | General outcome correlation. This variant takes no alpha_term argument: 14 | G is always weighted by (1 - alpha). 15 | """ 16 | # what happened: 17 | # z = whatever (can be: <0-vector snapped to label - model_out> (LP-step1), (LP-step2), or <0-vector snapped labels> (LP-only) ) 18 | # math: result = alpha * adj @ result + (1-alpha) * G (result is initialized as Y); returns 0.998 * Y + 0.002 * result 19 | orig_device = Y.device 20 | adj = adj.to(device) 21 | G = G.to(device) 22 | Y = Y.to(device) 23 | result = Y.clone() 24 | N = Y.shape[0] 25 | if use_sparse_mult: 26 | for _ in tqdm(range(num_propagations), disable = not display): 27 | result = alpha * torch.sparse.mm(adj, result) 28 | result += (1-alpha)*G 29 | result = result.coalesce() 30 | v = post_step(result.values()) 31 | result = torch.sparse_coo_tensor(result.indices(), v, [N,N]) 32 | else: 33 | for _ in tqdm(range(num_propagations), disable = not display): 34 | result = alpha * (adj @ result) 35 | result += (1-alpha)*G 36 | result = post_step(result) 37 | return (Y*0.998 + result*2e-3).to(orig_device) 38 | 39 | def process_adj(data): 40 | N = data.num_nodes 41 | data.edge_index = to_undirected(data.edge_index, data.num_nodes) 42 | 43 | row, col = data.edge_index 44 | 45 | adj = SparseTensor(row=row, 
col=col, sparse_sizes=(N, N)) 46 | deg = adj.sum(dim=1).to(torch.float) 47 | deg_inv_sqrt = deg.pow(-0.5) 48 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 49 | return adj, deg_inv_sqrt 50 | 51 | def gen_normalized_adjs(adj, D_isqrt): 52 | DAD = D_isqrt.view(-1,1)*adj*D_isqrt.view(1,-1) 53 | DA = D_isqrt.view(-1,1) * D_isqrt.view(-1,1)*adj 54 | AD = adj*D_isqrt.view(1,-1) * D_isqrt.view(1,-1) 55 | return DAD, DA, AD 56 | 57 | def gen_normalized_adj(adj, pw): # pw = 0 is D^-1A, pw=1 is AD^-1 58 | deg = adj.sum(dim=1).to(torch.float) 59 | front = deg.pow(-(1-pw)) 60 | front[front == float('inf')] = 0 61 | back = deg.pow(-(pw)) 62 | back[back == float('inf')] = 0 63 | return (front.view(-1,1)*adj*back.view(1,-1)) 64 | 65 | def model_load(file, device='cpu'): 66 | result = torch.load(file, map_location='cpu') 67 | run = get_run_from_file(file) 68 | try: 69 | split = torch.load(f'{file}.split', map_location='cpu') 70 | except Exception: 71 | split = None 72 | 73 | mx_diff = (result.sum(dim=-1) - 1).abs().max() 74 | if mx_diff > 1e-1: 75 | print(f'Max difference: {mx_diff}') 76 | print("model output doesn't seem to sum to 1. Did you remember to exp() if your model outputs log_softmax()?") 77 | raise ValueError("model output does not sum to 1") 78 | if split is not None: 79 | return (result, split), run 80 | else: 81 | return result, run 82 | 83 | def get_labels_from_name(labels, split_idx, **trash): 84 | if isinstance(labels, list): 85 | labels = list(labels) 86 | if len(labels) == 0: 87 | return torch.tensor([], dtype=torch.long) 88 | for idx, i in enumerate(list(labels)): 89 | labels[idx] = split_idx[i] 90 | residual_idx = torch.cat(labels) 91 | else: 92 | residual_idx = split_idx[labels] 93 | return residual_idx 94 | 95 | def pre_residual_correlation(labels, model_out, label_idx, **trash): 96 | """Generates the initial labels used for residual correlation""" 97 | # what happened: 98 | # z = model_out; y = labels 99 | # it takes BOTH model_out and labels as inputs 100 | # then assigns (y-z) at the indexed locations and 0 elsewhere. 
101 | labels = labels.cpu() 102 | labels[labels.isnan()] = 0 103 | labels = labels.long() 104 | model_out = model_out.cpu() 105 | label_idx = label_idx.cpu() 106 | c = labels.max() + 1 107 | n = labels.shape[0] 108 | y = torch.zeros((n, c)) 109 | y[label_idx] = F.one_hot(labels[label_idx],c).float().squeeze(1) - model_out[label_idx] 110 | return y 111 | 112 | def pre_outcome_correlation(labels, model_out, label_idx, **trash): 113 | """Generates the initial labels used for outcome correlation""" 114 | # what happened: 115 | # z = model_out; y = labels 116 | # it snaps z: z[idx] = y[idx] 117 | labels = labels.cpu() 118 | model_out = model_out.cpu() 119 | label_idx = label_idx.cpu() 120 | c = labels.max() + 1 121 | n = labels.shape[0] 122 | y = model_out.clone() 123 | if len(label_idx) > 0: 124 | y[label_idx] = F.one_hot(labels[label_idx],c).float().squeeze(1) 125 | 126 | return y 127 | 128 | def general_outcome_correlation(adj, y, alpha, num_propagations, post_step, alpha_term, device='cuda', display=True, **trash): 129 | """general outcome correlation. 
alpha_term = True for outcome correlation, alpha_term = False for residual correlation""" 130 | # what happened: 131 | # z = whatever (can be: <0-vector snapped to label - model_out> (LP-step1), <model_out snapped to labels> (LP-step2), or <0-vector snapped labels> (LP-only) ) 132 | # it loops: res = a * A@res + (1-a) * z (or res = a * A@res + z when alpha_term is False) 133 | 134 | adj = adj.to(device) 135 | orig_device = y.device 136 | y = y.to(device) 137 | result = y.clone() 138 | for _ in tqdm(range(num_propagations), disable = not display): 139 | result = alpha * (adj @ result) 140 | if alpha_term: 141 | result += (1-alpha)*y 142 | else: 143 | result += y 144 | result = post_step(result) 145 | return result.to(orig_device) 146 | 147 | def label_propagation(data, split_idx, A, alpha, num_propagations, idxs, **trash): 148 | labels = data.y.data 149 | c = labels.max() + 1 150 | n = labels.shape[0] 151 | y = torch.zeros((n, c),device=data.y.device) 152 | label_idx = get_labels_from_name(idxs, split_idx) 153 | y[label_idx] = F.one_hot(labels[label_idx],c).float().squeeze(1) 154 | 155 | res = general_outcome_correlation(A, y, alpha, num_propagations, post_step=lambda x:torch.clamp(x,0,1), alpha_term=True) 156 | return res 157 | 158 | def double_correlation_autoscale(data, model_out, split_idx, A1, alpha1, num_propagations1, A2, alpha2, num_propagations2, scale=1.0, train_only=False, device='cuda', display=True, **trash): 159 | # train_idx, valid_idx, test_idx = split_idx 160 | if train_only: 161 | label_idx = torch.cat([split_idx['train']]) 162 | residual_idx = split_idx['train'] 163 | else: 164 | label_idx = torch.cat([split_idx['train'], split_idx['valid']]) 165 | residual_idx = label_idx 166 | 167 | y = pre_residual_correlation(labels=data.y.data, model_out=model_out, label_idx=residual_idx) 168 | resid = general_outcome_correlation(adj=A1, y=y, alpha=alpha1, num_propagations=num_propagations1, post_step=lambda x: torch.clamp(x, -1.0, 1.0), alpha_term=True, display=display, device=device) 169 | 170 | orig_diff = 
y[residual_idx].abs().sum()/residual_idx.shape[0] 171 | resid_scale = (orig_diff/resid.abs().sum(dim=1, keepdim=True)) 172 | resid_scale[resid_scale.isinf()] = 1.0 173 | cur_idxs = (resid_scale > 1000) 174 | resid_scale[cur_idxs] = 1.0 175 | res_result = model_out + resid_scale*resid 176 | res_result[res_result.isnan()] = model_out[res_result.isnan()] 177 | y = pre_outcome_correlation(labels=data.y.data, model_out=res_result, label_idx = label_idx) 178 | result = general_outcome_correlation(adj=A2, y=y, alpha=alpha2, num_propagations=num_propagations2, post_step=lambda x: torch.clamp(x, 0,1), alpha_term=True, display=display, device=device) 179 | 180 | return res_result, result 181 | 182 | def double_correlation_fixed(data, model_out, split_idx, A1, alpha1, num_propagations1, A2, alpha2, num_propagations2, scale=1.0, train_only=False, device='cuda', display=True, **trash): 183 | train_idx, valid_idx, test_idx = split_idx 184 | if train_only: 185 | label_idx = torch.cat([split_idx['train']]) 186 | residual_idx = split_idx['train'] 187 | 188 | else: 189 | label_idx = torch.cat([split_idx['train'], split_idx['valid']]) 190 | residual_idx = label_idx 191 | 192 | y = pre_residual_correlation(labels=data.y.data, model_out=model_out, label_idx=residual_idx) 193 | 194 | fix_y = y[residual_idx].to(device) 195 | def fix_inputs(x): 196 | x[residual_idx] = fix_y 197 | return x 198 | 199 | resid = general_outcome_correlation(adj=A1, y=y, alpha=alpha1, num_propagations=num_propagations1, post_step=lambda x: fix_inputs(x), alpha_term=True, display=display, device=device) 200 | res_result = model_out + scale*resid 201 | 202 | y = pre_outcome_correlation(labels=data.y.data, model_out=res_result, label_idx = label_idx) 203 | 204 | result = general_outcome_correlation(adj=A2, y=y, alpha=alpha2, num_propagations=num_propagations2, post_step=lambda x: x.clamp(0, 1), alpha_term=True, display=display, device=device) 205 | 206 | return res_result, result 207 | 208 | def 
only_outcome_correlation(data, model_out, split_idx, A, alpha, num_propagations, labels, device='cuda', display=True, **trash): 209 | res_result = model_out.clone() 210 | label_idxs = get_labels_from_name(labels, split_idx) 211 | y = pre_outcome_correlation(labels=data.y.data, model_out=model_out, label_idx=label_idxs) 212 | result = general_outcome_correlation(adj=A, y=y, alpha=alpha, num_propagations=num_propagations, post_step=lambda x: torch.clamp(x, 0, 1), alpha_term=True, display=display, device=device) 213 | return res_result, result 214 | 215 | def get_run_from_file(out): 216 | return int(os.path.splitext(os.path.basename(out))[0]) 217 | 218 | def get_orig_acc(data, eval_test, model_outs, split_idx): 219 | logger_orig = Logger(len(model_outs)) # NOTE: Logger is not imported in this module 220 | for out in model_outs: 221 | model_out, run = model_load(out) 222 | if isinstance(model_out, tuple): 223 | model_out, split_idx = model_out 224 | test_acc = eval_test(model_out, split_idx['test']) 225 | logger_orig.add_result(run, (eval_test(model_out, split_idx['train']), eval_test(model_out, split_idx['valid']), test_acc)) 226 | print('Original accuracy') 227 | logger_orig.print_statistics() 228 | 229 | def prepare_folder(name, model): 230 | model_dir = f'models/{name}' 231 | os.makedirs(model_dir, exist_ok=True) 232 | with open(f'{model_dir}/metadata', 'w') as f: 233 | f.write(f'# of params: {sum(p.numel() for p in model.parameters())}\n') 234 | return model_dir 235 | -------------------------------------------------------------------------------- /Link_prediction_baseline/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gnn-tail-generalization/e6f85eea8721059819d1e74c3d274f3beb70f70c/Link_prediction_baseline/.DS_Store -------------------------------------------------------------------------------- /Link_prediction_baseline/compute_bound_filepath.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import dgl 3 | import networkx as nx 4 | import numpy as np 5 | import pickle 6 | import scipy.sparse as sp 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from collections import defaultdict 13 | from dgl import DGLGraph 14 | from dgl.data import register_data_args, load_data 15 | from random import shuffle 16 | from scipy.linalg import eigh, norm 17 | from torch.autograd import Variable 18 | from tqdm import tqdm 19 | 20 | def evaluate(model, features, labels, mask): 21 | model.eval() 22 | with torch.no_grad(): 23 | logits = model(features) 24 | logits = logits[mask] 25 | labels = labels[mask] 26 | _, indices = torch.max(logits, dim=1) 27 | correct = torch.sum(indices == labels) 28 | return correct.item() * 1.0 / len(labels) 29 | 30 | def degree_bucketing(graph, args, degree_emb=None, max_degree = 10): 31 | max_degree = args.n_hidden 32 | features = torch.ones([graph.number_of_nodes(), max_degree]) 33 | return features 34 | 35 | def createTraining(labels, valid_mask = None, train_ratio=0.8): 36 | train_mask = torch.zeros(labels.shape, dtype=torch.bool) 37 | test_mask = torch.ones(labels.shape, dtype=torch.bool) 38 | 39 | num_train = int(labels.shape[0] * train_ratio) 40 | all_node_index = list(range(labels.shape[0])) 41 | np.random.shuffle(all_node_index) 42 | train_mask[all_node_index[:num_train]] = 1 43 | test_mask[all_node_index[:num_train]] = 0 44 | if valid_mask is not None: 45 | train_mask *= valid_mask 46 | test_mask *= valid_mask 47 | return train_mask, test_mask 48 | 49 | def read_struct_net(file_path): 50 | g = nx.Graph() 51 | with open(file_path) as IN: 52 | for line in IN: 53 | tmp = line.strip().split() 54 | # print(tmp[0], tmp[1]) 55 | g.add_edge(int(tmp[0]), int(tmp[1])) 56 | return g 57 | 58 | def constructDGL(graph): 59 | node_mapping = defaultdict(int) 60 | 61 | for node in 
sorted(list(graph.nodes())): 62 | node_mapping[node] = len(node_mapping) 63 | new_g = DGLGraph() 64 | new_g.add_nodes(len(node_mapping)) 65 | 66 | for edge in graph.edges(): 67 | if not new_g.has_edge_between(node_mapping[edge[0]], node_mapping[edge[1]]): 68 | new_g.add_edge(node_mapping[edge[0]], node_mapping[edge[1]]) 69 | if not new_g.has_edge_between(node_mapping[edge[1]], node_mapping[edge[0]]): 70 | new_g.add_edge(node_mapping[edge[1]], node_mapping[edge[0]]) 71 | 72 | return new_g 73 | 74 | def output_adj(graph): 75 | A = np.zeros([graph.number_of_nodes(), graph.number_of_nodes()]) 76 | a,b = graph.all_edges() 77 | for id_a, id_b in zip(a.numpy().tolist(), b.numpy().tolist()): 78 | A[id_a, id_b] = 1 79 | return A 80 | 81 | def compute_term(l, r): 82 | n = l.shape[0] 83 | eval_ = eigh((l-r).T @ (l-r), eigvals_only=True) 84 | return np.sqrt(max(eval_)) 85 | 86 | def main(args): 87 | 88 | def constructSubG(file_path): 89 | g = read_struct_net(file_path) 90 | if True: 91 | g.remove_edges_from(nx.selfloop_edges(g)) 92 | g = constructDGL(g) 93 | g.readonly() 94 | 95 | node_sampler = dgl.contrib.sampling.NeighborSampler(g, 1, 10, # 0, 96 | neighbor_type='in', num_workers=1, 97 | add_self_loop=False, 98 | num_hops=args.n_layers + 1, shuffle=True) 99 | return g, node_sampler 100 | 101 | def constructHopdic(ego_g): 102 | hop_dic = dict() 103 | 104 | node_set = set([]) 105 | for layer_id in range(args.n_layers+2)[::-1]: 106 | hop_dic[layer_id] = list(set(ego_g.layer_parent_nid(layer_id).numpy()) 107 | - node_set) 108 | node_set |= set(hop_dic[layer_id]) 109 | 110 | return hop_dic 111 | 112 | def constructIdxCoding(hop_dic): 113 | idx_coding = {"00":[]} 114 | 115 | for layer_id in range(args.n_layers+2)[::-1]: 116 | for node_id in hop_dic[layer_id]: 117 | if node_id != "00": 118 | idx_coding[node_id] = len(idx_coding) + len(idx_coding["00"]) - 1 119 | else: 120 | idx_coding["00"] += [len(idx_coding) + len(idx_coding["00"]) - 1] 121 | return idx_coding 122 | 123 | def 
constructL(g, ego_g, idx_coding, neighbor_type='out'): 124 | dim = len(idx_coding) + len(idx_coding["00"]) - 1 125 | A = np.zeros([dim, dim]) 126 | 127 | for i in range(ego_g.num_blocks): 128 | u,v = g.find_edges(ego_g.block_parent_eid(i)) 129 | for left_id, right_id in zip(u.numpy().tolist(), v.numpy().tolist()): 130 | A[idx_coding[left_id], idx_coding[right_id]] = 1 131 | 132 | # lower part is the out-degree direction 133 | if neighbor_type=='in': 134 | # upper 135 | A = A.T 136 | 137 | # select the non-zero submatrix 138 | selector = list(set(np.arange(dim)) - set(idx_coding['00'])) 139 | A_full = A[np.ix_(selector, selector)] 140 | 141 | # find L 142 | D = np.diag(A_full.sum(1)) 143 | L = D - A_full 144 | D_ = np.diag(1.0 / np.sqrt(A_full.sum(1))) 145 | D_ = np.nan_to_num(D_, posinf=0, neginf=0) #set inf to 0 146 | normalized_L = np.matmul(np.matmul(D_, L), D_) 147 | 148 | # reassign the calculated Laplacian 149 | A[np.ix_(selector, selector)] = normalized_L 150 | 151 | if np.isnan(A.sum()): 152 | raise ValueError("normalized Laplacian contains NaN") 153 | 154 | return A 155 | 156 | def degPermute(ego_g, hop_dic, layer_id): 157 | if layer_id == 0: 158 | return hop_dic[layer_id] 159 | else: 160 | s, arg_degree_sort = torch.sort(-ego_g.layer_in_degree(layer_id)) 161 | return torch.tensor(hop_dic[layer_id])[arg_degree_sort].tolist() 162 | 163 | def pad_nbhd(lg, rg, lego_g, rego_g, perm_type='shuffle', neighbor_type='out'): 164 | # returns two padded Laplacians 165 | lhop_dic = constructHopdic(lego_g) 166 | rhop_dic = constructHopdic(rego_g) 167 | 168 | # equalize the neighborhood sizes by padding 169 | for layer_id in range(args.n_layers+2)[::-1]: 170 | diff = len(lhop_dic[layer_id]) - len(rhop_dic[layer_id]) 171 | 172 | if perm_type == 'shuffle': # including the padded terms 173 | if diff>0: 174 | rhop_dic[layer_id] += ["00"] * abs(diff) 175 | elif diff<0: 176 | lhop_dic[layer_id] += ["00"] * abs(diff) 177 | 178 | shuffle(lhop_dic[layer_id]) 179 | shuffle(rhop_dic[layer_id]) 180 | elif perm_type == 'degree': 181 | 
lhop_dic[layer_id] = degPermute(lego_g, lhop_dic, layer_id) 182 | rhop_dic[layer_id] = degPermute(rego_g, rhop_dic, layer_id) 183 | 184 | if diff>0: 185 | rhop_dic[layer_id] += ["00"] * abs(diff) 186 | elif diff<0: 187 | lhop_dic[layer_id] += ["00"] * abs(diff) 188 | else: 189 | if diff>0: 190 | rhop_dic[layer_id] += ["00"] * abs(diff) 191 | elif diff<0: 192 | lhop_dic[layer_id] += ["00"] * abs(diff) 193 | 194 | # construct coding dict 195 | lidx_coding = constructIdxCoding(lhop_dic) 196 | ridx_coding = constructIdxCoding(rhop_dic) 197 | 198 | lL = constructL(lg, lego_g, lidx_coding, neighbor_type=neighbor_type) 199 | rL = constructL(rg, rego_g, ridx_coding, neighbor_type=neighbor_type) 200 | 201 | return lL, rL 202 | 203 | print(args.file_path, args.label_path) 204 | Lg, Lego_list = constructSubG(args.file_path) 205 | Rg, Rego_list = constructSubG(args.label_path) 206 | print(Lg) 207 | # embed() 208 | bound = 0 209 | cntl = 0 210 | cntr = 0 211 | for lego_g in tqdm(Lego_list): 212 | cntl += 1 213 | cntr = 0 214 | for rego_g in Rego_list: 215 | cntr += 1 216 | # print(cntr) 217 | lL, rL = pad_nbhd(Lg, Rg, lego_g, rego_g, 218 | perm_type='shuffle', 219 | neighbor_type='in') 220 | bound += compute_term(lL, rL) 221 | 222 | print(bound / (cntl * cntr)) 223 | 224 | if __name__ == '__main__': 225 | parser = argparse.ArgumentParser(description='DGI') 226 | register_data_args(parser) 227 | parser.add_argument("--dropout", type=float, default=0.0, 228 | help="dropout probability") 229 | parser.add_argument("--gpu", type=int, default=-1, 230 | help="gpu") 231 | parser.add_argument("--dgi-lr", type=float, default=1e-2, 232 | help="dgi learning rate") 233 | parser.add_argument("--classifier-lr", type=float, default=1e-2, 234 | help="classifier learning rate") 235 | parser.add_argument("--n-dgi-epochs", type=int, default=300, 236 | help="number of training epochs") 237 | parser.add_argument("--n-classifier-epochs", type=int, default=100, 238 | help="number of training epochs") 
239 | parser.add_argument("--n-hidden", type=int, default=32, 240 | help="number of hidden gcn units") 241 | parser.add_argument("--n-layers", type=int, default=1, 242 | help="number of hidden gcn layers") 243 | parser.add_argument("--weight-decay", type=float, default=0., 244 | help="Weight for L2 loss") 245 | parser.add_argument("--patience", type=int, default=20, 246 | help="early stop patience condition") 247 | parser.add_argument("--model", action='store_true', 248 | help="graph self-loop (default=False)") 249 | parser.add_argument("--self-loop", action='store_true', 250 | help="graph self-loop (default=False)") 251 | parser.add_argument("--model-type", type=int, default=2, 252 | help="graph self-loop (default=False)") 253 | parser.add_argument("--graph-type", type=str, default="DD", 254 | help="graph self-loop (default=False)") 255 | parser.add_argument("--data-id", type=str, 256 | help="[usa, europe, brazil]") 257 | parser.add_argument("--data-src", type=str, default='', 258 | help="[usa, europe, brazil]") 259 | parser.add_argument("--file-path", type=str, 260 | help="graph path") 261 | parser.add_argument("--label-path", type=str, 262 | help="label path") 263 | parser.add_argument("--model-id", type=int, default=0, 264 | help="[0, 1, 2, 3]") 265 | 266 | parser.set_defaults(self_loop=False) 267 | args = parser.parse_args() 268 | 269 | main(args) 270 | -------------------------------------------------------------------------------- /Link_prediction_baseline/compute_bound_pickle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dgl 3 | import networkx as nx 4 | import numpy as np 5 | import pickle 6 | import scipy.sparse as sp 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from collections import defaultdict 12 | from dgl import DGLGraph 13 | from dgl.data import register_data_args, load_data 14 | from random import shuffle 15 | from scipy.linalg import 
eigh, norm 16 | from torch.autograd import Variable 17 | from tqdm import tqdm 18 | 19 | def evaluate(model, features, labels, mask): 20 | model.eval() 21 | with torch.no_grad(): 22 | logits = model(features) 23 | logits = logits[mask] 24 | labels = labels[mask] 25 | _, indices = torch.max(logits, dim=1) 26 | correct = torch.sum(indices == labels) 27 | return correct.item() * 1.0 / len(labels) 28 | 29 | def degree_bucketing(graph, args, degree_emb=None, max_degree = 10): 30 | #G = nx.DiGraph(graph) 31 | #embed() 32 | max_degree = args.n_hidden 33 | features = torch.ones([graph.number_of_nodes(), max_degree]) 34 | return features 35 | 36 | def createTraining(labels, valid_mask = None, train_ratio=0.8): 37 | train_mask = torch.zeros(labels.shape, dtype=torch.bool) 38 | test_mask = torch.ones(labels.shape, dtype=torch.bool) 39 | 40 | num_train = int(labels.shape[0] * train_ratio) 41 | all_node_index = list(range(labels.shape[0])) 42 | np.random.shuffle(all_node_index) 43 | train_mask[all_node_index[:num_train]] = 1 44 | test_mask[all_node_index[:num_train]] = 0 45 | if valid_mask is not None: 46 | train_mask *= valid_mask 47 | test_mask *= valid_mask 48 | return train_mask, test_mask 49 | 50 | def read_struct_net(file_path): 51 | g = nx.Graph() 52 | 53 | with open(file_path) as IN: 54 | for line in IN: 55 | tmp = line.strip().split() 56 | g.add_edge(int(tmp[0]), int(tmp[1])) 57 | return g 58 | 59 | def constructDGL(graph): 60 | node_mapping = defaultdict(int) 61 | 62 | for node in sorted(list(graph.nodes())): 63 | node_mapping[node] = len(node_mapping) 64 | 65 | new_g = DGLGraph() 66 | new_g.add_nodes(len(node_mapping)) 67 | 68 | for edge in graph.edges(): 69 | if not new_g.has_edge_between(node_mapping[edge[0]], node_mapping[edge[1]]): 70 | new_g.add_edge(node_mapping[edge[0]], node_mapping[edge[1]]) 71 | if not new_g.has_edge_between(node_mapping[edge[1]], node_mapping[edge[0]]): 72 | new_g.add_edge(node_mapping[edge[1]], node_mapping[edge[0]]) 73 | 74 | return 
new_g 75 | 76 | def output_adj(graph): 77 | A = np.zeros([graph.number_of_nodes(), graph.number_of_nodes()]) 78 | a,b = graph.all_edges() 79 | for id_a, id_b in zip(a.numpy().tolist(), b.numpy().tolist()): 80 | A[id_a, id_b] = 1 81 | return A 82 | 83 | # find the max eval 84 | def compute_term(l, r): 85 | n = l.shape[0] 86 | eval = eigh(l-r, eigvals_only=True) 87 | 88 | return max(max(eval), -min(eval)) 89 | 90 | def main(args): 91 | # load and preprocess dataset 92 | def constructSubG(g): 93 | g.readonly() 94 | node_sampler = dgl.contrib.sampling.NeighborSampler(g, 1, 10, # 0, 95 | neighbor_type='in', num_workers=1, 96 | add_self_loop=False, 97 | num_hops=args.n_layers + 1, shuffle=True) 98 | return node_sampler 99 | 100 | def constructHopdic(ego_g): 101 | hop_dic = dict() 102 | 103 | node_set = set([]) 104 | for layer_id in range(args.n_layers+2)[::-1]: 105 | hop_dic[layer_id] = list(set(ego_g.layer_parent_nid(layer_id).numpy()) 106 | - node_set) 107 | node_set |= set(hop_dic[layer_id]) 108 | 109 | return hop_dic 110 | 111 | def constructIdxCoding(hop_dic): 112 | idx_coding = {"00":[]} 113 | 114 | for layer_id in range(args.n_layers+2)[::-1]: 115 | for node_id in hop_dic[layer_id]: 116 | if node_id != "00": 117 | idx_coding[node_id] = len(idx_coding) + len(idx_coding["00"]) - 1 118 | 119 | else: 120 | idx_coding["00"] += [len(idx_coding) + len(idx_coding["00"]) - 1] 121 | return idx_coding 122 | 123 | def constructL(g, ego_g, idx_coding, neighbor_type='out'): 124 | dim = len(idx_coding) + len(idx_coding["00"]) - 1 125 | A = np.zeros([dim, dim]) 126 | 127 | for i in range(ego_g.num_blocks): 128 | u,v = g.find_edges(ego_g.block_parent_eid(i)) 129 | for left_id, right_id in zip(u.numpy().tolist(), v.numpy().tolist()): 130 | A[idx_coding[left_id], idx_coding[right_id]] = 1 131 | 132 | # lower part is the out-degree direction 133 | # A = np.tril(A, -1) 134 | if neighbor_type=='in': 135 | # upper 136 | A = A.T 137 | 138 | # select the non-zero submatrix 139 | selector 
= list(set(np.arange(dim)) - set(idx_coding['00'])) 140 | A_full = A[np.ix_(selector, selector)] 141 | 142 | # find L 143 | D = np.diag(A_full.sum(1)) 144 | L = D - A_full 145 | D_ = np.diag(1.0 / np.sqrt(A_full.sum(1))) 146 | D_ = np.nan_to_num(D_, posinf=0, neginf=0) #set inf to 0 147 | normalized_L = np.matmul(np.matmul(D_, L), D_) 148 | 149 | # reassign the calculated Laplacian 150 | A[np.ix_(selector, selector)] = normalized_L 151 | 152 | if np.isnan(A.sum()): 153 | raise ValueError("normalized Laplacian contains NaN") 154 | 155 | return A 156 | 157 | def degPermute(ego_g, hop_dic, layer_id): 158 | if layer_id == 0: 159 | return hop_dic[layer_id] 160 | else: 161 | s, arg_degree_sort = torch.sort(-ego_g.layer_in_degree(layer_id)) 162 | return torch.tensor(hop_dic[layer_id])[arg_degree_sort].tolist() 163 | 164 | def pad_nbhd(lg, rg, lego_g, rego_g, perm_type='shuffle', neighbor_type='out'): 165 | # returns two padded Laplacians 166 | lhop_dic = constructHopdic(lego_g) 167 | rhop_dic = constructHopdic(rego_g) 168 | 169 | # equalize the neighborhood sizes by padding 170 | for layer_id in range(args.n_layers+2)[::-1]: 171 | diff = len(lhop_dic[layer_id]) - len(rhop_dic[layer_id]) 172 | 173 | if perm_type == 'shuffle': # including the padded terms 174 | if diff>0: 175 | rhop_dic[layer_id] += ["00"] * abs(diff) 176 | elif diff<0: 177 | lhop_dic[layer_id] += ["00"] * abs(diff) 178 | 179 | shuffle(lhop_dic[layer_id]) 180 | shuffle(rhop_dic[layer_id]) 181 | elif perm_type == 'degree': 182 | lhop_dic[layer_id] = degPermute(lego_g, lhop_dic, layer_id) 183 | rhop_dic[layer_id] = degPermute(rego_g, rhop_dic, layer_id) 184 | 185 | if diff>0: 186 | rhop_dic[layer_id] += ["00"] * abs(diff) 187 | elif diff<0: 188 | lhop_dic[layer_id] += ["00"] * abs(diff) 189 | else: 190 | if diff>0: 191 | rhop_dic[layer_id] += ["00"] * abs(diff) 192 | elif diff<0: 193 | lhop_dic[layer_id] += ["00"] * abs(diff) 194 | 195 | # construct coding dict 196 | lidx_coding = constructIdxCoding(lhop_dic) 197 | ridx_coding = constructIdxCoding(rhop_dic) 198 | 199 | 
lL = constructL(lg, lego_g, lidx_coding, neighbor_type=neighbor_type) 200 | rL = constructL(rg, rego_g, ridx_coding, neighbor_type=neighbor_type) 201 | 202 | return lL, rL 203 | 204 | Lg = pickle.load(open(args.file_path, 'rb')) 205 | Rg = pickle.load(open(args.label_path, 'rb')) 206 | 207 | print(len(Lg['graphs'])) 208 | bound_ave = [] 209 | 210 | for i in range(1): 211 | Lgi = Lg['graphs'][39] 212 | bound_ave_i = [] 213 | for j in tqdm(range(1, len(Rg['graphs']))): 214 | # for j in range(1,2): 215 | Rgj = Rg['graphs'][j] 216 | Lego_list = constructSubG(Lgi) 217 | Rego_list = constructSubG(Rgj) 218 | 219 | # embed() 220 | bound = 0 221 | cntl = 0 222 | cntr = 0 223 | for lego_g in Lego_list: 224 | cntl += 1 225 | cntr = 0 226 | for rego_g in Rego_list: 227 | cntr += 1 228 | lL, rL = pad_nbhd(Lgi, Rgj, lego_g, rego_g, 229 | perm_type='degree', 230 | neighbor_type='in') 231 | bound += compute_term(lL, rL) 232 | 233 | bound_ave_i += [bound / (cntl * cntr)] 234 | 235 | print(sum(bound_ave_i)/len(bound_ave_i)) 236 | bound_ave += bound_ave_i 237 | 238 | print(sum(bound_ave)/len(bound_ave)) 239 | 240 | if __name__ == '__main__': 241 | parser = argparse.ArgumentParser(description='DGI') 242 | register_data_args(parser) 243 | parser.add_argument("--dropout", type=float, default=0.0, 244 | help="dropout probability") 245 | parser.add_argument("--gpu", type=int, default=-1, 246 | help="gpu") 247 | parser.add_argument("--dgi-lr", type=float, default=1e-2, 248 | help="dgi learning rate") 249 | parser.add_argument("--classifier-lr", type=float, default=1e-2, 250 | help="classifier learning rate") 251 | parser.add_argument("--n-dgi-epochs", type=int, default=300, 252 | help="number of training epochs") 253 | parser.add_argument("--n-classifier-epochs", type=int, default=100, 254 | help="number of training epochs") 255 | parser.add_argument("--n-hidden", type=int, default=32, 256 | help="number of hidden gcn units") 257 | parser.add_argument("--n-layers", type=int, default=1, 258 
| help="number of hidden gcn layers") 259 | parser.add_argument("--weight-decay", type=float, default=0., 260 | help="Weight for L2 loss") 261 | parser.add_argument("--patience", type=int, default=20, 262 | help="early stop patience condition") 263 | parser.add_argument("--model", action='store_true', 264 | help="graph self-loop (default=False)") 265 | parser.add_argument("--self-loop", action='store_true', 266 | help="graph self-loop (default=False)") 267 | parser.add_argument("--model-type", type=int, default=2, 268 | help="graph self-loop (default=False)") 269 | parser.add_argument("--graph-type", type=str, default="DD", 270 | help="graph self-loop (default=False)") 271 | parser.add_argument("--data-id", type=str, 272 | help="[usa, europe, brazil]") 273 | parser.add_argument("--data-src", type=str, default='', 274 | help="[usa, europe, brazil]") 275 | parser.add_argument("--file-path", type=str, 276 | help="graph path") 277 | parser.add_argument("--label-path", type=str, 278 | help="label path") 279 | parser.add_argument("--model-id", type=int, default=0, 280 | help="[0, 1, 2, 3]") 281 | 282 | parser.set_defaults(self_loop=False) 283 | args = parser.parse_args() 284 | print(args) 285 | 286 | main(args) 287 | -------------------------------------------------------------------------------- /Link_prediction_baseline/heuristics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from torch_geometric.utils import negative_sampling, add_self_loops, train_test_split_edges 5 | from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader 6 | 7 | import scipy.sparse as ssp 8 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 9 | 10 | def eva_heuristics_v2_dec25(which_heuristic, data, edge_index): 11 | # which_heuristic support: 'CN', 'AA', 'PPR' 12 | # 'edge_index' is the thing that need to be evaluated, predict either true or false, by 
giving a number. 13 | 14 | if 'edge_attr' in data: 15 | edge_weight = data.edge_attr.view(-1).cpu() 16 | elif 'edge_weight' in data: 17 | edge_weight = data.edge_weight.view(-1).cpu() 18 | else: 19 | edge_weight = torch.ones(data.edge_index.shape[1], dtype=int).cpu() 20 | 21 | if 'A' not in data: 22 | tmp_edge_index = data.edge_index.cpu() 23 | print('check : ', tmp_edge_index.shape, edge_weight.shape) 24 | data.A = ssp.csr_matrix((edge_weight, (tmp_edge_index[0], tmp_edge_index[1])), shape=(data.num_nodes, data.num_nodes)) 25 | 26 | pred_scores, ei = eval(which_heuristic)(data.A, torch.tensor(tonp(edge_index))) 27 | pred_scores = tonp(pred_scores) 28 | 29 | return pred_scores 30 | 31 | def eva_heuristics(args, data, split_edge): 32 | num_nodes = data.num_nodes 33 | if 'edge_weight' in data: 34 | edge_weight = data.edge_weight.view(-1) 35 | else: 36 | edge_weight = torch.ones(data.edge_index.shape[1], dtype=int) 37 | 38 | A = ssp.csr_matrix((edge_weight, (data.edge_index[0], data.edge_index[1])), shape=(num_nodes, num_nodes)) 39 | 40 | pos_val_edge, neg_val_edge = get_pos_neg_edges('valid', split_edge, 41 | data.edge_index, 42 | data.num_nodes) 43 | pos_test_edge, neg_test_edge = get_pos_neg_edges('test', split_edge, 44 | data.edge_index, 45 | data.num_nodes) 46 | pos_val_pred, pos_val_edge = eval(args.use_heuristic)(A, pos_val_edge) 47 | neg_val_pred, neg_val_edge = eval(args.use_heuristic)(A, neg_val_edge) 48 | pos_test_pred, pos_test_edge = eval(args.use_heuristic)(A, pos_test_edge) 49 | neg_test_pred, neg_test_edge = eval(args.use_heuristic)(A, neg_test_edge) 50 | 51 | if args.eval_metric == 'hits': 52 | results = evaluate_hits(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred) 53 | elif args.eval_metric == 'mrr': 54 | results = evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred) 55 | elif args.eval_metric == 'auc': 56 | val_pred = torch.cat([pos_val_pred, neg_val_pred]) 57 | val_true = torch.cat([torch.ones(pos_val_pred.shape[0], 
dtype=int), 58 | torch.zeros(neg_val_pred.shape[0], dtype=int)]) 59 | test_pred = torch.cat([pos_test_pred, neg_test_pred]) 60 | test_true = torch.cat([torch.ones(pos_test_pred.shape[0], dtype=int), 61 | torch.zeros(neg_test_pred.shape[0], dtype=int)]) 62 | results = evaluate_auc(val_pred, val_true, test_pred, test_true) 63 | 64 | def get_pos_neg_edges(split, split_edge, edge_index, num_nodes, percent=100): 65 | if 'edge' in split_edge['train']: 66 | pos_edge = split_edge[split]['edge'].t() 67 | if split == 'train': 68 | new_edge_index, _ = add_self_loops(edge_index) 69 | neg_edge = negative_sampling( 70 | new_edge_index, num_nodes=num_nodes, 71 | num_neg_samples=pos_edge.shape[1]) 72 | else: 73 | neg_edge = split_edge[split]['edge_neg'].t() 74 | # subsample for pos_edge 75 | np.random.seed(123) 76 | num_pos = pos_edge.shape[1] 77 | perm = np.random.permutation(num_pos) 78 | perm = perm[:int(percent / 100 * num_pos)] 79 | pos_edge = pos_edge[:, perm] 80 | # subsample for neg_edge 81 | np.random.seed(123) 82 | num_neg = neg_edge.shape[1] 83 | perm = np.random.permutation(num_neg) 84 | perm = perm[:int(percent / 100 * num_neg)] 85 | neg_edge = neg_edge[:, perm] 86 | 87 | elif 'source_node' in split_edge['train']: 88 | source = split_edge[split]['source_node'] 89 | target = split_edge[split]['target_node'] 90 | if split == 'train': 91 | target_neg = torch.randint(0, num_nodes, [target.shape[0], 1], 92 | dtype=torch.long) 93 | else: 94 | target_neg = split_edge[split]['target_node_neg'] 95 | # subsample 96 | np.random.seed(123) 97 | num_source = source.shape[0] 98 | perm = np.random.permutation(num_source) 99 | perm = perm[:int(percent / 100 * num_source)] 100 | source, target, target_neg = source[perm], target[perm], target_neg[perm, :] 101 | pos_edge = torch.stack([source, target]) 102 | neg_per_target = target_neg.shape[1] 103 | neg_edge = torch.stack([source.repeat_interleave(neg_per_target), 104 | target_neg.view(-1)]) 105 | return pos_edge, neg_edge 106 | 107 | 
def CN(A, edge_index, batch_size=100000): 108 | # Common Neighbors heuristic: for each (src, dst) pair, count the neighbors the two nodes share. 109 | link_loader = DataLoader(range(edge_index.shape[1]), batch_size) 110 | scores = [] 111 | for ind in tqdm(link_loader): 112 | src, dst = edge_index[0, ind], edge_index[1, ind] 113 | cur_scores = np.array(np.sum(A[src].multiply(A[dst]), 1)).flatten() 114 | scores.append(cur_scores) 115 | return torch.FloatTensor(np.concatenate(scores, 0)), edge_index 116 | 117 | def AA(A, edge_index, batch_size=100000): 118 | # Adamic-Adar heuristic: like Common Neighbors, but each shared neighbor is weighted by 1 / log of its (weighted) degree. 119 | multiplier = 1 / np.log(A.sum(axis=0)) 120 | multiplier[np.isinf(multiplier)] = 0 # degree-1 nodes give 1 / log(1) = inf; zero them out 121 | A_ = A.multiply(multiplier).tocsr() 122 | link_loader = DataLoader(range(edge_index.shape[1]), batch_size) 123 | scores = [] 124 | for ind in tqdm(link_loader): 125 | src, dst = edge_index[0, ind], edge_index[1, ind] 126 | cur_scores = np.array(np.sum(A[src].multiply(A_[dst]), 1)).flatten() 127 | scores.append(cur_scores) 128 | scores = np.concatenate(scores, 0) 129 | return torch.FloatTensor(scores), edge_index 130 | 131 | def PPR(A, edge_index): 132 | # The Personalized PageRank heuristic score. 133 | # Requires fast_pagerank; install it with "pip install fast-pagerank". 134 | # Currently too slow for large datasets. 
135 | from fast_pagerank import pagerank_power 136 | num_nodes = A.shape[0] 137 | src_index, sort_indices = torch.sort(edge_index[0]) 138 | dst_index = edge_index[1, sort_indices] 139 | edge_index = torch.stack([src_index, dst_index]) 140 | #edge_index = edge_index[:, :50] 141 | scores = [] 142 | visited = set([]) 143 | j = 0 144 | for i in tqdm(range(edge_index.shape[1])): 145 | if i < j: 146 | continue 147 | src = edge_index[0, i] 148 | personalize = np.zeros(num_nodes) 149 | personalize[src] = 1 150 | ppr = pagerank_power(A, p=0.85, personalize=personalize, tol=1e-7) 151 | j = i 152 | while edge_index[0, j] == src: 153 | j += 1 154 | if j == edge_index.shape[1]: 155 | break 156 | all_dst = edge_index[1, i:j] 157 | cur_scores = ppr[all_dst] 158 | if cur_scores.ndim == 0: 159 | cur_scores = np.expand_dims(cur_scores, 0) 160 | scores.append(np.array(cur_scores)) 161 | 162 | scores = np.concatenate(scores, 0) 163 | return torch.FloatTensor(scores), edge_index 164 | 165 | def tonp(arr): 166 | if type(arr) is torch.Tensor: 167 | return arr.detach().cpu().data.numpy() 168 | else: 169 | return np.asarray(arr) 170 | -------------------------------------------------------------------------------- /Link_prediction_baseline/models/pretrain_contextpred_gin.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from dgl.nn.pytorch import GraphConv, SAGEConv, GINConv 7 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 8 | from torch import optim 9 | import numpy as np 10 | dir_path = os.path.dirname(os.path.realpath(__file__)) 11 | parent_path = os.path.abspath(os.path.join(dir_path, os.pardir)) 12 | from dgl import backend as dgl_F 13 | 14 | class ApplyNodeFunc(nn.Module): 15 | """Update the node feature hv with MLP, BN and ReLU.""" 16 | def __init__(self, mlp): 17 | super(ApplyNodeFunc, self).__init__() 18 | self.mlp = mlp 19 | 
self.bn = nn.BatchNorm1d(self.mlp.output_dim) 20 | 21 | def forward(self, h): 22 | h = self.mlp(h) 23 | h = self.bn(h) 24 | h = F.relu(h) 25 | return h 26 | 27 | 28 | class MLP(nn.Module): 29 | """MLP with linear output""" 30 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 31 | """MLP layers construction 32 | Parameters 33 | ---------- 34 | num_layers: int 35 | The number of linear layers 36 | input_dim: int 37 | The dimensionality of input features 38 | hidden_dim: int 39 | The dimensionality of hidden units at ALL layers 40 | output_dim: int 41 | The number of classes for prediction 42 | """ 43 | super(MLP, self).__init__() 44 | self.linear_or_not = True # default is linear model 45 | self.num_layers = num_layers 46 | self.output_dim = output_dim 47 | 48 | if num_layers < 1: 49 | raise ValueError("number of layers should be positive!") 50 | elif num_layers == 1: 51 | # Linear model 52 | self.linear = nn.Linear(input_dim, output_dim) 53 | else: 54 | # Multi-layer model 55 | self.linear_or_not = False 56 | self.linears = torch.nn.ModuleList() 57 | self.batch_norms = torch.nn.ModuleList() 58 | 59 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 60 | for layer in range(num_layers - 2): 61 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 62 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 63 | 64 | for layer in range(num_layers - 1): 65 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 66 | 67 | def forward(self, x): 68 | if self.linear_or_not: 69 | # If linear model 70 | return self.linear(x) 71 | else: 72 | # If MLP 73 | h = x 74 | for i in range(self.num_layers - 1): 75 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 76 | return self.linears[-1](h) 77 | 78 | 79 | class GIN(nn.Module): 80 | """GIN model""" 81 | def __init__(self, g, num_layers, num_mlp_layers, input_dim, hidden_dim, 82 | output_dim, final_dropout, learn_eps, graph_pooling_type, 83 | neighbor_pooling_type): 84 | """model parameters setting 85 | 
Parameters 86 | ---------- 87 | num_layers: int 88 | The number of linear layers in the neural network 89 | num_mlp_layers: int 90 | The number of linear layers in mlps 91 | input_dim: int 92 | The dimensionality of input features 93 | hidden_dim: int 94 | The dimensionality of hidden units at ALL layers 95 | output_dim: int 96 | The number of classes for prediction 97 | final_dropout: float 98 | dropout ratio on the final linear layer 99 | learn_eps: boolean 100 | If True, learn epsilon to distinguish center nodes from neighbors 101 | If False, aggregate neighbors and center nodes altogether. 102 | neighbor_pooling_type: str 103 | how to aggregate neighbors (sum, mean, or max) 104 | graph_pooling_type: str 105 | how to aggregate entire nodes in a graph (sum, mean or max) 106 | """ 107 | super(GIN, self).__init__() 108 | self.num_layers = num_layers 109 | self.learn_eps = learn_eps 110 | 111 | # List of MLPs 112 | self.ginlayers = torch.nn.ModuleList() 113 | self.batch_norms = torch.nn.ModuleList() 114 | 115 | for layer in range(self.num_layers): 116 | if layer == 0: 117 | mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim) 118 | else: 119 | mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) 120 | 121 | self.ginlayers.append( 122 | GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) 123 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 124 | 125 | # Linear function for graph poolings of output of each layer 126 | # which maps the output of different layers into a prediction score 127 | self.linears_prediction = torch.nn.ModuleList() 128 | 129 | for layer in range(num_layers): 130 | if layer == 0: 131 | self.linears_prediction.append( 132 | nn.Linear(input_dim, output_dim)) 133 | else: 134 | self.linears_prediction.append( 135 | nn.Linear(hidden_dim, output_dim)) 136 | 137 | self.drop = nn.Dropout(final_dropout) 138 | 139 | if graph_pooling_type == 'sum': 140 | self.pool = SumPooling() 141 | elif graph_pooling_type == 'mean': 
142 | self.pool = AvgPooling() 143 | elif graph_pooling_type == 'max': 144 | self.pool = MaxPooling() 145 | else: 146 | raise NotImplementedError 147 | 148 | def forward(self, g, h): 149 | # list of hidden representation at each layer (including input) 150 | hidden_rep = [h] 151 | 152 | for i in range(self.num_layers): 153 | h = self.ginlayers[i](g, h) 154 | h = self.batch_norms[i](h) 155 | h = F.relu(h) 156 | hidden_rep.append(h) 157 | 158 | # only need node embedding 159 | return h 160 | 161 | 162 | class Encoder(nn.Module): 163 | def __init__(self, g, in_feats, n_hidden, n_layers, dropout): 164 | super(Encoder, self).__init__() 165 | self.conv = GIN(g, n_layers, 1, in_feats, n_hidden, n_hidden, dropout, True, 'sum', 'sum') 166 | 167 | def forward(self, g, features): 168 | features = self.conv(g, features) 169 | return features 170 | 171 | 172 | 173 | class contextpred_GIN(nn.Module): 174 | def __init__(self, args, g, in_feats, n_hidden, n_layers, dropout): 175 | super(contextpred_GIN, self).__init__() 176 | self.args = args 177 | self.g = g 178 | self.num_layer = n_layers 179 | assert n_layers > args.l1 180 | assert args.l2 > args.l1 181 | self.encoder = Encoder(g, in_feats, n_hidden, n_layers, dropout) 182 | self.model_context = Encoder(g, in_feats, n_hidden, int(args.l2 - args.l1), dropout) 183 | self.criterion = nn.BCEWithLogitsLoss() 184 | self.optimizer_substruct = optim.Adam(self.encoder.parameters(), lr=args.central_encoder_lr) 185 | self.optimizer_context = optim.Adam(self.model_context.parameters(), lr=args.context_encoder_lr) 186 | 187 | 188 | def forward(self, features): 189 | substruct_rep = self.encoder(self.g, features) 190 | return substruct_rep 191 | 192 | 193 | def cycle_index(self, num, shift): 194 | arr = torch.arange(num) + shift 195 | arr[-shift:] = torch.arange(shift) 196 | return arr 197 | 198 | 199 | def train_model(self, features): 200 | self.optimizer_substruct.zero_grad() 201 | self.optimizer_context.zero_grad() 202 | 203 | # 
self.context_Gs[idx], self.overlap_nodes[idx], self.central_id[idx] 204 | # TODO: check cids and the format of bg_overlap_nodes 205 | for bg, bg_features, bg_overlap_nodes, batch_num_overlaped_nodes, cids in self.contextgraph_loader: 206 | substruct_rep = self.forward(features) 207 | overlapped_node_rep = self.model_context(bg, bg_features) 208 | 209 | # TODO: agg by bg_overlap_nodes 210 | context_rep = overlapped_node_rep[bg_overlap_nodes] 211 | n_graphs = bg.batch_size 212 | batch_num_objs = batch_num_overlaped_nodes 213 | seg_id = dgl_F.zerocopy_from_numpy(np.arange(n_graphs, dtype='int64').repeat(batch_num_objs)) 214 | seg_id = dgl_F.copy_to(seg_id, dgl_F.context(context_rep)) 215 | context_rep = dgl_F.unsorted_1d_segment_mean(context_rep, seg_id, n_graphs, 0) 216 | 217 | assert context_rep.shape == substruct_rep.shape 218 | 219 | neg_context_rep = torch.cat( 220 | [context_rep[self.cycle_index(len(context_rep), i + 1)] for i in range(self.args.neg_samples)], dim=0) 221 | 222 | pred_pos = torch.sum(substruct_rep * context_rep, dim=1) 223 | pred_neg = torch.sum(substruct_rep.repeat((self.args.neg_samples, 1)) * neg_context_rep, dim=1) 224 | 225 | loss_pos = self.criterion(pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double()) 226 | loss_neg = self.criterion(pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double()) 227 | 228 | loss = loss_pos + self.args.neg_samples * loss_neg 229 | loss.backward() 230 | self.optimizer_substruct.step() 231 | self.optimizer_context.step() 232 | 233 | return loss -------------------------------------------------------------------------------- /Link_prediction_baseline/models/pretrain_masking_gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch import GraphConv, SAGEConv, GINConv 5 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 6 | 7 | 
from src.models.MLP import MLP as feat_MLP 8 | 9 | 10 | class ApplyNodeFunc(nn.Module): 11 | """Update the node feature hv with MLP, BN and ReLU.""" 12 | def __init__(self, mlp): 13 | super(ApplyNodeFunc, self).__init__() 14 | self.mlp = mlp 15 | self.bn = nn.BatchNorm1d(self.mlp.output_dim) 16 | 17 | def forward(self, h): 18 | h = self.mlp(h) 19 | h = self.bn(h) 20 | h = F.relu(h) 21 | return h 22 | 23 | class MLP(nn.Module): 24 | """MLP with linear output""" 25 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim, linear_or_not=True): 26 | """MLP layers construction 27 | Parameters 28 | ---------- 29 | num_layers: int 30 | The number of linear layers 31 | input_dim: int 32 | The dimensionality of input features 33 | hidden_dim: int 34 | The dimensionality of hidden units at ALL layers 35 | output_dim: int 36 | The number of classes for prediction 37 | """ 38 | super(MLP, self).__init__() 39 | self.linear_or_not = linear_or_not # default is linear model 40 | self.num_layers = num_layers 41 | self.output_dim = output_dim 42 | 43 | if num_layers < 1: 44 | raise ValueError("number of layers should be positive!") 45 | elif num_layers == 1: 46 | # Linear model 47 | self.linear = nn.Linear(input_dim, output_dim) 48 | else: 49 | # Multi-layer model 50 | self.linear_or_not = False 51 | self.linears = torch.nn.ModuleList() 52 | self.batch_norms = torch.nn.ModuleList() 53 | 54 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 55 | for layer in range(num_layers - 2): 56 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 57 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 58 | 59 | for layer in range(num_layers - 1): 60 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 61 | 62 | def forward(self, x): 63 | if self.linear_or_not: 64 | # If linear model 65 | return self.linear(x) 66 | else: 67 | # If MLP 68 | h = x 69 | for i in range(self.num_layers - 1): 70 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 71 | return 
self.linears[-1](h) 72 | 73 | class GIN(nn.Module): 74 | """GIN model""" 75 | def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, 76 | output_dim, final_dropout, learn_eps, graph_pooling_type, 77 | neighbor_pooling_type): 78 | """model parameters setting 79 | Parameters 80 | ---------- 81 | num_layers: int 82 | The number of linear layers in the neural network 83 | num_mlp_layers: int 84 | The number of linear layers in mlps 85 | input_dim: int 86 | The dimensionality of input features 87 | hidden_dim: int 88 | The dimensionality of hidden units at ALL layers 89 | output_dim: int 90 | The number of classes for prediction 91 | final_dropout: float 92 | dropout ratio on the final linear layer 93 | learn_eps: boolean 94 | If True, learn epsilon to distinguish center nodes from neighbors 95 | If False, aggregate neighbors and center nodes altogether. 96 | neighbor_pooling_type: str 97 | how to aggregate neighbors (sum, mean, or max) 98 | graph_pooling_type: str 99 | how to aggregate entire nodes in a graph (sum, mean or max) 100 | """ 101 | super(GIN, self).__init__() 102 | self.num_layers = num_layers 103 | self.learn_eps = learn_eps 104 | 105 | # List of MLPs 106 | self.ginlayers = torch.nn.ModuleList() 107 | self.batch_norms = torch.nn.ModuleList() 108 | 109 | for layer in range(self.num_layers): 110 | if layer == 0: 111 | mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim) 112 | else: 113 | mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) 114 | 115 | self.ginlayers.append( 116 | GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) 117 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 118 | 119 | # Linear function for graph poolings of output of each layer 120 | # which maps the output of different layers into a prediction score 121 | self.linears_prediction = torch.nn.ModuleList() 122 | 123 | for layer in range(num_layers): 124 | if layer == 0: 125 | self.linears_prediction.append( 126 | 
nn.Linear(input_dim, output_dim)) 127 | else: 128 | self.linears_prediction.append( 129 | nn.Linear(hidden_dim, output_dim)) 130 | 131 | self.drop = nn.Dropout(final_dropout) 132 | 133 | if graph_pooling_type == 'sum': 134 | self.pool = SumPooling() 135 | elif graph_pooling_type == 'mean': 136 | self.pool = AvgPooling() 137 | elif graph_pooling_type == 'max': 138 | self.pool = MaxPooling() 139 | else: 140 | raise NotImplementedError 141 | 142 | def forward(self, g, h): 143 | # list of hidden representation at each layer (including input) 144 | hidden_rep = [h] 145 | 146 | for i in range(self.num_layers): 147 | h = self.ginlayers[i](g, h) 148 | h = self.batch_norms[i](h) 149 | h = F.relu(h) 150 | hidden_rep.append(h) 151 | 152 | # only need node embedding 153 | return h 154 | 155 | # score_over_layer = 0 156 | # 157 | # # perform pooling over all nodes in each graph in every layer 158 | # for i, h in enumerate(hidden_rep): 159 | # pooled_h = self.pool(g, h) 160 | # score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) 161 | # 162 | # return score_over_layer 163 | 164 | 165 | class GCN(nn.Module): 166 | def __init__(self, n_layers, in_feats, n_hidden, output_dim, activation, dropout): 167 | super(GCN, self).__init__() 168 | self.layers = nn.ModuleList() 169 | self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type='gcn', activation=activation)) 170 | for i in range(1, n_layers - 1): 171 | self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type='gcn', activation=activation)) 172 | self.layers.append(SAGEConv(n_hidden, output_dim, aggregator_type='gcn', activation=None)) 173 | self.dropout = nn.Dropout(p=dropout) 174 | 175 | def forward(self, g, features): 176 | h = features 177 | for i, layer in enumerate(self.layers): 178 | if i != 0: 179 | h = self.dropout(h) 180 | h = layer(g, h) 181 | return h 182 | 183 | 184 | class Encoder(nn.Module): 185 | def __init__(self, in_feats, n_hidden, n_layers, dropout, type): 186 | super(Encoder, 
self).__init__() 187 | if type == 'gcn': 188 | self.conv = GCN(n_layers+1, in_feats, n_hidden, n_hidden, F.relu, dropout) 189 | elif type == 'gin': 190 | self.conv = GIN(n_layers+1, 1, in_feats, n_hidden, n_hidden, dropout, True, 'sum', 'sum') 191 | 192 | def forward(self, g, features): 193 | features = self.conv(g, features) 194 | return features 195 | 196 | 197 | class masking_GIN(nn.Module): 198 | def __init__(self, args, in_feats, n_hidden, n_layers, n_degree, dropout): 199 | super(masking_GIN, self).__init__() 200 | self.args = args 201 | self.num_layer = n_layers 202 | self.in_feats = in_feats 203 | self.hidden = n_hidden 204 | self.encoder = Encoder(in_feats, n_hidden, n_layers, dropout, args.encoder_type) 205 | if args.pretrain is not None: 206 | self.degree_classifier = torch.nn.Linear(n_hidden, 103) 207 | else: 208 | self.degree_classifier = torch.nn.Linear(n_hidden, n_degree) 209 | self.feat_encoder = None 210 | # self.prepare() 211 | 212 | def prepare(self): 213 | self.feat_encoder = feat_MLP(self.in_feats, self.hidden, self.in_feats) # MLP(1, self.in_feats, self.hidden, self.hidden) 214 | 215 | 216 | def forward(self, g, features, test_mask=None): 217 | if self.feat_encoder is not None: 218 | embedding = F.relu(self.feat_encoder(features)) 219 | else: 220 | embedding = features 221 | embedding = self.encoder(g, embedding) 222 | if test_mask is not None: 223 | pred = torch.log_softmax(self.degree_classifier(embedding[test_mask]), dim=1) 224 | embedding = embedding[test_mask] 225 | else: 226 | pred = torch.log_softmax(self.degree_classifier(embedding), dim=1) 227 | return pred, embedding 228 | 229 | 230 | def train_model(self, g, features, test_mask=None, train_label=None): 231 | self.train() 232 | self.optimizer.zero_grad() 233 | pred, embedding = self.forward(g, features, test_mask) 234 | if train_label is not None: 235 | loss = F.nll_loss(pred, train_label) 236 | else: 237 | loss = F.nll_loss(pred, self.degree) 238 | loss.backward() 239 | 
self.optimizer.step() 240 | return loss.item() 241 | -------------------------------------------------------------------------------- /Link_prediction_baseline/models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.autograd as autograd 9 | import numpy as np 10 | # from cortex_DIM.functions.misc import log_sum_exp 11 | 12 | class Mine(nn.Module): 13 | def __init__(self, input_size=2, hidden_size=100): 14 | super().__init__() 15 | self.fc1 = nn.Linear(input_size, hidden_size) 16 | self.fc2 = nn.Linear(hidden_size, hidden_size) 17 | self.fc3 = nn.Linear(hidden_size, 1) 18 | nn.init.normal_(self.fc1.weight,std=0.02) 19 | nn.init.constant_(self.fc1.bias, 0) 20 | nn.init.normal_(self.fc2.weight,std=0.02) 21 | nn.init.constant_(self.fc2.bias, 0) 22 | nn.init.normal_(self.fc3.weight,std=0.02) 23 | nn.init.constant_(self.fc3.bias, 0) 24 | 25 | def forward(self, input): 26 | output = F.elu(self.fc1(input)) 27 | output = F.elu(self.fc2(output)) 28 | output = self.fc3(output) 29 | return output 30 | 31 | def mutual_information(joint, marginal, mine_net): 32 | t = mine_net(joint) 33 | et = torch.exp(mine_net(marginal)) 34 | mi_lb = torch.mean(t) - torch.log(torch.mean(et)) 35 | return mi_lb, t, et 36 | 37 | def learn_mine(batch, mine_net, mine_net_optim, ma_et, ma_rate=0.01): 38 | # batch is a tuple of (joint, marginal) 39 | joint , marginal = batch 40 | joint = torch.autograd.Variable(torch.FloatTensor(joint)).cuda() 41 | marginal = torch.autograd.Variable(torch.FloatTensor(marginal)).cuda() 42 | mi_lb , t, et = mutual_information(joint, marginal, mine_net) 43 | ma_et = (1-ma_rate)*ma_et + ma_rate*torch.mean(et) 44 | 45 | # unbiasing use moving average 46 | loss = -(torch.mean(t) - (1/ma_et.mean()).detach()*torch.mean(et)) 47 | # use biased estimator 48 | # loss = - mi_lb 49 | 50 | 
mine_net_optim.zero_grad() 51 | autograd.backward(loss) 52 | mine_net_optim.step() 53 | return mi_lb, ma_et 54 | 55 | def sample_batch(data, batch_size=100, sample_mode='joint'): 56 | if sample_mode == 'joint': 57 | index = np.random.choice(range(data[0].shape[0]), size=batch_size, replace=False) 58 | batch = np.concatenate([data[0][index], data[1][index]], axis=1) 59 | else: 60 | joint_index = np.random.choice(range(data[0].shape[0]), size=batch_size, replace=False) 61 | marginal_index = np.random.choice(range(data[1].shape[0]), size=batch_size, replace=False) 62 | batch = np.concatenate([data[0][joint_index], 63 | data[1][marginal_index]], 64 | axis=1) 65 | return batch 66 | 67 | def train_mine(data, mine_net,mine_net_optim, batch_size=300, iter_num=int(2e+4), log_freq=int(1e+3)): 68 | # data is x or y 69 | result = list() 70 | ma_et = 1. 71 | for i in range(iter_num): 72 | batch = sample_batch(data,batch_size=batch_size)\ 73 | , sample_batch(data,batch_size=batch_size,sample_mode='marginal') 74 | mi_lb, ma_et = learn_mine(batch, mine_net, mine_net_optim, ma_et) 75 | result.append(mi_lb.detach().cpu().numpy()) 76 | if (i+1)%(log_freq)==0: 77 | print(result[-1]) 78 | return result 79 | 80 | def raise_measure_error(measure): 81 | supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'] 82 | raise NotImplementedError( 83 | 'Measure `{}` not supported. Supported: {}'.format(measure, 84 | supported_measures)) 85 | 86 | 87 | def get_positive_expectation(p_samples, measure, average=True): 88 | """Computes the positive part of a divergence / difference. 89 | Args: 90 | p_samples: Positive samples. 91 | measure: Measure to compute for. 92 | average: Average the result over samples. 93 | Returns: 94 | torch.Tensor 95 | """ 96 | log_2 = math.log(2.) 
97 | 98 | if measure == 'GAN': 99 | Ep = - F.softplus(-p_samples) 100 | elif measure == 'JSD': 101 | Ep = log_2 - F.softplus(- p_samples) 102 | elif measure == 'X2': 103 | Ep = p_samples ** 2 104 | elif measure == 'KL': 105 | Ep = p_samples + 1. 106 | elif measure == 'RKL': 107 | Ep = -torch.exp(-p_samples) 108 | elif measure == 'DV': 109 | Ep = p_samples 110 | elif measure == 'H2': 111 | Ep = 1. - torch.exp(-p_samples) 112 | elif measure == 'W1': 113 | Ep = p_samples 114 | else: 115 | raise_measure_error(measure) 116 | 117 | if average: 118 | return Ep.mean() 119 | else: 120 | return Ep 121 | 122 | 123 | def get_negative_expectation(q_samples, measure, average=True): 124 | """Computes the negative part of a divergence / difference. 125 | Args: 126 | q_samples: Negative samples. 127 | measure: Measure to compute for. 128 | average: Average the result over samples. 129 | Returns: 130 | torch.Tensor 131 | """ 132 | log_2 = math.log(2.) 133 | 134 | if measure == 'GAN': 135 | Eq = F.softplus(-q_samples) + q_samples 136 | elif measure == 'JSD': 137 | Eq = F.softplus(-q_samples) + q_samples - log_2 138 | elif measure == 'X2': 139 | Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2) 140 | elif measure == 'KL': 141 | Eq = torch.exp(q_samples) 142 | elif measure == 'RKL': 143 | Eq = q_samples - 1. 144 | elif measure == 'DV': 145 | Eq = torch.logsumexp(q_samples, 0) - math.log(q_samples.size(0)) # log-mean-exp over the sample dimension 146 | elif measure == 'H2': 147 | Eq = torch.exp(q_samples) - 1. 
148 | elif measure == 'W1': 149 | Eq = q_samples 150 | else: 151 | raise_measure_error(measure) 152 | 153 | if average: 154 | return Eq.mean() 155 | else: 156 | return Eq -------------------------------------------------------------------------------- /Link_prediction_baseline/models/vgae.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import dgl 4 | import torch 5 | import torch.nn as nn 6 | from dgl.nn.pytorch import GraphConv, SAGEConv 7 | import dgl.function as fn 8 | import torch.nn.functional as F 9 | import networkx as nx 10 | 11 | from src.models.MLP import MLP 12 | from src.models.inner_product_decoder import InnerProductDecoder 13 | from src.utils import loss_function 14 | import numpy as np 15 | from sklearn.metrics import roc_auc_score, average_precision_score, label_ranking_average_precision_score 16 | import scipy.sparse as sp 17 | from IPython import embed 18 | from torch.nn.utils import clip_grad_norm_ 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | 23 | def extract_nodeflow(nf): 24 | node_set = set() 25 | for i in range(nf.num_layers): 26 | node_set.update(nf.layer_parent_nid(i).tolist()) 27 | node_idx = list(node_set) 28 | # embed() 29 | return node_idx 30 | #return nf.layer_parent_nid(0), nf.layer_parent_nid(1) 31 | 32 | 33 | def sigmoid(x): 34 | return 1 / (1 + np.exp(-x)) 35 | 36 | 37 | class VGAE(nn.Module): 38 | def __init__(self, g, in_feats, hidden_dim1, hidden_dim2, dropout, pretrain = None): 39 | super(VGAE, self).__init__() 40 | self.g = g 41 | self.in_feats = in_feats 42 | self.pretrain = pretrain 43 | self.hidden_dim1 = hidden_dim1 44 | 45 | self.gc1 = SAGEConv(in_feats, hidden_dim1, aggregator_type='gcn', activation=F.relu) 46 | self.gc2 = SAGEConv(hidden_dim1, hidden_dim2, aggregator_type='gcn', activation=None) 47 | self.gc3 = SAGEConv(hidden_dim1, hidden_dim2, aggregator_type='gcn', activation=None) 48 | #self.gc1 = 
GraphConv(in_feats, hidden_dim1, bias=False, activation=F.relu) 49 | #self.gc2 = GraphConv(hidden_dim1, hidden_dim2, bias=False, activation=lambda x: x) 50 | #self.gc3 = GraphConv(hidden_dim1, hidden_dim2, bias=False, activation=lambda x: x) 51 | self.dc = InnerProductDecoder(dropout, act=lambda x: x) 52 | 53 | if pretrain is not None: 54 | #pass 55 | self.load_state_dict(torch.load(pretrain)) 56 | print("Loaded pre-train model") 57 | 58 | def prepare(self): 59 | if self.pretrain is not None: 60 | self.feat_encoder = MLP(self.in_feats, self.hidden_dim1, self.hidden_dim1) 61 | #self.feat_encoder = None 62 | else: 63 | #self.feat_encoder = MLP(self.in_feats, self.hidden_dim1, self.hidden_dim1) 64 | self.feat_encoder = None 65 | 66 | def reparameterize(self, mu, logvar): 67 | if self.training: 68 | std = torch.exp(logvar) 69 | eps = torch.randn_like(std) 70 | return eps.mul(std).add_(mu) 71 | else: 72 | return mu 73 | 74 | def encode(self, feature): 75 | if self.feat_encoder is not None: 76 | feature = self.feat_encoder(feature) 77 | hidden1 = self.gc1(self.g, feature) 78 | return self.gc2(self.g, hidden1), self.gc3(self.g, hidden1) 79 | 80 | 81 | 82 | def forward(self, features, relative_node_idx=None): 83 | mu, logvar = self.encode(features) 84 | # embed() 85 | if relative_node_idx is not None: 86 | z = self.reparameterize(mu[relative_node_idx], logvar[relative_node_idx]) 87 | return self.dc(z), mu, logvar 88 | else: 89 | return None, mu, None 90 | 91 | 92 | def train_model(self): 93 | self.train() 94 | cur_loss = [] 95 | for idx, nf in enumerate(self.train_sampler): 96 | t = time.time() 97 | decoder_node_id = extract_nodeflow(nf) 98 | self.optimizer.zero_grad() 99 | recovered, mu, logvar = self.forward(self.features, decoder_node_id) 100 | sub_sampled_adj = self.adj_train[decoder_node_id, :][:, decoder_node_id] 101 | adj_label = sub_sampled_adj + sp.eye(sub_sampled_adj.shape[0]) 102 | adj_label = torch.FloatTensor(adj_label.toarray()).reshape(-1).to(device) 103 | 
104 | pos_weight = torch.FloatTensor( [float(sub_sampled_adj.shape[0] * sub_sampled_adj.shape[0] - sub_sampled_adj.sum()) / sub_sampled_adj.sum()]).to(device) 105 | norm = sub_sampled_adj.shape[0] * sub_sampled_adj.shape[0] / float( (sub_sampled_adj.shape[0] * sub_sampled_adj.shape[0] - sub_sampled_adj.sum()) * 2) 106 | 107 | loss = loss_function(recovered, adj_label, mu=mu, logvar=logvar, n_nodes=self.features.shape[0], norm=norm, 108 | pos_weight=pos_weight) 109 | loss.backward() 110 | # clip_grad_norm_(self.parameters(), 1.0) 111 | cur_loss.append(loss.item()) 112 | self.optimizer.step() 113 | # embed() 114 | 115 | return np.mean(cur_loss) 116 | 117 | 118 | def test_model(self, test_edges, test_edges_false, feature_only = False): 119 | with torch.no_grad(): 120 | self.eval() 121 | # feature verify 122 | if feature_only: 123 | output_emb = self.features.cpu() 124 | else: 125 | _, mu, __ = self.forward(self.features) 126 | output_emb = mu.cpu().detach() 127 | # embed() 128 | pos_pred = torch.sigmoid((output_emb[test_edges[:,0]] * output_emb[test_edges[:,1]]).sum(dim=1)).numpy() 129 | neg_pred = torch.sigmoid((output_emb[test_edges_false[:,0]] * output_emb[test_edges_false[:,1]]).sum(dim=1)).numpy() 130 | #offset = 0 131 | pred = np.concatenate( (pos_pred, neg_pred)) 132 | 133 | roc_score = roc_auc_score( 134 | np.concatenate((np.ones(test_edges.shape[0]), np.zeros(test_edges_false.shape[0]))), pred) 135 | ap_score = average_precision_score( 136 | np.concatenate((np.ones(test_edges.shape[0]), np.zeros(test_edges_false.shape[0]))), pred) 137 | mrr_score = label_ranking_average_precision_score( 138 | np.concatenate((np.ones( (test_edges.shape[0],1) ), np.zeros( (test_edges_false.shape[0],1) )), axis = 0), pred[..., np.newaxis]) 139 | return roc_score, ap_score, mrr_score 140 | 141 | def generate_subgraph(self, nf): 142 | nll = 0.0 143 | message_func = fn.v_dot_u('z', 'z', 'm') 144 | for i in range(nf.num_blocks): 145 | nf.layers[i].data['z'] = 
self.z[nf.layer_parent_nid(i)] 146 | nf.layers[i+1].data['z'] = self.z[nf.layer_parent_nid(i+1)] 147 | #nf.layers[i].data['nz'] = nz[nf.layer_parent_nid(i)] 148 | #if x is not None: 149 | # nf.layers[i].data['x'] = x[nf.layer_parent_nid(i)] 150 | # embed() 151 | nf.block_compute(i, lambda edges: {'m': ((edges.src['z'] - edges.dst['z'])**2).sum(dim=1) }, 152 | lambda node :{'nll': - 1 / (1 + (node.mailbox['m']-1).exp() ).log().sum(dim=1)}) 153 | nll += nf.layers[i+1].data['nll'].sum().item() 154 | # nf.block_compute(i, fn.v_dot_u('z', 'z', 'm'), reduce) 155 | return nll 156 | 157 | 158 | def output_nll(self): 159 | with torch.no_grad(): 160 | self.eval() 161 | nll = 0.0 162 | _, mu, __ = self.forward(self.features) 163 | self.z = mu.detach() 164 | # output_emb = mu.cpu().detach() 165 | for idx, nf in enumerate(self.test_sampler): 166 | nll += self.generate_subgraph(nf) 167 | 168 | return nll 169 | 170 | -------------------------------------------------------------------------------- /Link_prediction_model/edge_LP.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import numpy as np 4 | from torch_sparse import SparseTensor 5 | import copy 6 | 7 | from utils import toitem, tonp 8 | from Label_propagation_model.outcome_correlation import general_outcome_correlation_YAG, gen_normalized_adjs 9 | 10 | 11 | def normalize_adj_v2(edge_index, num_nodes): 12 | row, col = edge_index 13 | adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes)) 14 | deg = adj.sum(dim=1).to(torch.float) 15 | D_isqrt = deg.pow(-0.5) 16 | D_isqrt[D_isqrt == float('inf')] = 0 17 | DAD, DA, AD = gen_normalized_adjs(adj, D_isqrt) 18 | adj = DAD 19 | return adj 20 | 21 | def normalize_adj_v3(edge_index, num_nodes): 22 | adj = torch.sparse_coo_tensor(edge_index, edge_index[0]*0+1, [num_nodes, num_nodes]).coalesce().float() 23 | def get_degree_vector(adj): 24 | _deg = torch.sparse.sum(adj, 
dim=1).to(torch.float) 25 | _idx = _deg.indices() 26 | _v = _deg.values() 27 | N = toitem(adj.shape[0]) 28 | deg = torch.zeros(N, device=adj.device) 29 | deg[_idx] = _v 30 | return deg 31 | deg = get_degree_vector(adj) 32 | D_isqrt = deg.pow(-1) 33 | D_isqrt[D_isqrt == float('inf')] = 0 34 | diag_e = torch.tensor([np.arange(len(D_isqrt)),np.arange(len(D_isqrt))], device=edge_index.device) 35 | D_isqrt = torch.sparse_coo_tensor(diag_e, D_isqrt).coalesce().float() 36 | adj = torch.sparse.mm(D_isqrt, adj) 37 | return adj 38 | 39 | def build_edge_adj(edge_index): 40 | # edge_adj is the 'edge_index_of_edge_index' 41 | # This function returns the "edge_index of edges": two edges are connected if they share a node, otherwise they are not. The edge IDs are dictated by the input edge_index. 42 | edge_adj = [[i,i] for i in range(edge_index.shape[1])] # add self loop 43 | node2edgeset = get_node2edge(edge_index) 44 | N_nodes = edge_index.max()+1 45 | for ino in range(N_nodes): 46 | edge_set = node2edgeset[ino] 47 | edge_mix = mutual_intermix(edge_set) 48 | edge_adj.extend(edge_mix) 49 | edge_adj = torch.tensor(edge_adj, device=edge_index.device).t() 50 | return edge_adj, node2edgeset 51 | 52 | def run_logitLP(edge_index, LP_device, alpha, num_propagations, 53 | pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred): 54 | # input: 55 | # edge_logits: logits for all edges (to be computed again using 'h') 56 | # A matrix (original) 57 | 58 | # ---- construct G and Y0 59 | train_val_test_nums = len(pos_train_pred), len(pos_valid_pred), len(pos_test_pred), len(neg_train_pred), len(neg_valid_pred), len(neg_test_pred) 60 | logits = torch.cat([pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred]) # shape: [N_edges] 61 | Y0 = torch.sigmoid(logits.reshape(-1,1)) 62 | G = torch.zeros(Y0.shape, device=Y0.device) 63 | G[:train_val_test_nums[0]] += 1 64 | 
G[train_val_test_nums[0]:sum(train_val_test_nums[:3])] += 0.5 65 | 66 | # ---- construct edge-graph-A 67 | edge_adj, node2edgeset = build_edge_adj(edge_index) 68 | adj = normalize_adj_v2(edge_adj, len(logits)) 69 | 70 | # ---- LP step ---- 71 | out = general_outcome_correlation_YAG(Y0, adj, G, alpha, num_propagations, device=LP_device) 72 | 73 | # ---- obtain logits from embedding 74 | edge_logits = invsigmoid(out.reshape(-1)) 75 | pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred = separate_logits(edge_logits, train_val_test_nums) 76 | return pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred 77 | 78 | def run_embLP(embs, edge_index, LP_device, alpha, num_propagations, 79 | pos_train_edge, pos_valid_edge, pos_test_edge, neg_train_edge, neg_valid_edge, neg_test_edge): 80 | # ---- construct G and Y0 81 | train_val_test_nums = len(pos_train_edge), len(pos_valid_edge), len(pos_test_edge), len(neg_train_edge), len(neg_valid_edge), len(neg_test_edge) 82 | edges = torch.cat([pos_train_edge, pos_valid_edge, pos_test_edge, neg_train_edge, neg_valid_edge, neg_test_edge]) # shape: [N_edges, 2] 83 | 84 | edge_embs = torch.zeros([ len(edges), embs.shape[1]*2 ], device=edge_index.device).float() 85 | for ie in range(len(edges)): 86 | src, dst = edges[ie] 87 | edge_embs[ie] = torch.cat([embs[src],embs[dst]]) 88 | 89 | # ---- construct edge-graph-A 90 | edge_adj, node2edgeset = build_edge_adj(edge_index) 91 | adj = normalize_adj_v2(edge_adj, len(edges)) 92 | 93 | # ---- LP step ---- 94 | Y0 = edge_embs 95 | G = Y0.clone() 96 | out = general_outcome_correlation_YAG(Y0, adj, G, alpha, num_propagations, device=LP_device) 97 | 98 | # ---- obtain logits from embedding 99 | out = out.view(len(edges), 2, embs.shape[1]) 100 | edge_logits = (out[:,0,:] * out[:,1,:]).sum(axis=1) 101 | 102 | pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred = 
separate_logits(edge_logits, train_val_test_nums) 103 | return pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred 104 | 105 | def run_xmcLP(edge_index, num_nodes, LP_device, alpha, num_propagations, 106 | pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred, 107 | pos_train_edge, pos_valid_edge, pos_test_edge, neg_train_edge, neg_valid_edge, neg_test_edge): 108 | # input: 109 | # edge_logits: logits for all edges (to be computed again using 'h') 110 | # A matrix (original) 111 | # ---- construct G and Y0 112 | train_val_test_nums_O = len(pos_train_pred), len(pos_valid_pred), len(pos_test_pred), len(neg_train_pred), len(neg_valid_pred), len(neg_test_pred) 113 | edges_O = torch.cat([pos_train_edge, pos_valid_edge, pos_test_edge, neg_train_edge, neg_valid_edge, neg_test_edge]).to(LP_device) # shape: [N_edges, 2] 114 | logits_O = torch.cat([pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred]).to(LP_device) # shape: [N_edges] 115 | 116 | def remove_duplicate(edges_O, logits_O, train_val_test_nums_O): 117 | # inputs are tensors on cpu/gpu 118 | train_val_test_nums_O = list(copy.copy(train_val_test_nums_O)) 119 | edge2loc = dict() 120 | edges, logits, train_val_test_nums, duplicate_pos = [],[], [], [] 121 | duplicate_mask = np.array([False]*len(logits_O)) 122 | landmark = 0 123 | cnt = -1 124 | for i,e in enumerate([tuple(x) for x in tonp(edges_O)]): 125 | cnt += 1 126 | if e not in edge2loc: 127 | edge2loc[e] = i 128 | edges.append(e) 129 | logits.append(logits_O[i].clone()) 130 | else: 131 | duplicate_mask[i] = True 132 | duplicate_pos.append(edge2loc[e]) 133 | if cnt+1 == train_val_test_nums_O[0]: 134 | train_val_test_nums.append(len(edges) - landmark) 135 | landmark = len(edges) 136 | train_val_test_nums_O.pop(0) 137 | cnt = -1 138 | 139 | edges = torch.tensor(edges).to(edges_O.device) 140 | logits = torch.stack(logits) 141 | return edges, 
logits, train_val_test_nums, duplicate_mask, duplicate_pos 142 | 143 | edges, logits, train_val_test_nums, duplicate_mask, duplicate_pos = remove_duplicate(edges_O, logits_O, train_val_test_nums_O) 144 | 145 | num_nodes = toitem(num_nodes) 146 | Gvalues = torch.zeros(sum(train_val_test_nums), device=LP_device) 147 | Gvalues[:train_val_test_nums[0]] += 1 148 | Gvalues[train_val_test_nums[0]:sum(train_val_test_nums[:3])] += 1 149 | 150 | Y0 = torch.sparse_coo_tensor(edges.T, torch.sigmoid(logits), [num_nodes, num_nodes]).coalesce().to(LP_device) 151 | G = torch.sparse_coo_tensor(edges.T, Gvalues, [num_nodes, num_nodes]).coalesce().to(LP_device) 152 | 153 | # ---- construct edge-graph-A 154 | adj = normalize_adj_v3(edge_index.to(LP_device), num_nodes) 155 | 156 | # ---- LP step ---- 157 | out = general_outcome_correlation_YAG(Y0, adj, G, alpha, num_propagations, device=LP_device, use_sparse_mult=True) 158 | 159 | # ---- obtain logits from embedding 160 | edges = [tuple(x) for x in tonp(edges)] 161 | edge_logits = invsigmoid(torch.stack([out[e] for e in edges])) 162 | 163 | def add_duplicate(edge_logits, duplicate_mask, duplicate_pos): 164 | edge_logits_O = torch.zeros(duplicate_mask.shape, device=edge_logits.device) 165 | edge_logits_O[~duplicate_mask] = edge_logits 166 | edge_logits_O[duplicate_mask] = edge_logits[duplicate_pos] 167 | return edge_logits_O 168 | 169 | edge_logits_O = add_duplicate(edge_logits, duplicate_mask, duplicate_pos) 170 | pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred = separate_logits(edge_logits_O, train_val_test_nums_O) 171 | return pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred 172 | 173 | def separate_logits(edge_logits, train_val_test_nums): 174 | acc = np.array(train_val_test_nums) 175 | acc[1:] += train_val_test_nums[0] 176 | acc[2:] += train_val_test_nums[1] 177 | acc[3:] += train_val_test_nums[2] 178 | acc[4:] += train_val_test_nums[3] 179 | 
acc[5:] += train_val_test_nums[4] 180 | 181 | pos_train_pred = edge_logits[:acc[0]] 182 | pos_valid_pred = edge_logits[acc[0]:acc[1]] 183 | pos_test_pred = edge_logits[acc[1]:acc[2]] 184 | neg_train_pred = edge_logits[acc[2]:acc[3]] 185 | neg_valid_pred = edge_logits[acc[3]:acc[4]] 186 | neg_test_pred = edge_logits[acc[4]:acc[5]] 187 | return pos_train_pred, pos_valid_pred, pos_test_pred, neg_train_pred, neg_valid_pred, neg_test_pred 188 | 189 | def invsigmoid(y): 190 | eps = 1e-9 191 | return - torch.log(1/(y+eps)-1) 192 | 193 | def mutual_intermix(edge_set): 194 | # input: a set of edge IDs that are incident to the same node; output: the pairwise connectivity of these edges, expressed in edge IDs 195 | edge_list = list(edge_set) 196 | edge_mix = [] 197 | for ie1 in range(len(edge_list)): 198 | for ie2 in range(ie1+1, len(edge_list)): 199 | edge_mix.append([edge_list[ie1],edge_list[ie2]]) 200 | edge_mix.append([edge_list[ie2],edge_list[ie1]]) 201 | return edge_mix 202 | 203 | def get_node2edge(edge_index): 204 | # the returned dict is composed of purely integers, no torch.tensor. 
205 | node2edge = defaultdict(set) 206 | for ie in range(edge_index.shape[1]): 207 | e = edge_index[:,ie] 208 | node2edge[int(e[0])].add(ie) 209 | node2edge[int(e[1])].add(ie) 210 | return node2edge 211 | -------------------------------------------------------------------------------- /Link_prediction_model/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import SAGEConv, GCNConv, GraphConv, TransformerConv 4 | from Link_prediction_baseline.heuristics import eva_heuristics_v2_dec25 5 | 6 | class Heuristics(torch.nn.Module): 7 | def __init__(self, name, data): 8 | super().__init__() 9 | self.name = name 10 | self.data = data 11 | def reset_parameters(self): 12 | return 13 | def forward(self, x, e): 14 | return x 15 | def get_score(self, edge_index): 16 | score = eva_heuristics_v2_dec25(self.name, self.data, edge_index) 17 | return score 18 | 19 | class BaseGNN(torch.nn.Module): 20 | def __init__(self, dropout): 21 | super(BaseGNN, self).__init__() 22 | self.convs = torch.nn.ModuleList() 23 | self.dropout = dropout 24 | 25 | def reset_parameters(self): 26 | for conv in self.convs: 27 | conv.reset_parameters() 28 | 29 | def forward(self, x, adj_t): 30 | for conv in self.convs[:-1]: 31 | x = conv(x, adj_t) 32 | x = F.relu(x) 33 | x = F.dropout(x, p=self.dropout, training=self.training) 34 | x = self.convs[-1](x, adj_t) 35 | return x 36 | 37 | class MLP(BaseGNN): 38 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 39 | super().__init__(dropout) 40 | for i in range(num_layers): 41 | first_channels = in_channels if i == 0 else hidden_channels 42 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 43 | self.convs.append(torch.nn.Linear(first_channels, second_channels)) 44 | 45 | def forward(self, x, adj_t): 46 | for conv in self.convs[:-1]: 47 | x = conv(x) 48 | x = F.relu(x) 49 | x = 
F.dropout(x, p=self.dropout, training=self.training) 50 | x = self.convs[-1](x) 51 | return x 52 | 53 | class SAGE(BaseGNN): 54 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 55 | super(SAGE, self).__init__(dropout) 56 | for i in range(num_layers): 57 | first_channels = in_channels if i == 0 else hidden_channels 58 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 59 | self.convs.append(SAGEConv(first_channels, second_channels)) 60 | 61 | class GCN(BaseGNN): 62 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 63 | super(GCN, self).__init__(dropout) 64 | for i in range(num_layers): 65 | first_channels = in_channels if i == 0 else hidden_channels 66 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 67 | self.convs.append(GCNConv(first_channels, second_channels, normalize=False)) 68 | 69 | class WSAGE(BaseGNN): 70 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 71 | super(WSAGE, self).__init__(dropout) 72 | for i in range(num_layers): 73 | first_channels = in_channels if i == 0 else hidden_channels 74 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 75 | self.convs.append(GraphConv(first_channels, second_channels)) 76 | 77 | class Transformer(BaseGNN): 78 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 79 | super(Transformer, self).__init__(dropout) 80 | for i in range(num_layers): 81 | first_channels = in_channels if i == 0 else hidden_channels 82 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 83 | self.convs.append(TransformerConv(first_channels, second_channels)) 84 | 85 | class MLPPredictor(torch.nn.Module): 86 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 87 | super(MLPPredictor, self).__init__() 88 | self.lins = torch.nn.ModuleList() 89 | for i in 
range(num_layers): 90 | first_channels = in_channels if i == 0 else hidden_channels 91 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 92 | self.lins.append(torch.nn.Linear(first_channels, second_channels)) 93 | self.dropout = dropout 94 | 95 | def reset_parameters(self): 96 | for lin in self.lins: 97 | lin.reset_parameters() 98 | 99 | def forward(self, x_i, x_j): 100 | x = x_i * x_j 101 | for lin in self.lins[:-1]: 102 | x = lin(x) 103 | x = F.relu(x) 104 | x = F.dropout(x, p=self.dropout, training=self.training) 105 | x = self.lins[-1](x) 106 | return x 107 | 108 | class MLPCatPredictor(torch.nn.Module): 109 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout): 110 | super(MLPCatPredictor, self).__init__() 111 | self.lins = torch.nn.ModuleList() 112 | in_channels = 2 * in_channels 113 | for i in range(num_layers): 114 | first_channels = in_channels if i == 0 else hidden_channels 115 | second_channels = out_channels if i == num_layers - 1 else hidden_channels 116 | self.lins.append(torch.nn.Linear(first_channels, second_channels)) 117 | self.dropout = dropout 118 | 119 | def reset_parameters(self): 120 | for lin in self.lins: 121 | lin.reset_parameters() 122 | 123 | def forward(self, x_i, x_j): 124 | x1 = torch.cat([x_i, x_j], dim=-1) 125 | x2 = torch.cat([x_j, x_i], dim=-1) 126 | for lin in self.lins[:-1]: 127 | x1, x2 = lin(x1), lin(x2) 128 | x1, x2 = F.relu(x1), F.relu(x2) 129 | x1 = F.dropout(x1, p=self.dropout, training=self.training) 130 | x2 = F.dropout(x2, p=self.dropout, training=self.training) 131 | x1 = self.lins[-1](x1) 132 | x2 = self.lins[-1](x2) 133 | x = (x1 + x2)/2 134 | return x 135 | 136 | class MLPDotPredictor(torch.nn.Module): 137 | def __init__(self, in_channels, hidden_channels, num_layers, dropout): 138 | super(MLPDotPredictor, self).__init__() 139 | self.lins = torch.nn.ModuleList() 140 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 141 | for _ in 
range(num_layers - 1): 142 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 143 | self.dropout = dropout 144 | 145 | def reset_parameters(self): 146 | for lin in self.lins: 147 | lin.reset_parameters() 148 | 149 | def forward(self, x_i, x_j): 150 | for lin in self.lins: 151 | x_i, x_j = lin(x_i), lin(x_j) 152 | x_i, x_j = F.relu(x_i), F.relu(x_j) 153 | x_i, x_j = F.dropout(x_i, p=self.dropout, training=self.training), \ 154 | F.dropout(x_j, p=self.dropout, training=self.training) 155 | x = torch.sum(x_i * x_j, dim=-1) 156 | return x 157 | 158 | class MLPBilPredictor(torch.nn.Module): 159 | def __init__(self, in_channels, hidden_channels, num_layers, dropout): 160 | super(MLPBilPredictor, self).__init__() 161 | self.lins = torch.nn.ModuleList() 162 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 163 | for _ in range(num_layers - 1): 164 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 165 | self.bilin = torch.nn.Linear(hidden_channels, hidden_channels, bias=False) 166 | self.dropout = dropout 167 | 168 | def reset_parameters(self): 169 | for lin in self.lins: 170 | lin.reset_parameters() 171 | self.bilin.reset_parameters() 172 | 173 | def forward(self, x_i, x_j): 174 | for lin in self.lins: 175 | x_i, x_j = lin(x_i), lin(x_j) 176 | x_i, x_j = F.relu(x_i), F.relu(x_j) 177 | x_i, x_j = F.dropout(x_i, p=self.dropout, training=self.training), \ 178 | F.dropout(x_j, p=self.dropout, training=self.training) 179 | x = torch.sum(self.bilin(x_i) * x_j, dim=-1) 180 | return x 181 | 182 | class DotPredictor(torch.nn.Module): 183 | def __init__(self): 184 | super(DotPredictor, self).__init__() 185 | 186 | def reset_parameters(self): 187 | return 188 | 189 | def forward(self, x_i, x_j): 190 | x = torch.sum(x_i * x_j, dim=-1) 191 | return x 192 | 193 | class BilinearPredictor(torch.nn.Module): 194 | def __init__(self, hidden_channels): 195 | super(BilinearPredictor, self).__init__() 196 | self.bilin = 
torch.nn.Linear(hidden_channels, hidden_channels, bias=False) 197 | 198 | def reset_parameters(self): 199 | self.bilin.reset_parameters() 200 | 201 | def forward(self, x_i, x_j): 202 | x = torch.sum(self.bilin(x_i) * x_j, dim=-1) 203 | return x -------------------------------------------------------------------------------- /Link_prediction_model/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import sys 4 | 5 | class Logger(object): 6 | def __init__(self, runs, info=None): 7 | self.info = info 8 | self.results = [[] for _ in range(runs)] 9 | 10 | def add_result(self, run, result): 11 | # assert len(result) == 2 12 | assert run >= 0 and run < len(self.results) 13 | self.results[run].append(result) 14 | 15 | def print_statistics(self, run=None, f=sys.stdout, last_best=False): 16 | if run is not None: 17 | result = 100 * torch.tensor(self.results[run]) 18 | if last_best: 19 | # get last max value index by reversing result tensor 20 | argmax = result.size(0) - result[:, 0].flip(dims=[0]).argmax().item() - 1 21 | else: 22 | argmax = result[:, 0].argmax().item() 23 | print(f'Run {run + 1:02d}:', file=f) 24 | print(f'Highest Valid: {result[:, 0].max():.2f}', file=f) 25 | print(f'Highest Eval Point: {argmax + 1}', file=f) 26 | print(f' Final Test: {result[argmax, 1]:.2f}', file=f) 27 | else: 28 | result = 100 * torch.tensor(self.results) 29 | best_results = [] 30 | for r in result: 31 | valid = r[:, 0].max().item() 32 | if last_best: 33 | # get last max value index by reversing result tensor 34 | argmax = r.size(0) - r[:, 0].flip(dims=[0]).argmax().item() - 1 35 | else: 36 | argmax = r[:, 0].argmax().item() 37 | test = r[argmax, 1].item() 38 | best_results.append((valid, test)) 39 | 40 | best_result = torch.tensor(best_results) 41 | 42 | print(f'All runs:', file=f) 43 | r = best_result[:, 0] 44 | print(f'Highest Valid: {r.mean():.2f} {r.std():.2f}', file=f) 45 | r = best_result[:, 
1] 46 | print(f' Final Test: {r.mean():.2f} {r.std():.2f}', file=f) 47 | -------------------------------------------------------------------------------- /Link_prediction_model/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | def auc_loss(pos_out, neg_out, num_neg): 5 | pos_out = torch.reshape(pos_out, (-1, 1)) 6 | neg_out = torch.reshape(neg_out, (-1, num_neg)) 7 | return torch.square(1 - (pos_out - neg_out)).sum() 8 | 9 | def adaptive_auc_loss(pos_out, neg_out, num_neg, weight): 10 | weight = torch.reshape(weight, (-1, 1)) 11 | pos_out = torch.reshape(pos_out, (-1, 1)) 12 | neg_out = torch.reshape(neg_out, (-1, num_neg)) 13 | return (weight*torch.square(1 - (pos_out - neg_out))).sum() 14 | 15 | def log_rank_loss(pos_out, neg_out, num_neg): 16 | pos_out = torch.reshape(pos_out, (-1, 1)) 17 | neg_out = torch.reshape(neg_out, (-1, num_neg)) 18 | return -torch.log(torch.sigmoid(pos_out - neg_out) + 1e-15).mean() 19 | 20 | def ce_loss(pos_out, neg_out): 21 | pos_loss = -torch.log(torch.sigmoid(pos_out) + 1e-15).mean() 22 | neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + 1e-15).mean() 23 | return pos_loss + neg_loss 24 | 25 | def info_nce_loss(pos_out, neg_out, num_neg): 26 | pos_out = torch.reshape(pos_out, (-1, 1)) 27 | neg_out = torch.reshape(neg_out, (-1, num_neg)) 28 | pos_exp = torch.exp(pos_out) 29 | neg_exp = torch.sum(torch.exp(neg_out), 1, keepdim=True) 30 | return -torch.log(pos_exp / (pos_exp + neg_exp) + 1e-15).mean() 31 | -------------------------------------------------------------------------------- /Link_prediction_model/negative_sample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch_geometric.utils import negative_sampling, add_self_loops 4 | 5 | def global_neg_sample(edge_index, num_nodes, num_samples, 6 | num_neg, method='sparse'): 7 | new_edge_index, _ = 
add_self_loops(edge_index) 8 | neg_edge = negative_sampling(new_edge_index, num_nodes=num_nodes, 9 | num_neg_samples=num_samples * num_neg, method=method) 10 | 11 | neg_src = neg_edge[0] 12 | neg_dst = neg_edge[1] 13 | if neg_edge.size(1) < num_samples * num_neg: 14 | k = num_samples * num_neg - neg_edge.size(1) 15 | rand_index = torch.randperm(neg_edge.size(1))[:k] 16 | neg_src = torch.cat((neg_src, neg_src[rand_index])) 17 | neg_dst = torch.cat((neg_dst, neg_dst[rand_index])) 18 | return torch.reshape(torch.stack( 19 | (neg_src, neg_dst), dim=-1), (-1, num_neg, 2)) 20 | 21 | def global_perm_neg_sample(edge_index, num_nodes, num_samples, 22 | num_neg, method='sparse'): 23 | new_edge_index, _ = add_self_loops(edge_index) 24 | neg_edge = negative_sampling(new_edge_index, num_nodes=num_nodes, 25 | num_neg_samples=num_samples, method=method) 26 | return sample_perm_copy(neg_edge, num_samples, num_neg) 27 | 28 | def local_neg_sample(pos_edges, num_nodes, num_neg, random_src=False): 29 | if random_src: 30 | neg_src = pos_edges[torch.arange(pos_edges.size(0)), torch.randint( 31 | 0, 2, (pos_edges.size(0),), dtype=torch.long)] 32 | else: 33 | neg_src = pos_edges[:, 0] 34 | neg_src = torch.reshape(neg_src, (-1, 1)).repeat(1, num_neg) 35 | neg_src = torch.reshape(neg_src, (-1,)) 36 | neg_dst = torch.randint( 37 | 0, num_nodes, (num_neg * pos_edges.size(0),), dtype=torch.long) 38 | 39 | return torch.reshape(torch.stack( 40 | (neg_src, neg_dst), dim=-1), (-1, num_neg, 2)) 41 | 42 | def sample_perm_copy(edge_index, target_num_sample, num_perm_copy): 43 | src = edge_index[0] 44 | dst = edge_index[1] 45 | if edge_index.size(1) < target_num_sample: 46 | k = target_num_sample - edge_index.size(1) 47 | rand_index = torch.randperm(edge_index.size(1))[:k] 48 | src = torch.cat((src, src[rand_index])) 49 | dst = torch.cat((dst, dst[rand_index])) 50 | tmp_src = src 51 | tmp_dst = dst 52 | for i in range(num_perm_copy - 1): 53 | rand_index = torch.randperm(target_num_sample) 54 | src = 
torch.cat((src, tmp_src[rand_index])) 55 | dst = torch.cat((dst, tmp_dst[rand_index])) 56 | return torch.reshape(torch.stack( 57 | (src, dst), dim=-1), (-1, num_perm_copy, 2)) 58 | -------------------------------------------------------------------------------- /Link_prediction_model/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | from .negative_sample import global_neg_sample, global_perm_neg_sample, local_neg_sample 5 | from utils import cal_recall 6 | 7 | def get_pos_neg_edges(split, split_edge, edge_index=None, num_nodes=None, neg_sampler_name=None, num_neg=None): 8 | if 'edge' in split_edge['train']: 9 | pos_edge = split_edge[split]['edge'] 10 | elif 'source_node' in split_edge['train']: 11 | source = split_edge[split]['source_node'] 12 | target = split_edge[split]['target_node'] 13 | pos_edge = torch.stack([source, target]).t() 14 | 15 | if split == 'train': 16 | if neg_sampler_name == 'local': 17 | neg_edge = local_neg_sample( 18 | pos_edge, 19 | num_nodes=num_nodes, 20 | num_neg=num_neg) 21 | elif neg_sampler_name == 'global': # pos_edge: [720132, 2] ; neg_edge: [720132, 3, 2] 22 | neg_edge = global_neg_sample( 23 | edge_index, 24 | num_nodes=num_nodes, 25 | num_samples=pos_edge.size(0), 26 | num_neg=num_neg) 27 | else: 28 | neg_edge = global_perm_neg_sample( 29 | edge_index, 30 | num_nodes=num_nodes, 31 | num_samples=pos_edge.size(0), 32 | num_neg=num_neg) 33 | else: 34 | if 'edge' in split_edge['train']: 35 | neg_edge = split_edge[split]['edge_neg'] 36 | elif 'source_node' in split_edge['train']: 37 | target_neg = split_edge[split]['target_node_neg'] 38 | neg_per_target = target_neg.size(1) 39 | neg_edge = torch.stack([source.repeat_interleave(neg_per_target), 40 | target_neg.view(-1)]).t() 41 | return pos_edge, neg_edge 42 | 43 | def evaluate_hits(evaluator, pos_val_pred, neg_val_pred, 44 | pos_test_pred, neg_test_pred): 45 | results = {} 46 
| for K in [20, 50, 100]: 47 | evaluator.K = K 48 | valid_hits = evaluator.eval({ 49 | 'y_pred_pos': pos_val_pred, 50 | 'y_pred_neg': neg_val_pred, 51 | })[f'hits@{K}'] 52 | test_hits = evaluator.eval({ 53 | 'y_pred_pos': pos_test_pred, 54 | 'y_pred_neg': neg_test_pred, 55 | })[f'hits@{K}'] 56 | 57 | results[f'Hits@{K}'] = (valid_hits, test_hits) 58 | 59 | return results 60 | 61 | def evaluate_mrr(evaluator, pos_val_pred, neg_val_pred, 62 | pos_test_pred, neg_test_pred): 63 | # neg_val_pred = neg_val_pred.view(pos_val_pred.shape[0], -1) 64 | neg_val_pred = neg_val_pred.view(neg_val_pred.shape[0], -1) 65 | # neg_test_pred = neg_test_pred.view(pos_test_pred.shape[0], -1) 66 | neg_test_pred = neg_test_pred.view(neg_test_pred.shape[0], -1) 67 | results = {} 68 | valid_mrr = evaluator.eval({ 69 | 'y_pred_pos': pos_val_pred, 70 | 'y_pred_neg': neg_val_pred, 71 | })['mrr_list'].mean().item() 72 | 73 | test_mrr = evaluator.eval({ 74 | 'y_pred_pos': pos_test_pred, 75 | 'y_pred_neg': neg_test_pred, 76 | })['mrr_list'].mean().item() 77 | 78 | results['MRR'] = (valid_mrr, test_mrr) 79 | 80 | return results 81 | 82 | def evaluate_recall_my(pos_train_pred, neg_train_pred, 83 | pos_val_pred, neg_val_pred, 84 | pos_test_pred, neg_test_pred, topk=None): 85 | results = {} 86 | recall_train = cal_recall(pos_train_pred, neg_train_pred, topk=topk) 87 | recall_valid = cal_recall(pos_val_pred, neg_val_pred, topk=topk) 88 | recall_test = cal_recall(pos_test_pred, neg_test_pred, topk=topk) 89 | results['recall@100%'] = (recall_train, recall_valid, recall_test) 90 | 91 | return results 92 | 93 | def gcn_normalization(adj_t): 94 | adj_t = adj_t.set_diag() 95 | deg = adj_t.sum(dim=1).to(torch.float) 96 | deg_inv_sqrt = deg.pow(-0.5) 97 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 98 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) 99 | return adj_t 100 | 101 | def adj_normalization(adj_t): 102 | deg = adj_t.sum(dim=1).to(torch.float) 103 | deg_inv_sqrt = deg.pow(-1) 
104 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 105 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t 106 | return adj_t 107 | 108 | def generate_neg_dist_table(num_nodes, adj_t, power=0.75, table_size=1e8): 109 | table_size = int(table_size) 110 | adj_t = adj_t.set_diag() 111 | node_degree = adj_t.sum(dim=1).to(torch.float) 112 | node_degree = node_degree.pow(power) 113 | 114 | norm = float((node_degree).sum()) # float is faster than tensor when visited 115 | node_degree = node_degree.tolist() # list has fastest visit speed 116 | sample_table = np.zeros(table_size, dtype=np.int32) 117 | p = 0 118 | i = 0 119 | for j in range(num_nodes): 120 | p += node_degree[j] / norm 121 | while i < table_size and float(i) / float(table_size) < p: 122 | sample_table[i] = j 123 | i += 1 124 | sample_table = torch.from_numpy(sample_table) 125 | return sample_table 126 | -------------------------------------------------------------------------------- /MLP_model/__init__.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | class StudentBaseMLP(nn.Module): 4 | def __init__(self, args): 5 | super().__init__() 6 | self.args = args 7 | 8 | if args.StudentBaseMLP.dim_model==-1: 9 | dim_model = None 10 | else: 11 | dim_model = args.StudentBaseMLP.dim_model 12 | self.model = BlockResMLP(dims_in_out = args.StudentBaseMLP.dims_in_out, dim_model=dim_model, skip_conn_period=args.StudentBaseMLP.skip_conn_period, num_blocks=args.StudentBaseMLP.num_blocks) 13 | 14 | def forward(self, x, edge_index=None, mask=None): 15 | if mask is not None: 16 | x = x[mask] 17 | return self.model(x) 18 | def get_emb4linkp(self, x, edge_index, mask=None): 19 | # return ALL nodes 20 | return self.model(x) 21 | 22 | class BlockResMLP(nn.Module): 23 | def __init__(self, dims_in_out, num_blocks, skip_conn_period=2, dim_hidden=None, dim_model=None, activation=nn.GELU, bias=True, dropout=0.1): 24 | # dims_in_out: it is a 2-element list 25 | super().__init__() 
26 | 27 | self.dims_in_out = dims_in_out 28 | self.dim_model = dim_model or min(max(dims_in_out), 256) 29 | self.dim_hidden = dim_hidden or int(self.dim_model*1.5)+2 30 | 31 | 32 | self.num_blocks = num_blocks 33 | 34 | self.in_proj = nn.Identity() if self.dim_model==dims_in_out[0] else nn.Linear(dims_in_out[0], self.dim_model) 35 | self.out_proj = nn.Identity() if self.dim_model==dims_in_out[1] else nn.Linear(self.dim_model, dims_in_out[1]) 36 | 37 | neurons = [self.dim_model] + [self.dim_hidden]*(skip_conn_period-1) + [self.dim_model] 38 | self.blocks = nn.ModuleList([getMLP(neurons, activation=activation, bias=bias, dropout=dropout, last_dropout=True) for _ in range(self.num_blocks-1)]) 39 | self.blocks.append(getMLP(neurons, activation=activation, bias=bias, dropout=dropout, last_dropout=False)) 40 | return 41 | 42 | def forward(self, x): 43 | x = self.in_proj(x) 44 | for block in self.blocks: 45 | h = x 46 | x = block(x) 47 | x = h + x 48 | x = self.out_proj(x) 49 | return x 50 | 51 | class SEMLP(nn.Module): 52 | # The implementation of Cold Brew's MLP. 
53 | def __init__(self, args, data, teacherGNN): 54 | super().__init__() 55 | self.hidden_dim = 256 56 | self.args = args 57 | self.train_mask = data.train_mask 58 | self.test_mask = data.test_mask 59 | self.train_idx = data.train_idx # torch.where(self.train_mask==True)[0] 60 | self.test_idx = data.test_idx # torch.where(self.test_mask==True)[0] 61 | if self.args.batch_size>len(self.train_idx): 62 | print(f'\n\n Batch size too large...\n Changing batch_size from {self.args.batch_size} to {len(self.train_idx)}!\n\n') 63 | self.args.batch_size = len(self.train_idx) 64 | self.has_NCloss = False 65 | self.adj_pow = None 66 | self.topK_2_replace = args.SEMLP_topK_2_replace 67 | 68 | self.teacherGNN = teacherGNN 69 | self.part1 = None 70 | self.part2 = None 71 | self.alphas = nn.Parameter(torch.tensor([0.0001,0.0001]), requires_grad=True) 72 | if self.has_NCloss: 73 | self.out_proj = nn.Linear(self.hidden_dim, args.num_classes_bkup) 74 | return 75 | 76 | def forward_part1(self, x, edge_index=None, batch_idx=None): 77 | if self.has_NCloss and self.adj_pow is None: 78 | adj = graphUtils.normalize_adj(edge_index) 79 | self.adj_pow = graphUtils.sparse_power(adj, self.args.graphMLP_r) 80 | 81 | if self.part1 is None: 82 | if self.args.StudentMLP__dim_model==-1: 83 | dim_model = None 84 | else: 85 | dim_model = self.args.StudentBaseMLP.dim_model 86 | # ----- build part1 MLP module ----- 87 | neurons_io = [self.args.num_feats, self.teacherGNN.model.model.get_se_dim(x, edge_index)] 88 | if self.args.SEMLP_part1_arch == 'residual': # options: 'residual', '2layer', '3layer', '4layer' 89 | self.part1 = BlockResMLP(dims_in_out = neurons_io, dim_model=dim_model, skip_conn_period=self.args.StudentBaseMLP.skip_conn_period, num_blocks=self.args.StudentBaseMLP.num_blocks).to(self.args.device) 90 | else: 91 | nlayer = int(self.args.SEMLP_part1_arch[0]) 92 | neurons = [neurons_io[0]] + [256]*(nlayer-1) + [neurons_io[1]] 93 | self.part1 = getMLP(neurons, 
dropout=self.args.dropout_MLP).to(self.args.device) 94 | self.opt = self.optfun(self.parameters(),lr=self.args.lr, weight_decay=self.args.weight_decay) 95 | 96 | if batch_idx is not None: 97 | x = x[batch_idx] 98 | le_guess = self.part1(x) 99 | return le_guess 100 | 101 | def forward_part2(self, x, batch_idx=None, edge_index=None, ): 102 | if batch_idx is not None: 103 | x = x[batch_idx] 104 | if self.args.SEMLP__downgrade_to_MLP: 105 | part2_in = x 106 | else: 107 | part1_out = self.forward_part1(x, batch_idx).detach()*self.alphas[0] 108 | replaced = self.replacement(part1_out)*self.alphas[1] 109 | 110 | if self.args.SEMLP__include_part1out: 111 | part2_in = torch.cat([x, replaced, part1_out], dim=-1) 112 | else: 113 | part2_in = torch.cat([x, replaced], dim=-1) # x was already indexed by batch_idx above 114 | 115 | if self.part2 is None: 116 | if self.args.StudentMLP__dim_model==-1: 117 | dim_model = None 118 | else: 119 | dim_model = self.args.StudentBaseMLP.dim_model 120 | neurons_io = [part2_in.shape[-1], self.args.num_classes_bkup] 121 | 122 | if self.args.train_which=='GraphMLP': 123 | self.part2 = GraphMLP(self.args, self.train_mask).to(self.args.device) 124 | elif self.args.train_which=='StudentBaseMLP': 125 | self.part2 = BlockResMLP(dims_in_out = [self.args.num_feats, self.args.num_classes_bkup], dim_model=dim_model, skip_conn_period=self.args.StudentBaseMLP.skip_conn_period, num_blocks=self.args.StudentBaseMLP.num_blocks).to(self.args.device) 126 | else: 127 | neurons = [part2_in.shape[1], 256, self.args.num_classes_bkup] 128 | self.part2 = getMLP(neurons, dropout=self.args.dropout_MLP).to(self.args.device) 129 | 130 | self.opt = self.optfun(self.parameters(),lr=self.args.lr, weight_decay=self.args.weight_decay) 131 | 132 | if self.args.train_which=='GraphMLP': 133 | res = self.part2(part2_in, edge_index=edge_index, batch_idx=batch_idx) 134 | y = res.emb 135 | self.loss_NContrastive = res.loss_NContrastive 136 | else: 137 | y = self.part2(part2_in) 138 | return y 139 | 140 | def 
forward(self, x, edge_index=None): 141 | return 142 | 143 | def replacement(self, le_guess, node_idx=None): 144 | le_guess = le_guess.detach() 145 | res_N_feat = [] 146 | teacherSE_T = self.teacherSE.transpose(0,1) 147 | if node_idx is None: 148 | node_idx = np.arange(len(le_guess)) 149 | for idx in node_idx: 150 | attn_1N = torch.matmul(le_guess[[idx]], teacherSE_T) 151 | sortidx = attn_1N.argsort()[0] 152 | select = sortidx[-self.topK_2_replace:] 153 | attn_1N = F.softmax(attn_1N[:,select], dim=1) 154 | z_1_feat = torch.matmul(attn_1N, self.teacherSE[select]) 155 | res_N_feat.append(z_1_feat) 156 | return torch.cat(res_N_feat, dim=0).detach() 157 | 158 | class GraphMLP(nn.Module): 159 | # pytorch re-implementation of GRAPHMLP: https://arxiv.org/pdf/2106.04051.pdf 160 | def __init__(self, args, train_mask): 161 | super().__init__() 162 | self.dropout = 0.6 # reported in the paper 163 | self.hidden_dim = 256 # reported in the paper 164 | self.args = args 165 | neurons = [args.num_feats, self.hidden_dim, self.hidden_dim] 166 | self.model = getMLP(neurons, dropout=self.dropout).to(args.device) 167 | self.out_proj = nn.Linear(self.hidden_dim, args.num_classes_bkup) 168 | self.train_mask = train_mask 169 | self.train_idx = torch.where(self.train_mask==True)[0] 170 | if self.args.batch_size>len(self.train_idx): 171 | print(f'\n\n Batch size too large...\n Changing batch_size from {self.args.batch_size} to {len(self.train_idx)}!\n\n') 172 | self.args.batch_size = len(self.train_idx) 173 | self.adj_pow = None 174 | 175 | def forward(self, x, edge_index=None, batch_idx=None): 176 | if self.adj_pow is None: 177 | adj = graphUtils.normalize_adj(edge_index) 178 | self.adj_pow = graphUtils.sparse_power(adj, self.args.graphMLP_r) 179 | z = self.model(x) 180 | info = D() 181 | info.loss_NContrastive = get_neighbor_contrastive_loss(z, self.adj_pow, batch_idx, self.args.graphMLP_tau) 182 | info.emb = self.out_proj(z) 183 | return info 184 | 185 | def get_emb4linkp(self, x, 
edge_index, mask=None): 186 | # return ALL nodes 187 | raise NotImplementedError 188 | return self.model(x) 189 | 190 | def get_neighbor_contrastive_loss(z, adj_pow, batch_idx, tau): 191 | mask = torch.eye(len(z)).to(z.device) 192 | simz = (1 - mask) * torch.exp(cosine_sim(z)/tau) # shape: [B, B] 193 | adj_pow = graphUtils.crop_adj_to_subgraph(adj_pow, batch_idx).to_dense() # shape: [B, B] 194 | numerator = (adj_pow*simz).sum(dim=1, keepdim=False) # 1D tensor 195 | denominator = simz.sum(dim=1, keepdim=False) # 1D tensor 196 | nonzero = torch.where(numerator!=0)[0] 197 | loss_NContrastive = - torch.mean(torch.log(numerator[nonzero]/denominator[nonzero])) 198 | return loss_NContrastive 199 | 200 | def cosine_sim(x): 201 | # This function returns the pair-wise cosine similarity. 202 | # x.shape = [N_nodes, 256] 203 | # returned shape: [N_nodes, N_nodes] 204 | x_dis = x @ x.T 205 | x_sum = torch.norm(x, p=2, dim=1, keepdim=True) 206 | x_sum = x_sum @ x_sum.T 207 | x_dis = x_dis * (x_sum ** (-1)) 208 | return x_dis 209 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Cold-Brew 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | -------------------------------------------------------------------------------- /figs/coldbrew.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gnn-tail-generalization/e6f85eea8721059819d1e74c3d274f3beb70f70c/figs/coldbrew.png -------------------------------------------------------------------------------- /figs/gnns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gnn-tail-generalization/e6f85eea8721059819d1e74c3d274f3beb70f70c/figs/gnns.png -------------------------------------------------------------------------------- /figs/longtail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gnn-tail-generalization/e6f85eea8721059819d1e74c3d274f3beb70f70c/figs/longtail.png -------------------------------------------------------------------------------- /figs/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/gnn-tail-generalization/e6f85eea8721059819d1e74c3d274f3beb70f70c/figs/motivation.png -------------------------------------------------------------------------------- /func_libs-deprecated.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import numpy.linalg as la 6 | import os 7 | from time import time as timer 8 | import time 9 | import pickle 10 | import copy 11 | import gc 12 | import numpy.linalg as la 13 | from tqdm import tqdm 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | class C: pass 22 | class D: pass 23 | choice = np.random.choice 24 | join = os.path.join 25 | 26 | 27 | 28 | def AcontainsB(A,listB): 29 | # A: string; listB: list of strings 30 | for s in listB: 31 | if s in A: return True 32 | return False 33 | 34 | def getMLP(neurons, activation=nn.GELU, bias=True, dropout=0.1, last_dropout=False, normfun='layernorm'): 35 | # How to access parameters in module: replace printed < model.0.weight > to < model._modules['0'].weight > 36 | # neurons: all n+1 dims from input to output 37 | # len(neurons) = n+1 38 | # num of params layers = n 39 | # num of activations = n-1 40 | if len(neurons) in [0,1]: 41 | return nn.Identity() 42 | if len(neurons) == 2: 43 | return nn.Linear(*neurons) 44 | 45 | nn_list = [] 46 | n = len(neurons)-1 47 | for i in range(n-1): 48 | if normfun=='layernorm': 49 | norm = nn.LayerNorm(neurons[i+1]) 50 | elif normfun=='batchnorm': 51 | norm = nn.BatchNorm1d(neurons[i+1]) 52 | nn_list.extend([nn.Linear(neurons[i], neurons[i+1], bias=bias), norm, activation(), nn.Dropout(dropout)]) 53 | 54 | nn_list.extend([nn.Linear(neurons[n-1], neurons[n], bias=bias)]) 55 | if last_dropout: 56 | nn_list.extend([nn.Dropout(dropout)]) 57 | return nn.Sequential(*nn_list) 58 | 59 | 60 | 61 | 62 | def get_partial_sorted_idx(arr, mode='top25'): 63 | # mode: top25, bottom25, top50, bottom50; top = smaller 64 | arr = tonp(arr).reshape(-1) 65 | if 'top' in mode: 66 | idx = np.where(arr<=np.median(arr))[0] 67 | else: 68 | idx = np.where(arr>=np.median(arr))[0] 69 | 70 | arr1 = arr[idx] 71 | if mode in 
['top25','top12','top6','top3']: 72 | idx = np.where(arr<=np.median(arr1))[0] 73 | elif mode in ['bottom25','bottom12','bottom6','bottom3']: 74 | idx = np.where(arr>=np.median(arr1))[0] 75 | 76 | arr1 = arr[idx] 77 | if mode in ['top12','top6','top3']: 78 | idx = np.where(arr<=np.median(arr1))[0] 79 | elif mode in ['bottom12','bottom6','bottom3']: 80 | idx = np.where(arr>=np.median(arr1))[0] 81 | 82 | arr1 = arr[idx] 83 | if mode in ['top6','top3']: 84 | idx = np.where(arr<=np.median(arr1))[0] 85 | elif mode in ['bottom6','bottom3']: 86 | idx = np.where(arr>=np.median(arr1))[0] 87 | 88 | arr1 = arr[idx] 89 | if mode in ['top3']: 90 | idx = np.where(arr<=np.median(arr1))[0] 91 | elif mode in ['bottom3']: 92 | idx = np.where(arr>=np.median(arr1))[0] 93 | 94 | return idx 95 | 96 | def tonp(arr): 97 | if type(arr) is torch.Tensor: 98 | return arr.detach().cpu().data.numpy() 99 | else: 100 | return np.asarray(arr) 101 | 102 | 103 | 104 | def toitem(arr,round=True): 105 | arr1 = tonp(arr) 106 | value = arr1.reshape(-1)[0] 107 | if round: 108 | value = np.round(value,3) 109 | assert arr1.size==1 110 | return value 111 | 112 | 113 | def save_model(net, cwd): # June 2021 114 | torch.save(net.state_dict(), cwd) 115 | print(f"‹‹‹‹‹‹‹--- Saved @ :{cwd}\n\n\n") 116 | 117 | 118 | 119 | def load_model(net, cwd, verbose=True, strict=True, multiGPU=False): 120 | def load_multiGPUModel(network, cwd): 121 | network_dict = torch.load(cwd) 122 | # create new OrderedDict that does not contain `module.` 123 | from collections import OrderedDict 124 | new_state_dict = OrderedDict() 125 | for k, v in network_dict.items(): 126 | namekey = k[7:] # remove `module.` 127 | new_state_dict[namekey] = v 128 | # load params 129 | network.load_state_dict(new_state_dict) 130 | 131 | def load_singleGPUModel(network, cwd): 132 | network_dict = torch.load(cwd, map_location=lambda storage, loc: storage) 133 | network.load_state_dict(network_dict, strict=strict) 134 | 135 | if os.path.exists(cwd): 136 | 
if not multiGPU: 137 | load_singleGPUModel(net, cwd) 138 | else: 139 | load_multiGPUModel(net, cwd) 140 | 141 | if verbose: print(f"---›››› LOAD success! from {cwd}\n\n\n") 142 | else: 143 | if verbose: print(f"---›››› !!! FileNotFound when load_model: {cwd}\n\n\n") 144 | 145 | 146 | def viz(net, ttl=''): 147 | viz_ = [] 148 | flop_est = 0 149 | for i, (name, p) in enumerate(net.named_parameters()): 150 | print(f'{name:36} {list(p.size())}') 151 | _size = list(p.size()) 152 | viz_.append((name, p.numel(), _size)) 153 | if len(_size)==2: flop_est += _size[0]*_size[1] 154 | 155 | ttl = str(type(net)) if ttl=='' else ttl 156 | print(f'\nAbove is viz for: {ttl}.\n\tDevice is: {p.device}\n\tN_groups = {len(viz_)}\n\tTotal params = {numParamsOf(net)}\n\tMLP FLOP ~= {flop_est}') 157 | 158 | return 159 | 160 | 161 | def wzRec(datas, ttl='', want_save_npy=False, npy_dir='', save_history_fig=True): 162 | # this function save the input datas into two places: ___wIns___.pdf and a record history in wIns folder; 163 | # datas: 1D or 2D of the same meaning, multiple collections 164 | # two options: 165 | # want_save_npy: save data or not 166 | # save_history_fig: save history record fig or not 167 | 168 | if type(datas) is torch.Tensor: 169 | datas = datas.detach().cpu().data.numpy() 170 | 171 | if save_history_fig: 172 | os.makedirs(f'wIns',exist_ok=1) 173 | if want_save_npy: 174 | npy_fname = 'some_arr' if ttl == '' else ttl 175 | recDir = join('wIns/Recs', npy_dir) 176 | fDirName = f'{recDir}/{npy_fname}.npy' 177 | os.makedirs(recDir, exist_ok=1) 178 | np.save(fDirName, datas) 179 | else: 180 | fDirName = 'data not saved' 181 | 182 | datas = np.asarray(datas) 183 | plt.close('all') 184 | plt.figure() 185 | if len(datas.shape)==1: 186 | min_v = min(datas) 187 | plt.plot(datas) 188 | plt.title(ttl+f', min = {min_v:5.4f}\n') 189 | plt.xlabel('step') 190 | elif len(datas.shape)==2: 191 | min_s = np.min(datas, axis=1) 192 | mean_min = np.mean(min_s) 193 | std_min = np.std(min_s) 
194 | min_str = f'(avg={mean_min:5.4f},std={std_min:5.4f})' 195 | plot_ci(datas, ttl=ttl+f', min = {min_str}\n', xlb='step') 196 | else: 197 | raise ValueError('dim should be 1D or 2D') 198 | 199 | lt2 = time.strftime("%Y-%m-%d--%H_%M_%S", time.localtime()) 200 | figDirName = 'fig not saved' 201 | if save_history_fig: 202 | figDirName = f'wIns/{ttl}__{lt2}.jpg' 203 | plt.savefig(figDirName) 204 | plt.savefig('___wIns___.pdf',bbox_inches='tight') 205 | plt.show() 206 | 207 | return figDirName, fDirName 208 | 209 | 210 | 211 | 212 | 213 | def figure(): 214 | plt.figure(); plt.pause(0.01) 215 | 216 | 217 | def plot_many(arr_list, legends, ttl='', save_history_fig=True, have_line_yon='y', marker_size_bos='s'): 218 | # arr_list is a list of 1D arrays, lengths can be different 219 | assert len(arr_list)==len(legends) 220 | for i in range(len(arr_list)): 221 | plt.plot(arr_list[i], random_line_marker(have_line_yon=have_line_yon, marker_size_bos=marker_size_bos), label = legends[i],linewidth=1.2, markersize=5) 222 | plt.legend() 223 | plt.title(ttl) 224 | if save_history_fig: 225 | os.makedirs(f'wIns',exist_ok=1) 226 | lt2 = time.strftime("%Y-%m-%d--%H_%M_%S", time.localtime()) 227 | figDirName = f'wIns/{ttl}__{lt2}.jpg' 228 | plt.savefig(figDirName) 229 | plt.savefig('___wIns___.pdf', bbox_inches='tight') 230 | # plt.show() 231 | return 232 | 233 | 234 | def random_line_marker(have_line_yon='y', marker_size_bos = 's'): 235 | # have_line_yon take values from: ['y', 'n', 'o']; 'o' means doesn't care. 
236 | # marker_size_bos: ['b',s','o'] for 'big', 'small', 'not care' 237 | 238 | if marker_size_bos=='s': 239 | mk = ['', 'x','.',',','1','2','3','4','*','+','|','_'] 240 | elif marker_size_bos=='o': 241 | mk = ['', 'x','.',',','o','v','^','<','>','1','2','3','4','s','p','*','h','H','+','x','D','d','|','_'] 242 | elif marker_size_bos=='b': 243 | mk = ['o','v','^','<','>','s','p','h','H','x','D','d'] 244 | 245 | 246 | if have_line_yon=='y': 247 | sty = ['-.','-','--',':'] 248 | elif have_line_yon=='o': 249 | sty = ['','-.','-','--',':'] 250 | elif have_line_yon=='n': 251 | sty = [''] 252 | 253 | return np.random.choice(mk)+np.random.choice(sty) 254 | 255 | 256 | class graphUtils: 257 | example_edge_index = torch.tensor([[0,0,1,1,1,2],[0,1,0,1,2,2]]) # 3 nodes, 6 edges; not symmetric. 258 | @staticmethod 259 | def remove_self_loops(edge_index): 260 | mask = edge_index[0] != edge_index[1] 261 | edge_index = edge_index[:, mask] 262 | return edge_index 263 | @staticmethod 264 | def add_self_loops(edge_index, num_nodes=None): 265 | if num_nodes is None: 266 | num_nodes = int(edge_index.max()+1) 267 | loop_index = torch.arange(0, num_nodes, dtype=torch.long, device=edge_index.device) 268 | loop_index = loop_index.unsqueeze(0).repeat(2, 1) 269 | edge_index = torch.cat([edge_index, loop_index], dim=1) 270 | return edge_index 271 | @staticmethod 272 | def edge_index_to_sparse_adj(edge_index, num_nodes=None, edge_weight=None): 273 | if num_nodes is None: 274 | num_nodes = int(edge_index.max()+1) 275 | if edge_weight is None: 276 | v = [1] * edge_index.shape[1] 277 | adj = torch.sparse_coo_tensor(edge_index, v, size=(num_nodes, num_nodes), device=edge_index.device).float() 278 | return adj 279 | @staticmethod 280 | def normalize_adj(edge_index, num_nodes=None): 281 | # math: 282 | # A is without self-loop 283 | # A_hat = A + I 284 | # Dii = A_hat.sum(dim=1) 285 | # A_tilde = D^(-1/2) A_hat D^(-1/2) 286 | if num_nodes is None: 287 | num_nodes = int(edge_index.max()+1) 288 | 
edge_index = graphUtils.remove_self_loops(edge_index) 289 | edge_index = graphUtils.add_self_loops(edge_index, num_nodes) 290 | Adj = graphUtils.edge_index_to_sparse_adj(edge_index, num_nodes) 291 | D_mtxv = torch.sparse.sum(Adj, dim=1).values() 292 | assert len(D_mtxv)==num_nodes 293 | D_mtxinv = torch.diag(D_mtxv**(-1/2)).to_sparse() 294 | A_tilde = torch.sparse.mm(D_mtxinv, torch.sparse.mm(Adj, D_mtxinv)) 295 | return A_tilde.coalesce() 296 | @staticmethod 297 | def sparse_power(x, N): 298 | x0 = x 299 | assert N>0 300 | for n in range(N-1): 301 | x = torch.sparse.mm(x, x0) 302 | return x.coalesce() 303 | @staticmethod 304 | def subgraph(subset, edge_index, edge_attr=None, relabel_nodes=True, 305 | num_nodes=None): 306 | device = edge_index.device 307 | if isinstance(subset, list) or isinstance(subset, tuple): 308 | subset = torch.tensor(subset, dtype=torch.long) 309 | if num_nodes is None: 310 | num_nodes = int(edge_index.max()+1) 311 | n_mask = torch.zeros(num_nodes, dtype=torch.bool) 312 | n_mask[subset] = 1 # convert idx or mask to mask 313 | if relabel_nodes: 314 | n_idx = torch.zeros(num_nodes, dtype=torch.long, device=device) 315 | n_idx[subset] = torch.arange(subset.shape[0], device=device) 316 | mask = n_mask[edge_index[0]] & n_mask[edge_index[1]] 317 | edge_index = edge_index[:, mask] 318 | edge_attr = edge_attr[mask] if edge_attr is not None else None 319 | if relabel_nodes: 320 | edge_index = n_idx[edge_index] 321 | return edge_index, edge_attr 322 | @staticmethod 323 | def crop_adj_to_subgraph(adj_mtx, subset_idx): 324 | if isinstance(subset_idx, list) or isinstance(subset_idx, tuple): 325 | subset_idx = torch.tensor(subset_idx, dtype=torch.long) 326 | edge_index, edge_attr = adj_mtx.indices(), adj_mtx.values() 327 | edge_index, edge_attr = graphUtils.subgraph(subset_idx, edge_index, edge_attr, relabel_nodes=True, num_nodes=adj_mtx.shape[0]) 328 | n2 = len(subset_idx) 329 | adj2 = torch.sparse_coo_tensor(edge_index, edge_attr, size=(n2, n2), 
device=adj_mtx.device).float() 330 | return adj2 331 | @staticmethod 332 | def edge_index_to_symmetric(edge_index): 333 | adj = graphUtils.edge_index_to_sparse_adj(edge_index) 334 | adjT = graphUtils.edge_index_to_sparse_adj(edge_index[[1,0]]) 335 | adj = (adj + adjT).coalesce() 336 | edge_index = adj.indices() 337 | return edge_index 338 | @staticmethod 339 | def homo_g_to_data(g): 340 | 341 | # x = g.ndata['x'] 342 | # y = g.ndata['y'] 343 | 344 | edges_two_tuple = g.edges() 345 | data = D() 346 | # data.x = x 347 | # data.y = y 348 | data.edge_index = torch.stack(edges_two_tuple) 349 | 350 | return data 351 | 352 | 353 | 354 | def example_run(): 355 | print('--------- demo normalize ---------') 356 | edge_index = graphUtils.example_edge_index 357 | adj = graphUtils.normalize_adj(edge_index) 358 | 359 | print('--------- demo subgraph ---------') 360 | adj2 = graphUtils.crop_adj_to_subgraph(adj, [0,1]) 361 | 362 | print('--------- demo to-dense ---------') 363 | adj2_dense = adj2.to_dense() 364 | adj2 = adj2_dense.to_sparse() 365 | 366 | print('--------- demo A^K ---------') 367 | adj3 = graphUtils.sparse_power(adj, 3) 368 | print('--------- demo to-symmetric ---------') 369 | edge_index2 = graphUtils.edge_index_to_symmetric(edge_index) 370 | print(edge_index, '\n\n',edge_index2) 371 | return adj 372 | 373 | def demo_bipt_graph(): 374 | import dgl 375 | 376 | edge_in_query = torch.tensor([0, 0, 1]) 377 | edge_in_asin = torch.tensor([1, 2, 4]) 378 | N_nodes_query = 6 379 | N_nodes_asin = 7 380 | dim_feature = 10 381 | 382 | 383 | # edge_in_query = datas.edge_in_query 384 | # edge_in_asin = datas.edge_in_asin 385 | # N_nodes_query = datas.N_nodes_query 386 | # N_nodes_asin = datas.N_nodes_asin 387 | # dim_feature = 768 388 | 389 | x_asin = torch.randn(N_nodes_asin, dim_feature) 390 | x_query = torch.randn(N_nodes_query, dim_feature) 391 | y_asin = torch.randn(N_nodes_asin, 1) 392 | y_query = torch.randn(N_nodes_query, 1) 393 | 394 | graph_data = { 395 | ('query', 
'qa', 'asin'): (edge_in_query, edge_in_asin), 396 | ('asin', 'qareverse', 'query'): (edge_in_asin, edge_in_query), 397 | } 398 | num_nodes_dict = {'query': N_nodes_query, 'asin': N_nodes_asin} 399 | hg = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict) 400 | 401 | 402 | hg.nodes('asin') # -> full index tensor 403 | 404 | hg.nodes['asin'].data['x'] = x_asin 405 | hg.nodes['query'].data['x'] = x_query 406 | hg.nodes['asin'].data['y'] = y_asin 407 | hg.nodes['query'].data['y'] = y_query 408 | 409 | 410 | 411 | 412 | g = dgl.to_homogeneous(hg, ndata=['x', 'y']) 413 | data = graphUtils.homo_g_to_data(g) 414 | 415 | 416 | print('g', g, 'data.edge_index', data.edge_index.shape) 417 | return hg, g, data 418 | 419 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import gc 5 | import json 6 | import numpy as np 7 | import os 8 | import random 9 | import torch 10 | 11 | from base_options import BaseOptions 12 | from datetime import datetime 13 | 14 | def main(): 15 | args = BaseOptions().get_arguments() 16 | if args.exp_mode == 'coldbrew': 17 | from trainer_node_classification import trainer 18 | elif args.exp_mode == 'I2_GTL': 19 | from trainer_link_prediction import trainer # expected to contain codes for extended work other than cold brew; coming soon. 
20 | 21 | if args.prog: tensorRex(None, args.prog, args.rexName) 22 | full_recs_3D = [] 23 | for seed in range(args.N_exp): 24 | print(f'seed (which_run) = <{seed}>') 25 | 26 | args.random_seed = seed 27 | set_seed(args) 28 | torch.cuda.empty_cache() 29 | trnr = trainer(args, seed) 30 | 31 | results_arr2D = trnr.main() 32 | 33 | full_recs_3D.append(results_arr2D) # dimensions: [seeds, record_type, epochs] 34 | del trnr 35 | torch.cuda.empty_cache() 36 | gc.collect() 37 | 38 | if args.prog: tensorRex(full_recs_3D, args.prog, args.rexName) 39 | 40 | def set_seed(args): 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | if args.cuda and not torch.cuda.is_available(): # cuda is not available 44 | args.cuda = False 45 | if args.cuda: 46 | torch.cuda.manual_seed(args.random_seed) 47 | torch.cuda.manual_seed_all(args.random_seed) 48 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_num) 49 | torch.manual_seed(args.random_seed) 50 | np.random.seed(args.random_seed) 51 | import random 52 | random.seed(args.random_seed) 53 | 54 | def tensorRex(dataND, prog, rexName, support_inequal_shape=True): 55 | # This version is better than the previous 1D version; use this. --Aug.12.2021 56 | # 1. Difference from previous version: previous one only deal with mean/std, which is 1D data; this one deal with N-D data; 57 | # 2. share in common: direct store the input data, and do NOT increase data dimensions. 
58 | 59 | # support function for batch running 60 | # how to use: call this function twice, at the beginning & the end; set dataND=None at the beginning; 61 | # it will automatically skip experiments that are already completed 62 | 63 | # prog contains 3 elements: tensor_indices, vector_idx, tensor_shape, eg: 64 | # '1_3_2_5__//__143__//__2*4*3*6' (the shape part is split on '*' below) 65 | # indicies = [1,3,2,5] 66 | # dataND is N-Dim 67 | # rec[...,0] is flag (=1 means exp is finished) 68 | # rec[...,1:] is dataND 69 | 70 | indicies, idx, shape = prog.split('__//__') 71 | indicies = np.array(indicies.split('_'), dtype=int) 72 | idx = int(idx) 73 | shape = list(np.array(shape.split('*'), dtype=int)) 74 | 75 | if dataND is None: # in the first call: only check if exp has been completed 76 | try: # load or init new rec 77 | rec = np.load(rexName, allow_pickle=1).item() 78 | rec_data = rec['data'] 79 | rec_flag = rec['flag'] 80 | if rec_flag[tuple(indicies)] == 1.: 81 | raise UserWarning('\n\n\nThis exp has completed already\n\n\n') 82 | return 83 | except FileNotFoundError: 84 | assert idx == 0, '\n\n\nFatal Error! previous experiment file deleted!\n\n\n' 85 | return # first run, first exp 86 | 87 | else: # calling at the end: store dataND and exit 88 | dataND = np.asarray(dataND) 89 | try: # load or init new rec 90 | rec = np.load(rexName, allow_pickle=1).item() 91 | rec_data = rec['data'] 92 | rec_flag = rec['flag'] 93 | except FileNotFoundError: 94 | assert idx == 0, '\n\n\nFatal Error! previous experiment file deleted!\n\n\n' 95 | rec_data = np.zeros(shape + list(dataND.shape), dtype=float) 96 | rec_flag = np.zeros(shape, dtype=float) 97 | 98 | if support_inequal_shape and (rec_data[tuple(indicies)].shape != dataND.shape): 99 | to_fill = rec_data[tuple(indicies)] 100 | assert len(dataND.shape)==len(to_fill.shape), f'\n\nFatal Error! 
new exp has different number of dims ({dataND.shape}) than existing exp ({to_fill.shape})!\n\n' 101 | 102 | def tolerantly_fill_b_in_A(b, A): 103 | # b has dynamic shape; 104 | # A has fixed shape; 105 | # fill b to the upper frontmost corner of A 106 | _shape_str = [] 107 | for _s in range(len(b.shape)): 108 | ms = min(b.shape[_s], A.shape[_s]) 109 | _shape_str.append(f':{ms}') 110 | _shape_str = ','.join(_shape_str) 111 | # evastr = f'A[{_shape_str}]' 112 | exec(f'A[{_shape_str}]=b[{_shape_str}]') 113 | return A 114 | to_fill = tolerantly_fill_b_in_A(dataND, to_fill) 115 | rec_data[tuple(indicies)] = to_fill 116 | print(f'\n\n\n tolerantly fill success!!! \n\n to_fill is:\n{to_fill}\n\n') 117 | 118 | else: 119 | rec_data[tuple(indicies)] = dataND 120 | 121 | rec_flag[tuple(indicies)] = 1. 122 | rec = {'flag':rec_flag,'data':rec_data} 123 | np.save(rexName, rec) 124 | return 125 | 126 | def print_line_by_line(*b, tight=False): 127 | print('\n') 128 | for x in b: 129 | print(x) 130 | if not tight: print() 131 | print('\n') 132 | return 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Cold Brew: Distilling Graph Node Representations with Incomplete or Missing Neighborhoods 2 | 3 | Wenqing Zheng, Edward W Huang, Nikhil Rao, Sumeet Katariya, Zhangyang Wang, Karthik Subbian. 4 | 5 | In [ICLR 2022](https://openreview.net/forum?id=1ugNpm7W6E) and [Arxiv](https://arxiv.org/abs/2111.04840). 6 | 7 | ## Introduction 8 | 9 | Graph Neural Networks (GNNs) have demonstrated superior performance in node classification or regression tasks, and have emerged as the state of the art in several applications. However, (inductive) GNNs require the edge connectivity structure of nodes to be known beforehand to work well. 
This is often not the case in several practical applications where the node degrees have power-law distributions, and nodes with a few connections might have noisy edges. An extreme case is the strict cold start (SCS) problem, where there is no neighborhood information available, forcing prediction models to rely completely on node features only. To study the viability of using inductive GNNs to solve the SCS problem, we introduce feature-contribution ratio (FCR), a metric to quantify the contribution of a node's features and that of its neighborhood in predicting node labels, and use this new metric as a model selection reward. We then propose Cold Brew, a new method that generalizes GNNs better in the SCS setting compared to pointwise and graph-based models, via a distillation approach. We show experimentally how FCR allows us to disentangle the contributions of various components of graph datasets, and demonstrate the superior performance of Cold Brew on several public benchmarks and proprietary e-commerce datasets. 10 | 11 | ## Motivation 12 | 13 | Long-tail distributions are ubiquitous in large-scale graph mining tasks. In some applications, cold-start nodes have too few neighbors in the graph, or none at all, which makes graph-based methods sub-optimal because there are not enough high-quality edges to perform message passing. 14 | 15 | ![gnns](figs/motivation.png) 16 | 17 | ![gnns](figs/longtail.png) 18 | 19 | ## Method 20 | 21 | We improve the teacher GNN with Structural Embedding, and propose a student MLP model with a latent neighborhood discovery step. We also propose a metric called FCR to judge the difficulty of cold-start generalization. 22 | 23 | ![gnns](figs/gnns.png) 24 | 25 | ![coldbrew](figs/coldbrew.png) 26 | 27 | ## Installation Guide 28 | 29 | The following commands install the key dependencies; the others can be installed directly via pip or conda. 
A full (and redundant) dependency list is in `requirements.txt`. 30 | 31 | ``` 32 | pip install dgl 33 | pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 34 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html 35 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html 36 | pip install torch-geometric 37 | ``` 38 | 39 | ## Training Guide 40 | 41 | `base_options.py` contains the full list of usable arguments, with their default values and candidate choices. 42 | 43 | ### Comparing traditional GCN (optimized with Initial/Jumping/Dense/PairNorm/NodeNorm/GroupNorm/Dropouts) with Cold Brew's GNN (optimized with Structural Embedding) 44 | 45 | #### Training optimized traditional GNN: 46 | 47 | `python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `84.15` 48 | 49 | `python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `69.7` 50 | 51 | `python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `78.2` 52 | 53 | #### Training Cold Brew's Teacher GNN: 54 | 55 | `python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=32 --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `85.10` 56 | 57 | `python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `71.40` 58 | 59 | `python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='111' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1` Result: `78.2` 60 | 61 | ### Comparing MLP models: 62 | 63 | #### Training naive MLP: 64 | 65 | 
`python main.py --dataset='Cora' --train_which='StudentBaseMLP'` Result on isolation split: `61.80` 66 | 67 | #### Training GraphMLP: 68 | 69 | `python main.py --dataset='Cora' --train_which='GraphMLP'` Result on isolation split: `68.63` 70 | 71 | #### Training Cold Brew's MLP: 72 | 73 | `python main.py --dataset='Cora' --train_which="SEMLP" --SEMLP_topK_2_replace=3 --SEMLP_part1_arch="2layer" --dropout_MLP=0.5 --studentMLP__opt_lr='torch.optim.Adam&0.005'` Result on isolation split: `72.50` 74 | 75 | ## Hyperparameter meanings 76 | 77 | `--whetherHasSE`: whether Cold Brew's TeacherGNN has structural embedding. The first '1' means the structural embedding exists in the first layer; the second '1' means it exists in every middle layer; the third '1' means it exists in the last layer. 78 | 79 | `--se_reg`: regularization coefficient for the structural embedding of Cold Brew's teacher model. 80 | 81 | `--SEMLP_topK_2_replace`: the number K of top-ranked virtual neighbor nodes to select. 82 | 83 | `--manual_assign_GPU`: the GPU ID to train on. The default, -9999, means the GPU with the most remaining memory is chosen dynamically. 84 | 85 | ## Adaptation Guide 86 | 87 | #### How to leverage this repo to train on other datasets: 88 | 89 | In `trainer.py`, put any new graph dataset (node classification) under `load_data()` and return it. 90 | 91 | What to load: 92 | return a dataset, which is a namespace called 'data', with the following fields: 93 | data.x: 2D tensor, on cpu; shape = [N_nodes, dim_feature]. 94 | data.y: 1D tensor, on cpu; shape = [N_nodes]; values are integers, indicating the class of each node. 95 | data.edge_index: tensor, on cpu; shape = [2, N_edge]; edges contain self loops. 96 | data.train_mask: bool tensor, shape = [N_nodes], indicating the training node set. 
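The field layout listed above can be sketched without any dependencies. This is a minimal illustration only: `build_toy_data` and `check_layout` are hypothetical helpers that are not part of this repo, and plain Python lists stand in for the CPU tensors that the trainer would actually expect.

```python
from types import SimpleNamespace


def build_toy_data(num_nodes, features, labels, edges, train_nodes):
    """Assemble x, y, edge_index and train_mask from plain Python lists.

    Hypothetical helper for illustration; real datasets would use torch tensors.
    """
    edge_set = set(edges)
    # The field description says edges contain self loops, so add any missing ones.
    edge_set.update((i, i) for i in range(num_nodes))
    ordered = sorted(edge_set)
    # edge_index has shape [2, N_edge]: row 0 holds sources, row 1 holds targets.
    edge_index = [[s for s, _ in ordered], [t for _, t in ordered]]
    train_set = set(train_nodes)
    # train_mask has one boolean per node.
    train_mask = [i in train_set for i in range(num_nodes)]
    return SimpleNamespace(x=features, y=labels,
                           edge_index=edge_index, train_mask=train_mask)


def check_layout(data):
    """Return True when the four fields have mutually consistent sizes."""
    n = len(data.x)
    dim = len(data.x[0])
    return (
        all(len(row) == dim for row in data.x)
        and len(data.y) == n
        and len(data.train_mask) == n
        and len(data.edge_index) == 2
        and len(data.edge_index[0]) == len(data.edge_index[1])
        and all(0 <= v < n for row in data.edge_index for v in row)
    )


data = build_toy_data(
    num_nodes=3,
    features=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    labels=[0, 1, 0],
    edges=[(0, 1), (1, 0), (1, 2)],
    train_nodes=[0, 2],
)
print(check_layout(data), data.edge_index, data.train_mask)
```

When adapting this to a real dataset, each list would presumably be converted to a tensor of the matching shape before being returned from `load_data()`.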
97 | Template class for the 'data': 98 | 99 | ``` 100 | class MyDataset(torch_geometric.data.data.Data): 101 | def __init__(self): 102 | super().__init__() 103 | ``` 104 | 105 | # Citation 106 | 107 | ``` 108 | @article{zheng2021cold, 109 | title={Cold Brew: Distilling Graph Node Representations with Incomplete or Missing Neighborhoods}, 110 | author={Zheng, Wenqing and Huang, Edward W and Rao, Nikhil and Katariya, Sumeet and Wang, Zhangyang and Subbian, Karthik}, 111 | journal={arXiv preprint arXiv:2111.04840}, 112 | year={2021} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | alembic==1.6.5 3 | astor==0.8.1 4 | attrs==21.2.0 5 | cachetools==4.2.2 6 | certifi==2023.7.22 7 | chardet==4.0.0 8 | click==8.0.1 9 | cliff==3.8.0 10 | cloudpickle==1.6.0 11 | cmaes==0.8.2 12 | cmd2==2.1.2 13 | colorama==0.4.4 14 | colorlog==5.0.1 15 | contextlib2==21.6.0 16 | cramjam==2.3.2 17 | cycler==0.10.0 18 | dataclasses==0.8 19 | dgl==0.7.0 20 | dill==0.3.4 21 | fastparquet==0.6.3 22 | filelock==3.0.12 23 | fsspec==2021.6.1 24 | gast==0.5.0 25 | gdown==4.0.1 26 | gensim==3.8.3 27 | google-auth==1.35.0 28 | google-auth-oauthlib==0.4.6 29 | googledrivedownloader==0.4 30 | GPUtil==1.4.0 31 | greenlet==1.1.0 32 | grpcio==1.53.0 33 | huggingface-hub==0.0.12 34 | idna==2.10 35 | imageio==2.9.0 36 | importlib-metadata==4.6.1 37 | isodate==0.6.0 38 | Jinja2==3.0.1 39 | joblib==1.2.0 40 | julia==0.5.6 41 | Keras==2.1.0 42 | Keras-Applications==1.0.8 43 | Keras-Preprocessing==1.1.2 44 | kiwisolver==1.3.1 45 | littleutils==0.2.2 46 | Mako==1.2.2 47 | Markdown==3.3.4 48 | MarkupSafe==2.0.1 49 | matplotlib==3.3.4 50 | mkl-fft==1.3.0 51 | mkl-random==1.1.1 52 | mkl-service==2.3.0 53 | ml-collections==0.1.0 54 | networkx==2.5.1 55 | numpy==1.19.5 56 | oauthlib==3.1.1 57 | ogb==1.3.1 58 | outdated==0.2.1 59 | 
packaging==21.0 60 | pandas==1.1.5 61 | pbr==5.6.0 62 | pillow>=8.3.2 63 | prettytable==2.1.0 64 | protobuf==3.18.3 65 | pyasn1==0.4.8 66 | pyasn1-modules==0.2.8 67 | pyglet==1.5.21 68 | pyparsing==2.4.7 69 | pyperclip==1.8.2 70 | PySocks==1.7.1 71 | python-dateutil==2.8.1 72 | python-editor==1.0.4 73 | python-louvain==0.15 74 | pytz==2021.1 75 | PyWavelets==1.1.1 76 | PyYAML==5.4.1 77 | rdflib==5.0.0 78 | regex==2021.7.6 79 | requests==2.25.1 80 | requests-oauthlib==1.3.0 81 | rsa==4.7.2 82 | sacremoses==0.0.45 83 | sb3-contrib==1.1.0 84 | scikit-image==0.17.2 85 | scikit-learn==0.24.2 86 | scipy==1.5.4 87 | seaborn==0.11.2 88 | smart-open==5.2.1 89 | SQLAlchemy==1.4.22 90 | stable-baselines3==1.1.0 91 | stevedore==3.3.0 92 | termcolor==1.1.0 93 | texttable==1.6.4 94 | threadpoolctl==2.1.0 95 | thrift==0.13.0 96 | tifffile==2020.9.3 97 | tokenizers==0.10.3 98 | torch==1.13.1 99 | torch-geometric==1.7.2 100 | torch-scatter==2.0.7 101 | torch-sparse==0.6.10 102 | tqdm==4.61.1 103 | typing-extensions==3.10.0.0 104 | urllib3==1.26.6 105 | wcwidth==0.2.5 106 | Werkzeug==2.2.3 107 | zipp==3.5.0 108 | --------------------------------------------------------------------------------