├── BWGNN.py ├── README.md ├── dataset.py ├── figures │   ├── heatmap.png │   ├── heterophily.png │   └── topology.png └── main.py /BWGNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl.function as fn 5 | import sympy 6 | import scipy.special 7 | import numpy as np 8 | 9 | from torch.nn import init 10 | 11 | ''' 12 | BWGNN model from "https://github.com/squareRoot3/Rethinking-Anomaly-Detection" 13 | ''' 14 | class PolyConv(nn.Module): 15 | def __init__(self, 16 | in_feats, 17 | out_feats, 18 | theta, 19 | activation=F.leaky_relu, 20 | lin=False, 21 | bias=False): 22 | super(PolyConv, self).__init__() 23 | self._theta = theta 24 | self._k = len(self._theta) 25 | self._in_feats = in_feats 26 | self._out_feats = out_feats 27 | self.activation = activation 28 | self.linear = nn.Linear(in_feats, out_feats, bias) 29 | self.lin = lin 30 | # self.reset_parameters() 31 | # self.linear2 = nn.Linear(out_feats, out_feats, bias) 32 | 33 | def reset_parameters(self): 34 | if self.linear.weight is not None: 35 | init.xavier_uniform_(self.linear.weight) 36 | if self.linear.bias is not None: 37 | init.zeros_(self.linear.bias) 38 | 39 | def forward(self, graph, feat): 40 | def unnLaplacian(feat, D_invsqrt, graph): 41 | """ Apply I - D^-1/2 A D^-1/2 (the normalized Laplacian) to feat """ 42 | graph.ndata['h'] = feat * D_invsqrt 43 | graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) 44 | return feat - graph.ndata.pop('h') * D_invsqrt 45 | 46 | with graph.local_scope(): 47 | D_invsqrt = torch.pow(graph.in_degrees().float().clamp( 48 | min=1), -0.5).unsqueeze(-1).to(feat.device) 49 | h = self._theta[0]*feat 50 | for k in range(1, self._k): 51 | feat = unnLaplacian(feat, D_invsqrt, graph) 52 | h += self._theta[k]*feat 53 | 54 | if self.lin: 55 | h = self.linear(h) 56 | h = self.activation(h) 57 | return h 58 | 59 | class PolyConvBatch(nn.Module): 60 | def __init__(self, 61 | in_feats, 62 | out_feats, 63 | theta, 64 | activation=F.leaky_relu, 65 | lin=False, 66 | bias=False): 67 | super(PolyConvBatch, self).__init__() 68 | self._theta = theta 69 | self._k = len(self._theta) 70 | self._in_feats = in_feats 71 | self._out_feats = out_feats 72 | self.activation = activation 73 | 74 | def reset_parameters(self): 75 | # this batch variant creates no linear layer, so there is nothing to reset 76 | pass 77 | 78 | 79 | 80 | def forward(self, block, feat): 81 | def unnLaplacian(feat, D_invsqrt, block): 82 | """ Apply I - D^-1/2 A D^-1/2 (the normalized Laplacian) to feat """ 83 | block.srcdata['h'] = feat * D_invsqrt 84 | block.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) 85 | return feat - block.srcdata.pop('h') * D_invsqrt 86 | 87 | with block.local_scope(): 88 | D_invsqrt = torch.pow(block.out_degrees().float().clamp( 89 | min=1), -0.5).unsqueeze(-1).to(feat.device) 90 | h = self._theta[0]*feat 91 | for k in range(1, self._k): 92 | feat = unnLaplacian(feat, D_invsqrt, block) 93 | h += self._theta[k]*feat 94 | return h 95 | 96 | 97 | def calculate_theta2(d): 98 | thetas = [] 99 | x = sympy.symbols('x') 100 | for i in range(d+1): 101 | f = sympy.poly((x/2) ** i * (1 - x/2) ** (d-i) / (scipy.special.beta(i+1, d+1-i))) 102 | coeff = f.all_coeffs() 103 | inv_coeff = [] 104 | for j in range(d+1): 105 | inv_coeff.append(float(coeff[d-j])) 106 | thetas.append(inv_coeff) 107 | return thetas 108 | 109 | 110 | class BWGNN(nn.Module): 111 | def __init__(self, in_feats, h_feats,
num_classes, graph, d=2, batch=False): 112 | super(BWGNN, self).__init__() 113 | self.g = graph 114 | self.thetas = calculate_theta2(d=d) 115 | self.conv = [] 116 | for i in range(len(self.thetas)): 117 | if not batch: 118 | self.conv.append(PolyConv(h_feats, h_feats, self.thetas[i], lin=False)) 119 | else: 120 | self.conv.append(PolyConvBatch(h_feats, h_feats, self.thetas[i], lin=False)) 121 | self.linear = nn.Linear(in_feats, h_feats) 122 | self.linear2 = nn.Linear(h_feats, h_feats) 123 | self.linear3 = nn.Linear(h_feats*len(self.conv), h_feats) 124 | self.linear4 = nn.Linear(h_feats, num_classes) 125 | self.act = nn.ReLU() 126 | self.d = d 127 | 128 | def forward(self, in_feat): 129 | h = self.linear(in_feat) 130 | h = self.act(h) 131 | h = self.linear2(h) 132 | h = self.act(h) 133 | h_final = torch.zeros([len(in_feat), 0]) 134 | # h0_final = [] 135 | for conv in self.conv: 136 | h0 = conv(self.g, h) 137 | h_final = torch.cat([h_final, h0], -1) 138 | # print(h_final.shape) 139 | h = self.linear3(h_final) 140 | h = self.act(h) 141 | h = self.linear4(h) 142 | return h 143 | 144 | def testlarge(self, g, in_feat): 145 | h = self.linear(in_feat) 146 | h = self.act(h) 147 | h = self.linear2(h) 148 | h = self.act(h) 149 | h_final = torch.zeros([len(in_feat), 0]) 150 | for conv in self.conv: 151 | h0 = conv(g, h) 152 | h_final = torch.cat([h_final, h0], -1) 153 | # print(h_final.shape) 154 | h = self.linear3(h_final) 155 | h = self.act(h) 156 | h = self.linear4(h) 157 | return h 158 | 159 | def batch(self, blocks, in_feat): 160 | h = self.linear(in_feat) 161 | h = self.act(h) 162 | h = self.linear2(h) 163 | h = self.act(h) 164 | 165 | h_final = torch.zeros([len(in_feat),0]) 166 | for conv in self.conv: 167 | h0 = conv(blocks[0], h) 168 | h_final = torch.cat([h_final, h0], -1) 169 | # print(h_final.shape) 170 | h = self.linear3(h_final) 171 | h = self.act(h) 172 | h = self.linear4(h) 173 | return h 174 | 175 | 176 | # heterogeneous graph 177 | class BWGNN_Hetero(nn.Module): 178 | def __init__(self, in_feats, h_feats, num_classes, graph, d=2): 179 | super(BWGNN_Hetero, self).__init__() 180 | self.g = graph 181 | self.thetas = calculate_theta2(d=d) 182 | self.h_feats = h_feats 183 | self.conv = [PolyConv(h_feats, h_feats, theta, lin=False) for theta in self.thetas] 184 | self.linear = nn.Linear(in_feats, h_feats) 185 | self.linear2 = nn.Linear(h_feats, h_feats) 186 | self.linear3 = nn.Linear(h_feats*len(self.conv), h_feats) 187 | self.linear4 = nn.Linear(h_feats, num_classes) 188 | self.act = nn.LeakyReLU() 189 | # print(self.thetas) 190 | for param in self.parameters(): 191 | print(type(param), param.size()) 192 | 193 | def forward(self, in_feat): 194 | h = self.linear(in_feat) 195 | h = self.act(h) 196 | h = self.linear2(h) 197 | h = self.act(h) 198 | h_all = [] 199 | 200 | for relation in self.g.canonical_etypes: 201 | # print(relation) 202 | h_final = torch.zeros([len(in_feat), 0]) 203 | for conv in self.conv: 204 | h0 = conv(self.g[relation], h) 205 | h_final = torch.cat([h_final, h0], -1) 206 | # print(h_final.shape) 207 | h = self.linear3(h_final) 208 | h_all.append(h) 209 | 210 | h_all = torch.stack(h_all).sum(0) 211 | h_all = self.act(h_all) 212 | h_all = self.linear4(h_all) 213 | return h_all 214 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GHRN: Addressing Heterophily in Graph Anomaly Detection: A Perspective of Graph Spectrum 2 | 3 | This is a PyTorch 
implementation of 4 | 5 | Addressing Heterophily in Graph Anomaly Detection: A Perspective of Graph Spectrum (WWW 2023) 6 | 7 | # Overview 8 | In this work, we address the heterophily problem in the spectral domain. We point out that heterophily is positively associated with the frequency of a graph. Accordingly, we can prune inter-class edges by emphasizing and delineating the high-frequency components of the graph. We adopt the graph Laplacian to measure the extent of 1-hop label change around each center node, which indicates the high-frequency components. Our indicator effectively reduces the heterophily degree of the graph and is robust to prediction error. 9 | 10 |
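To make the frequency claim concrete, here is a minimal sketch (our illustration, not part of the released code; the toy graph and labels are assumptions) showing that the Laplacian quadratic form of the label signal grows with the number of inter-class edges:

```python
import numpy as np

# Toy undirected graph as an edge list, with binary node labels.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]   # every edge is inter-class below
y = np.array([0.0, 1.0, 0.0, 1.0])         # node labels viewed as a graph signal

n = len(y)
A = np.zeros((n, n))
for u, v in edges:
    A[u, v] = A[v, u] = 1.0
L = np.diag(A.sum(1)) - A                  # unnormalized graph Laplacian

# y^T L y = sum over edges of (y_u - y_v)^2, so it counts inter-class edges:
# a heterophilous label signal is a high-frequency signal on the graph.
print(y @ L @ y / (y @ y))                 # 4 inter-class edges -> 2.0
```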

11 | ![topology](figures/topology.png) 12 |

13 | 14 | # Some questions 15 | 1. What is heterophily, and how does it affect the performance of GNNs? 16 | Heterophily refers to edges that connect nodes with different labels. Low-pass filters such as GCN can smooth away the discriminative 17 | information of anomalies on heterophilous graphs. 18 | 19 |
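As a quick check, the edge helpers at the bottom of dataset.py (find_inter, cal_hetero, and friends) quantify exactly this. A minimal self-contained sketch of the same idea (the toy graph and the 'inter' key are our own illustration):

```python
import dgl
import torch

# Toy graph: 4 nodes with labels 0/0/1/1 and a single inter-class edge 1 -> 2.
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g.ndata['label'] = torch.tensor([0, 0, 1, 1])

def mark_inter(edges):
    # Same edge-UDF style as dataset.py: flag edges whose endpoints disagree.
    return {'inter': edges.src['label'] != edges.dst['label']}

g.apply_edges(mark_inter)
print(g.edata['inter'].float().mean())  # heterophily degree: 1/3 here
```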

20 | ![heterophily](figures/heterophily.png) 21 |

22 | 23 | 2. How does the indicator work? 24 | GHRN calculates the post-aggregation matrix for the graph; the smaller an edge's entry, the higher the probability that the edge is inter-class (see the sketch after the figure below). 25 | 26 |

27 | ![heatmap](figures/heatmap.png) 28 |
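A condensed sketch of that computation, following random_walk_update in dataset.py (the function name and example ratio are our own; pred_y is the (N, 2) prediction tensor dumped by main.py):

```python
import dgl
import dgl.function as fn
import torch
from dgl.nn.pytorch.conv import EdgeWeightNorm

def prune_inter_class(graph, pred_y, del_ratio=0.015):
    """Drop the del_ratio fraction of edges most likely to be inter-class."""
    graph = dgl.add_self_loop(graph)                 # avoid zero-degree nodes
    norm = EdgeWeightNorm(norm='both')               # symmetric D^-1/2 A D^-1/2
    graph.edata['w'] = norm(graph, torch.ones(graph.num_edges()))
    graph.ndata['h'] = pred_y
    # Post-aggregation signal ly = y - A_hat y, the high-frequency residue.
    graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'ay'))
    graph.ndata['ly'] = pred_y - graph.ndata['ay']
    # Edge score <ly_u, ly_v>: the smaller it is, the more likely inter-class.
    graph.apply_edges(lambda e: {'s': (e.src['ly'] * e.dst['ly']).sum(1)})
    k = int(del_ratio * graph.num_edges())
    graph = dgl.remove_edges(graph, graph.edata['s'].sort()[1][:k])
    return dgl.remove_self_loop(graph)
```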

29 | 30 | # Dataset 31 | YelpChi and Amazon can be downloaded from [here](https://github.com/YingtongDou/CARE-GNN/tree/master/data) or [dgl.data.FraudDataset](https://docs.dgl.ai/api/python/dgl.data.html#fraud-dataset). The T-Finance and T-Social datasets developed in the paper are on [google drive](https://drive.google.com/drive/folders/1PpNwvZx_YRSCDiHaBUmRIS3x1rZR7fMr?usp=sharing). 32 | 33 | # Dependencies 34 | ```sh 35 | - pytorch 1.9.0 36 | - dgl 0.8.1 37 | - sympy 38 | - argparse 39 | - sklearn 40 | - scipy 41 | - pickle 42 | ``` 43 | 44 | # Reproduce 45 | ```sh 46 | python main.py --dataset tfinance 47 | python main.py --dataset tfinance --del_ratio 0.015 48 | ``` 49 | Note that a delete ratio of 0 should be run first to get predictions y. 50 | 51 | Also, [here](https://github.com/squareRoot3/GADBench)'s an awesome implementation. 52 | 53 | # Acknowledgement 54 | Our code references: 55 | - [BWGNN](https://github.com/squareRoot3/Rethinking-Anomaly-Detection) 56 | 57 | # Reference 58 | ``` 59 | @inproceedings{ 60 | gao2023ghrn, 61 | title={Addressing Heterophily in Graph Anomaly Detection: A Perspective of Graph Spectrum}, 62 | author={Yuan Gao and Xiang Wang and Xiangnan He and Zhenguang Liu and Huamin Feng and Yongdong Zhang}, 63 | booktitle={WWW}, 64 | year={2023}, 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from dgl.data import FraudDataset 2 | from dgl.data.utils import load_graphs 3 | import dgl 4 | import torch 5 | import warnings 6 | import pickle as pkl 7 | import torch 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | import numpy as np 11 | plt.rcParams['axes.unicode_minus']=False 12 | import seaborn as sns 13 | from dgl.nn.pytorch.conv import EdgeWeightNorm 14 | import pickle as pkl 15 | import dgl 16 | import dgl.function as fn 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | class Dataset: 21 | def __init__(self, load_epoch, name='tfinance', del_ratio=0., homo=True, data_path='', adj_type='sym'): 22 | self.name = name 23 | graph = None 24 | prefix = data_path 25 | if name == 'tfinance': 26 | graph, label_dict = load_graphs(f'{prefix}/tfinance') 27 | graph = graph[0] 28 | graph.ndata['label'] = graph.ndata['label'].argmax(1) 29 | if del_ratio != 0.: 30 | graph = graph.add_self_loop() 31 | with open(f'probs_tfinance_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 32 | pred_y = pkl.load(f) 33 | graph.ndata['pred_y'] = pred_y 34 | graph = random_walk_update(graph, del_ratio, adj_type) 35 | graph = dgl.remove_self_loop(graph) 36 | 37 | elif name == 'tsocial': 38 | graph, label_dict = load_graphs(f'{prefix}/tsocial') 39 | graph = graph[0] 40 | if del_ratio != 0.: 41 | graph = graph.add_self_loop() 42 | with open(f'probs_tsocial_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 43 | pred_y = pkl.load(f) 44 | graph.ndata['pred_y'] = pred_y 45 | graph = random_walk_update(graph, del_ratio, adj_type) 46 | graph = dgl.remove_self_loop(graph) 47 | 48 | elif name == 'yelp': 49 | dataset = FraudDataset(name, train_size=0.4, val_size=0.2) 50 | graph = dataset[0] 51 | if homo: 52 | graph = dgl.to_homogeneous(dataset[0], ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask']) 53 | graph = dgl.add_self_loop(graph) 54 | if del_ratio != 0.: 55 | with open(f'probs_yelp_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 56 | graph.ndata['pred_y'] = pkl.load(f) 57 | graph = random_walk_update(graph, del_ratio, adj_type) 58 | 
graph = dgl.add_self_loop(dgl.remove_self_loop(graph)) 59 | else: 60 | if del_ratio != 0.: 61 | with open(f'probs_yelp_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 62 | pred_y = pkl.load(f) 63 | data_dict = {} 64 | flag = 1 65 | for relation in graph.canonical_etypes: 66 | graph_r = dgl.to_homogeneous(graph[relation], ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask']) 67 | graph_r = dgl.add_self_loop(graph_r) 68 | graph_r.ndata['pred_y'] = pred_y 69 | graph_r = random_walk_update(graph_r, del_ratio, adj_type) 70 | graph_r = dgl.remove_self_loop(graph_r) 71 | data_dict[('review', str(flag), 'review')] = graph_r.edges() 72 | flag += 1 73 | graph_new = dgl.heterograph(data_dict) 74 | graph_new.ndata['label'] = graph.ndata['label'] 75 | graph_new.ndata['feature'] = graph.ndata['feature'] 76 | graph_new.ndata['train_mask'] = graph.ndata['train_mask'] 77 | graph_new.ndata['val_mask'] = graph.ndata['val_mask'] 78 | graph_new.ndata['test_mask'] = graph.ndata['test_mask'] 79 | graph = graph_new 80 | 81 | 82 | 83 | elif name == 'amazon': 84 | dataset = FraudDataset(name, train_size=0.4, val_size=0.2) 85 | graph = dataset[0] 86 | if homo: 87 | graph = dgl.to_homogeneous(dataset[0], ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask']) 88 | graph = dgl.add_self_loop(graph) 89 | if del_ratio != 0.: 90 | with open(f'probs_amazon_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 91 | graph.ndata['pred_y'] = pkl.load(f) 92 | graph = random_walk_update(graph, del_ratio, adj_type) 93 | graph = dgl.add_self_loop(dgl.remove_self_loop(graph)) 94 | else: 95 | if del_ratio != 0.: 96 | with open(f'probs_amazon_BWGNN_{load_epoch}_{homo}.pkl', 'rb') as f: 97 | pred_y = pkl.load(f) 98 | data_dict = {} 99 | flag = 1 100 | for relation in graph.canonical_etypes: 101 | graph[relation].ndata['pred_y'] = pred_y 102 | graph_r = dgl.add_self_loop(graph[relation]) 103 | graph_r = random_walk_update(graph_r, del_ratio, adj_type) 104 | graph_r = dgl.remove_self_loop(graph_r) 105 | data_dict[('review', str(flag), 'review')] = graph_r.edges() 106 | flag += 1 107 | graph_new = dgl.heterograph(data_dict) 108 | graph_new.ndata['label'] = graph.ndata['label'] 109 | graph_new.ndata['feature'] = graph.ndata['feature'] 110 | graph_new.ndata['train_mask'] = graph.ndata['train_mask'] 111 | graph_new.ndata['val_mask'] = graph.ndata['val_mask'] 112 | graph_new.ndata['test_mask'] = graph.ndata['test_mask'] 113 | graph = graph_new 114 | else: 115 | print('no such dataset') 116 | exit(1) 117 | 118 | graph.ndata['label'] = graph.ndata['label'].long().squeeze(-1) 119 | graph.ndata['feature'] = graph.ndata['feature'].float() 120 | print(graph) 121 | 122 | self.graph = graph 123 | 124 | def random_walk_update(graph, delete_ratio, adj_type): 125 | edge_weight = torch.ones(graph.num_edges()) 126 | if adj_type == 'sym': 127 | norm = EdgeWeightNorm(norm='both') 128 | else: 129 | norm = EdgeWeightNorm(norm='left') 130 | graph.edata['w'] = norm(graph, edge_weight) 131 | # functions 132 | aggregate_fn = fn.u_mul_e('h', 'w', 'm') 133 | reduce_fn = fn.sum(msg='m', out='ay') 134 | 135 | graph.ndata['h'] = graph.ndata['pred_y'] 136 | graph.update_all(aggregate_fn, reduce_fn) 137 | graph.ndata['ly'] = graph.ndata['pred_y'] - graph.ndata['ay'] 138 | # graph.ndata['lyyl'] = torch.matmul(graph.ndata['ly'], graph.ndata['ly'].T) 139 | graph.apply_edges(inner_product_black) 140 | # graph.apply_edges(inner_product_white) 141 | black = graph.edata['inner_black'] 142 | # white = graph.edata['inner_white'] 143 | # delete 144 | threshold = 
int(delete_ratio * graph.num_edges()) 145 | edge_to_move = set(black.sort()[1][:threshold].tolist()) 146 | # edge_to_protect = set(white.sort()[1][-threshold:].tolist()) 147 | edge_to_protect = set() 148 | graph_new = dgl.remove_edges(graph, list(edge_to_move.difference(edge_to_protect))) 149 | return graph_new 150 | 151 | def inner_product_black(edges): 152 | return {'inner_black': (edges.src['ly'] * edges.dst['ly']).sum(axis=1)} 153 | 154 | def inner_product_white(edges): 155 | return {'inner_white': (edges.src['ay'] * edges.dst['ay']).sum(axis=1)} 156 | 157 | def find_inter(edges): 158 | return edges.src['label'] != edges.dst['label'] 159 | 160 | def cal_hetero(edges): 161 | return {'same': edges.src['label'] != edges.dst['label']} 162 | 163 | def cal_hetero_normal(edges): 164 | return {'same_normal': (edges.src['label'] != edges.dst['label']) & (edges.src['label'] == 0)} 165 | 166 | def cal_normal(edges): 167 | return {'normal': edges.src['label'] == 0} 168 | 169 | def cal_hetero_anomal(edges): 170 | return {'same_anomal': (edges.src['label'] != edges.dst['label']) & (edges.src['label'] == 1)} 171 | 172 | def cal_anomal(edges): 173 | return {'anomal': edges.src['label'] == 1} 174 | -------------------------------------------------------------------------------- /figures/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blacksingular/GHRN/d47e047b76df429c0c0a8444f5d9a5dea57f6801/figures/heatmap.png -------------------------------------------------------------------------------- /figures/heterophily.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blacksingular/GHRN/d47e047b76df429c0c0a8444f5d9a5dea57f6801/figures/heterophily.png -------------------------------------------------------------------------------- /figures/topology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blacksingular/GHRN/d47e047b76df429c0c0a8444f5d9a5dea57f6801/figures/topology.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import argparse 4 | import time 5 | from dataset import Dataset 6 | from sklearn.metrics import f1_score, recall_score, roc_auc_score, precision_score 7 | from BWGNN import * 8 | from sklearn.model_selection import train_test_split 9 | import pickle as pkl 10 | 11 | 12 | def train(model, g, args): 13 | features = g.ndata['feature'] 14 | labels = g.ndata['label'] 15 | if dataset_name in ['tfinance', 'tsocial']: 16 | index = list(range(len(labels))) 17 | if dataset_name == 'amazon': 18 | index = list(range(3305, len(labels))) 19 | 20 | idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels[index], stratify=labels[index], 21 | train_size=args.train_ratio, 22 | random_state=2, shuffle=True) 23 | idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest, 24 | test_size=0.67, 25 | random_state=2, shuffle=True) 26 | train_mask = torch.zeros([len(labels)]).bool() 27 | val_mask = torch.zeros([len(labels)]).bool() 28 | test_mask = torch.zeros([len(labels)]).bool() 29 | 30 | train_mask[idx_train] = 1 31 | val_mask[idx_valid] = 1 32 | test_mask[idx_test] = 1 33 | else: 34 | train_mask = torch.ByteTensor(g.ndata['train_mask']) 35 | val_mask = 
torch.ByteTensor(g.ndata['val_mask']) 36 | test_mask = torch.ByteTensor(g.ndata['test_mask']) 37 | print('train/dev/test samples: ', train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item()) 38 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 39 | best_f1, final_tf1, final_trec, final_tpre, final_tmf1, final_tauc = 0., 0., 0., 0., 0., 0. 40 | best_loss = 100 41 | 42 | weight = (1-labels[train_mask]).sum().item() / labels[train_mask].sum().item() 43 | print('cross entropy weight: ', weight) 44 | time_start = time.time() 45 | for e in range(1, args.epoch+1): 46 | model.train() 47 | logits = model(features) 48 | loss = F.cross_entropy(logits[train_mask], labels[train_mask], weight=torch.tensor([1., weight])) 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | model.eval() 53 | loss = F.cross_entropy(logits[val_mask], labels[val_mask], weight=torch.tensor([1., weight])) 54 | probs = logits.softmax(1) 55 | f1, thres = get_best_f1(labels[val_mask], probs[val_mask]) 56 | preds = torch.zeros_like(labels) 57 | preds[probs[:, 1] > thres] = 1 58 | trec = recall_score(labels[test_mask], preds[test_mask]) 59 | tpre = precision_score(labels[test_mask], preds[test_mask]) 60 | tmf1 = f1_score(labels[test_mask], preds[test_mask], average='macro') 61 | tauc = roc_auc_score(labels[test_mask], probs[test_mask][:, 1].detach().numpy()) 62 | 63 | if loss <= best_loss: 64 | best_loss = loss 65 | final_trec = trec 66 | final_tpre = tpre 67 | final_tmf1 = tmf1 68 | final_tauc = tauc 69 | pred_y = probs 70 | print('Epoch {}, loss: {:.4f}, val mf1: {:.4f}, (best {:.4f})'.format(e, loss, f1, best_f1)) 71 | if args.del_ratio == 0 and e % 20 == 0: 72 | with open(f'probs_{dataset_name}_BWGNN_{e}_{args.homo}.pkl', 'wb') as f: 73 | pkl.dump(pred_y, f) 74 | 75 | time_end = time.time() 76 | print('time cost: ', time_end - time_start, 's') 77 | result = 'REC {:.2f} PRE {:.2f} MF1 {:.2f} AUC {:.2f}'.format(final_trec*100, 78 | final_tpre*100, final_tmf1*100, final_tauc*100) 79 | with open('result.txt', 'a+') as f: 80 | f.write(f'{result}\n') 81 | return final_tmf1, final_tauc 82 | 83 | 84 | # threshold adjusting for best macro f1 85 | def get_best_f1(labels, probs): 86 | best_f1, best_thre = 0, 0 87 | for thres in np.linspace(0.05, 0.95, 19): 88 | preds = np.zeros_like(labels) 89 | preds[probs[:,1] > thres] = 1 90 | mf1 = f1_score(labels, preds, average='macro') 91 | if mf1 > best_f1: 92 | best_f1 = mf1 93 | best_thre = thres 94 | return best_f1, best_thre 95 | 96 | def set_random_seed(seed): 97 | torch.manual_seed(seed) 98 | torch.cuda.manual_seed_all(seed) 99 | np.random.seed(seed) 100 | 101 | if __name__ == '__main__': 102 | 103 | parser = argparse.ArgumentParser(description='BWGNN') 104 | parser.add_argument("--dataset", type=str, default="amazon", 105 | help="Dataset for this model (yelp/amazon/tfinance/tsocial)") 106 | parser.add_argument("--train_ratio", type=float, default=0.4, help="Training ratio") 107 | parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension") 108 | parser.add_argument("--order", type=int, default=2, help="Order C in Beta Wavelet") 109 | parser.add_argument("--homo", type=int, default=1, help="1 for BWGNN(Homo) and 0 for BWGNN(Hetero)") 110 | parser.add_argument("--epoch", type=int, default=100, help="The max number of epochs") 111 | parser.add_argument("--run", type=int, default=1, help="Running times") 112 | parser.add_argument("--del_ratio", type=float, default=0., help="delete ratios") 113 | 
parser.add_argument("--adj_type", type=str, default='sym', help="sym or rw") 114 | parser.add_argument("--load_epoch", type=int, default=100, help="load epoch prediction") 115 | parser.add_argument("--data_path", type=str, default='./data', help="data path") 116 | 117 | args = parser.parse_args() 118 | # with open('result.txt', 'a+') as f: 119 | # f.write(f'{args}\n') 120 | print(args) 121 | dataset_name = args.dataset 122 | del_ratio = args.del_ratio 123 | homo = args.homo 124 | order = args.order 125 | h_feats = args.hid_dim 126 | adj_type = args.adj_type 127 | load_epoch = args.load_epoch 128 | data_path = args.data_path 129 | graph = Dataset(load_epoch, dataset_name, del_ratio, homo, data_path, adj_type=adj_type).graph 130 | in_feats = graph.ndata['feature'].shape[1] 131 | num_classes = 2 132 | 133 | # official seed 134 | set_random_seed(717) 135 | 136 | if args.run == 1: 137 | if homo: 138 | model = BWGNN(in_feats, h_feats, num_classes, graph, d=order) 139 | else: 140 | model = BWGNN_Hetero(in_feats, h_feats, num_classes, graph, d=order) 141 | train(model, graph, args) 142 | 143 | else: 144 | final_mf1s, final_aucs = [], [] 145 | for tt in range(args.run): 146 | if homo: 147 | model = BWGNN(in_feats, h_feats, num_classes, graph, d=order) 148 | else: 149 | model = BWGNN_Hetero(in_feats, h_feats, num_classes, graph, d=order) 150 | mf1, auc = train(model, graph, args) 151 | final_mf1s.append(mf1) 152 | final_aucs.append(auc) 153 | final_mf1s = np.array(final_mf1s) 154 | final_aucs = np.array(final_aucs) 155 | result = 'MF1-mean: {:.2f}, MF1-std: {:.2f}, AUC-mean: {:.2f}, AUC-std: {:.2f}'.format(100 * np.mean(final_mf1s), 156 | 100 * np.std(final_mf1s), 157 | 100 * np.mean(final_aucs), 100 * np.std(final_aucs)) 158 | print(result) 159 | 160 | --------------------------------------------------------------------------------