├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── file_paths.json
│   └── log_config.json
└── src
    ├── dataLoader.py
    ├── main.py
    ├── models.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# ignore data folder and logs folder
data/
logs/
models/
results/
.vscode/

# Byte-compiled / optimized / DLL files
src/__pycache__/
*.py[cod]
*$py.class
.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Tong Zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepFD-pyTorch
This is a PyTorch implementation of DeepFD ([Deep Structure Learning for Fraud Detection](https://ieeexplore.ieee.org/abstract/document/8594881)).
In addition to the unsupervised DBSCAN classifier used in the original paper, I also added a supervised 3-layer MLP as a classifier option. The embedding part itself is always unsupervised.

#### Authors of this code package:
[Tong Zhao](https://github.com/zhao-tong) (tzhao2@nd.edu),
[Kaifeng Yu](https://github.com/kaifeng16) (ykf16@mails.tsinghua.edu.cn),
[Chuchen Deng](https://github.com/ChuchenD) (cdeng@nd.edu).

## Environment settings
- python==3.6.8
- pytorch==1.0.1.post2


## Basic Usage
Before running the model, you first need to create two folders: `results/` and `data/`.

**Data Inputs**

The required input data files are `graph_u2p` and `labels`; their paths need to be set in `configs/file_paths.json`.

`graph_u2p` is the pickled adjacency matrix in `scipy.sparse.csr_matrix` format, where each non-zero entry stands for an edge.

`labels` is the pickled binary label vector in `numpy.ndarray` format, where 1 stands for a fraudulent user and 0 for a benign user. For datasets with limited labels, unlabeled users should be labeled as -1 in the labels vector.
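As a reference, here is a minimal sketch of how such inputs could be prepared (the toy graph below is purely illustrative; the output paths follow the `weibo` entry in `configs/file_paths.json`):

```python
import pickle
import numpy as np
from scipy.sparse import csr_matrix

# toy user-to-product graph: 4 users x 3 products, one non-zero entry per edge
rows = np.array([0, 0, 1, 2, 3])   # user indices
cols = np.array([0, 2, 1, 1, 2])   # product indices
vals = np.ones(len(rows))
graph_u2p = csr_matrix((vals, (rows, cols)), shape=(4, 3))

# one label per user: 1 = fraudulent, 0 = benign, -1 = unlabeled
labels = np.array([0, 1, 0, -1])

# the data/ folder must already exist (see Basic Usage above)
pickle.dump(graph_u2p, open('./data/weibo_s/weibo_s_graph_u2p.pkl', 'wb'))
pickle.dump(labels, open('./data/weibo_s/weibo_s_labels_u.pkl', 'wb'))
```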

**Example Usage**

To run the model on CUDA with the default GPU card:
```
python -m src.main --cuda 9 --dataSet [YourDataSet] --cls_method [dbscan or mlp]
```

**Main Parameters:**

```
--dataSet      The input graph dataset. (default: weibo)
--name         The name of this run. (default: debug)
--cls_method   The classification method to be used. Choose between dbscan and mlp. (default: dbscan)
--epochs       Number of epochs. (default: 10)
--b_sz         Batch size. (default: 100)
--seed         Random seed. (default: 1234)
--hidden_size  The size of the hidden layers in the encoder and decoder. (default: 128)
--emb_size     The size of the embedding for each user. (default: 2)
--cuda         Which GPU card to use. -1 for CPU, 9 for the default GPU, 0~3 for a specific GPU. (default: -1)
```

--------------------------------------------------------------------------------
/configs/file_paths.json:
--------------------------------------------------------------------------------
{
    "weibo": {
        "graph_u2p": "./data/weibo_s/weibo_s_graph_u2p.pkl",
        "labels": "./data/weibo_s/weibo_s_labels_u.pkl",
        "graph_u2u_simi": "./data/weibo_s/weibo_s_graph_u2u_simi.pkl"
    },
    "alpha": {
        "graph_u2p": "./data/alpha/alpha_graph_u2p.pickle",
        "labels": "./data/alpha/alpha_labels.pickle",
        "graph_u2u_simi": "./data/alpha/alpha_graph_u2u_simi.pkl"
    }
}
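To run on your own data, one option (not shipped with the repo) is to add an entry to `configs/file_paths.json` whose key matches the value passed to `--dataSet`; the `mydata` entry below is only a hypothetical example. The `graph_u2u_simi` file does not need to exist beforehand; `src/dataLoader.py` computes the user-user similarity graph and caches it at that path on the first run. Note that the limited-label split is currently only triggered for the `alpha` dataset.

```json
"mydata": {
    "graph_u2p": "./data/mydata/mydata_graph_u2p.pkl",
    "labels": "./data/mydata/mydata_labels.pkl",
    "graph_u2u_simi": "./data/mydata/mydata_graph_u2u_simi.pkl"
}
```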
--------------------------------------------------------------------------------
/configs/log_config.json:
--------------------------------------------------------------------------------

{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {
            "format": "%(asctime)s - [%(levelname)s] - %(message)s"
        }
    },

    "handlers": {
        "file_handler": {
            "class": "logging.FileHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "filename": "python_logging.log",
            "encoding": "utf8"
        }
    },

    "root": {
        "level": "DEBUG",
        "handlers": ["file_handler"]
    }
}
--------------------------------------------------------------------------------
/src/dataLoader.py:
--------------------------------------------------------------------------------
__author__ = 'Tong Zhao'
__email__ = 'tzhao2@nd.edu'

import os
import sys
import copy
import pickle
import pathlib
import numpy as np
from multiprocessing import Pool
from scipy.sparse import csr_matrix

from src.utils import *
from collections import defaultdict, Counter

class DataLoader():
    """ data loader """
    def __init__(self, args, logger):
        self.ds = args.dataSet
        self.args = args
        self.logger = logger
        self.file_paths = json.load(open(f'{args.config_dir}/{args.file_paths}'))
        self.load_dataSet(args.dataSet)

    def load_dataSet(self, dataSet):
        ds = dataSet
        graph_u2p_file = self.file_paths[ds]['graph_u2p']
        graph_simi_file = self.file_paths[ds]['graph_u2u_simi']
        labels_file = self.file_paths[ds]['labels']

        graph_u2p = pickle.load(open(graph_u2p_file, 'rb'))
        labels = pickle.load(open(labels_file, 'rb'))
        graph_u2p[graph_u2p > 0] = 1
        graph_u2u = graph_u2p @ graph_u2p.T
        if os.path.isfile(graph_simi_file):
            graph_simi = pickle.load(open(graph_simi_file, 'rb'))
            self.logger.info('loaded similarity graph from cache')
        else:
            graph_simi = np.zeros(np.shape(graph_u2u))
            nz_entries = []
            for i in range(np.shape(graph_u2u)[0]):
                for j in range(i+1, np.shape(graph_u2u)[0]):
                    nz_entries.append([i, j])
            self.logger.info(f'Calculating user-user similarity graph, {len(nz_entries)} edges to go...')
            sz = 1000
            n_batch = math.ceil(len(nz_entries) / sz)
            batches = np.array_split(nz_entries, n_batch)
            pool = Pool()
            results = pool.map(get_simi_single_iter, [(entries_batch, graph_u2p) for entries_batch in batches])
            results = list(zip(*results))
            row = np.concatenate(results[0])
            col = np.concatenate(results[1])
            dat = np.concatenate(results[2])
            for x in range(len(row)):
                graph_simi[row[x], col[x]] = dat[x]
                graph_simi[col[x], row[x]] = dat[x]
            pickle.dump(graph_simi, open(graph_simi_file, "wb"))
            self.logger.info('Calculated user-user similarity and saved it to cache.')

        assert len(labels) == np.shape(graph_u2p)[0] == np.shape(graph_u2u)[0]
        if ds == 'alpha':
            labeled_nodes = np.where(labels>=0)[0]
            test_indexs_cls, val_indexs_cls, train_indexs_cls = self._split_data_cls_limited(labeled_nodes)
        else:
            test_indexs_cls, val_indexs_cls, train_indexs_cls = self._split_data_cls(len(labels))

        setattr(self, dataSet+'_train', np.arange(np.shape(graph_u2p)[0]))
        setattr(self, dataSet+'_cls_test', test_indexs_cls)
        setattr(self, dataSet+'_cls_val', val_indexs_cls)
        setattr(self, dataSet+'_cls_train', train_indexs_cls)
        setattr(self, dataSet+'_u2p', graph_u2p)
        setattr(self, dataSet+'_u2u', graph_u2u)
        setattr(self, dataSet+'_simi', graph_simi)
        setattr(self, dataSet+'_labels', labels)

    def get_train(self):
        training_cps = defaultdict(list)
        g_u2u = getattr(self, self.ds+'_u2u')
        n = np.shape(g_u2u)[0]
        for i in range(n):
            line = g_u2u[i].toarray().squeeze()
            pos_pool = np.where(line != 0)[0]
            neg_pool = np.where(line == 0)[0]
            if len(pos_pool) <= 10:
                pos_nodes = pos_pool
            else:
                pos_nodes = np.random.choice(pos_pool, 10, replace=False)
            if len(neg_pool) <= 10:
                neg_nodes = neg_pool
            else:
                neg_nodes = np.random.choice(neg_pool, 10, replace=False)
            for pos_n in pos_nodes:
                training_cps[i].append((i, pos_n))
            for neg_n in neg_nodes:
                training_cps[i].append((i, neg_n))
        return training_cps

    def _split_data_cls(self, num_nodes, test_split=3, val_split=6):
        rand_indices = np.random.permutation(num_nodes)

        test_size = num_nodes // test_split
        val_size = num_nodes // val_split
        train_size = num_nodes - (test_size + val_size)

        test_indexs = rand_indices[:test_size]
        val_indexs = rand_indices[test_size:(test_size+val_size)]
        train_indexs = rand_indices[(test_size+val_size):]

        return test_indexs, val_indexs, train_indexs

    def _split_data_cls_limited(self, nodes, test_split = 4, val_split = 4):
        # used when only limited nodes are labeled
        np.random.shuffle(nodes)

        test_size = len(nodes) // test_split
        val_size = len(nodes) // val_split
        train_size = len(nodes) - (test_size + val_size)

        val_indexs = nodes[:test_size]
        test_indexs = nodes[test_size:(test_size+val_size)]
        train_indexs = nodes[(test_size+val_size):]

        return test_indexs, val_indexs, train_indexs
--------------------------------------------------------------------------------
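For reference, the cached user-user similarity computed above is, per `get_simi` in `src/utils.py`, essentially the Jaccard similarity of the two users' product sets, with smoothing applied when the intersection is empty or covers the whole union. A self-contained sketch of the main case (the toy vectors are illustrative):

```python
import numpy as np

u1 = np.array([[1, 0, 1, 1, 0]])   # products touched by user 1
u2 = np.array([[0, 0, 1, 1, 1]])   # products touched by user 2

nz_u1 = set(u1.nonzero()[1])
nz_u2 = set(u2.nonzero()[1])
# plain Jaccard similarity: |intersection| / |union|
simi = len(nz_u1 & nz_u2) / len(nz_u1 | nz_u2)
print(simi)  # 0.5
```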
/src/main.py:
--------------------------------------------------------------------------------
__author__ = 'Tong Zhao'
__email__ = 'tzhao2@nd.edu'

import os
import sys
import time
import torch
import random
import argparse
import numpy as np

from src.utils import *
from src.models import *
from src.dataLoader import *

parser = argparse.ArgumentParser(description='DeepFD')
parser.add_argument('--cuda', type=int, default=-1, help='Which GPU to run on (-1 for using CPU, 9 for not specifying which GPU to use.)')
parser.add_argument('--dataSet', type=str, default='weibo')
parser.add_argument('--file_paths', type=str, default='file_paths.json')
parser.add_argument('--config_dir', type=str, default='./configs')
parser.add_argument('--logs_dir', type=str, default='./logs')
parser.add_argument('--out_dir', default='./results')
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--cls_method', type=str, default='dbscan')

parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--b_sz', type=int, default=100)
parser.add_argument('--hidden_size', type=int, default=128)
parser.add_argument('--emb_size', type=int, default=2)
parser.add_argument('--max_vali_f1', type=float, default=0)
# Hyper parameters
parser.add_argument('--alpha', type=float, default=10)      # weight of the pairwise similarity loss
parser.add_argument('--beta', type=float, default=20)       # reconstruction weight on non-zero (observed) entries
parser.add_argument('--gamma', type=float, default=0.001)   # L2 regularization, applied as weight_decay in SGD
parser.add_argument('--lr', type=float, default=0.025)      # learning rate
args = parser.parse_args()
args.argv = sys.argv

# check if CUDA is available; warn if it is available but not used.
if torch.cuda.is_available():
    if args.cuda == -1:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        device_id = torch.cuda.current_device()
        print('using device', device_id, torch.cuda.get_device_name(device_id))
args.device = torch.device(f"cuda:{args.cuda}" if args.cuda>=0 else "cpu")
if args.cuda == 9:
    args.device = torch.device('cuda')  # 9 means "use the default CUDA device"

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

def main():
    args.name = f'{args.name}_{args.dataSet}_{args.cls_method}_{time.strftime("%m-%d_%H-%M")}'
    args.out_path = args.out_dir + '/' + args.name
    if not os.path.isdir(args.out_path): os.mkdir(args.out_path)

    logger = getLogger(args.name, args.out_path, args.config_dir)
    logger.info(f'Implementation of DeepFD, all results, embeddings and logs will be saved in {args.out_path}/')
    Dl = DataLoader(args, logger)
    device = args.device
    features = torch.FloatTensor(getattr(Dl, Dl.ds+'_u2p').toarray()).to(device)

    deepFD = DeepFD(features, features.size(1), args.hidden_size, args.emb_size)
    deepFD.to(args.device)
    model_loss = Loss_DeepFD(features, getattr(Dl, Dl.ds+'_simi'), args.device, args.alpha, args.beta, args.gamma)
    if args.cls_method == 'mlp':
        cls_model = Classification(args.emb_size)
        cls_model.to(args.device)

    for epoch in range(args.epochs):
        logger.info(f'----------------------EPOCH {epoch}-----------------------')
        deepFD = train_model(Dl, args, logger, deepFD, model_loss, device, epoch)
        if args.cls_method == 'dbscan':
            test_dbscan(Dl, args, logger, deepFD, epoch)
        elif args.cls_method == 'mlp':
            args.max_vali_f1 = train_classification(Dl, args, logger, deepFD, cls_model, device, args.max_vali_f1, epoch)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
__author__ = 'Tong Zhao'
__email__ = 'tzhao2@nd.edu'

import os
import sys
import copy
import torch
import random
import numpy as np
from scipy.sparse import csr_matrix

import torch.nn as nn
import torch.nn.functional as F

class Classification(nn.Module):
    def __init__(self, emb_size):
        super(Classification, self).__init__()

        self.fc1 = nn.Linear(emb_size, 64)
        self.fc2 = nn.Linear(64, 2)

    def init_params(self):
        for param in self.parameters():
            if len(param.size()) == 2:
                nn.init.xavier_uniform_(param)
            else:
                # initialize all bias as zeros
                nn.init.constant_(param, 0.0)

    def forward(self, embeds):
        x = F.elu_(self.fc1(embeds))
        x = F.elu_(self.fc2(x))
        logists = torch.log_softmax(x, 1)
        return logists

class DeepFD(nn.Module):
    def __init__(self, features, feat_size, hidden_size, emb_size):
        super(DeepFD, self).__init__()
        self.features = features

        self.fc1 = nn.Linear(feat_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, emb_size)
        self.fc3 = nn.Linear(emb_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, feat_size)

    def init_params(self):
        for param in self.parameters():
            if len(param.size()) == 2:
                nn.init.xavier_uniform_(param)
            else:
                # initialize all bias as zeros
                nn.init.constant_(param, 0.0)

    def forward(self, nodes_batch):
        feats = self.features[nodes_batch]
        x_en = F.relu_(self.fc1(feats))
        embs = F.relu_(self.fc2(x_en))
        x_de = F.relu_(self.fc3(embs))
        recon = F.relu_(self.fc4(x_de))
        return embs, recon

class Loss_DeepFD():
    def __init__(self, features, graph_simi, device, alpha, beta, gamma):
        self.features = features
        self.graph_simi = graph_simi
        self.device = device
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.node_pairs = {}
        self.original_nodes_batch = None
        self.extended_nodes_batch = None

    def extend_nodes(self, nodes_batch, training_cps):
        self.original_nodes_batch = copy.deepcopy(nodes_batch)
        self.node_pairs = {}
        self.extended_nodes_batch = set(nodes_batch)

        for node in nodes_batch:
            cps = training_cps[node]
            self.node_pairs[node] = cps
            for cp in cps:
                self.extended_nodes_batch.add(cp[1])
        self.extended_nodes_batch = list(self.extended_nodes_batch)
        return self.extended_nodes_batch

    def get_loss(self, nodes_batch, embs_batch, recon_batch):
        # calculate loss_simi and loss_recon;
        # loss_reg is included in the SGD optimizer as weight_decay
        loss_recon = self.get_loss_recon(nodes_batch, recon_batch)
        loss_simi = self.get_loss_simi(embs_batch)
        loss = loss_recon + self.alpha * loss_simi
        return loss

    def get_loss_simi(self, embs_batch):
        node2index = {n:i for i,n in enumerate(self.extended_nodes_batch)}
        simi_feat = []
        simi_embs = []
        for node, cps in self.node_pairs.items():
            for i, j in cps:
                simi_feat.append(torch.FloatTensor([self.graph_simi[i, j]]))
                dis_ij = (embs_batch[node2index[i]] - embs_batch[node2index[j]]) ** 2
                dis_ij = torch.exp(-dis_ij.sum())
                simi_embs.append(dis_ij.view(1))
        simi_feat = torch.cat(simi_feat, 0).to(self.device)
        simi_embs = torch.cat(simi_embs, 0)
        L = simi_feat * ((simi_embs - simi_feat) ** 2)
        return L.mean()

    def get_loss_recon(self, nodes_batch, recon_batch):
        feats_batch = self.features[nodes_batch]
        # weight reconstruction errors on observed (non-zero) entries by beta, on zero entries by 1
        H_batch = (feats_batch * (self.beta - 1)) + 1
        assert feats_batch.size() == recon_batch.size() == H_batch.size()
        L = ((recon_batch - feats_batch) * H_batch) ** 2
        return L.mean()
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
__author__ = 'Tong Zhao'
__email__ = 'tzhao2@nd.edu'

import os
import sys
import json
import math
import torch
import pickle
import random
import logging
import logging.config

import numpy as np
import torch.nn as nn

from collections import Counter, defaultdict
from scipy.sparse import csr_matrix
from sklearn import metrics
from sklearn.metrics import *
from sklearn.cluster import OPTICS, DBSCAN, cluster_optics_dbscan

def getLogger(name, out_path, config_dir):
    config_dict = json.load(open(config_dir + '/log_config.json'))

    config_dict['handlers']['file_handler']['filename'] = f'{out_path}/log-{name}.txt'
    logging.config.dictConfig(config_dict)
    logger = logging.getLogger(name)

    std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s'
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setFormatter(logging.Formatter(std_out_format))
    logger.addHandler(consoleHandler)

    return logger

def get_simi_single_iter(params):
    entries_batch, feats = params
    ii, jj = entries_batch.T
    simi = []
    for x in range(len(ii)):
        simi.append(get_simi(feats[ii[x]].toarray(), feats[jj[x]].toarray()))
    simi = np.asarray(simi)
    assert np.shape(ii) == np.shape(jj) == np.shape(simi)
    return ii, jj, simi

def get_simi(u1, u2):
    nz_u1 = u1.nonzero()[1]
    nz_u2 = u2.nonzero()[1]
    nz_inter = np.array(list(set(nz_u1) & set(nz_u2)))
    nz_union = np.array(list(set(nz_u1) | set(nz_u2)))
    if len(nz_inter) == 0:
        simi_score = 1 / (len(nz_union) + len(u1))
    elif len(nz_inter) == len(nz_union):
        simi_score = (len(nz_union) + len(u1) - 1) / (len(nz_union) + len(u1))
    else:
        simi_score = len(nz_inter) / len(nz_union)
    return float(simi_score)

def cls_evaluate(Dl, logger, cls_model, features, device, out_path, max_vali_f1, epoch):
    test_nodes = getattr(Dl, Dl.ds+'_cls_test')
    val_nodes = getattr(Dl, Dl.ds+'_cls_val')
    labels = getattr(Dl, Dl.ds+'_labels')

    embs = features[val_nodes]
    with torch.no_grad():
        logists = cls_model(embs)
        _, predicts = torch.max(logists, 1)
    labels_val = labels[val_nodes]
    assert len(labels_val) == len(predicts)
    logists = logists.cpu().numpy().T[1]
    logists = np.exp(logists)
    vali_results = _eval(labels_val, logists, predicts.cpu().numpy())
    logger.info('Epoch [{}], Validation F1: {:.6f}'.format(epoch, vali_results['f1']))
    if vali_results['f1'] > max_vali_f1:
        max_vali_f1 = vali_results['f1']
        embs = features[test_nodes]
        with torch.no_grad():
            logists = cls_model(embs)
            _, predicts = torch.max(logists, 1)
        labels_test = labels[test_nodes]
        assert len(labels_test) == len(predicts)
        logists = logists.cpu().numpy().T[1]
        logists = np.exp(logists)
        test_results = _eval(labels_test, logists, predicts.cpu().numpy())

        logger.info('Epoch [{}], Current best test F1: {:.6f}'.format(epoch, test_results['f1']))

        resultfile = f'{out_path}/result.txt'
        with open(resultfile, 'w') as fr:
            fr.write(f'Epoch {epoch}\n')
            fr.write(' \t pre \t rec \t f1 \t ap \tpr_auc\troc_auc\t h_pre\t h_rec\t h_f1 \n')
            fr.write('vali:\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(vali_results['pre'],vali_results['rec'],vali_results['f1'],vali_results['ap'],vali_results['pr_auc'],vali_results['roc_auc'],vali_results['h_pre'],vali_results['h_rec'],vali_results['h_f1']))
            fr.write('test:\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(test_results['pre'],test_results['rec'],test_results['f1'],test_results['ap'],test_results['pr_auc'],test_results['roc_auc'],test_results['h_pre'],test_results['h_rec'],test_results['h_f1']))
    return max_vali_f1

def _eval(labels, logists, predicts):
    if np.sum(labels<0) > 0:
        labeled_nodes = np.where(labels>=0)[0]
        labels = labels[labeled_nodes]
        logists = logists[labeled_nodes]
        predicts = predicts[labeled_nodes]

    pre, rec, f1, _ = precision_recall_fscore_support(labels, predicts, average='binary')
    fpr, tpr, _ = roc_curve(labels, logists, pos_label=1)
    roc_auc = metrics.auc(fpr, tpr)
    precisions, recalls, thresholds = precision_recall_curve(labels, logists, pos_label=1)
    pr_auc = metrics.auc(recalls, precisions)
    ap = average_precision_score(labels, logists)
    f1s = np.nan_to_num(2*precisions*recalls/(precisions+recalls))
    best_comb = np.argmax(f1s)
    best_f1 = f1s[best_comb]
    best_pre = precisions[best_comb]
    best_rec = recalls[best_comb]
    best_threshold = thresholds[best_comb]
    results = {
        'h_pre': pre,
        'h_rec': rec,
        'h_f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'ap': ap,
        'pre': best_pre,
        'rec': best_rec,
        'f1': best_f1,
    }
    return results

def get_embeddings(deepFD, Dl):
    nodes = getattr(Dl, Dl.ds+'_train')
    b_sz = 500
    batches = math.ceil(len(nodes) / b_sz)
    embs = []
    for index in range(batches):
        nodes_batch = nodes[index*b_sz:(index+1)*b_sz]
        with torch.no_grad():
            embs_batch, _ = deepFD(nodes_batch)
        # print(embs_batch.size(), np.shape(nodes_batch))
        assert len(embs_batch) == len(nodes_batch)
        embs.append(embs_batch)
    assert len(embs) == batches
    embs = torch.cat(embs, 0)
    assert len(embs) == len(nodes)
    return embs.detach()

def save_embeddings(embs, out_path, outer_epoch):
    pickle.dump(embs, open(f'{out_path}/embs_ep{outer_epoch}.pkl', 'wb'))

def train_classification(Dl, args, logger, deepFD, cls_model, device, max_vali_f1, outer_epoch, epochs=500):
    logger.info('Testing with MLP')
    cls_model.zero_grad()
    c_optimizer = torch.optim.SGD(cls_model.parameters(), lr=0.5)
    c_optimizer.zero_grad()
    b_sz = 100
    train_nodes = getattr(Dl, Dl.ds+'_cls_train')
    labels = getattr(Dl, Dl.ds+'_labels')
    features = get_embeddings(deepFD, Dl)
    save_embeddings(features.cpu().numpy(), args.out_path, outer_epoch)

    for epoch in range(epochs):
        # train_nodes = shuffle(train_nodes)
        np.random.shuffle(train_nodes)
        batches = math.ceil(len(train_nodes) / b_sz)

        for index in range(batches):
            nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]
            labels_batch = labels[nodes_batch]
            embs_batch = features[nodes_batch]
            logists = cls_model(embs_batch)
            loss = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
            loss /= len(nodes_batch)
            loss.backward()

            nn.utils.clip_grad_norm_(cls_model.parameters(), 5)
            c_optimizer.step()
            c_optimizer.zero_grad()
            cls_model.zero_grad()

        max_vali_f1 = cls_evaluate(Dl, logger, cls_model, features, device, args.out_path, max_vali_f1, 1000*outer_epoch+epoch)
    return max_vali_f1

def test_dbscan(Dl, args, logger, deepFD, epoch):
    logger.info('Testing with DBSCAN...')
    labels = getattr(Dl, Dl.ds+'_labels')
    features = get_embeddings(deepFD, Dl).cpu().numpy()
    save_embeddings(features, args.out_path, epoch)

    resultfile = f'{args.out_path}/results.txt'
    fa = open(resultfile, 'a')
    fa.write(f'====== Epoch {epoch} ======\n')
    # optics
    optics = OPTICS()
    optics.fit(features)
    logists = optics.labels_
    logists[logists >= 0] = 0
    logists[logists < 0] = 1
    logger.info('evaluating with optics')
    results = _eval(labels, logists, logists)
    logger.info(' pre \t rec \t f1 \t ap \tpr_auc\troc_auc\t h_pre\t h_rec\t h_f1')
    logger.info('{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(results['pre'],results['rec'],results['f1'],results['ap'],results['pr_auc'],results['roc_auc'],results['h_pre'],results['h_rec'],results['h_f1']))
    fa.write('OPTICS\n')
    fa.write(' pre \t rec \t f1 \t ap \tpr_auc\troc_auc\t h_pre\t h_rec\t h_f1 \n')
    fa.write('{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(results['pre'],results['rec'],results['f1'],results['ap'],results['pr_auc'],results['roc_auc'],results['h_pre'],results['h_rec'],results['h_f1']))

    # dbscan with different epsilon
    epsilons = [0.5, 2, 5, 10]
    for ep in epsilons:
        logists = cluster_optics_dbscan(reachability=optics.reachability_, core_distances=optics.core_distances_, ordering=optics.ordering_, eps=ep)
        logists[logists >= 0] = 0
        logists[logists < 0] = 1
        logger.info(f'evaluating with dbscan at {ep}')
        results = _eval(labels, logists, logists)
        logger.info(' pre \t rec \t f1 \t ap \tpr_auc\troc_auc\t h_pre\t h_rec\t h_f1')
        logger.info('{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(results['pre'],results['rec'],results['f1'],results['ap'],results['pr_auc'],results['roc_auc'],results['h_pre'],results['h_rec'],results['h_f1']))
        fa.write(f'DBSCAN at {ep}\n')
        fa.write(' pre \t rec \t f1 \t ap \tpr_auc\troc_auc\t h_pre\t h_rec\t h_f1 \n')
        fa.write('{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format(results['pre'],results['rec'],results['f1'],results['ap'],results['pr_auc'],results['roc_auc'],results['h_pre'],results['h_rec'],results['h_f1']))
    fa.close()

def train_model(Dl, args, logger, deepFD, model_loss, device, epoch):
    train_nodes = getattr(Dl, Dl.ds+'_train')
    np.random.shuffle(train_nodes)

    params = []
    for param in deepFD.parameters():
        if param.requires_grad:
            params.append(param)
    optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.gamma)
    optimizer.zero_grad()
    deepFD.zero_grad()

    batches = math.ceil(len(train_nodes) / args.b_sz)
    visited_nodes = set()
    training_cps = Dl.get_train()
    logger.info('sampled pos and neg nodes for each node in this epoch.')
    for index in range(batches):
        nodes_batch = train_nodes[index*args.b_sz:(index+1)*args.b_sz]
        nodes_batch = np.asarray(model_loss.extend_nodes(nodes_batch, training_cps))
        visited_nodes |= set(nodes_batch)

        embs_batch, recon_batch = deepFD(nodes_batch)
        loss = model_loss.get_loss(nodes_batch, embs_batch, recon_batch)

        logger.info(f'EP[{epoch}], Batch [{index+1}/{batches}], Loss: {loss.item():.4f}, Processed Nodes [{len(visited_nodes)}/{len(train_nodes)}]')
        loss.backward()

        nn.utils.clip_grad_norm_(deepFD.parameters(), 5)
        optimizer.step()

        optimizer.zero_grad()
        deepFD.zero_grad()

        # stop when all nodes are trained
        if len(visited_nodes) == len(train_nodes):
            break

    return deepFD
--------------------------------------------------------------------------------
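As a side note, since `save_embeddings` pickles the learned embeddings as a plain `numpy.ndarray` to `results/<run_name>/embs_ep<epoch>.pkl`, a minimal sketch for inspecting a saved embedding file offline could look like this (the run folder name below is illustrative; the actual name is built from `--name`, `--dataSet`, `--cls_method`, and a timestamp):

```python
import pickle
import numpy as np
from sklearn.cluster import DBSCAN

embs = pickle.load(open('./results/debug_weibo_dbscan_01-01_00-00/embs_ep9.pkl', 'rb'))
print(embs.shape)  # (num_users, emb_size)

# cluster the embeddings; points labeled -1 by DBSCAN are treated as suspicious,
# mirroring how test_dbscan maps cluster outliers to the fraud class
clusters = DBSCAN(eps=0.5, min_samples=5).fit_predict(embs)
suspects = np.where(clusters < 0)[0]
print(f'{len(suspects)} users flagged as outliers')
```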