├── .gitignore
├── models
│   ├── __init__.py
│   ├── BWGNN.py
│   ├── base.py
│   └── DSGAD.py
├── datasets
│   ├── __init__.py
│   └── datasets.py
├── paper
│   └── AAAI2025_Camera_version.pdf
├── README.md
└── t.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc

datasets/*
!datasets/*.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .BWGNN import BWGNN, BWGNN_hete
from .DSGAD import DSGAD, DSGAD_hete
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
'''
Load the datasets.
'''

from .datasets import yelp, tfinance, amazon, tolokers, yelp_hete, amazon_hete
--------------------------------------------------------------------------------
/paper/AAAI2025_Camera_version.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IWantBe/Dynamic-Spectral-Graph-Anomaly-Detection/HEAD/paper/AAAI2025_Camera_version.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Dynamic-Spectral-Graph-Anomaly-Detection
Dynamic Spectral Graph Anomaly Detection is published at AAAI 2025.

I would like to express my sincere gratitude to Mr. Zhang Tairui for his work on the code implementation, and to extend my heartfelt thanks to the other collaborators.

To use this code, please
- install Python (3.9.18), DGL (2.0.0.cu118, py39_0), and PyTorch (2.2.1, py3.9_cuda11.8_cudnn8.7.0_0), as well as numpy, scipy, scikit-learn, etc.
- download the tfinance and tolokers datasets and put them into the [`datasets`](datasets/) folder (yelp and amazon are built into dgl).
- run `python t.py --model DSGAD --run 1 --dataset [yelp, tfinance, amazon, tolokers]` for homogeneous graphs, or `python t.py --model DSGAD_hete --run 1 --dataset [yelp_hete, amazon_hete]` for heterogeneous graphs (see the example below for passing model hyperparameters).


Note: Please refer to the original BWGNN code published at https://github.com/squareRoot3/Rethinking-Anomaly-Detection. In this paper, we have adapted it into our framework.
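
For example, the optional `--model_config` flag accepts a JSON string that is forwarded as keyword arguments to the model constructor. The keys below are the keyword arguments of `DSGAD` (`h_feats`, `num_classes`, `d`, `mix_beta`); the values shown are only illustrative, not tuned settings: `python t.py --model DSGAD --run 1 --dataset yelp --model_config '{"h_feats": 64, "d": 2, "mix_beta": 2}'`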

If you find my work useful, please cite it as follows:

@article{Zheng2025,
  title={Dynamic Spectral Graph Anomaly Detection},
  volume={39},
  url={https://ojs.aaai.org/index.php/AAAI/article/view/33464},
  DOI={10.1609/aaai.v39i12.33464},
  number={12},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence},
  author={Zheng, Jianbo and Yang, Chao and Zhang, Tairui and Cao, Longbing and Jiang, Bin and Fan, Xuhui and Wu, Xiao-ming and Zhu, Xianxun},
  year={2025},
  month={Apr.},
  pages={13410--13418}
}

--------------------------------------------------------------------------------
/t.py:
--------------------------------------------------------------------------------
'''The file in which the model is trained.'''

import torch
import argparse
import json
import models
import warnings
import random
import numpy as np
import os
import time


def seed_everything(seed, strengthen=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if strengthen:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        os.environ['PYTHONHASHSEED'] = str(seed)


if __name__ == '__main__':

    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)  # random seed
    parser.add_argument('--run', type=int, default=10)  # number of training runs

    parser.add_argument('--dataset', type=str, default='yelp')  # dataset name
    parser.add_argument('--ratio', type=float, nargs='+', default=[0.4, 0.3, 0.3])  # train/validation/test split ratio
    parser.add_argument('--epoch', type=int, default=100)  # number of training epochs
    parser.add_argument('--device', type=str, default='cuda')  # training device: cuda or cpu

    parser.add_argument('--model', type=str, default='DSGAD')  # model name
    parser.add_argument('--model_config', type=json.loads, default='{}')  # model keyword arguments as a JSON string
    parser.add_argument('--save', type=str, default='', help='save weights path')  # path to save the model weights
    args = parser.parse_args()

    seed_everything(args.seed)

    print(f'dataset: {args.dataset}')
    print(f'model: {args.model}')
    print(f'model config: {args.model_config}')
    print(f'device: {args.device}')
    print()

    metrics = ['auc', 'recall', 'precision', 'f1_macro']
    performance = {metric: [] for metric in metrics}

    start = time.time()
    for t in range(args.run):  # train the model and collect the metrics of each run
        if args.device == 'cuda': torch.cuda.empty_cache()

        print(f'trial: {t+1}/{args.run}')

        m, p = getattr(models, args.model).trainfit(args)
        if args.save: torch.save(m.state_dict(), args.save)  # save the weights
        for metric in metrics:
            performance[metric].append(p[metric])

        print()
    end = time.time()
    dt = int(end - start)  # total training time in seconds

    # finally, print the metrics of every run in a unified manner
    print(f'dataset: {args.dataset}')
    print(f'model: {args.model}')
    print(f'model config: {args.model_config}')
    print(f'time: {dt//60}:{dt%60:02d}')  # minutes:seconds
    for metric in metrics:
        print(f'{metric} ', end='')
    print()
    for t in range(args.run):
        print(f'{t+1:<2}: ', end='')
        for metric in metrics:
            print(f'{performance[metric][t]:.5f} ', end='')
        print()
--------------------------------------------------------------------------------
/models/BWGNN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import dgl
import dgl.function as dglF
import scipy.special
import sympy
from .base import BaseModel


def calculate_theta2(d: int) -> list[list[float]]:
    '''Calculate the polynomial coefficients of the Beta wavelet filters (see the worked example at the end of this file).'''
    thetas = []
    x = sympy.symbols('x')
    for i in range(d + 1):
        f = sympy.poly((x / 2)**i * (1 - x / 2)**(d - i) / (scipy.special.beta(i + 1, d + 1 - i)))
        coeff = f.all_coeffs()
        inv_coeff = []
        for j in range(d + 1):  # reverse to ascending order (constant term first)
            inv_coeff.append(float(coeff[d - j]))
        thetas.append(inv_coeff)
    return thetas


def poly_conv(theta: list[float], g: dgl.DGLGraph, features: torch.Tensor) -> torch.Tensor:
    '''Polynomial graph convolution using the given filter coefficients.'''

    def unnLaplacian(feat, D_invsqrt, graph):
        '''Apply the symmetric normalized Laplacian: feat - D^-1/2 A D^-1/2 feat.'''
        graph.ndata['h'] = feat * D_invsqrt
        graph.update_all(dglF.copy_u('h', 'm'), dglF.sum('m', 'h'))
        return feat - graph.ndata.pop('h') * D_invsqrt

    with g.local_scope():
        D_invsqrt = torch.pow(g.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(features.device)
        h = theta[0] * features
        for k in range(1, len(theta)):
            features = unnLaplacian(features, D_invsqrt, g)
            h += theta[k] * features
        return h


class BWGNN(BaseModel):
    '''Beta Wavelet Graph Neural Network, adapted from the original BWGNN code.'''

    def __init__(self, in_feats: int, h_feats: int = 64, num_classes: int = 2, d: int = 2):
        super().__init__()

        self.thetas = calculate_theta2(d)

        self.input = nn.Sequential(
            nn.Linear(in_feats, h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, h_feats),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(h_feats * len(self.thetas), h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, num_classes),
        )

    def forward(self, g: dgl.DGLGraph, in_feat: torch.Tensor):
        h = self.input(in_feat)

        # apply every filter and concatenate the outputs directly
        h_final = [poly_conv(theta, g, h) for theta in self.thetas]
        h_final = torch.cat(h_final, -1)

        h = self.classifier(h_final)
        return h


class BWGNN_hete(BaseModel):
    '''BWGNN for heterogeneous graphs: the filters run per relation and the outputs are summed.'''

    def __init__(self, in_feats: int, h_feats: int = 64, num_classes: int = 2, d: int = 2):
        super().__init__()

        self.thetas = calculate_theta2(d)

        self.input = nn.Sequential(
            nn.Linear(in_feats, h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, h_feats),
            nn.ReLU(),
        )
        self.linear3 = nn.Linear(h_feats * len(self.thetas), h_feats)
        self.linear4 = nn.Linear(h_feats, num_classes)
        self.act = nn.LeakyReLU()

    def forward(self, g: dgl.DGLGraph, in_feat: torch.Tensor):
        h = self.input(in_feat)

        h_all = []
        for relation in g.canonical_etypes:

            # apply every filter on this relation's subgraph and concatenate the outputs directly
            h_final = [poly_conv(theta, g[relation], h) for theta in self.thetas]
            h_final = torch.cat(h_final, -1)
            h_final = self.linear3(h_final)
            h_all.append(h_final)

        # sum over relations, then classify
        h = torch.stack(h_all).sum(0)
        h = self.act(h)
        h = self.linear4(h)

        return h
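
# Worked example (illustrative, computed by hand from the definition above): for d = 2,
#   calculate_theta2(2) -> [[3.0, -3.0, 0.75], [0.0, 3.0, -1.5], [0.0, 0.0, 0.75]]
# Coefficients are in ascending order, so thetas[0] encodes 3 - 3x + 0.75x**2 = 3(1 - x/2)**2,
# which responds strongly at low graph frequencies (x = 0) and vanishes at x = 2,
# while thetas[2] = 0.75x**2 is its high-frequency mirror image.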
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import numpy as np
import sklearn.metrics as skmetrics

import datasets


def get_thres(labels, probs):
    '''
    Get the threshold that maximizes the macro F1 score.

    @param labels ground-truth labels
    @param probs  probability of being predicted as the positive class
    '''
    best_f1, best_thres = 0, 0
    for thres in np.linspace(0, 1, 101):
        preds = np.zeros_like(labels)
        preds[probs >= thres] = 1
        mf1 = skmetrics.f1_score(labels, preds, average='macro')
        if mf1 > best_f1:
            best_f1 = mf1
            best_thres = thres
    return best_thres


class BaseModel(nn.Module):

    @classmethod
    def trainfit(cls, args):
        '''Train the model and return it together with its test metrics.'''

        # load the graph dataset and split it
        dataset = getattr(datasets, args.dataset)()
        g: dgl.DGLGraph = dataset.g
        print(f'{args.dataset}: {g}')
        g = dataset.split(g, args.ratio)
        print(f'after split: {g}')

        train_mask = g.ndata['train_mask']
        val_mask = g.ndata['val_mask']
        test_mask = g.ndata['test_mask']
        print(f'train/val/test samples: {train_mask.sum().item()} {val_mask.sum().item()} {test_mask.sum().item()}')

        g = g.to(args.device)
        features = g.ndata['feature']  # N x L node features
        labels = g.ndata['label']  # N node labels

        # class weight for the anomalous class: number of normals / number of anomalies
        weight = (1 - labels[train_mask]).sum().item() / labels[train_mask].sum().item()
        print(f'cross entropy weight: {weight}')

        # model
        model = cls(features.shape[1], **args.model_config).to(args.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        # train
        for e in range(args.epoch):

            model.train()
            out = model(g, features)
            loss = F.cross_entropy(
                out[train_mask],
                labels[train_mask],
                weight=torch.tensor([1., weight], device=args.device),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                out = model(g, features)
            probs = out.softmax(1)[:, 1].cpu()  # probability of being predicted as the positive class
            thres = get_thres(labels[val_mask].cpu(), probs[val_mask])  # select the threshold on the validation set

            probs = probs[test_mask]
            y_pred = torch.zeros_like(probs)
            y_pred[probs >= thres] = 1
            y_true = labels[test_mask].cpu()

            # test set metrics
            auc = skmetrics.roc_auc_score(y_true, probs)
            recall = skmetrics.recall_score(y_true, y_pred)
            precision = skmetrics.precision_score(y_true, y_pred)
            f1_macro = skmetrics.f1_score(y_true, y_pred, average='macro')

            print(f'Epoch {e+1}, loss: {loss.item():.5f}')
            print('auc recall precision f1_macro')
            print(f'{auc:.5f} {recall:.5f} {precision:.5f} {f1_macro:.5f}')

        return model, {
            'auc': auc,
            'recall': recall,
            'precision': precision,
            'f1_macro': f1_macro,
        }
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
import torch
import dgl
from dgl.data import FraudAmazonDataset, FraudYelpDataset
import sklearn.model_selection as skselection
from pathlib import Path

datasets_path = Path(__file__).resolve().parent


class BaseDataset:

    def __init__(self, *args, **kwargs):
        ...

    def split(self, g: dgl.DGLGraph, ratio: list[float]) -> dgl.DGLGraph:
        '''
        Split the graph into 3 parts (train, val, test).

        The split masks are stored in the 'train_mask', 'val_mask' and 'test_mask' keys of ndata.
        '''

        assert len(ratio) == 3
        assert abs(sum(ratio) - 1) < 1e-6  # compare with a tolerance: float ratios may not sum to exactly 1
        for i in ratio:
            assert 0 <= i <= 1

        num_nodes = g.num_nodes()
        labels = g.ndata['label']

        s = ['train_mask', 'val_mask', 'test_mask']
        if 1 in ratio:
            # one part takes all nodes, the other two are empty
            idx = ratio.index(1)
            for i in range(3):
                if i == idx:
                    g.ndata[s[i]] = torch.ones(num_nodes).bool()
                else:
                    g.ndata[s[i]] = torch.zeros(num_nodes).bool()

        elif 0 in ratio:
            # one part is empty, split the nodes between the remaining two
            idx = ratio.index(0)
            g.ndata[s[idx]] = torch.zeros(num_nodes).bool()

            idx1, idx2 = [i for i in range(3) if i != idx]

            indices = list(range(num_nodes))
            x_train, x_test, _, _ = skselection.train_test_split(
                indices,
                labels,
                stratify=labels,
                train_size=ratio[idx1],
                random_state=2,
                shuffle=True,
            )
            mask1 = torch.zeros(num_nodes).bool()
            mask2 = torch.zeros(num_nodes).bool()
            mask1[x_train] = True
            mask2[x_test] = True

            g.ndata[s[idx1]] = mask1
            g.ndata[s[idx2]] = mask2

        else:
            # general three-way stratified split
            index = list(range(num_nodes))
            idx_train, idx_rest, _, y_rest = skselection.train_test_split(
                index,
                labels,
                stratify=labels,
                train_size=ratio[0],
                random_state=2,
                shuffle=True,
            )
            idx_valid, idx_test, _, _ = skselection.train_test_split(
                idx_rest,
                y_rest,
                stratify=y_rest,
                train_size=ratio[1] / (ratio[1] + ratio[2]),
                random_state=2,
                shuffle=True,
            )
            train_mask = torch.zeros(num_nodes).bool()
            val_mask = torch.zeros(num_nodes).bool()
            test_mask = torch.zeros(num_nodes).bool()

            train_mask[idx_train] = True
            val_mask[idx_valid] = True
            test_mask[idx_test] = True

            g.ndata[s[0]] = train_mask
            g.ndata[s[1]] = val_mask
            g.ndata[s[2]] = test_mask

        return g


class yelp(BaseDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = FraudYelpDataset(raw_dir=str(datasets_path))[0]
        # to homogeneous graph
        self.g = dgl.to_homogeneous(self.g, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
        self.g = dgl.add_self_loop(self.g)

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()


class tfinance(BaseDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = dgl.load_graphs(str(datasets_path / 'tfinance'))[0][0]
        self.g.ndata['label'] = self.g.ndata['label'].argmax(1)  # labels are stored one-hot; take the argmax

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()


class amazon(BaseDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = FraudAmazonDataset(raw_dir=str(datasets_path))[0]
        # to homogeneous graph
        self.g = dgl.to_homogeneous(self.g, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
        self.g = dgl.add_self_loop(self.g)

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()


class tolokers(BaseDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = dgl.load_graphs(str(datasets_path / 'tolokers'))[0][0]

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()


class yelp_hete(BaseDataset):
    '''Heterogeneous Yelp.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = FraudYelpDataset(raw_dir=str(datasets_path))[0]

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()


class amazon_hete(BaseDataset):
    '''Heterogeneous Amazon.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.g = FraudAmazonDataset(raw_dir=str(datasets_path))[0]

        self.g.ndata['label'] = self.g.ndata['label'].long().squeeze(-1)
        self.g.ndata['feature'] = self.g.ndata['feature'].float()
--------------------------------------------------------------------------------
/models/DSGAD.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as dglF
import scipy.special
import sympy
import sklearn.metrics as skmetrics
from .base import BaseModel, get_thres
import datasets


def calculate_theta2(d: int) -> list[list[float]]:
    '''Calculate the polynomial coefficients of the Beta wavelet filters.'''
    thetas = []
    x = sympy.symbols('x')
    for i in range(d + 1):
        f = sympy.poly((x / 2)**i * (1 - x / 2)**(d - i) / (scipy.special.beta(i + 1, d + 1 - i)))
        coeff = f.all_coeffs()
        inv_coeff = []
        for j in range(d + 1):  # reverse to ascending order (constant term first)
            inv_coeff.append(float(coeff[d - j]))
        thetas.append(inv_coeff)
    return thetas


def poly_conv(theta: list[float], g: dgl.DGLGraph, features: torch.Tensor) -> torch.Tensor:
    '''Polynomial graph convolution using the given filter coefficients.'''

    def unnLaplacian(feat, D_invsqrt, graph):
        '''Apply the symmetric normalized Laplacian: feat - D^-1/2 A D^-1/2 feat.'''
        graph.ndata['h'] = feat * D_invsqrt
        graph.update_all(dglF.copy_u('h', 'm'), dglF.sum('m', 'h'))
        return feat - graph.ndata.pop('h') * D_invsqrt

    with g.local_scope():
        D_invsqrt = torch.pow(g.in_degrees().float().clamp(min=1), -0.5).unsqueeze(-1).to(features.device)
        h = theta[0] * features
        for k in range(1, len(theta)):
            features = unnLaplacian(features, D_invsqrt, g)
            h += theta[k] * features
        return h


class DSGAD(BaseModel):
    '''Dynamic Spectral Graph Anomaly Detection model.'''

    def __init__(
        self,
        in_nodes: int,
        in_feats: int,
        h_feats: int = 64,
        num_classes: int = 2,
        d: int = 2,
        mix_beta: int = 2,  # number of hybrid (mixed) Beta filters
    ):
        super().__init__()

        self.thetas = calculate_theta2(d)  # coefficients of each fixed filter
        self.num_filters = len(self.thetas)  # number of fixed filters

        self.input = nn.Sequential(
            nn.Linear(in_feats, h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, h_feats),
            nn.ReLU(),
        )

        self.fcs = nn.ModuleList([nn.Sequential(
            nn.Linear(h_feats, h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, h_feats),
        ) for _ in range(self.num_filters + mix_beta)])

        c = self.num_filters + mix_beta  # one channel per filter (fixed + hybrid)
        ks = 3
        stride = 1
        self.conv = nn.Sequential(
            nn.Conv1d(c, c, ks, stride, 'same'),
            nn.BatchNorm1d(c),
            nn.ReLU(),
            nn.Conv1d(c, c, ks, stride, 'same'),
            nn.BatchNorm1d(c),
            nn.ReLU(),
            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(c * h_feats, h_feats),
            nn.ReLU(),
            nn.Linear(h_feats, num_classes),
        )

        # learnable weights used to mix the fixed Beta filters into hybrid filters
        self.weights = nn.Parameter(torch.randn(mix_beta, self.num_filters, in_nodes, h_feats))

    def forward(self, g: dgl.DGLGraph, in_feat: torch.Tensor):
        h = self.input(in_feat)

        # the mixing weights of the hybrid filters depend on the output of the first MLP
        if self.weights.shape[0] > 0:  # if there are hybrid filters
            X = h.unsqueeze(0).unsqueeze(0)
            X = X.repeat(self.weights.shape[0], self.weights.shape[1], 1, 1)
            weights = self.weights * X
            weights = weights.sum(dim=(2, 3))  # mix_beta x num_filters

        h = [poly_conv(theta, g, h) for theta in self.thetas]  # fixed filters
        if self.weights.shape[0] > 0:
            # each hybrid filter is a softmax-weighted sum of the fixed filter outputs
            mix = [sum([h[i] * w[i] for i in range(self.num_filters)]) for w in weights.softmax(1)]
            h += mix

        h = [self.fcs[i](h[i]) for i in range(len(self.fcs))]  # one FC block per filter

        h = [x.unsqueeze(1) for x in h]  # add a channel dimension for the subsequent convolution
        h = torch.cat(h, 1)  # merge the channels: N x c x h_feats
        h = self.conv(h)  # 1-D convolution across channels, then flatten
        h = self.classifier(h)  # classifier

        return h

    @classmethod
    def trainfit(cls, args):

        # load the graph dataset and split it
        dataset = getattr(datasets, args.dataset)()
        g: dgl.DGLGraph = dataset.g
        print(f'{args.dataset}: {g}')
        g = dataset.split(g, args.ratio)
        print(f'after split: {g}')

        train_mask = g.ndata['train_mask']
        val_mask = g.ndata['val_mask']
        test_mask = g.ndata['test_mask']
        print(f'train/val/test samples: {train_mask.sum().item()} {val_mask.sum().item()} {test_mask.sum().item()}')

        g = g.to(args.device)
        features = g.ndata['feature']  # N x L node features
        labels = g.ndata['label']  # N node labels

        # class weight for the anomalous class: number of normals / number of anomalies
        weight = (1 - labels[train_mask]).sum().item() / labels[train_mask].sum().item()
        print(f'cross entropy weight: {weight}')

        # model (DSGAD additionally needs the number of nodes for its mixing weights)
        model = cls(features.shape[0], features.shape[1], **args.model_config).to(args.device)

        params = []
        for name, param in model.named_parameters():
            if name == 'weights': params.append({'params': param, 'lr': 1e-1})  # larger learning rate for the mixing weights
            else: params.append({'params': param, 'lr': 1e-3})  # normal learning rate for everything else
        optimizer = torch.optim.Adam(params)

        # train
        for e in range(args.epoch):

            model.train()
            out = model(g, features)
            loss = F.cross_entropy(
                out[train_mask],
                labels[train_mask],
                weight=torch.tensor([1., weight], device=args.device),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                out = model(g, features)
            probs = out.softmax(1)[:, 1].cpu()  # probability of being predicted as the positive class
            thres = get_thres(labels[val_mask].cpu(), probs[val_mask])  # select the threshold on the validation set

            probs = probs[test_mask]
            y_pred = torch.zeros_like(probs)
            y_pred[probs >= thres] = 1
            y_true = labels[test_mask].cpu()

            # test set metrics
            auc = skmetrics.roc_auc_score(y_true, probs)
            recall = skmetrics.recall_score(y_true, y_pred)
            precision = skmetrics.precision_score(y_true, y_pred)
            f1_macro = skmetrics.f1_score(y_true, y_pred, average='macro')

            print(f'Epoch {e+1}, loss: {loss.item():.5f}')
            print('auc recall precision f1_macro')
            print(f'{auc:.5f} {recall:.5f} {precision:.5f} {f1_macro:.5f}')

        return model, {
            'auc': auc,
            'recall': recall,
            'precision': precision,
            'f1_macro': f1_macro,
        }


class DSGAD_hete(DSGAD):
    '''DSGAD for heterogeneous graphs: the filters run per relation and the logits are averaged.'''

    def forward(self, g: dgl.DGLGraph, in_feat: torch.Tensor):
        h = self.input(in_feat)

        # the mixing weights of the hybrid filters depend on the output of the first MLP
        if self.weights.shape[0] > 0:
            X = h.unsqueeze(0).unsqueeze(0)
            X = X.repeat(self.weights.shape[0], self.weights.shape[1], 1, 1)
            weights = self.weights * X
            weights = weights.sum(dim=(2, 3))

        h_all = []
        for relation in g.canonical_etypes:

            hh = [poly_conv(theta, g[relation], h) for theta in self.thetas]  # fixed filters on this relation
            if self.weights.shape[0] > 0:
                mix = [sum([hh[i] * w[i] for i in range(self.num_filters)]) for w in weights.softmax(1)]  # hybrid filters
                hh += mix

            hh = [self.fcs[i](hh[i]) for i in range(len(self.fcs))]

            hh = [x.unsqueeze(1) for x in hh]
            hh = torch.cat(hh, 1)
            hh = self.conv(hh)
            hh = self.classifier(hh)

            h_all.append(hh)

        # average the per-relation logits
        h = sum(h_all) / len(h_all)

        return h
--------------------------------------------------------------------------------
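
A minimal usage sketch (illustrative, not a file in the repository): it shows how DSGAD can be instantiated and called on a toy homogeneous graph, assuming the repository root is the working directory; the graph size, feature dimension and hyperparameter values below are made up.

import torch
import dgl
from models import DSGAD

# toy homogeneous graph: 100 nodes, 500 random edges, plus self-loops
g = dgl.add_self_loop(dgl.rand_graph(100, 500))
g.ndata['feature'] = torch.randn(100, 16)  # N x L node features

# DSGAD needs the node count up front because its mixing weights
# carry an in_nodes dimension (see models/DSGAD.py)
model = DSGAD(in_nodes=100, in_feats=16, h_feats=64, d=2, mix_beta=2)
logits = model(g, g.ndata['feature'])  # N x num_classes logits
print(logits.shape)  # torch.Size([100, 2])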