├── requirements.txt
├── Meta-GDN.png
├── graphs
│   └── pubmed
│       ├── Pubmed_0.mat
│       ├── Pubmed_1.mat
│       ├── Pubmed_2.mat
│       ├── Pubmed_3.mat
│       └── Pubmed_4.mat
├── getConfig.py
├── models.py
├── README.md
├── utils.py
├── learner.py
├── data.py
├── run.py
└── meta.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Meta-GDN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/Meta-GDN.png
--------------------------------------------------------------------------------
/graphs/pubmed/Pubmed_0.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/graphs/pubmed/Pubmed_0.mat
--------------------------------------------------------------------------------
/graphs/pubmed/Pubmed_1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/graphs/pubmed/Pubmed_1.mat
--------------------------------------------------------------------------------
/graphs/pubmed/Pubmed_2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/graphs/pubmed/Pubmed_2.mat
--------------------------------------------------------------------------------
/graphs/pubmed/Pubmed_3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/graphs/pubmed/Pubmed_3.mat
--------------------------------------------------------------------------------
/graphs/pubmed/Pubmed_4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaize0409/Meta-GDN_AnomalyDetection/HEAD/graphs/pubmed/Pubmed_4.mat
--------------------------------------------------------------------------------
/getConfig.py:
--------------------------------------------------------------------------------
1 | def modelArch(in_feature, out_feature):
2 | 
3 |     config = [
4 |         ('linear', [512, in_feature]),   # each entry is ('linear', [out_dim, in_dim]), consumed by Learner
5 |         ('linear', [out_feature, 512])
6 |     ]
7 | 
8 |     return config
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class SGC(nn.Module):
8 | 
9 |     def __init__(self, in_feature, out_feature):
10 |         super(SGC, self).__init__()
11 |         self.fc1 = nn.Linear(in_feature, 512)
12 |         self.out = nn.Linear(512, out_feature)
13 | 
14 |     def forward(self, x):
15 |         x = self.fc1(x)
16 |         x = self.out(x)
17 |         return x
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-GDN
2 | 
3 | This is the implementation for the TheWebConf'21 paper: ["Few-shot Network Anomaly Detection via Cross-network Meta-learning"](https://arxiv.org/pdf/2102.11165.pdf).
4 | ![The proposed framework](Meta-GDN.png)
5 | 
6 | 
7 | # Requirements
8 | - Python: 3.6
9 | - PyTorch: 1.1.0
10 | - numpy: 1.19.2
11 | - scikit-learn: 0.20.3
12 | - scipy: 1.2.1
13 | 
14 | # Evaluation
15 | python run.py
16 | 
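All options are defined in `run.py`. For example, an explicit run on the Pubmed graphs with the default 5 graphs (4 auxiliary + 1 target) and 10 repetitions would look like the following (the values shown are illustrative, not the settings reported in the paper):

```
python run.py --data_name pubmed --num_graph 5 --num_epochs 100 --bs 16 --num_run 10
```
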
17 | # Others
18 | Please cite our paper if you use this code in your own work:
19 | 
20 | ```
21 | @inproceedings{ding2021few,
22 |   title={Few-shot Network Anomaly Detection via Cross-network Meta-learning},
23 |   author={Ding, Kaize and Zhou, Qinghai and Tong, Hanghang and Liu, Huan},
24 |   booktitle={Proceedings of the Web Conference 2021},
25 |   pages={2448--2456},
26 |   year={2021}
27 | }
28 | ```
29 | 
30 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | from time import perf_counter
5 | import networkx as nx
6 | import random
7 | from sklearn.metrics import roc_auc_score, average_precision_score, auc, precision_recall_curve
8 | import scipy.sparse as sp
9 | 
10 | 
11 | def aucPerformance(y_true, y_pred):
12 |     y_true = y_true.flatten()
13 |     y_pred = y_pred.flatten()
14 |     roc_auc = roc_auc_score(y_true, y_pred)
15 |     precision, recall, _ = precision_recall_curve(y_true, y_pred)
16 |     auc_pr = auc(recall, precision)
17 |     ap = average_precision_score(y_true, y_pred)
18 |     return roc_auc, auc_pr, ap
19 | 
20 | 
21 | def sgc_precompute(features, adj, degree):
22 |     # compute S^K X: propagate the features `degree` times with the normalized adjacency
23 |     for i in range(degree):
24 |         features = torch.spmm(adj, features)
25 |     return features
26 | 
27 | 
28 | def normalize_adjacency(adj):
29 |     adj = adj + sp.eye(adj.shape[0])
30 |     adj = sp.coo_matrix(adj)
31 |     row_sum = np.array(adj.sum(1))
32 |     d_inv_sqrt = np.power(row_sum, -0.5).flatten()
33 |     d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
34 |     d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
35 |     return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()
36 | 
37 | 
38 | def normalize_feature(feature):
39 |     # Row-wise normalization of sparse feature matrix
40 |     rowsum = np.array(feature.sum(1))
41 |     r_inv = np.power(rowsum, -1).flatten()
42 |     r_inv[np.isinf(r_inv)] = 0.  # rows that sum to zero would give inf; keep them all-zero instead
43 |     r_mat_inv = sp.diags(r_inv)
44 |     mx = r_mat_inv.dot(feature)
45 |     return mx
46 | 
47 | 
--------------------------------------------------------------------------------
/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | 
6 | 
7 | class Learner(nn.Module):
8 | 
9 |     def __init__(self, config):
10 |         super(Learner, self).__init__()
11 | 
12 |         self.config = config
13 |         self.vars = nn.ParameterList()
14 |         self.vars_bn = nn.ParameterList()
15 |         for i, (name, param) in enumerate(self.config):
16 |             if name == 'linear':
17 |                 w = nn.Parameter(torch.ones(*param))
18 |                 torch.nn.init.kaiming_normal_(w)
19 |                 self.vars.append(w)
20 |                 # [ch_out]
21 |                 self.vars.append(nn.Parameter(torch.zeros(param[0])))
22 |             elif name == 'bn':
23 |                 w = nn.Parameter(torch.ones(param[0]))
24 |                 self.vars.append(w)
25 |                 # [ch_out]
26 |                 self.vars.append(nn.Parameter(torch.zeros(param[0])))
27 |                 # must set requires_grad=False
28 |                 running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
29 |                 running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
30 |                 self.vars_bn.extend([running_mean, running_var])
31 |             elif name in ['tanh', 'relu', 'upsample', 'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
32 |                 continue
33 |             else:
34 |                 raise NotImplementedError
35 |     def extra_repr(self):
36 |         info = ''
37 |         for name, param in self.config:
38 |             if name == 'linear':
39 |                 tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0])
40 |                 info += tmp + '\n'
41 |             elif name == 'leakyrelu':
42 |                 tmp = 'leakyrelu:(slope:%f)'%(param[0])
43 |                 info += tmp + '\n'
44 |             elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
45 |                 tmp = name + ':' + str(tuple(param))
46 |                 info += tmp + '\n'
47 |             else:
48 |                 raise NotImplementedError
49 | 
50 |         return info
51 | 
52 |     def forward(self, x, vars=None, bn_training=True):
53 |         '''
54 |         :param x: input feature batch [batch_size, in_dim]
55 |         :param vars: optional parameter list to run with (e.g. adapted weights from the inner loop); defaults to self.vars
56 |         :param bn_training: whether batch-norm layers (if any) use batch statistics
57 |         :return: network output
58 |         '''
59 |         if vars is None:
60 |             vars = self.vars
61 |         idx = 0
62 |         idx_bn = 0
63 |         for name, param in self.config:
64 |             if name == 'linear':
65 |                 w, b = vars[idx], vars[idx + 1]
66 |                 x = F.linear(x, w, b)
67 |                 idx += 2
68 |             elif name == 'bn':
69 |                 w, b = vars[idx], vars[idx + 1]
70 |                 running_mean, running_var = self.vars_bn[idx_bn], self.vars_bn[idx_bn+1]
71 |                 x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
72 |                 idx += 2
73 |                 idx_bn += 2
74 |             elif name == 'relu':
75 |                 x = F.relu(x, inplace=param[0])
76 |             else:
77 |                 raise NotImplementedError
78 |         assert idx == len(vars)
79 |         assert idx_bn == len(self.vars_bn)
80 | 
81 |         return x
82 | 
83 |     def zero_grad(self, vars=None):
84 |         """
85 |         :param vars: optional parameter list whose gradients are zeroed; defaults to self.vars
86 |         :return:
87 |         """
88 |         with torch.no_grad():
89 |             if vars is None:
90 |                 for p in self.vars:
91 |                     if p.grad is not None:
92 |                         p.grad.zero_()
93 |             else:
94 |                 for p in vars:
95 |                     if p.grad is not None:
96 |                         p.grad.zero_()
97 | 
98 |     def parameters(self):
99 |         """
100 |         override this function since initial parameters will return with a generator.
101 |         :return:
102 |         """
103 |         return self.vars
104 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import glob
3 | from utils import *
4 | 
5 | 
6 | def task_generator(feature, l_list, ul_list, bs, device):
7 |     feature_l = []
8 |     label_l = []
9 |     feature_l_qry = []
10 |     label_l_qry = []
11 | 
12 |     for i in range(len(feature)):
13 |         perm_l = list(range(len(l_list[i])))
14 |         random.shuffle(perm_l)
15 | 
16 |         perm_ul = list(range(len(ul_list[i])))
17 |         random.shuffle(perm_ul)
18 | 
19 |         # generate support set
20 |         support_idx = np.array(l_list[i])[perm_l[:int(bs / 2)]].tolist()
21 |         support_idx += np.array(ul_list[i])[perm_ul[:int(bs / 2)]].tolist()
22 |         label_t = np.concatenate((np.ones(int(bs / 2)), np.zeros(int(bs / 2))))
23 | 
24 |         feature_l.append(feature[i][support_idx].to(device))
25 |         label_l.append(torch.FloatTensor(label_t).to(device))
26 | 
27 |         # generate query set
28 |         bs_qry = 2 * (len(l_list[i]) - int(bs / 2))
29 |         qry_idx = np.array(l_list[i])[perm_l[-int(bs_qry / 2):]].tolist()
30 |         qry_idx += np.array(ul_list[i])[perm_ul[-int(bs_qry / 2):]].tolist()
31 |         label_t_qry = np.concatenate((np.ones(int(bs_qry / 2)), np.zeros(int(bs_qry / 2))))
32 | 
33 |         feature_l_qry.append(feature[i][qry_idx].to(device))
34 |         label_l_qry.append(torch.FloatTensor(label_t_qry).to(device))
35 | 
36 |     return feature_l, label_l, feature_l_qry, label_l_qry
37 | 
38 | 
39 | def test_task_generator(feature, l_list, ul_list, bs, label, test_idx, device):
40 |     feature_l = []
41 |     label_l = []
42 | 
43 |     for q in range(3):
44 |         perm_l = list(range(len(l_list)))
45 |         random.shuffle(perm_l)
46 | 
47 |         perm_ul = list(range(len(ul_list)))
48 |         random.shuffle(perm_ul)
49 | 
50 |         # generate support set
51 |         support_idx = np.array(l_list)[perm_l[:int(bs / 2)]].tolist()
52 |         support_idx += np.array(ul_list)[perm_ul[:int(bs / 2)]].tolist()
53 |         label_t = np.concatenate((np.ones(int(bs / 2)), np.zeros(int(bs / 2))))
54 | 
55 |         feature_l.append(feature[support_idx].to(device))
56 |         label_l.append(torch.FloatTensor(label_t).to(device))
57 | 
58 |     return feature_l, label_l, feature[test_idx].to(
59 |         device), torch.FloatTensor(label[test_idx]).to(device)
60 | 
61 | 
62 | def test_task_generator_backup(feature, l_list, ul_list, bs, label, test_idx, device):
63 |     perm_l = list(range(len(l_list)))
64 |     random.shuffle(perm_l)
65 | 
66 |     perm_ul = list(range(len(ul_list)))
67 |     random.shuffle(perm_ul)
68 | 
69 |     # generate support set
70 |     support_idx = np.array(l_list)[perm_l[:int(bs / 2)]].tolist()
71 |     support_idx += np.array(ul_list)[perm_ul[:int(bs / 2)]].tolist()
72 |     label_t = np.concatenate((np.ones(int(bs / 2)), np.zeros(int(bs / 2))))
73 | 
74 |     return feature[support_idx].to(device), torch.FloatTensor(label_t).to(device), feature[test_idx].to(
75 |         device), torch.FloatTensor(label[test_idx]).to(device)
76 | 
77 | 
78 | def load_yelp(file):
79 |     data = sio.loadmat(file)
80 |     network = data['Network'].astype(np.float)
81 |     labels = data['Label'].flatten()
82 |     attributes = data['Attributes'].astype(np.float)
83 | 
84 |     return network, attributes, labels
85 | 
86 | 
87 | def sp_matrix_to_torch_sparse_tensor(sparse_mx):
88 |     """Convert a scipy sparse matrix to a torch sparse tensor."""
89 |     sparse_mx = sparse_mx.tocoo().astype(np.float32)
90 |     indices = torch.from_numpy(
91 |         np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
92 |     values = torch.from_numpy(sparse_mx.data)
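    # indices is a 2 x nnz LongTensor of (row, col) coordinates and values holds the
    # corresponding nonzero entries; together with the dense shape below they define the COO tensor.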
93 |     shape = torch.Size(sparse_mx.shape)
94 |     return torch.sparse.FloatTensor(indices, values, shape)
95 | 
96 | 
97 | class DataProcessor:
98 | 
99 |     def __init__(self, num_graph, degree, data_name):
100 |         self.num_graph = num_graph  # number of auxiliary graphs + 1 (the target graph)
101 |         self.degree = degree
102 |         self.data_name = data_name
103 |         self.feature_l, self.label_l, self.adj_l = [], [], []
104 |         self.target, self.target_adj, self.target_feature, self.target_label = None, None, None, None
105 |         self.target_idx_train_ano_all, self.target_idx_train_normal_all, self.target_idx_val, self.target_idx_test = None, None, None, None
106 | 
107 |         self.labeled_idx_l, self.unlabeled_idx_l = [], []
108 |         self.target_labeled_idx, self.target_unlabeled_idx = [], []
109 | 
110 |     def data_loader(self):
111 | 
112 |         l = glob.glob("graphs/{}/*.mat".format(self.data_name))
113 | 
114 |         f_l = random.sample(l, self.num_graph)
115 |         random.shuffle(f_l)
116 |         for f in f_l[:-1]:
117 |             adj, feature, label = load_yelp(f)
118 |             adj = normalize_adjacency(adj)
119 |             adj = sp_matrix_to_torch_sparse_tensor(adj).float()
120 |             feature = torch.FloatTensor(feature.toarray())
121 |             feature = sgc_precompute(feature, adj, self.degree)
122 | 
123 |             self.feature_l.append(feature)
124 |             self.label_l.append(label)
125 |             self.adj_l.append(adj)
126 | 
127 |         # load the target graph
128 |         self.target = f_l[-1]
129 |         adj, feature, label = load_yelp(self.target)
130 |         adj = normalize_adjacency(adj)
131 |         self.target_adj = sp_matrix_to_torch_sparse_tensor(adj).float()
132 |         self.target_feature = torch.FloatTensor(feature.toarray())
133 |         self.target_feature = sgc_precompute(self.target_feature, self.target_adj, self.degree)
134 |         self.target_label = label
135 | 
136 |         # split the target graph nodes into train/valid/test with a 4/2/4 ratio
137 |         idx_anomaly = np.random.permutation(np.nonzero(self.target_label == 1)[0])
138 |         idx_normal = np.random.permutation(np.nonzero(self.target_label == 0)[0])
139 |         split_ano = int(0.4 * len(idx_anomaly))
140 |         split_normal = int(0.4 * len(idx_normal))
141 | 
142 |         self.target_idx_train_ano_all = idx_anomaly[:split_ano]
143 |         self.target_idx_train_normal_all = idx_normal[:split_normal]
144 |         self.target_idx_val = np.concatenate((idx_anomaly[split_ano:-split_ano], idx_normal[split_normal:-split_normal])).tolist()
145 |         self.target_idx_test = np.concatenate((idx_anomaly[-split_ano:], idx_normal[-split_normal:])).tolist()
146 | 
147 |         print("data loading finished.")
148 | 
149 |     def sample_anomaly(self, num_labeled_ano):
150 | 
151 |         for i in range(self.num_graph - 1):
152 |             # sampling anomalies from auxiliary graphs
153 |             label_tmp = self.label_l[i]
154 |             idx_anomaly = np.random.permutation(np.nonzero(label_tmp == 1)[0])
155 |             idx_normal = np.random.permutation(np.nonzero(label_tmp == 0)[0])
156 |             self.labeled_idx_l.append(idx_anomaly[:num_labeled_ano].tolist())
157 |             self.unlabeled_idx_l.append(np.concatenate((idx_normal, idx_anomaly[num_labeled_ano:])).tolist())
158 | 
159 |         self.target_idx_train_ano_all = np.random.permutation(self.target_idx_train_ano_all)
160 |         self.target_idx_train_normal_all = np.random.permutation(self.target_idx_train_normal_all)
161 | 
162 |         if num_labeled_ano <= len(self.target_idx_train_ano_all):
163 |             self.target_labeled_idx = self.target_idx_train_ano_all[:num_labeled_ano].tolist()
164 |             self.target_unlabeled_idx = np.concatenate((self.target_idx_train_normal_all, self.target_idx_train_ano_all[num_labeled_ano:])).tolist()
165 | 
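        # returns [aux features, aux labeled-anomaly indices, aux unlabeled indices] (each a list over the
        # auxiliary graphs) plus the corresponding triple for the target graph; "unlabeled" mixes normal
        # nodes with the anomalies that were not sampled as labeled.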
166 |         return [self.feature_l, self.labeled_idx_l, self.unlabeled_idx_l], \
167 |                [self.target_feature, self.target_labeled_idx, self.target_unlabeled_idx]
168 | 
169 | 
170 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import scipy.stats
4 | # from torch.utils.data import DataLoader
5 | from torch.optim import lr_scheduler
6 | import random, sys, pickle
7 | import argparse
8 | from meta import *
9 | from getConfig import modelArch
10 | from data import DataProcessor, task_generator, test_task_generator, test_task_generator_backup
11 | from models import SGC
12 | from sklearn.metrics import auc, roc_curve
13 | from utils import aucPerformance
14 | 
15 | 
16 | 
17 | def main():
18 | 
19 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 | 
21 |     random.seed(args.seed)
22 |     torch.manual_seed(args.seed)
23 |     np.random.seed(args.seed)
24 |     if torch.cuda.is_available():
25 |         torch.cuda.manual_seed(args.seed)
26 | 
27 |     num_labeled_ano = 10  # each graph (auxiliary or target) has 10 sampled anomaly nodes
28 | 
29 |     results_meta_gdn = []
30 |     results_gdn = []
31 |     for t in range(args.num_run):
32 |         dataset = DataProcessor(num_graph=args.num_graph, degree=2, data_name=args.data_name)
33 |         dataset.data_loader()
34 | 
35 |         # training meta-gdn
36 |         print("Meta-GDN training...")
37 |         print("In %d-th run..." % (t + 1))
38 |         [feature_list, l_list, ul_list], [target_feature, target_l_idx, target_ul_idx] = dataset.sample_anomaly(num_labeled_ano)
39 | 
40 |         config = modelArch(feature_list[0].shape[1], 1)
41 | 
42 |         maml = Meta(args, config).to(device)
43 |         best_val_auc = 0
44 |         for e in range(1, args.num_epochs + 1):
45 | 
46 |             # training
47 |             maml.train()
48 |             x_train, y_train, x_qry, y_qry = task_generator(feature_list, l_list, ul_list, bs=args.bs, device=device)
49 |             loss = maml(x_train, y_train, x_qry, y_qry)
50 |             torch.save(maml.state_dict(), 'temp.pkl')
51 |             # validation
52 |             model_meta_eval = Meta(args, config).to(device)
53 |             model_meta_eval.load_state_dict(torch.load('temp.pkl'))
54 |             model_meta_eval.eval()
55 |             x_train, y_train, x_val, y_val = test_task_generator(target_feature, target_l_idx,
56 |                                                                  target_ul_idx, args.bs,
57 |                                                                  dataset.target_label,
58 |                                                                  dataset.target_idx_val, device)
59 |             auc_roc, auc_pr, ap = model_meta_eval.evaluate(x_train, y_train, x_val, y_val)
60 |             print("%dth Epoch: Training Loss %.4f, Validation, AUC-ROC %.4f, AUC-PR %.4f, AP %.4f" % (e, loss.item(), auc_roc, auc_pr, ap))
61 | 
62 |             if auc_roc > best_val_auc:  # store the best model
63 |                 best_val_auc = auc_roc
64 |                 torch.save(maml.state_dict(), 'best_meta_GDN.pkl')
65 | 
66 |         print("End of training.")
67 |         # testing
68 |         print("Load the best performing Meta-GDN model and Evaluate")
69 |         maml = Meta(args, config).to(device)
70 |         maml.load_state_dict(torch.load('best_meta_GDN.pkl'))
71 |         maml.eval()
72 |         x_train, y_train, x_test, y_test = test_task_generator(target_feature, target_l_idx,
73 |                                                                target_ul_idx, args.bs,
74 |                                                                dataset.target_label,
75 |                                                                dataset.target_idx_test, device)
76 |         auc_roc, auc_pr, ap = maml.evaluate(x_train, y_train, x_test, y_test)
77 |         print("Testing performance of Meta-GDN: AUC-ROC %.4f, AUC-PR %.4f, AP %.4f" % (auc_roc, auc_pr, ap))
78 |         print("End of evaluating.")
79 |         results_meta_gdn.append(auc_roc)
80 | 
81 |         # GDN training
82 |         print('GDN training...')
83 |         model = SGC(target_feature.shape[1], 1).to(device)
84 |         optim = torch.optim.Adam(model.parameters(), lr=args.gdn_lr, weight_decay=0)
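        # single-graph GDN baseline: the same two-layer scoring network trained with the deviation
        # loss on the target graph only (no auxiliary graphs, no meta-learning), for comparison with Meta-GDN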
85 |         best_val_auc = 0
86 |         for e in range(1, args.num_epochs_GDN + 1):
87 | 
88 |             x_train, y_train, x_test, y_test = test_task_generator_backup(target_feature, target_l_idx,
89 |                                                                           target_ul_idx, num_labeled_ano * 2,
90 |                                                                           dataset.target_label,
91 |                                                                           dataset.target_idx_test, device)
92 |             x_train, y_train = x_train.to(device), y_train.to(device)
93 |             model.train()
94 |             optim.zero_grad()
95 |             y_pred = model(x_train)
96 |             loss = dev_loss(y_train, y_pred)
97 |             loss.backward()
98 |             optim.step()
99 | 
100 |             # validation
101 |             _, _, x_val, y_val = test_task_generator_backup(target_feature, target_l_idx,
102 |                                                             target_ul_idx, num_labeled_ano * 2,
103 |                                                             dataset.target_label,
104 |                                                             dataset.target_idx_val, device)
105 |             model.eval()
106 |             y_pred = model(x_val).detach().cpu().numpy()
107 |             y_val = y_val.detach().cpu().numpy()
108 |             auc_roc, auc_pr, ap = aucPerformance(y_val, y_pred)
109 |             print("%dth Epoch: Training Loss %.4f, Validation, AUC-ROC %.4f, AUC-PR %.4f, AP %.4f" % (e, loss.item(), auc_roc, auc_pr, ap))
110 | 
111 |             if auc_roc > best_val_auc:  # store the best model
112 |                 best_val_auc = auc_roc
113 |                 torch.save(model.state_dict(), 'best_GDN.pkl')
114 | 
115 |         # testing
116 |         model = SGC(target_feature.shape[1], 1).to(device)
117 |         model.load_state_dict(torch.load('best_GDN.pkl'))
118 |         model.eval()
119 |         _, _, x_test, y_test = test_task_generator_backup(target_feature, target_l_idx,
120 |                                                           target_ul_idx, num_labeled_ano * 2,
121 |                                                           dataset.target_label,
122 |                                                           dataset.target_idx_test, device)
123 |         y_pred = model(x_test).detach().cpu().numpy()
124 |         y_test = y_test.detach().cpu().numpy()
125 |         auc_roc, auc_pr, ap = aucPerformance(y_test, y_pred)
126 |         print("Testing performance of GDN: AUC-ROC: %.4f, AUC-PR: %.4f, AP: %.4f" % (auc_roc, auc_pr, ap))
127 |         results_gdn.append(auc_roc)
128 | 
129 |     print(results_gdn)
130 |     print(results_meta_gdn)
131 |     print("Average Testing performance of GDN: AUC-ROC: %.4f" % (sum(results_gdn)*1.0/len(results_gdn)))
132 |     print("Average Testing performance of meta-GDN: AUC-ROC: %.4f" % (sum(results_meta_gdn) * 1.0 / len(results_meta_gdn)))
133 | 
134 | 
135 | if __name__ == '__main__':
136 |     argparser = argparse.ArgumentParser()
137 |     argparser.add_argument('--data_name', help='pubmed/yelp', default='pubmed')
138 |     argparser.add_argument('--num_epochs', type=int, help='epoch number', default=100)
139 |     argparser.add_argument('--num_epochs_GDN', type=int, help='epoch number for GDN', default=100)
140 |     argparser.add_argument('--gdn_lr', type=float, help='learning rate for GDN', default=0.01)
141 |     argparser.add_argument('--bs', type=int, help='batch size', default=16)
142 |     argparser.add_argument('--num_graph', type=int, help='meta batch size, namely task num', default=5)
143 |     argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=0.003)
144 |     argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.5)
145 |     argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=3)
146 |     argparser.add_argument('--update_step_test', type=int, help='update steps for fine-tuning', default=3)
147 |     argparser.add_argument('--seed', type=int, default=1234, help='Random seed.')
148 |     argparser.add_argument('--num_run', type=int, help='run the experiments multiple times', default=100)
149 | 
150 |     args = argparser.parse_args()
151 | 
152 |     main()
--------------------------------------------------------------------------------
/meta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from torch.nn import functional as F
5 | from torch.utils.data import TensorDataset, DataLoader
6 | from torch import optim
7 | import numpy as np
8 | from learner import Learner
9 | from copy import deepcopy
10 | from sklearn.metrics import auc, roc_curve
11 | from utils import aucPerformance
12 | 
13 | 
14 | def dev_loss(y_true, y_prediction):
15 |     '''
16 |     z-score based deviation loss: dev = (score - mean(ref)) / std(ref) with ref ~ N(0, 1); loss = (1 - y) * |dev| + y * max(0, margin - dev)
17 |     :param y_true: true anomaly labels (1 = anomaly, 0 = normal / unlabeled)
18 |     :param y_prediction: predicted anomaly scores
19 |     :return: loss used in training
20 |     '''
21 |     confidence_margin = 5.0
22 |     ref = torch.tensor(np.random.normal(loc=0.0, scale=1.0, size=5000), dtype=torch.float32)
23 |     dev = (y_prediction - torch.mean(ref)) / torch.std(ref)
24 |     inlier_loss = torch.abs(dev)
25 |     outlier_loss = confidence_margin - dev
26 |     outlier_loss[outlier_loss < 0.] = 0
27 |     return torch.mean((1 - y_true) * inlier_loss.flatten() + y_true * outlier_loss.flatten())
28 | 
29 | 
30 | class Meta(nn.Module):
31 | 
32 |     def __init__(self, args, config):
33 | 
34 |         super(Meta, self).__init__()
35 | 
36 |         self.update_lr = args.update_lr
37 |         self.meta_lr = args.meta_lr
38 |         self.task_num = args.num_graph - 1
39 |         self.update_step = args.update_step
40 |         self.update_step_test = args.update_step_test
41 | 
42 |         self.net = Learner(config)
43 |         self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
44 | 
45 | 
46 |     def forward(self, x_train, y_train, x_qry, y_qry):
47 |         '''
48 |         :param x_train: [nb_task, batch_size, attr_dimension]
49 |         :param y_train: [nb_task, batch_size]
50 |         :param x_qry: [nb_task, qry_batch_size, attr_dimension]
51 |         :param y_qry: [nb_task, qry_batch_size]
52 |         :return:
53 |         '''
54 |         num_task = len(x_train)
55 |         losses = [0 for _ in range(self.update_step + 1)]
56 |         results = []
57 | 
58 |         for t in range(num_task):
59 |             prediction = self.net(x_train[t], vars=None, bn_training=True)
60 |             loss = dev_loss(y_train[t], prediction)
61 |             grad = torch.autograd.grad(loss, self.net.parameters())
62 |             # update the parameters
63 |             adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
64 | 
65 |             # before the first update
66 |             with torch.no_grad():
67 |                 prediction_q = self.net(x_qry[t], self.net.parameters(), bn_training=True)
68 |                 loss_q = dev_loss(y_qry[t], prediction_q)
69 |                 losses[0] += loss_q
70 | 
71 |             # after the first update
72 |             with torch.no_grad():
73 |                 prediction_q = self.net(x_qry[t], adapt_weights, bn_training=True)
74 |                 loss_q = dev_loss(y_qry[t], prediction_q)
75 |                 losses[1] += loss_q
76 | 
77 |             # for multiple step update
78 |             for k in range(1, self.update_step):
79 |                 # evaluate the i-th task
80 |                 prediction = self.net(x_train[t], adapt_weights, bn_training=True)
81 |                 loss = dev_loss(y_train[t], prediction)
82 |                 # compute gradients on theta'
83 |                 grad = torch.autograd.grad(loss, adapt_weights)
84 |                 # perform one-step update step i + 1
85 |                 adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, adapt_weights)))
86 | 
87 |                 prediction_q = self.net(x_qry[t], adapt_weights, bn_training=True)
88 |                 loss_q = dev_loss(y_qry[t], prediction_q)
89 |                 losses[k+1] += loss_q
90 | 
91 |             # evaluation can be done here
92 | 
93 |         # finish all tasks
94 |         loss_f = losses[-1] / num_task
95 |         # update parameters
96 |         self.meta_optim.zero_grad()
97 |         loss_f.backward()
98 |         self.meta_optim.step()
99 | 
100 |         # evaluate
101 |         return loss_f
102 | 
103 | 
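    # Note on forward(): the inner-loop gradients above are taken with torch.autograd.grad without
    # create_graph=True, so the outer update on losses[-1] effectively uses the first-order MAML
    # approximation rather than the full second-order meta-gradient.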
104 |     def evaluate(self, x_train, y_train, x_test, y_test):
105 | 
106 |         prediction = self.net(x_train[0], vars=None, bn_training=True)
107 |         loss = dev_loss(y_train[0], prediction)
108 |         grad = torch.autograd.grad(loss, self.net.parameters())
109 |         # update the parameters
110 |         adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
111 | 
112 |         # for multiple step update
113 |         for k in range(1, self.update_step_test):
114 |             # evaluate the i-th task
115 |             prediction = self.net(x_train[0], adapt_weights, bn_training=True)
116 |             loss = dev_loss(y_train[0], prediction)
117 |             # compute gradients on theta'
118 |             grad = torch.autograd.grad(loss, adapt_weights)
119 |             # perform one-step update step i + 1
120 |             adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, adapt_weights)))
121 | 
122 |         for i in range(1, len(x_train)):
123 |             # for multiple step update
124 |             for k in range(self.update_step_test):
125 |                 # evaluate the i-th task
126 |                 prediction = self.net(x_train[i], adapt_weights, bn_training=True)
127 |                 loss = dev_loss(y_train[i], prediction)
128 |                 # compute gradients on theta'
129 |                 grad = torch.autograd.grad(loss, adapt_weights)
130 |                 # perform one-step update step i + 1
131 |                 adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, adapt_weights)))
132 | 
133 |         y_pred = self.net(x_test, adapt_weights, bn_training=True)
134 |         y_test = y_test.detach().cpu().numpy()
135 |         y_pred = y_pred.detach().cpu().numpy()
136 | 
137 |         auc_roc, auc_pr, ap = aucPerformance(y_test, y_pred)
138 |         return auc_roc, auc_pr, ap
139 | 
140 | 
141 |     def evaluate2(self, x_train, y_train, x_test, y_test):
142 | 
143 |         for i in range(len(x_train)):
144 |             prediction = self.net(x_train[i], vars=None, bn_training=True)
145 |             loss = dev_loss(y_train[i], prediction)
146 |             grad = torch.autograd.grad(loss, self.net.parameters())
147 |             # update the parameters
148 |             adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
149 | 
150 | 
151 |             # for multiple step update
152 |             for k in range(1, self.update_step_test):
153 |                 # evaluate the i-th task
154 |                 prediction = self.net(x_train[i], adapt_weights, bn_training=True)
155 |                 loss = dev_loss(y_train[i], prediction)
156 |                 # compute gradients on theta'
157 |                 grad = torch.autograd.grad(loss, adapt_weights)
158 |                 # perform one-step update step i + 1
159 |                 adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, adapt_weights)))
160 | 
161 |         y_pred = self.net(x_test, adapt_weights, bn_training=True)
162 |         y_test = y_test.detach().cpu().numpy()
163 |         y_pred = y_pred.detach().cpu().numpy()
164 | 
165 |         auc_roc, auc_pr, ap = aucPerformance(y_test, y_pred)
166 |         return auc_roc, auc_pr, ap
167 | 
168 |     def evaluate_backup(self, x_train, y_train, x_test, y_test):
169 | 
170 |         prediction = self.net(x_train, vars=None, bn_training=True)
171 |         loss = dev_loss(y_train, prediction)
172 |         grad = torch.autograd.grad(loss, self.net.parameters())
173 |         # update the parameters
174 |         adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
175 | 
176 | 
177 |         # for multiple step update
178 |         for k in range(1, self.update_step_test):
179 |             # evaluate the i-th task
180 |             prediction = self.net(x_train, adapt_weights, bn_training=True)
181 |             loss = dev_loss(y_train, prediction)
182 |             # compute gradients on theta'
183 |             grad = torch.autograd.grad(loss, adapt_weights)
184 |             # perform one-step update step i + 1
185 |             adapt_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, adapt_weights)))
186 | 
187 | 
188 |         y_pred = self.net(x_test, adapt_weights, bn_training=True)
189 |         y_test = y_test.detach().cpu().numpy()
190 |         y_pred = y_pred.detach().cpu().numpy()
191 | 
192 |         auc_roc, auc_pr, ap = aucPerformance(y_test, y_pred)
193 |         return auc_roc, auc_pr, ap
194 | 
195 | 
196 | 
197 | def main():
198 |     pass
199 | 
200 | if __name__ == '__main__':
201 |     main()
202 | 
--------------------------------------------------------------------------------