├── gcn
    ├── __init__.py
    ├── models.py
    └── layers.py
├── utils
    ├── __init__.py
    ├── logging.py
    ├── early_stopping.py
    ├── sensing.py
    └── load.py
├── README.md
├── main_simple.py
├── attack_stats_all.py
├── main.py
├── mlp_trainer.py
├── attacker.py
├── gcn_trainer.py
└── worker.py
/gcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .load import *
2 | from .logging import *
3 | from .sensing import *
4 | from .early_stopping import *
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | 
5 | def init_logger(log_path, log_file, print_log=True, level=logging.INFO):
6 |     if not os.path.isdir(log_path):
7 |         os.makedirs(log_path)
8 | 
9 |     fileHandler = logging.FileHandler("{0}/{1}.log".format(log_path, log_file))
10 | 
11 |     handlers = [fileHandler]
12 | 
13 |     if print_log:
14 |         consoleHandler = logging.StreamHandler(sys.stdout)
15 |         handlers.append(consoleHandler)
16 | 
17 |     logging.basicConfig(
18 |         level=level,
19 |         format=
20 |         "%(asctime)s [%(process)d] [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s",
21 |         handlers=handlers)
22 | 
--------------------------------------------------------------------------------
/utils/early_stopping.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | import torch
4 | 
5 | class EarlyStopping:
6 |     """Stops training early if the validation loss doesn't improve after a given patience."""
7 |     def __init__(self, patience=50, verbose=False, delta=0):
8 |         """
9 |         Args:
10 |             patience (int): How long to wait after last time validation loss improved.
11 |                             Default: 50
12 |             verbose (bool): If True, prints a message for each validation loss improvement.
13 |                             Default: False
14 |             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
15 |                            Default: 0
16 |         """
17 |         self.patience = patience
18 |         self.verbose = verbose
19 |         self.counter = 0
20 |         self.best_score = None
21 |         self.early_stop = False
22 |         self.val_loss_min = np.Inf
23 |         self.delta = delta
24 |         self.best_model = None
25 | 
26 |     def __call__(self, val_loss, model):
27 | 
28 |         score = -val_loss
29 | 
30 |         if self.best_score is None:
31 |             self.best_score = score
32 |             self.best_model = copy.deepcopy(model)
33 |             # self.save_checkpoint(val_loss, model)
34 |         elif score < self.best_score + self.delta:
35 |             self.counter += 1
36 |             if self.verbose:
37 |                 print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
38 |             if self.counter >= self.patience:
39 |                 self.early_stop = True
40 |         else:
41 |             self.best_score = score
42 |             self.best_model = copy.deepcopy(model)
43 |             # self.save_checkpoint(val_loss, model)
44 |             self.counter = 0
45 | 
46 |     def save_checkpoint(self, val_loss, model):
47 |         '''Saves model when validation loss decreases.'''
48 |         if self.verbose:
49 |             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 50 | torch.save(model.state_dict(), 'checkpoint.pt') 51 | self.val_loss_min = val_loss 52 | -------------------------------------------------------------------------------- /utils/sensing.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import numpy as np 3 | import scipy.sparse as sp 4 | 5 | def haar(N, mul1, mul2): 6 | print(N) 7 | if N == 2: 8 | return sp.csr_matrix(np.array([[1, 1], [1, -1]])) 9 | 10 | res = haar(N // 2, mul1, mul2) 11 | upper = sp.kron(res, mul1) 12 | lower = sqrt(N // 2) * sp.kron(sp.eye(N // 2), mul2) 13 | 14 | return sp.vstack([upper, lower]) 15 | 16 | def Haar(n): 17 | mul1 = sp.csr_matrix(np.array([1, 1])) 18 | mul2 = sp.csr_matrix(np.array([1, -1])) 19 | return 1 / sqrt(n) * haar(n, mul1, mul2) 20 | 21 | def H_k(x, k): 22 | ind = np.argpartition(x, -k)[-k:] 23 | ind_other = list(set(range(len(x))) - set(ind)) 24 | x[ind_other] = 0 25 | return x 26 | 27 | def support(x): 28 | return np.where(x>1e-6)[0] 29 | 30 | def threshold(x_1, x_0, phi, c=0.01): 31 | return (1-c) * np.linalg.norm(x_1-x_0, 2) ** 2 \ 32 | / np.linalg.norm(np.dot(phi, x_1-x_0), 2) ** 2 33 | 34 | def cost(x_0, y_star, phi): 35 | return np.linalg.norm(y_star-np.dot(phi, x_0)) 36 | 37 | def compressive_sensing(args, adj): 38 | # initialize and pad the input (symmetric) 39 | n = adj.shape[0] ** 2 40 | n_pad = int(2 ** np.ceil(np.log2(n))) 41 | 42 | D = np.matrix.flatten(adj.A) 43 | n_pad_all = n_pad - n 44 | n_pad_left = n_pad_all // 2 45 | n_pad_right = n_pad_all - n_pad_left 46 | D = np.insert(D, 0, np.zeros(n_pad_left)) 47 | D = np.insert(D, -1, np.zeros(n_pad_right)) 48 | 49 | S = int(args.S) 50 | k = int(S * np.log(n_pad / S)) 51 | 52 | # generate the sensing matrix (~ symmetric Bernoulli distr) 53 | phi = np.random.binomial(1, 1/2, k*n_pad) 54 | phi[np.where(phi==0)[0]] = -1 55 | # phi = phi.astype(np.float32) * coeff / sqrt(k) 56 | # coeff = 1 57 | # phi *= 58 | phi = phi.reshape(k, n_pad) 59 | print('generating Phi done!') 60 | 61 | # generate the representation matrix 62 | haar_N = 1 / sqrt(n_pad) * Haar(n_pad) 63 | print('generating Haar matrix done!') 64 | 65 | # add noise in the latent space 66 | print(phi.shape, D.shape) 67 | y = np.dot(phi, D) 68 | print(y.shape) 69 | noise = np.random.laplace(0, 2*k/args.epsilon, size=k) 70 | y_star = y + noise 71 | print('adding noise done!') 72 | 73 | # recover 74 | x_0 = np.zeros_like(D, dtype=np.float32) 75 | ini = np.dot(phi.T, y) 76 | support_0 = np.argpartition(ini, -S)[-S:] 77 | cost_0 = cost(x_0, y_star, phi) 78 | 79 | for i in range(10000): 80 | g = np.dot(phi.T, y_star - np.dot(phi, x_0)) 81 | 82 | g_tau = g[support_0] 83 | phi_tau = phi[:, support_0] 84 | mu = np.dot(g_tau.T, g_tau) / np.dot(np.dot(np.dot(g_tau.T, phi_tau.T), phi_tau), g_tau) 85 | 86 | x_1 = H_k(x_0 + mu * g, S) 87 | 88 | support_1 = support(x_1) 89 | 90 | if (support_1 == support_0).all(): 91 | x_0 = x_1 92 | else: 93 | lim = threshold(x_1, x_0, phi, args.c) 94 | if mu <= lim: 95 | x_0 = x_1 96 | support_0 = support_1 97 | else: 98 | while mu > lim: 99 | mu /= 2 # k * (1-c) 100 | x_1 = H_k(x_0 + mu * g, S) 101 | lim = threshold(x_1, x_0, phi, args.c) 102 | x_0 = x_1 103 | support_0 = support(x_1) 104 | 105 | cost_1 = cost(x_1, y_star, phi) 106 | if cost_0 - cost_1 < 1e-6: 107 | break 108 | cost_0 = cost_1 109 | print(f'{i}: {cost_0 : .4f}') 110 | 111 | # reconstruct 112 | ret = haar_N.dot(x_0) 113 | A = sp.csr_matrix(ret[n_pad_left : n_pad - n_pad_right].reshape(1000, 1000)) 114 | A = 
sp.triu(A, k=1) 115 | A += A.T 116 | return A -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LinkTeller 2 | 3 | This repository contains the official implementation of the S&P 22 paper 4 | 5 | "[LinkTeller: Recovering Private Edges from Graph Neural Networks via Influence Analysis](https://arxiv.org/abs/2108.06504)" 6 | 7 | Fan Wu, Yunhui Long, Ce Zhang, Bo Li 8 | 9 | ## Download and Installation 10 | 11 | 1. The code requires Python >=3.6 and is built on PyTorch. Note that PyTorch may need to be [installed manually](https://pytorch.org/get-started/locally/) depending on different platforms and CUDA drivers. 12 | 2. The graph datasets can be obtained from the [Google Drive link](https://drive.google.com/file/d/1_TV_XNy0ljy_KWli30k_n4Kcj3pgwzPx/view?usp=sharing). Please download the file ``data.zip`` and uncompress it under the root path. 13 | 14 | ## Usage 15 | 16 | We will use the twitch datasets as an example in the following. 17 | 18 | ### Evaluation of LinkTeller 19 | 20 | In this part, we introduce how to use LinkTeller attack to reveal the private edges given access to a blackbox API. 21 | 22 | 1. Train a vanilla GCN 23 | 24 | ```bash 25 | python main.py --mode vanilla-clean --dataset twitch/ES/RU --hidden 256 \ 26 | --display --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN 27 | ``` 28 | 29 | The concrete explanations for each **option**, the **choices** of each option, and the **optimal hyper-parameters** or configurations for each dataset can all be referred to in the Appendix F of the paper. 30 | 31 | Additionally, the trained model can be tested via the following script: 32 | 33 | ```bash 34 | python main.py --mode vanilla-clean --dataset twitch/ES/RU --hidden 256 \ 35 | --display --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \ 36 | --test --model-path [model_path] 37 | ``` 38 | 39 | where the ``[model_path]`` can be obtained from the training log in the previous run. 40 | 41 | 2. Attack the trained GCN model 42 | 43 | ```bash 44 | python main.py --mode vanilla-clean --dataset twitch/ES/RU --hidden 256 \ 45 | --display --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \ 46 | --test --model-path [model_path] \ 47 | --attack --approx --sample-type unbalanced --n-test 500 --influence 0.0001 \ 48 | --sample-seed 42 --attack-mode efficient 49 | ``` 50 | 51 | Basically, given a trained model, we configure the following options of the attack 52 | 53 | * **sample-type**: { 'unbalanced', 'unbalanced-lo', 'unbalanced-hi' }, where 'unbalanced' means random node sampling among all nodes in the inference graph, 'unbalanced-lo' means sampling from low degree nodes, and 'unbalanced-hi' means sampling from high degree nodes. 54 | * **n-test**: the size of the node set of interest ($V^{(C)}$ In the paper) 55 | * **sample-seed:** the preset random seed for node sampling 56 | * **attack-mode**: { 'efficient', 'baseline', 'baseline-feat' }, where 'efficient' represents our LinkTeller attack, 'baseline' represents the baseline method LSA2-post, and 'baseline-feat' represents the baseline method LSA2-attr. 57 | 58 | For ease of experimentation, training and attack can be merged into one stage. To do so, simply remove the options ``--test`` and ``--model-path`` in the attack script. 59 | 60 | The attack results will be saved to the local path as indicated in the log. 
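For reference, the two baseline attacks (LSA2-post and LSA2-attr) can be evaluated against the same trained model by changing only the ``--attack-mode`` option. The command below is a sketch that assumes the remaining options are kept identical to the LinkTeller attack command above, with ``[model_path]`` again taken from the training log:

```bash
python main.py --mode vanilla-clean --dataset twitch/ES/RU --hidden 256 \
  --display --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \
  --test --model-path [model_path] \
  --attack --approx --sample-type unbalanced --n-test 500 --influence 0.0001 \
  --sample-seed 42 --attack-mode baseline   # or: --attack-mode baseline-feat
```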
61 | 62 | ### Evaluation of Differentially Private GCNs 63 | 64 | In this part, we describe the evaluation of Differentially Private (DP) GCNs from two perspectives: utility and privacy. 65 | 66 | 1. Train a DP GCN model 67 | 68 | ```bash 69 | python main.py --mode vanilla --dataset twitch/ES/RU --hidden 256 \ 70 | --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \ 71 | --perturb-type continuous --eps 5 --noise-seed 42 \ 72 | ``` 73 | 74 | We briefly introduce the privacy parameters: 75 | 76 | * **perturb-type**: { 'continuous', 'discrete' }, where 'continuous' refers to the method Lapgraph and 'discrete' refers to the method EdgeRand 77 | * **eps**: the privacy budget in differential privacy 78 | * **noise-seed**: the random seed for noise generation 79 | 80 | 2. Measure the **utility** of the trained DP GCN model (i.e., test the trained model) 81 | 82 | ```bash 83 | python main.py --mode vanilla --dataset twitch/ES/RU --hidden 256 \ 84 | --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \ 85 | --perturb-type continuous --eps 5 --noise-seed 42 \ 86 | --test --model-path [model_path] 87 | ``` 88 | 89 | 3. Measure the **privacy** of the trained DP GCN model via LinkTeller 90 | 91 | ```bash 92 | python main.py --mode vanilla --dataset twitch/ES/RU --hidden 256 \ 93 | --num-epochs 200 --dropout 0.5 --lr 0.01 --norm FirstOrderGCN \ 94 | --perturb-type continuous --eps 5 --noise-seed 42 \ 95 | --test --model-path [model_path] \ 96 | --attack --approx --sample-type unbalanced --n-test 500 --influence 0.0001 \ 97 | --sample-seed 42 --attack-mode efficient 98 | ``` 99 | 100 | The options for attack are the same as previously introduced in the evaluation of LinkTeller. 101 | 102 | ## Citation 103 | 104 | To be updated -------------------------------------------------------------------------------- /main_simple.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import datetime 6 | import logging 7 | import numpy as np 8 | 9 | import torch 10 | 11 | from worker import Worker 12 | from mlp_trainer import MLPTrainer 13 | from utils import init_logger 14 | 15 | 16 | def get_arguments(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 19 | parser.add_argument('--num-epochs', type=int, default=500, 20 | help='Number of epochs to train.') 21 | parser.add_argument('--lr', type=float, default=0.2, 22 | help='Initial learning rate.') 23 | parser.add_argument('--weight_decay', type=float, default=5e-6, 24 | help='Weight decay (L2 loss on parameters).') 25 | parser.add_argument('--hidden', type=int, default=16, 26 | help='Number of hidden units.') 27 | parser.add_argument('--depth', type=int, default=1, 28 | help='Depth.') 29 | parser.add_argument('--dropout', type=float, default=0.5, 30 | help='Dropout rate (1 - keep probability).') 31 | parser.add_argument('--dataset', type=str, default='cora', 32 | help='Dataset: cora; cora.sample') 33 | parser.add_argument('--model-path', type=str, default='') 34 | parser.add_argument('--mode', type=str, default='sgc', 35 | help='[ vanilla | vanilla-clean | clusteradj | clusteradj-clean ] ') 36 | parser.add_argument('--scale', type=str, default='small', 37 | help='[ large | small ]') 38 | parser.add_argument('--normalization', type=str, default='AugNormAdj', 39 | choices=['AugNormAdj']) 40 | 41 | parser.add_argument('--l2-norm-clip', 
type=float, default=1., help='upper bound on the l2 norm of gradient updates (default: 0.1)') 42 | parser.add_argument('--noise-multiplier', type=float, default=1.1, help='ratio between clipping bound and std of noise applied to gradients (default: 1.1)') 43 | parser.add_argument('--microbatch-size', type=int, default=1, help='input microbatch size for training (default: 1)') 44 | parser.add_argument('--minibatch-size', type=int, default=256, help='input minibatch size for training (default: 256)') 45 | parser.add_argument('--l2-penalty', type=float, default=0.001, help='l2 penalty on model weights (default: 0.001)') 46 | 47 | parser.add_argument('--epsilon', type=float, default=0.1) 48 | parser.add_argument('--delta', type=float, default=1e-5) 49 | parser.add_argument('--train-ratio', type=float, default=0.5) 50 | parser.add_argument('--patience', type=int, default=50) 51 | parser.add_argument('--batch-size', type=int, default=256) 52 | parser.add_argument('--k', type=float, default=1) 53 | 54 | parser.add_argument('--batch', action='store_true', default=False) 55 | parser.add_argument('--test', action='store_true', default=False) 56 | parser.add_argument('--display', action='store_true', default=False) 57 | parser.add_argument('--trainable', action='store_true', default=False) 58 | parser.add_argument('--early', action='store_true', default=False) 59 | 60 | parser.add_argument('--noise-seed', type=int, default=42) 61 | parser.add_argument('--cluster-seed', type=int, default=42) 62 | parser.add_argument('--noise-type', type=str, default='gaussian') 63 | parser.add_argument('--trainer-type', type=str, default='mlp', 64 | choices=['mlp', 'dp_mlp', 'lr', 'dp_lr', 'multi_mlp', 'dp_multi_mlp']) 65 | parser.add_argument('--feature-size', type=int, default=-1) 66 | parser.add_argument('--iterations', type=int, default=14000, help='number of iterations to train (default: 14000)') 67 | parser.add_argument('--coeff', type=float, default=1) 68 | parser.add_argument('--degree', type=int, default=2) 69 | parser.set_defaults(assign_seed=42) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def main(): 75 | args = get_arguments() 76 | 77 | np.random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | if torch.cuda.is_available(): 80 | torch.cuda.manual_seed(args.seed) 81 | 82 | if args.test: 83 | worker = Worker(args, dataset=args.dataset, mode=args.mode) 84 | 85 | trainer = MLPTrainer(args, worker=worker, train=False) 86 | trainer.init_model(model_path=args.model_path) 87 | trainer.test() 88 | 89 | else: 90 | cur_time = datetime.datetime.now().strftime("%m-%d-%H:%M:%S.%f") 91 | 92 | subdir = 'mode-{}_hidden-{}_lr-{}_decay-{}_dropout-{}_bs-{}_{}'.format(args.mode, args.hidden, args.lr, args.weight_decay, args.dropout, args.batch_size, cur_time) 93 | 94 | print('subdir = {}'.format(subdir)) 95 | init_logger('./logs_{}'.format(args.dataset), subdir, print_log=False) 96 | logging.info(str(args)) 97 | 98 | worker = Worker(args, dataset=args.dataset, mode=args.mode) 99 | 100 | trainer = MLPTrainer(args, subdir=subdir, worker=worker, train=True) 101 | trainer.init_model() 102 | trainer.train() 103 | trainer.test() 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /gcn/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | from .layers import 
GraphConvolution, ProjectionGraphConvolution, DegreeMLP, BasicMLP 7 | 8 | class GCN(nn.Module): 9 | def __init__(self, nfeat, nhid, nclass, dropout): 10 | super(GCN, self).__init__() 11 | 12 | self.gc1 = GraphConvolution(nfeat, nhid) 13 | self.gc2 = GraphConvolution(nhid, nclass) 14 | # self.gc = GraphConvolution(nfeat, nclass) 15 | self.dropout = dropout 16 | 17 | print('init GCN model done!') 18 | 19 | def forward(self, x, adj): 20 | x = F.relu(self.gc1(x, adj)) 21 | x = F.dropout(x, self.dropout, training=self.training) 22 | x = self.gc2(x, adj) 23 | # x = self.gc(x, adj) 24 | return x 25 | # return F.log_softmax(x, dim=1) 26 | 27 | 28 | class GCN3(nn.Module): 29 | def __init__(self, nfeat, nhid1, nhid2, nclass, dropout): 30 | super(GCN3, self).__init__() 31 | 32 | self.gc1 = GraphConvolution(nfeat, nhid1) 33 | self.gc2 = GraphConvolution(nhid1, nhid2) 34 | self.gc3 = GraphConvolution(nhid2, nclass) 35 | # self.gc = GraphConvolution(nfeat, nclass) 36 | self.dropout = dropout 37 | 38 | print('init GCN model done!') 39 | 40 | def forward(self, x, adj): 41 | x = F.relu(self.gc1(x, adj)) 42 | x = F.dropout(x, self.dropout, training=self.training) 43 | x = F.relu(self.gc2(x, adj)) 44 | x = F.dropout(x, self.dropout, training=self.training) 45 | x = self.gc3(x, adj) 46 | return x 47 | 48 | 49 | class DeGCN(nn.Module): 50 | def __init__(self, nfeat, nhid, nclass, dropout): 51 | super(DeGCN, self).__init__() 52 | 53 | self.gc1_1 = GraphConvolution(nfeat, nhid) 54 | self.gc1_2 = GraphConvolution(nfeat, nhid) 55 | self.gc1_3 = GraphConvolution(nfeat, nhid) 56 | self.gc1 = [self.gc1_1, self.gc1_2, self.gc1_3] 57 | 58 | self.gc2 = GraphConvolution(nhid, nclass) 59 | self.dropout = dropout 60 | 61 | print('init DeGCN model done!') 62 | 63 | def forward(self, x, adj, sub_adj): 64 | x_d = F.dropout(x, self.dropout, training=self.training) 65 | x = self.gc1[0](x_d, sub_adj[0]) 66 | for i in range(1, len(self.gc1)): 67 | x += self.gc1[i](x_d, sub_adj[i]) 68 | x = F.relu(x) 69 | x = F.dropout(x, self.dropout, training=self.training) 70 | x = self.gc2(x, adj) 71 | return F.log_softmax(x, dim=1) 72 | 73 | 74 | class ProjectionGCN(nn.Module): 75 | def __init__(self, nfeat, nhid, nclass, dropout, projection, size, args): 76 | super(ProjectionGCN, self).__init__() 77 | 78 | self.gc1 = ProjectionGraphConvolution(nfeat, nhid, projection, size, trainable=args.trainable) 79 | self.gc2 = ProjectionGraphConvolution(nhid, nclass, projection, size, trainable=args.trainable) 80 | self.dropout = dropout 81 | 82 | print(f'init ProjectionGCN model done!') 83 | 84 | def forward(self, x, adj): 85 | x = F.relu(self.gc1(x, adj)) 86 | x = F.dropout(x, self.dropout, training=self.training) 87 | x = self.gc2(x, adj) 88 | return F.log_softmax(x, dim=1) 89 | 90 | 91 | class SingleHiddenLayerMLP(nn.Module): 92 | def __init__(self, nfeat, nhid, nclass, dropout): 93 | super(SingleHiddenLayerMLP, self).__init__() 94 | 95 | self.W1 = nn.Linear(nfeat, nhid) 96 | self.W2 = nn.Linear(nhid, nclass) 97 | self.dropout = dropout 98 | 99 | self.W = nn.Linear(nfeat, nclass) 100 | 101 | print(f'init SingleHiddeenLayerMLP model done!') 102 | 103 | def forward(self, x): 104 | x = F.relu(self.W1(x)) 105 | x = F.dropout(x, self.dropout, training=self.training) 106 | x = self.W2(x) 107 | return x 108 | 109 | 110 | class OneLayerMLP(nn.Module): 111 | def __init__(self, nfeat, nclass): 112 | super(OneLayerMLP, self).__init__() 113 | 114 | self.W = nn.Linear(nfeat, nclass) 115 | 116 | print(f'init OneLayerMLP model done!') 117 | 118 | def forward(self, 
x): 119 | x = self.W(x) 120 | return x 121 | 122 | 123 | class MLP(nn.Module): 124 | def __init__(self, nfeat, nhid, nclass, dropout, size, args): 125 | super(MLP, self).__init__() 126 | 127 | layer = { 128 | 'degree_mlp': DegreeMLP, 129 | 'basic_mlp': BasicMLP, 130 | }.get(args.mode) 131 | 132 | self.gc1 = layer(nfeat, nhid, size, args=args) 133 | self.gc2 = layer(nhid, nclass, size, args=args) 134 | self.dropout = dropout 135 | 136 | print(f'init MLP model of layer_type = {layer.__name__} done!') 137 | 138 | def forward(self, x, adj): 139 | x = F.relu(self.gc1(x, adj)) 140 | x = F.dropout(x, self.dropout, training=self.training) 141 | x = self.gc2(x, adj) 142 | return F.log_softmax(x, dim=1) 143 | 144 | 145 | class SGC(nn.Module): 146 | """ 147 | A Simple PyTorch Implementation of Logistic Regression. 148 | Assuming the features have been preprocessed with k-step graph propagation. 149 | """ 150 | def __init__(self, nfeat, nclass): 151 | super(SGC, self).__init__() 152 | 153 | self.W = nn.Linear(nfeat, nclass) 154 | 155 | def forward(self, x): 156 | return self.W(x) 157 | 158 | 159 | class BinaryLR(nn.Module): 160 | """ 161 | A Simple PyTorch Implementation of Logistic Regression. 162 | Assuming the features have been preprocessed with k-step graph propagation. 163 | """ 164 | def __init__(self, nfeat): 165 | super(BinaryLR, self).__init__() 166 | 167 | self.w = Parameter(torch.FloatTensor(nfeat, 1)) 168 | self.b = Parameter(torch.FloatTensor(1)) 169 | self.reset_parameters() 170 | 171 | def reset_parameters(self): 172 | stdv = 1 173 | self.w.data.uniform_(-stdv, stdv) 174 | self.b.data.uniform_(-stdv, stdv) 175 | 176 | def forward(self, x): 177 | return torch.mm(x, self.w) + self.b 178 | -------------------------------------------------------------------------------- /attack_stats_all.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import ceil 3 | import numpy as np 4 | import argparse 5 | import os 6 | import os.path as osp 7 | from sklearn import metrics 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | sample_seed = [ 42, 2, 82 ] 12 | 13 | sample_types = ['unbalanced-lo', 'unbalanced', 'unbalanced-hi'] 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--dataset', type=str, default='') 17 | parser.add_argument('--method-type', type=str, default='baseline', 18 | choices=['efficient', 'baseline', 'baseline-feat']) 19 | parser.add_argument('--perturb-type', type=str, default='discrete', 20 | choices=['discrete', 'continuous']) 21 | parser.add_argument('--n-attack', type=int, default=500) 22 | parser.add_argument('--eps', type=float, default=6) 23 | args = parser.parse_args() 24 | 25 | n_edges = np.zeros((3,3), dtype=np.int32) 26 | 27 | dataset = args.dataset 28 | 29 | 30 | for j, sample_type in enumerate(tqdm(sample_types)): 31 | for k, seed in enumerate(sample_seed): 32 | # data = torch.load(f'eval_{args.dataset}/attack-type-{args.method_type}_sample-type-{sample_type}_n-attack-{args.n_attack}_seed-{seed}.pt') # pytorch-GAT 33 | # data = torch.load(f'eval_{args.dataset}/{args.method_type}_{sample_type}_{args.n_attack}_{seed}.pt') # GCN 34 | data = torch.load(f'eval_{args.dataset}/{args.method_type}_{sample_type}_{args.perturb_type}_{args.n_attack}_{seed}_eps-{args.eps}_seed-42.pt') # DP-GCN 35 | # data = torch.load(f'eval_{args.dataset}/type-{sample_type}_n-attack-{args.n_attack}_seed-42.pt') 36 | 37 | result = data['result'] 38 | y = np.asarray(result['y']) 39 | 40 | n_edges[j][k] = 
y.sum() 41 | 42 | print('pre-collecting stats done!') 43 | 44 | n_nodes = 500 45 | n_total = (n_nodes-1) * n_nodes // 2 46 | 47 | def get_ratio(ne): 48 | return ne / n_total 49 | 50 | baseline = np.zeros((3,3)) 51 | for j in range(3): 52 | for k in range(3): 53 | baseline[j][k] = get_ratio(n_edges[j][k]) 54 | 55 | 56 | def get_closest(value): 57 | cnt = 0 58 | while value < 1: 59 | value *= 10 60 | cnt += 1 61 | base = int(value) 62 | value -= int(value) 63 | if value >= 0.5: 64 | base += 1 65 | return base, cnt 66 | 67 | approx = np.zeros((3,3), dtype=np.int32) 68 | expn = np.zeros((3,3), dtype=np.int32) 69 | 70 | for j in range(3): 71 | for k in range(3): 72 | a, b = get_closest(baseline[j][k]) 73 | approx[j][k] = a 74 | expn[j][k] = b 75 | 76 | def calc_f1(x, y): 77 | if x == 0 or y == 0: return 0 78 | return 2 * x * y / (x + y) 79 | 80 | result_prec = np.zeros((3,3,5)) 81 | result_rec = np.zeros((3,3,5)) 82 | result_cnt = np.zeros((3,3,5), dtype=np.int32) 83 | result_f1 = np.zeros((3,3,5)) 84 | result_auc = np.zeros((3,3)) 85 | 86 | for j, sample_type in enumerate(tqdm(sample_types)): 87 | for k, seed in enumerate(sample_seed): 88 | ratio = approx[j][k] / 10**expn[j][k] 89 | ratio_list = [ratio/4, ratio/2, ratio, ratio*2, ratio*4] 90 | 91 | # data = torch.load(f'eval_{args.dataset}/attack-type-{args.method_type}_sample-type-{sample_type}_n-attack-{args.n_attack}_seed-{seed}.pt') 92 | # data = torch.load(f'eval_{args.dataset}/{args.method_type}_{sample_type}_{args.n_attack}_{seed}.pt') 93 | data = torch.load(f'eval_{args.dataset}/{args.method_type}_{sample_type}_{args.perturb_type}_{args.n_attack}_{seed}_eps-{args.eps}_seed-42.pt') # DP-GCN 94 | # data = torch.load(f'eval_{args.dataset}/type-{sample_type}_n-attack-{args.n_attack}_seed-{seed}.pt') 95 | result = data['result'] 96 | 97 | auc = data['auc'] 98 | fpr = auc['fpr'] 99 | tpr = auc['tpr'] 100 | auc_res = metrics.auc(fpr, tpr) 101 | result_auc[j][k] = auc_res 102 | 103 | pred = np.asarray(result['pred']) 104 | y = np.asarray(result['y']) 105 | 106 | n_total = len(pred) 107 | 108 | for l, ratio in enumerate(ratio_list): 109 | n_pos = ceil(ratio * n_total) 110 | ind = np.argpartition(pred, -n_pos)[-n_pos:] 111 | n_tp = y[ind].sum() 112 | 113 | result_prec[j][k][l] = n_tp / n_pos 114 | result_rec[j][k][l] = n_tp / y.sum() 115 | result_cnt[j][k][l] = n_tp 116 | result_f1[j][k][l] = calc_f1(result_prec[j][k][l], result_rec[j][k][l]) 117 | 118 | 119 | ########################### 120 | ######### save 121 | ########################### 122 | 123 | def get_df(mat, st): 124 | df = pd.DataFrame(mat.T) 125 | df.columns = ['low', 'normal', 'high'] 126 | df.index = [st+':ratio/4', st+':ratio/2', st+':ratio', st+':ratio*2', st+':ratio*4'] 127 | return df 128 | 129 | if dataset.startswith('twitch'): 130 | datadir = dataset[:dataset.find('/')] 131 | cty = dataset[dataset.rfind('/')+1:] 132 | else: 133 | datadir = dataset 134 | cty = '' 135 | 136 | 137 | savedir = f'sheets/dp_attack_{datadir}' 138 | if not osp.exists(savedir): 139 | os.makedirs(savedir) 140 | 141 | # writer = pd.ExcelWriter(f'sheets/attack_{datadir}_3_layer/{cty}_{args.method_type}_{args.n_attack}.xlsx', engine='openpyxl') 142 | # filename1 = f'sheets/attack_{datadir}/{args.method_type}_{args.n_attack}.xlsx' 143 | if cty: 144 | filename1 = osp.join(savedir, f'{cty}_{args.method_type}_{args.perturb_type}_{args.eps}_{args.n_attack}.xlsx') 145 | else: 146 | filename1 = osp.join(savedir, f'{args.method_type}_{args.perturb_type}_{args.eps}_{args.n_attack}.xlsx') 147 | 148 | writer = 
pd.ExcelWriter(filename1, engine='openpyxl') 149 | 150 | prec_mean = result_prec.mean(axis=1) 151 | prec_std = result_prec.std(axis=1) 152 | 153 | rec_mean = result_rec.mean(axis=1) 154 | rec_std = result_rec.std(axis=1) 155 | 156 | cnt_mean = result_cnt.mean(axis=1) 157 | cnt_std = result_cnt.std(axis=1) 158 | 159 | f1_mean = result_f1.mean(axis=1) 160 | f1_std = result_f1.std(axis=1) 161 | 162 | df = pd.concat([get_df(prec_mean, 'prec_mean'), get_df(prec_std, 'prec_std')]) 163 | df.to_excel(writer) 164 | df = pd.concat([get_df(rec_mean, 'rec_mean'), get_df(rec_std, 'rec_std')]) 165 | df.to_excel(writer, startrow=15) 166 | df = pd.concat([get_df(cnt_mean, 'cnt_mean'), get_df(cnt_std, 'cnt_std')]) 167 | df.to_excel(writer, startrow=30) 168 | df = pd.concat([get_df(f1_mean, 'f1_mean'), get_df(f1_std, 'f1_std')]) 169 | df.to_excel(writer, startrow=45) 170 | 171 | writer.save() 172 | print(f'save prec, rec, cnt, f1 to {filename1}!') 173 | 174 | 175 | auc_mean = result_auc.mean(axis=1) 176 | auc_std = result_auc.std(axis=1) 177 | auc = np.vstack((auc_mean, auc_std)) 178 | 179 | # pd.DataFrame(auc).to_csv(f'sheets/attack_{datadir}_3_layer/{cty}_{args.method_type}_{args.n_attack}_auc.csv') 180 | 181 | # filename2 = f'sheets/attack_{datadir}/{args.method_type}_{args.n_attack}_auc.csv' 182 | if cty: 183 | filename2 = osp.join(savedir, f'{cty}_{args.method_type}_{args.perturb_type}_{args.eps}_{args.n_attack}_auc.csv') 184 | else: 185 | filename2 = osp.join(savedir, f'{args.method_type}_{args.perturb_type}_{args.eps}_{args.n_attack}_auc.csv') 186 | pd.DataFrame(auc).to_csv(filename2) 187 | print(f'save auc to {filename2}!') 188 | -------------------------------------------------------------------------------- /gcn/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | from torch.nn.modules.module import Module 6 | 7 | 8 | class GraphConvolution(Module): 9 | """ 10 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 11 | """ 12 | 13 | def __init__(self, in_features, out_features, bias=True): 14 | super(GraphConvolution, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 18 | if bias: 19 | self.bias = Parameter(torch.FloatTensor(out_features)) 20 | else: 21 | self.register_parameter('bias', None) 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 26 | self.weight.data.uniform_(-stdv, stdv) 27 | if self.bias is not None: 28 | self.bias.data.uniform_(-stdv, stdv) 29 | 30 | def forward(self, input, adj): 31 | support = torch.mm(input, self.weight) 32 | output = torch.spmm(adj, support) 33 | if self.bias is not None: 34 | return output + self.bias 35 | else: 36 | return output 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + ' (' \ 40 | + str(self.in_features) + ' -> ' \ 41 | + str(self.out_features) + ')' 42 | 43 | 44 | class ProjectionGraphConvolution(Module): 45 | def __init__(self, in_features, out_features, projection, size, bias=True, trainable=False, args=None): 46 | super(ProjectionGraphConvolution, self).__init__() 47 | self.in_features = in_features 48 | self.out_features = out_features 49 | 50 | # self.mask = (projection != 0).type(torch.FloatTensor) 51 | if torch.cuda.is_available(): 52 | # self.mask = self.mask.cuda() 53 | self.projection = projection.cuda() 54 | # self.projection = Parameter(projection) 55 | 56 | self.n_nodes = size 57 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 58 | if bias: 59 | self.bias = Parameter(torch.FloatTensor(out_features)) 60 | else: 61 | self.register_parameter('bias', None) 62 | 63 | self.coeff = Parameter(torch.FloatTensor([1]), requires_grad=trainable) 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | # torch.manual_seed(100) 68 | stdv = 1. / math.sqrt(self.weight.size(1)) 69 | self.weight.data.uniform_(-stdv, stdv) 70 | if self.bias is not None: 71 | self.bias.data.uniform_(-stdv, stdv) 72 | # self.coeff.data.uniform_(0.5, 5) 73 | 74 | def forward(self, input, adj): 75 | support = torch.mm(input, self.weight) 76 | 77 | output = torch.spmm(self.projection, support) 78 | output = torch.spmm(adj, output) + self.coeff * support 79 | 80 | # output = (1+self.coeff) * support 81 | # output = torch.spmm(adj, output) + support 82 | # adjacency = torch.spmm(adj, self.projection*self.mask) + self.identity 83 | # adjacency = torch.spmm(adj, self.projection) + self.identity 84 | 85 | # adjacency = torch.spmm(adj, self.projection) 86 | # output = torch.spmm(adjacency, support) 87 | 88 | if self.bias is not None: 89 | return output + self.bias 90 | else: 91 | return output 92 | 93 | def __repr__(self): 94 | return self.__class__.__name__ + ' (' \ 95 | + str(self.in_features) + ' -> ' \ 96 | + str(self.out_features) + ')' 97 | 98 | 99 | class BasicMLP(Module): 100 | def __init__(self, in_features, out_features, size, bias=True, args=None, trainable=False): 101 | super(BasicMLP, self).__init__() 102 | self.in_features = in_features 103 | self.out_features = out_features 104 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 105 | if bias: 106 | self.bias = Parameter(torch.FloatTensor(out_features)) 107 | else: 108 | self.register_parameter('bias', None) 109 | 110 | self.n_nodes = size 111 | self.coeff = Parameter(torch.FloatTensor([1]), requires_grad=trainable) 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 116 | self.weight.data.uniform_(-stdv, stdv) 117 | if self.bias is not None: 118 | self.bias.data.uniform_(-stdv, stdv) 119 | 120 | def forward(self, input, d_vec): 121 | support = torch.mm(input, self.weight) 122 | 123 | output = torch.matmul(torch.ones(self.n_nodes,1).cuda(), 124 | torch.matmul(torch.ones(1,self.n_nodes).cuda(), support))# + support 125 | 126 | # output = torch.mm(torch.ones(self.n_nodes, self.n_nodes).cuda()/2708, self.coeff * support) 127 | 128 | if self.bias is not None: 129 | return output + self.bias 130 | else: 131 | return output 132 | 133 | def __repr__(self): 134 | return self.__class__.__name__ + ' (' \ 135 | + str(self.in_features) + ' -> ' \ 136 | + str(self.out_features) + ')' 137 | 138 | 139 | class DegreeMLP(Module): 140 | def __init__(self, in_features, out_features, size, bias=True, args=None, trainable=False): 141 | super(DegreeMLP, self).__init__() 142 | self.in_features = in_features 143 | self.out_features = out_features 144 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 145 | if bias: 146 | self.bias = Parameter(torch.FloatTensor(out_features)) 147 | else: 148 | self.register_parameter('bias', None) 149 | 150 | self.projection = torch.ones(1, size).cuda() / size 151 | 152 | self.n_nodes = size 153 | # self.d_vec_add = Parameter(torch.FloatTensor(n_nodes), requires_grad=False) 154 | self.d_vec_add = Parameter(torch.zeros(self.n_nodes), requires_grad=True) 155 | self.coeff = Parameter(torch.FloatTensor([args.coeff]), requires_grad=False) 156 | self.reset_parameters() 157 | 158 | def reset_parameters(self): 159 | stdv = 1. / math.sqrt(self.weight.size(1)) 160 | self.weight.data.uniform_(-stdv, stdv) 161 | if self.bias is not None: 162 | self.bias.data.uniform_(-stdv, stdv) 163 | self.d_vec_add.data.uniform_(-1, 1) 164 | 165 | def forward(self, input, d_vec): 166 | support = torch.mm(input, self.weight) 167 | output = torch.spmm(self.projection, support) 168 | 169 | output = torch.mm( 170 | (torch.ones(self.n_nodes).cuda()+self.d_vec_add).unsqueeze(-1), output)# + self.coeff * support 171 | 172 | if self.bias is not None: 173 | return output + self.bias 174 | else: 175 | return output 176 | 177 | def __repr__(self): 178 | return self.__class__.__name__ + ' (' \ 179 | + str(self.in_features) + ' -> ' \ 180 | + str(self.out_features) + ')' -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import datetime 6 | import logging 7 | import numpy as np 8 | import random 9 | 10 | import torch 11 | 12 | from worker import Worker 13 | from gcn_trainer import GCNTrainer 14 | from utils import init_logger 15 | 16 | 17 | def get_arguments(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--no-cuda', action='store_true', default=False, 20 | help='Disables CUDA training.') 21 | parser.add_argument('--fastmode', action='store_true', default=False, 22 | help='Validate during training pass.') 23 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 24 | parser.add_argument('--num-epochs', type=int, default=500, 25 | help='Number of epochs to train.') 26 | parser.add_argument('--lr', type=float, default=0.01, 27 | help='Initial learning rate.') 28 | parser.add_argument('--weight_decay', type=float, default=5e-4, 29 | help='Weight decay (L2 loss on 
parameters).') 30 | parser.add_argument('--hidden', type=int, default=16, 31 | help='Number of hidden units.') 32 | parser.add_argument('--hidden1', type=int, default=16, 33 | help='Number of hidden units.') 34 | parser.add_argument('--hidden2', type=int, default=16, 35 | help='Number of hidden units.') 36 | parser.add_argument('--dropout', type=float, default=0.5, 37 | help='Dropout rate (1 - keep probability).') 38 | parser.add_argument('--dataset', type=str, default='cora', 39 | help='Dataset: cora; cora.sample') 40 | parser.add_argument('--model-path', type=str, default='') 41 | parser.add_argument('--mode', type=str, default='vanilla-clean', 42 | help='[ vanilla | vanilla-clean | clusteradj | clusteradj-clean ] ') 43 | parser.add_argument('--init-method', type=str, default='knn', 44 | help='[ naive | voting | knn | gt ]') 45 | parser.add_argument('--cluster-method', type=str, default='hierarchical', 46 | help='[ label | random | kmeans | sskmeans ]') 47 | parser.add_argument('--scale', type=str, default='small', 48 | help='[ large | small ]') 49 | parser.add_argument('--break-method', type=str, default='kmeans', 50 | help='[ kmeans | dp ]') 51 | parser.add_argument('--norm', type=str, default='AugNormAdj', 52 | choices=['AugNormAdj', 'FirstOrderGCN', 'BingGeNormAdj', 'NormAdj', 'RWalk', 'AugRWalk']) 53 | parser.add_argument('--sample-type', type=str, default='balanced', 54 | choices=['balanced', 'unbalanced', 'unbalanced-lo', 'unbalanced-hi', 'bfs', 'balanced-full']) 55 | 56 | parser.add_argument('--epsilon', type=float, default=0.1) 57 | parser.add_argument('--delta', type=float, default=1e-5) 58 | parser.add_argument('--influence', type=float, default=0.0001) 59 | parser.add_argument('--train-ratio', type=float, default=0.5) 60 | parser.add_argument('--patience', type=int, default=10) 61 | parser.add_argument('--n-clusters', type=int, default=10) 62 | parser.add_argument('--n-test', type=int, default=100) 63 | parser.add_argument('--n-layer', type=int, default=2) 64 | parser.add_argument('--break-ratio', type=float, default=1) 65 | parser.add_argument('--feature-size', type=int, default=-1) 66 | parser.add_argument('--k', type=float, default=1) 67 | 68 | parser.add_argument('--approx', action='store_true', default=False) 69 | parser.add_argument('--attack', action='store_true', default=False) 70 | parser.add_argument('--test', action='store_true', default=False) 71 | parser.add_argument('--break-down', action='store_true', default=False) 72 | parser.add_argument('--display', action='store_true', default=False) 73 | parser.add_argument('--same-size', action='store_true', default=False) 74 | parser.add_argument('--eval-degree', action='store_true', default=False) 75 | parser.add_argument('--trainable', action='store_true', default=False) 76 | parser.add_argument('--early', action='store_true', default=False) 77 | parser.add_argument('--fnormalize', action='store_true', default=False) 78 | 79 | parser.add_argument('--noise-seed', type=int, default=42) 80 | parser.add_argument('--sample-seed', type=int, default=42) 81 | parser.add_argument('--cluster-seed', type=int, default=42) 82 | parser.add_argument('--knn', type=int, default=-1) 83 | parser.add_argument('--noise-type', type=str, default='laplace') 84 | parser.add_argument('--perturb-type', type=str, default='discrete', 85 | choices=[ 'discrete', 'continuous' ]) 86 | parser.add_argument('--attack-mode', type=str, default='efficient', 87 | choices=['efficient', 'naive', 'baseline', 'baseline-feat']) 88 | 89 | 
parser.add_argument('--coeff', type=float, default=1) 90 | parser.add_argument('--degree', type=int, default=2) 91 | parser.set_defaults(assign_seed=42) 92 | 93 | return parser.parse_args() 94 | 95 | 96 | def main(): 97 | args = get_arguments() 98 | print(str(args)) 99 | logging.info(str(args)) 100 | 101 | random.seed(args.seed) 102 | np.random.seed(args.seed) 103 | torch.manual_seed(args.seed) 104 | if torch.cuda.is_available(): 105 | torch.cuda.manual_seed(args.seed) 106 | 107 | if args.test: 108 | worker = Worker(args, dataset=args.dataset, mode=args.mode) 109 | 110 | trainer = GCNTrainer(args, worker=worker) 111 | trainer.init_model(model_path=args.model_path) 112 | trainer.test(args.eval_degree) 113 | 114 | else: 115 | cur_time = datetime.datetime.now().strftime("%m-%d-%H:%M:%S.%f") 116 | 117 | if args.mode in ( 'sgc-clean', 'sgc' ): 118 | subdir = 'mode-{}_lr-{}_{}'.format(args.mode, args.lr, cur_time) 119 | 120 | elif args.mode in ( 'vanilla-clean', 'cs' ): 121 | subdir = 'mode-{}_hidden-{}_lr-{}_decay-{}_dropout-{}_norm-{}_{}'.format(\ 122 | args.mode, args.hidden, args.lr, args.weight_decay, args.dropout, args.norm, cur_time) 123 | 124 | elif args.mode == 'clusteradj-clean': 125 | if not args.scale: 126 | subdir = 'mode-clusteradj-clean_{}_{}'.format(\ 127 | args.cluster_method, cur_time) 128 | else: 129 | subdir = 'mode-clean_small_n-clusters-{}_{}'.format(\ 130 | args.n_clusters, cur_time) 131 | 132 | elif args.mode == 'vanilla': 133 | subdir = 'mode-global_perturb-{}_eps-{}_{}'.format(\ 134 | args.perturb_type, args.epsilon, cur_time) 135 | 136 | elif args.mode == 'clusteradj': 137 | if not args.scale: 138 | subdir = 'mode-clusteradj_ratio-{}_eps-{}_train-{}_{}_{}'.format(\ 139 | args.train_ratio, args.epsilon, args.trainable, args.cluster_method, cur_time) 140 | elif args.scale == 'small': 141 | subdir = 'mode-clusteradj_small_eps-{}_n-clusters-{}_{}'.format(\ 142 | args.epsilon, args.n_clusters, cur_time) 143 | 144 | elif args.mode in ( 'degree_mlp', 'basic_mlp' ): 145 | subdir = 'mode-{}_{}'.format(args.mode, cur_time) 146 | 147 | elif args.mode in ( 'degcn-clean', 'degcn' ): 148 | subdir = 'mode-{}_eps-{}_{}'.format(args.mode, args.epsilon, cur_time) 149 | 150 | else: 151 | print('mode={} not implemented!'.format(args.mode)) 152 | raise NotImplementedError 153 | 154 | print('subdir = {}'.format(subdir)) 155 | init_logger('./logs_{}'.format(args.dataset), subdir, print_log=False) 156 | 157 | worker = Worker(args, dataset=args.dataset, mode=args.mode) 158 | 159 | trainer = GCNTrainer(args, subdir=subdir, worker=worker) 160 | trainer.init_model() 161 | trainer.train() 162 | trainer.test(args.eval_degree) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /mlp_trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import os 5 | import scipy.sparse as sp 6 | from sklearn.decomposition import PCA 7 | from sklearn.metrics import f1_score, average_precision_score 8 | from tensorboardX import SummaryWriter 9 | import time 10 | from tqdm import trange, tqdm 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from torch.utils.data import DataLoader, TensorDataset 16 | 17 | from gcn import GCN, ProjectionGCN, MLP, SGC, SingleHiddenLayerMLP, OneLayerMLP 18 | from utils import EarlyStopping, get_noise 19 | 20 | class MLPTrainer(): 21 | def 
__init__(self, args, subdir='', worker=None, train=True): 22 | self.args = args 23 | 24 | self.worker = worker 25 | self.loss_func = F.cross_entropy if self.worker.multi_label == 1 \ 26 | else F.binary_cross_entropy_with_logits 27 | self.mode = self.worker.mode 28 | self.dataset = self.worker.dataset 29 | self.subdir = subdir 30 | self.is_train = train 31 | 32 | self.gcnt_train = self.gcnt_valid = 0 33 | if self.args.early: 34 | self.early_stopping = EarlyStopping(patience=self.args.patience) 35 | 36 | if subdir: 37 | self.init_all_logging(subdir) 38 | 39 | self.transfer = (not self.dataset.startswith('twitch-train') and self.dataset.startswith('twitch')) or \ 40 | self.dataset.startswith('wikipedia') or \ 41 | self.dataset.startswith('deezer') 42 | 43 | self.prepare_data() 44 | 45 | 46 | def calc_loss(self, input, target): 47 | if self.loss_func == F.cross_entropy: 48 | return self.loss_func(input, target.squeeze()) 49 | else: 50 | return self.loss_func(input, target.float()) 51 | 52 | 53 | def prepare_data(self): 54 | if self.is_train: 55 | if self.transfer: 56 | self.train_loader = DataLoader(TensorDataset( 57 | self.worker.features_1, 58 | self.worker.labels_1 59 | ), batch_size=self.args.batch_size, shuffle=True) 60 | 61 | else: 62 | self.train_loader = DataLoader(TensorDataset( 63 | self.worker.features[self.worker.idx_train], 64 | self.worker.labels[self.worker.idx_train] 65 | ), batch_size=self.args.batch_size, shuffle=True) 66 | 67 | 68 | def init_all_logging(self, subdir): 69 | tflog_path = os.path.join('tflogs_{}'.format(self.dataset), subdir) 70 | self.model_path = os.path.join('model_{}'.format(self.dataset), subdir) 71 | self.writer = SummaryWriter(log_dir=tflog_path) 72 | if not os.path.exists(self.model_path): os.makedirs(self.model_path) 73 | 74 | 75 | def init_model(self, model_path=''): 76 | # Model and optimizer 77 | if self.mode in ( 'mlp' ): 78 | if self.args.depth == 1: 79 | self.model = OneLayerMLP(nfeat=self.worker.n_features, 80 | nclass=self.worker.n_classes) 81 | elif self.args.depth == 2: 82 | self.model = SingleHiddenLayerMLP(nfeat=self.worker.n_features, 83 | nhid=self.args.hidden, 84 | nclass=self.worker.n_classes, 85 | dropout=self.args.dropout) 86 | 87 | if model_path: 88 | self.model.load_state_dict(torch.load(model_path)) 89 | print('load model from {} done!'.format(model_path)) 90 | else: 91 | self.optimizer = optim.Adam(self.model.parameters(), 92 | lr=self.args.lr, weight_decay=self.args.weight_decay) 93 | if torch.cuda.is_available(): 94 | self.model.cuda() 95 | 96 | 97 | def train_one_epoch(self, epoch): 98 | # training 99 | self.model.train() 100 | 101 | loss_seq = [] 102 | acc_seq = [] 103 | 104 | if self.args.batch: 105 | for features, labels in self.train_loader: 106 | output = self.model(features) 107 | 108 | loss = self.calc_loss(output, labels) 109 | acc = self.f1_score(output, labels) 110 | 111 | self.optimizer.zero_grad() 112 | loss.backward() 113 | self.optimizer.step() 114 | 115 | loss_seq.append(loss.item()) 116 | acc_seq.append(acc[0].item()) 117 | self.gcnt_train += 1 118 | if self.gcnt_train % 10 == 0: 119 | self.writer.add_scalar('train/loss', np.mean(loss_seq), self.gcnt_train) 120 | self.writer.add_scalar('train/acc', np.mean(acc_seq), self.gcnt_train) 121 | loss_seq = [] 122 | acc_seq = [] 123 | 124 | 125 | if not self.transfer: 126 | # validation 127 | output = self.model(self.worker.features[self.worker.idx_val]) 128 | loss_val = self.calc_loss(output, self.worker.labels[self.worker.idx_val]) 129 | acc_val = 
self.f1_score(output, self.worker.labels[self.worker.idx_val]) 130 | self.gcnt_valid += 1 131 | self.writer.add_scalar('valid/loss', loss_val, self.gcnt_valid) 132 | self.writer.add_scalar('valid/acc', acc_val[0], self.gcnt_valid) 133 | 134 | else: 135 | output = self.model(self.worker.features_1) 136 | loss = self.calc_loss(output, self.worker.labels_1) 137 | acc = self.f1_score(output, self.worker.labels_1) 138 | 139 | self.optimizer.zero_grad() 140 | loss.backward() 141 | self.optimizer.step() 142 | 143 | 144 | def train(self): 145 | # Train model 146 | t_total = time.time() 147 | 148 | for epoch in tqdm(range(self.args.num_epochs)): 149 | logging.info('[epoch {}]'.format(epoch)) 150 | self.train_one_epoch(epoch) 151 | 152 | if self.args.early and self.early_stopping.early_stop: 153 | self.model = self.early_stopping.best_model 154 | logging.info(f'early stop at epoch {epoch}') 155 | break 156 | 157 | torch.save(self.model.state_dict(), os.path.join(self.model_path, 'model.pt')) 158 | 159 | print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) 160 | 161 | 162 | def f1_score(self, output, labels): 163 | if self.worker.multi_label == 1: 164 | preds = F.softmax(output, dim=1) 165 | preds = preds.max(1)[1].type_as(labels) 166 | return f1_score(labels.cpu(), preds.detach().cpu(), average='micro'), \ 167 | f1_score(labels.cpu(), preds.detach().cpu(), average='macro'), \ 168 | f1_score(labels.cpu(), preds.detach().cpu(), average='weighted') 169 | # unique, count = torch.unique(preds, return_counts=True) 170 | # correct = preds.eq(labels).double() 171 | # correct = correct.sum() 172 | # return correct / len(labels) 173 | 174 | else: # multi_label 175 | preds = torch.sigmoid(output) > 0.5 176 | return f1_score(labels.cpu(), preds.detach().cpu(), average='micro'), \ 177 | f1_score(labels.cpu(), preds.detach().cpu(), average='macro'), \ 178 | f1_score(labels.cpu(), preds.detach().cpu(), average='weighted') 179 | 180 | 181 | def rare_class_f1(self, output, labels): 182 | # identify the rare class 183 | ind = [torch.where(labels==0)[0], 184 | torch.where(labels==1)[0]] 185 | rare_class = int(len(ind[0]) > len(ind[1])) 186 | 187 | preds = F.softmax(output, dim=1).max(1) 188 | 189 | ap_score = average_precision_score(labels.cpu() if rare_class==1 else 1-labels.cpu(), preds[0].detach().cpu()) 190 | 191 | preds = preds[1].type_as(labels) 192 | 193 | TP = torch.sum(preds[ind[rare_class]] == rare_class).item() 194 | T = len(ind[rare_class]) 195 | P = torch.sum(preds == rare_class).item() 196 | 197 | if P == 0: return 0 198 | 199 | precision = TP / P 200 | recall = TP / T 201 | F1 = 2 * (precision * recall) / (precision + recall) 202 | return F1, precision, recall, ap_score 203 | 204 | 205 | def eval_output(self): 206 | if not self.transfer: 207 | output = self.model(self.worker.features[self.worker.idx_val]) 208 | loss_val = self.calc_loss(output, self.worker.labels[self.worker.idx_val]) 209 | acc_val = self.f1_score(output, self.worker.labels[self.worker.idx_val]) 210 | 211 | output_info = f'''Valid set results: '''\ 212 | f'''loss = {loss_val.item():.4f} '''\ 213 | f'''f1_score = {acc_val[0].item():.4f} ''' 214 | print(output_info) 215 | logging.info(output_info) 216 | 217 | output = self.model(self.worker.features_2) if self.transfer \ 218 | else self.model(self.worker.features[self.worker.idx_test]) 219 | target = self.worker.labels_2 if self.transfer \ 220 | else self.worker.labels[self.worker.idx_test] 221 | loss_test = self.calc_loss(output, target) 222 | acc_test = 
self.f1_score(output, target) if not self.worker.transfer \ 223 | else self.rare_class_f1(output, target) 224 | 225 | output_info = f'''Test set results: '''\ 226 | f'''loss = {loss_test.item():.4f} ''' 227 | output_info += f'rare_class_f1 = {acc_test[0]:.4f} prec = {acc_test[1]:.4f} reca = {acc_test[2]:.4f} ap_score = {acc_test[3]:.4f}' if self.worker.transfer else \ 228 | f'''f1_score [micro, macro, weighted] = {acc_test[0].item():.4f} {acc_test[1].item():.4f} {acc_test[2].item():.4f}''' 229 | print(output_info) 230 | logging.info(output_info) 231 | 232 | 233 | def test(self, eval_degree=False): 234 | self.model.eval() 235 | 236 | self.eval_output() 237 | 238 | 239 | def __del__(self): 240 | if hasattr(self, 'writer'): 241 | self.writer.close() -------------------------------------------------------------------------------- /attacker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | from sklearn import metrics 6 | import time 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from utils import construct_edge_sets, construct_edge_sets_from_random_subgraph, construct_edge_sets_through_bfs, construct_balanced_edge_sets 13 | 14 | class Attacker: 15 | def __init__(self, args, model, worker): 16 | self.args = args 17 | self.dataset = args.dataset 18 | self.model = model 19 | self.worker = worker 20 | 21 | if args.sample_type == 'balanced-full': 22 | self.args.n_test = self.worker.n_nodes 23 | 24 | if self.dataset.startswith('twitch') or self.dataset.startswith('deezer'): 25 | self.features = self.worker.features_2 26 | self.adj = self.worker.adj_2 27 | 28 | else: 29 | self.features = self.worker.features 30 | self.adj = self.worker.adj_full 31 | 32 | 33 | def prepare_test_data(self): 34 | func = { 35 | 'balanced': construct_edge_sets, 36 | 'balanced-full': construct_balanced_edge_sets, 37 | 'unbalanced': construct_edge_sets_from_random_subgraph, 38 | 'unbalanced-lo': construct_edge_sets_from_random_subgraph, 39 | 'unbalanced-hi': construct_edge_sets_from_random_subgraph, 40 | 'bfs': construct_edge_sets_through_bfs, 41 | }.get(self.args.sample_type) 42 | if not func: 43 | raise NotImplementedError(f'sample_type = {self.args.sample_type} not implemented!') 44 | 45 | np.random.seed(self.args.sample_seed) 46 | (self.exist_edges, self.nonexist_edges), self.test_nodes = func( 47 | self.dataset, self.args.sample_type, self.worker.adj_ori, self.args.n_test) 48 | print(f'generating testing (non-)edge set done!') 49 | 50 | 51 | # \partial_f(x_u) / \partial_x_v 52 | def get_gradient(self, u, v): 53 | h = 0.0001 54 | ret = torch.zeros(self.worker.n_features) 55 | for i in range(self.worker.n_features): 56 | pert = torch.zeros_like(self.worker.features) 57 | pert[v][i] = h 58 | with torch.no_grad(): 59 | grad = (self.model(self.worker.features + pert, self.worker.adj_full).detach() - 60 | self.model(self.worker.features - pert, self.worker.adj_full).detach()) / (2 * h) 61 | ret[i] = grad[u].sum() 62 | 63 | return ret 64 | 65 | 66 | # # \partial_f(x_u) / \partial_epsilon_v 67 | # def get_gradient_eps(self, u, v): 68 | # if self.dataset.startswith('twitch'): 69 | # features = self.worker.features_2 70 | # adj = self.worker.adj_2 71 | 72 | # else: 73 | # features = self.worker.features 74 | # adj = self.worker.adj_full 75 | 76 | # h = 0.00001 77 | # pert_1 = torch.zeros_like(features) 78 | # pert_2 = torch.zeros_like(features) 
79 | 80 | # pert_1[v] = features[v] * self.args.influence 81 | # pert_2[v] = features[v] * h 82 | # grad = (self.model(features + pert_1 + pert_2, adj).detach() - 83 | # self.model(features + pert_1 - pert_2, adj).detach()) / (2 * h) 84 | 85 | # return grad[u] 86 | 87 | 88 | # \partial_f(x_u) / \partial_epsilon_v 89 | def get_gradient_eps(self, u, v): 90 | pert_1 = torch.zeros_like(self.features) 91 | 92 | pert_1[v] = self.features[v] * self.args.influence 93 | 94 | grad = (self.model(self.features + pert_1, self.adj).detach() - 95 | self.model(self.features, self.adj).detach()) / self.args.influence 96 | 97 | return grad[u] 98 | 99 | 100 | def get_gradient_eps_mat(self, v): 101 | pert_1 = torch.zeros_like(self.features) 102 | 103 | pert_1[v] = self.features[v] * self.args.influence 104 | 105 | grad = (self.model(self.features + pert_1, self.adj).detach() - 106 | self.model(self.features, self.adj).detach()) / self.args.influence 107 | 108 | return grad 109 | 110 | 111 | def calculate_auc(self, v1, v0): 112 | v1 = sorted(v1) 113 | v0 = sorted(v0) 114 | vall = sorted(v1 + v0) 115 | 116 | TP = self.args.n_test 117 | FP = self.args.n_test 118 | T = F = self.args.n_test # fixed 119 | 120 | p0 = p1 = 0 121 | 122 | TPR = TP / T 123 | FPR = FP / F 124 | 125 | result = [(FPR, TPR)] 126 | auc = 0 127 | for elem in vall: 128 | if p1 < self.args.n_test and abs(elem - v1[p1]) < 1e-6: 129 | p1 += 1 130 | TP -= 1 131 | TPR = TP / T 132 | else: 133 | p0 += 1 134 | FP -= 1 135 | FPR = FP / F 136 | auc += TPR * 1 / F 137 | 138 | result.append((FPR, TPR)) 139 | 140 | return result, auc 141 | 142 | 143 | def link_prediction_attack(self): 144 | norm_exist = [] 145 | norm_nonexist = [] 146 | 147 | t = time.time() 148 | 149 | with torch.no_grad(): 150 | for u, v in tqdm(self.exist_edges): 151 | 152 | grad = self.get_gradient_eps(u, v) # if self.args.approx else self.get_gradient(u, v) 153 | norm_exist.append(grad.norm().item()) 154 | 155 | print(f'time for predicting existing edges: {time.time() - t}') 156 | 157 | t = time.time() 158 | for u, v in tqdm(self.nonexist_edges): 159 | 160 | grad = self.get_gradient_eps(u, v) # if self.args.approx else self.get_gradient(u, v) 161 | norm_nonexist.append(grad.norm().item()) 162 | 163 | print(f'time for predicting non-existing edges: {time.time() - t}') 164 | 165 | # print(sorted(norm_exist)) 166 | # print(sorted(norm_nonexist)) 167 | 168 | 169 | y = [1] * len(norm_exist) + [0] * len(norm_nonexist) 170 | pred = norm_exist + norm_nonexist 171 | 172 | fpr, tpr, thresholds = metrics.roc_curve(y, pred) 173 | print('auc =', metrics.auc(fpr, tpr)) 174 | 175 | precision, recall, thresholds_2 = metrics.precision_recall_curve(y, pred) 176 | print('ap =', metrics.average_precision_score(y, pred)) 177 | 178 | folder_name = f'eval_{self.dataset}' 179 | if not osp.exists(folder_name): os.makedirs(folder_name) 180 | 181 | if self.args.mode == 'vanilla-clean': 182 | filename = osp.join(folder_name, f'{self.args.sample_type}_{self.args.n_test}_{self.args.sample_seed}.pt') 183 | else: 184 | filename = osp.join(folder_name, f'{self.args.sample_type}_{self.args.perturb_type}_{self.args.n_test}_{self.args.sample_seed}_eps-{self.args.epsilon}_seed-{self.args.noise_seed}.pt') 185 | 186 | torch.save({ 187 | 'auc': { 188 | 'fpr': fpr, 189 | 'tpr': tpr, 190 | 'thresholds': thresholds 191 | }, 192 | 'pr': { 193 | 'precision': precision, 194 | 'recall': recall, 195 | 'thresholds': thresholds_2 196 | }, 197 | 'result': { 198 | 'y': y, 199 | 'pred': pred, 200 | } 201 | }, filename) 202 | 203 | # 
result, auc = self.calculate_auc(norm_exist, norm_nonexist) 204 | 205 | # print('auc =', auc) 206 | # torch.save(result, 'result.pt') 207 | 208 | 209 | def link_prediction_attack_efficient(self): 210 | norm_exist = [] 211 | norm_nonexist = [] 212 | 213 | t = time.time() 214 | 215 | # 2. compute influence value for all pairs of nodes 216 | influence_val = np.zeros((self.args.n_test, self.args.n_test)) 217 | 218 | with torch.no_grad(): 219 | 220 | for i in tqdm(range(self.args.n_test)): 221 | u = self.test_nodes[i] 222 | grad_mat = self.get_gradient_eps_mat(u) 223 | 224 | for j in range(self.args.n_test): 225 | v = self.test_nodes[j] 226 | 227 | grad_vec = grad_mat[v] 228 | 229 | influence_val[i][j] = grad_vec.norm().item() 230 | 231 | print(f'time for predicting edges: {time.time() - t}') 232 | 233 | node2ind = { node : i for i, node in enumerate(self.test_nodes) } 234 | 235 | for u, v in self.exist_edges: 236 | i = node2ind[u] 237 | j = node2ind[v] 238 | 239 | norm_exist.append(influence_val[j][i]) 240 | 241 | for u, v in self.nonexist_edges: 242 | i = node2ind[u] 243 | j = node2ind[v] 244 | 245 | norm_nonexist.append(influence_val[j][i]) 246 | 247 | self.compute_and_save(norm_exist, norm_nonexist) 248 | 249 | 250 | def link_prediction_attack_efficient_balanced(self): 251 | norm_exist = [] 252 | norm_nonexist = [] 253 | 254 | # organize exist_edges and nonexist_edges into dict 255 | edges_dict = defaultdict(list) 256 | nonedges_dict = defaultdict(list) 257 | for u, v in self.exist_edges: 258 | edges_dict[u].append(v) 259 | for u, v in self.nonexist_edges: 260 | nonedges_dict[u].append(v) 261 | 262 | t = time.time() 263 | with torch.no_grad(): 264 | for u in tqdm(range(self.worker.n_nodes)): 265 | if u not in edges_dict and u not in nonedges_dict: 266 | continue 267 | 268 | grad_mat = self.get_gradient_eps_mat(u) 269 | 270 | if u in edges_dict: 271 | v_list = edges_dict[u] 272 | for v in v_list: 273 | grad_vec = grad_mat[v] 274 | norm_exist.append(grad_vec.norm().item()) 275 | 276 | if u in nonedges_dict: 277 | v_list = nonedges_dict[u] 278 | for v in v_list: 279 | grad_vec = grad_mat[v] 280 | norm_nonexist.append(grad_vec.norm().item()) 281 | 282 | print(f'time for predicting edges: {time.time() - t}') 283 | 284 | self.compute_and_save(norm_exist, norm_nonexist) 285 | 286 | 287 | def baseline_attack(self): 288 | norm_exist = [] 289 | norm_nonexist = [] 290 | 291 | t = time.time() 292 | 293 | with torch.no_grad(): 294 | # 0. compute posterior 295 | if self.args.attack_mode == 'baseline': 296 | if self.dataset != 'ppi': 297 | posterior = F.softmax(self.model(self.features, self.adj), dim=1) 298 | else: 299 | posterior = F.sigmoid(self.model(self.features, self.adj)) 300 | elif self.args.attack_mode == 'baseline-feat': 301 | posterior = self.features 302 | else: 303 | raise NotImplementedError(f'attack_mode={self.args.attack_mode} not implemented!') 304 | 305 | # 1. compute the mean posterior of sampled nodes 306 | mean = torch.mean(posterior[self.test_nodes], dim=0) 307 | 308 | # 2. 
compute correlation value for all pairs of nodes 309 | dist = np.zeros((self.args.n_test, self.args.n_test)) 310 | 311 | for i in tqdm(range(self.args.n_test)): 312 | u = self.test_nodes[i] 313 | for j in range(i+1, self.args.n_test): 314 | v = self.test_nodes[j] 315 | 316 | dist[i][j] = torch.dot(posterior[u] - mean, posterior[v] - mean) / torch.norm(posterior[u] - mean) / torch.norm(posterior[v] - mean) 317 | 318 | print(f'time for computing correlation value: {time.time() - t}') 319 | 320 | node2ind = { node : i for i, node in enumerate(self.test_nodes) } 321 | 322 | for u, v in self.exist_edges: 323 | i = node2ind[u] 324 | j = node2ind[v] 325 | 326 | norm_exist.append(dist[i][j] if i < j else dist[j][i]) 327 | 328 | for u, v in self.nonexist_edges: 329 | i = node2ind[u] 330 | j = node2ind[v] 331 | 332 | norm_nonexist.append(dist[i][j] if i < j else dist[j][i]) 333 | 334 | self.compute_and_save(norm_exist, norm_nonexist) 335 | 336 | 337 | def baseline_attack_balanced(self): 338 | norm_exist = [] 339 | norm_nonexist = [] 340 | 341 | # organize exist_edges and nonexist_edges into dict 342 | edges_dict = defaultdict(list) 343 | nonedges_dict = defaultdict(list) 344 | for u, v in self.exist_edges: 345 | edges_dict[u].append(v) 346 | for u, v in self.nonexist_edges: 347 | nonedges_dict[u].append(v) 348 | 349 | t = time.time() 350 | 351 | with torch.no_grad(): 352 | # 0. compute posterior 353 | if self.args.attack_mode == 'baseline': 354 | if self.dataset != 'ppi': 355 | posterior = F.softmax(self.model(self.features, self.adj), dim=1) 356 | else: 357 | posterior = F.sigmoid(self.model(self.features, self.adj)) 358 | elif self.args.attack_mode == 'baseline-feat': 359 | posterior = self.features 360 | else: 361 | raise NotImplementedError(f'attack_mode={self.args.attack_mode} not implemented!') 362 | 363 | # 1. compute the mean posterior of sampled nodes 364 | mean = torch.mean(posterior, dim=0) 365 | 366 | # 2. 
compute correlation value for all pairs 367 | for u, v in tqdm(self.exist_edges): 368 | norm_exist.append((torch.dot(posterior[u] - mean, posterior[v] - mean) / torch.norm(posterior[u] - mean) / torch.norm(posterior[v] - mean)).item()) 369 | 370 | for u, v in tqdm(self.nonexist_edges): 371 | norm_nonexist.append((torch.dot(posterior[u] - mean, posterior[v] - mean) / torch.norm(posterior[u] - mean) / torch.norm(posterior[v] - mean)).item()) 372 | 373 | print(f'time for computing correlation value: {time.time() - t}') 374 | 375 | self.compute_and_save(norm_exist, norm_nonexist) 376 | 377 | 378 | def compute_and_save(self, norm_exist, norm_nonexist): 379 | y = [1] * len(norm_exist) + [0] * len(norm_nonexist) 380 | pred = norm_exist + norm_nonexist 381 | 382 | fpr, tpr, thresholds = metrics.roc_curve(y, pred) 383 | print('auc =', metrics.auc(fpr, tpr)) 384 | 385 | precision, recall, thresholds_2 = metrics.precision_recall_curve(y, pred) 386 | print('ap =', metrics.average_precision_score(y, pred)) 387 | 388 | folder_name = f'eval_{self.dataset}' 389 | if not osp.exists(folder_name): os.makedirs(folder_name) 390 | 391 | if self.args.mode == 'vanilla-clean': 392 | filename = osp.join(folder_name, f'{self.args.attack_mode}_{self.args.sample_type}_{self.args.n_test}_{self.args.sample_seed}.pt') 393 | else: 394 | filename = osp.join(folder_name, f'{self.args.attack_mode}_{self.args.sample_type}_{self.args.perturb_type}_{self.args.n_test}_{self.args.sample_seed}_eps-{self.args.epsilon}_seed-{self.args.noise_seed}.pt') 395 | 396 | torch.save({ 397 | 'auc': { 398 | 'fpr': fpr, 399 | 'tpr': tpr, 400 | 'thresholds': thresholds 401 | }, 402 | 'pr': { 403 | 'precision': precision, 404 | 'recall': recall, 405 | 'thresholds': thresholds_2 406 | }, 407 | 'result': { 408 | 'y': y, 409 | 'pred': pred, 410 | } 411 | }, filename) 412 | print(f'attack results saved to: {filename}') 413 | 414 | # result, auc = self.calculate_auc(norm_exist, norm_nonexist) 415 | 416 | # print('auc =', auc) 417 | # torch.save(result, 'result.pt') 418 | -------------------------------------------------------------------------------- /gcn_trainer.py: -------------------------------------------------------------------------------- 1 | from pandas.core.base import DataError 2 | from sklearn.decomposition import PCA 3 | import logging 4 | import os 5 | import os.path as osp 6 | import numpy as np 7 | import random 8 | import scipy.sparse as sp 9 | from sklearn.metrics import f1_score, average_precision_score 10 | from tensorboardX import SummaryWriter 11 | import time 12 | from tqdm import tqdm, trange 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | from attacker import Attacker 19 | from gcn import GCN, GCN3, ProjectionGCN, MLP, SGC, DeGCN 20 | from utils import EarlyStopping, get_noise 21 | 22 | class GCNTrainer(): 23 | def __init__(self, args, subdir='', worker=None): 24 | self.args = args 25 | 26 | self.worker = worker 27 | self.loss_func = F.cross_entropy if self.worker.multi_label == 1 \ 28 | else F.binary_cross_entropy_with_logits 29 | self.mode = self.worker.mode 30 | self.dataset = self.worker.dataset 31 | self.subdir = subdir 32 | 33 | self.gcnt_train = self.gcnt_valid = 0 34 | if self.args.early: 35 | self.early_stopping = EarlyStopping(patience=self.args.patience) 36 | 37 | if subdir: 38 | self.init_all_logging(subdir) 39 | 40 | 41 | def calc_loss(self, input, target): 42 | if self.loss_func == F.cross_entropy: 43 | return self.loss_func(input, target.squeeze()) 44 | 
else: 45 | return self.loss_func(input, target.float()) 46 | 47 | 48 | def init_all_logging(self, subdir): 49 | tflog_path = os.path.join('tflogs_{}'.format(self.dataset), subdir) 50 | self.model_path = os.path.join('model_{}'.format(self.dataset), subdir) 51 | self.writer = SummaryWriter(log_dir=tflog_path) 52 | if not os.path.exists(self.model_path): os.makedirs(self.model_path) 53 | 54 | 55 | def init_model(self, model_path=''): 56 | # Model and optimizer 57 | if self.mode in ( 'sgc-clean', 'sgc' ): 58 | self.model = SGC(nfeat=self.worker.n_features, 59 | nclass=self.worker.n_classes) 60 | 61 | elif self.mode in ( 'degree_mlp', 'basic_mlp' ): 62 | self.model = MLP(nfeat=self.worker.n_features, 63 | nhid=self.args.hidden, 64 | nclass=self.worker.n_classes, 65 | dropout=self.args.dropout, 66 | size=self.worker.n_nodes, 67 | args=self.args) 68 | 69 | elif self.mode in ( 'degcn-clean' ): 70 | self.model = DeGCN(nfeat=self.worker.n_features, 71 | nhid=self.args.hidden, 72 | nclass=self.worker.n_classes, 73 | dropout=self.args.dropout) 74 | 75 | elif self.mode in ( 'vanilla-clean', 'vanilla' ) or not self.args.fnormalize: 76 | if self.args.n_layer == 2: 77 | self.model = GCN(nfeat=self.worker.n_features, 78 | nhid=self.args.hidden, 79 | nclass=self.worker.n_classes, 80 | dropout=self.args.dropout) 81 | elif self.args.n_layer == 3: 82 | self.model = GCN3(nfeat=self.worker.n_features, 83 | nhid1=self.args.hidden1, 84 | nhid2=self.args.hidden2, 85 | nclass=self.worker.n_classes, 86 | dropout=self.args.dropout) 87 | else: 88 | raise NotImplementedError(f'n_layer = {self.args.n_layer} not implemented!') 89 | 90 | elif self.mode in ( 'clusteradj-clean', 'clusteradj' ): 91 | self.model = ProjectionGCN(nfeat=self.worker.n_features, 92 | nhid=self.args.hidden, 93 | nclass=self.worker.n_classes, 94 | dropout=self.args.dropout, 95 | projection=self.worker.prj, 96 | size=self.worker.n_nodes, 97 | args=self.args) 98 | 99 | else: 100 | raise NotImplementedError('mode = {} no corrsponding model!'.format(self.mode)) 101 | 102 | if model_path: 103 | self.model.load_state_dict(torch.load(model_path)) 104 | print('load model from {} done!'.format(model_path)) 105 | self.model_path = model_path 106 | else: 107 | self.optimizer = optim.Adam(self.model.parameters(), 108 | lr=self.args.lr, weight_decay=self.args.weight_decay) 109 | if torch.cuda.is_available(): 110 | self.model.cuda() 111 | 112 | 113 | def forward(self, mode='train'): 114 | if self.mode in ( 'degcn-clean' ): 115 | output = self.model(self.worker.features, self.worker.adj, self.worker.sub_adj) 116 | 117 | elif self.mode in ( 'sgc-clean' ): 118 | if self.dataset in ('reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed') and mode == "train": 119 | output = self.model(self.worker.features_train) 120 | 121 | elif self.worker.transfer: 122 | output = self.model(self.worker.features_1) if mode == 'train' \ 123 | else self.model(self.worker.features_2) 124 | 125 | else: 126 | output = self.model(self.worker.features) 127 | 128 | else: 129 | if self.dataset in ('reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed') \ 130 | or self.dataset.startswith('twitch-train'): 131 | output = self.model(self.worker.features_train, self.worker.adj_train) if mode == 'train' \ 132 | else self.model(self.worker.features, self.worker.adj_full) 133 | 134 | elif self.worker.transfer: 135 | output = self.model(self.worker.features_1, self.worker.adj_1) if mode == 'train' \ 136 | else self.model(self.worker.features_2, self.worker.adj_2) 137 | 138 
| else: 139 | output = self.model(self.worker.features, self.worker.adj) 140 | 141 | return output 142 | 143 | 144 | def train_one_epoch(self, epoch): 145 | 146 | t = time.time() 147 | self.model.train() 148 | self.optimizer.zero_grad() 149 | 150 | output = self.forward(mode='train') 151 | output = output[self.worker.idx_train] if self.dataset in ( 'cora', 'citeseer', 'pubmed' ) else output 152 | 153 | target_labels = self.worker.labels_1 if self.worker.transfer \ 154 | else self.worker.labels[self.worker.idx_train] 155 | 156 | loss_train = self.calc_loss(output, target_labels) 157 | acc_train = self.f1_score(output, target_labels) 158 | 159 | loss_train.backward() 160 | 161 | self.optimizer.step() 162 | 163 | self.writer.add_scalar('train/loss', loss_train, self.gcnt_train) 164 | self.writer.add_scalar('train/acc', acc_train[0], self.gcnt_train) 165 | self.gcnt_train += 1 166 | 167 | if self.worker.transfer: # no validation set 168 | self.model.eval() 169 | output = self.forward(mode='valid') 170 | loss_val = self.calc_loss(output, self.worker.labels_2) 171 | acc_val = self.f1_score(output, self.worker.labels_2) 172 | self.writer.add_scalar('valid/loss', loss_val, self.gcnt_valid) 173 | self.writer.add_scalar('valid/acc', acc_val[0], self.gcnt_valid) 174 | self.gcnt_valid += 1 175 | 176 | output_info = 'Epoch: {:04d}'.format(epoch+1),\ 177 | 'loss_train: {:.4f}'.format(loss_train.item()),\ 178 | 'acc_train: {:.4f}'.format(acc_train[0].item()),\ 179 | 'time: {:.4f}s'.format(time.time() - t) 180 | logging.info(output_info) 181 | return loss_train 182 | 183 | if not self.args.fastmode: 184 | # Evaluate validation set performance separately, 185 | # deactivates dropout during validation run. 186 | self.model.eval() 187 | output = self.forward(mode='valid') 188 | 189 | loss_val = self.calc_loss(output[self.worker.idx_val], self.worker.labels[self.worker.idx_val]) 190 | acc_val = self.f1_score(output[self.worker.idx_val], self.worker.labels[self.worker.idx_val]) 191 | self.writer.add_scalar('valid/loss', loss_val, self.gcnt_valid) 192 | self.writer.add_scalar('valid/acc', acc_val[0], self.gcnt_valid) 193 | self.gcnt_valid += 1 194 | 195 | if self.args.early: 196 | self.early_stopping(loss_val, self.model) 197 | 198 | output_info = 'Epoch: {:04d}'.format(epoch+1),\ 199 | 'loss_train: {:.4f}'.format(loss_train.item()),\ 200 | 'acc_train: {:.4f}'.format(acc_train[0].item()),\ 201 | 'loss_val: {:.4f}'.format(loss_val.item()),\ 202 | 'acc_val: {:.4f}'.format(acc_val[0].item()),\ 203 | 'time: {:.4f}s'.format(time.time() - t) 204 | 205 | logging.info(output_info) 206 | return loss_train 207 | 208 | 209 | def train(self): 210 | # Train model 211 | t_total = time.time() 212 | 213 | if self.args.display: 214 | epochs = trange(self.args.num_epochs, desc='Progress') 215 | else: 216 | epochs = range(self.args.num_epochs) 217 | 218 | # if self.mode in ( 'clusteradj-clean', 'clusteradj' ): 219 | # data = { 220 | # # 'adj': self.worker.adj 221 | # 'values': self.worker.adj.coalesce().values(), 222 | # 'indices': self.worker.adj.coalesce().indices(), 223 | # } 224 | # else: 225 | # data = {'adj': self.worker.adj} 226 | 227 | # torch.save(data, 'temp_adj.pt') 228 | 229 | for epoch in epochs: 230 | logging.info('[epoch {}]'.format(epoch)) 231 | output = self.train_one_epoch(epoch) 232 | if self.args.display: 233 | epochs.set_description(f"Train Loss: {output}") 234 | 235 | if self.args.early and self.early_stopping.early_stop: 236 | self.model = self.early_stopping.best_model 237 | logging.info(f'early stop at 
epoch {epoch}') 238 | break 239 | 240 | torch.save(self.model.state_dict(), os.path.join(self.model_path, 'model.pt')) 241 | 242 | print("Optimization Finished!") 243 | print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) 244 | 245 | 246 | def f1_score(self, output, labels): 247 | if self.worker.multi_label == 1: 248 | preds = F.softmax(output, dim=1) 249 | preds = preds.max(1)[1].type_as(labels) 250 | return f1_score(labels.cpu(), preds.detach().cpu(), average='micro'), \ 251 | f1_score(labels.cpu(), preds.detach().cpu(), average='macro'), \ 252 | f1_score(labels.cpu(), preds.detach().cpu(), average='weighted') 253 | # unique, count = torch.unique(preds, return_counts=True) 254 | # correct = preds.eq(labels).double() 255 | # correct = correct.sum() 256 | # return correct / len(labels) 257 | 258 | else: # multi_label 259 | preds = torch.sigmoid(output) > 0.5 260 | return f1_score(labels.cpu(), preds.detach().cpu(), average='micro'), \ 261 | f1_score(labels.cpu(), preds.detach().cpu(), average='macro'), \ 262 | f1_score(labels.cpu(), preds.detach().cpu(), average='weighted') 263 | 264 | 265 | def rare_class_f1(self, output, labels): 266 | # identify the rare class 267 | ind = [torch.where(labels==0)[0], 268 | torch.where(labels==1)[0]] 269 | rare_class = int(len(ind[0]) > len(ind[1])) 270 | 271 | preds = F.softmax(output, dim=1).max(1) 272 | 273 | ap_score = average_precision_score(labels.cpu() if rare_class==1 else 1-labels.cpu(), preds[0].detach().cpu()) 274 | 275 | preds = preds[1].type_as(labels) 276 | 277 | TP = torch.sum(preds[ind[rare_class]] == rare_class).item() 278 | T = len(ind[rare_class]) 279 | P = torch.sum(preds == rare_class).item() 280 | 281 | if P == 0: return 0 282 | 283 | precision = TP / P 284 | recall = TP / T 285 | F1 = 2 * (precision * recall) / (precision + recall) 286 | return F1, precision, recall, ap_score 287 | 288 | 289 | def eval_degree(self, output): 290 | degrees = self.worker.calculate_degree() 291 | if self.dataset.startswith('twitch'): 292 | path = self.dataset.replace('/', '_') 293 | else: 294 | path = self.dataset 295 | torch.save(degrees, f'{path}_degrees.pt') 296 | 297 | # unique = np.unique(degrees) 298 | # acc_list = np.zeros_like(degrees) 299 | # total_list = np.zeros_like(degrees) 300 | 301 | # idx_list = list(range(self.worker.n_nodes_2)) if self.dataset.startswith( 'twitch' ) else self.worker.idx_test 302 | # labels = self.worker.labels_2 if self.dataset.startswith( 'twitch' ) else self.worker.labels 303 | 304 | 305 | # for i, value in enumerate(unique): 306 | # indice_cur = np.intersect1d(np.where(degrees == value)[0], idx_list, assume_unique=True) 307 | # if indice_cur.size == 0: continue 308 | # acc_cur = self.f1_score(output[indice_cur], labels[indice_cur]) 309 | 310 | # acc_list[i] = acc_cur[0] 311 | # total_list[i] = len(indice_cur) 312 | 313 | # degree_info = 'acc for different node degree: {}'.format(list(zip(unique, acc_list))) 314 | # # torch.save(list(zip(unique, acc_list)), 'degree_{}_{}.pt'.format(mode, self.subdir)) 315 | # # torch.save(list(zip(unique, total_list), 'total_num.pt')) 316 | # print(degree_info) 317 | # logging.info(degree_info) 318 | 319 | 320 | def eval_output(self, output, mode='clean', eval_degree=False): 321 | if self.args.attack: 322 | self.attacker = Attacker(args=self.args, model=self.model, worker=self.worker) 323 | self.attacker.prepare_test_data() 324 | 325 | t = time.time() 326 | if self.args.attack_mode == 'efficient': 327 | if self.args.sample_type == 'balanced-full': 328 | 
self.attacker.link_prediction_attack_efficient_balanced() 329 | else: 330 | self.attacker.link_prediction_attack_efficient() 331 | elif self.args.attack_mode == 'naive': 332 | self.attacker.link_prediction_attack() 333 | elif self.args.attack_mode in ( 'baseline', 'baseline-feat' ): 334 | if self.args.sample_type == 'balanced-full': 335 | self.attacker.baseline_attack_balanced() 336 | else: 337 | self.attacker.baseline_attack() 338 | # self.attacker.link_prediction_attack() 339 | print(f'attacks done using {time.time() - t} seconds!') 340 | 341 | if not self.worker.transfer: 342 | loss_valid = self.calc_loss(output[self.worker.idx_val], self.worker.labels[self.worker.idx_val]) 343 | acc_valid = self.f1_score(output[self.worker.idx_val], self.worker.labels[self.worker.idx_val]) 344 | 345 | # result on validation set 346 | output_info = f'''[{mode}] Validation set results: '''\ 347 | f'''loss = {loss_valid.item():.4f} '''\ 348 | f'''f1_score = {acc_valid[0].item():.4f}''' 349 | print(output_info) 350 | logging.info(output_info) 351 | 352 | if self.dataset.startswith('twitch-train'): return 353 | 354 | output_labels = output if self.worker.transfer \ 355 | else output[self.worker.idx_test] 356 | target_labels = self.worker.labels_2 if self.worker.transfer \ 357 | else self.worker.labels[self.worker.idx_test] 358 | 359 | loss_test = self.calc_loss(output_labels, target_labels) 360 | acc_test = self.f1_score(output_labels, target_labels) if not self.worker.transfer \ 361 | else self.rare_class_f1(output_labels, target_labels) 362 | 363 | # if 'model.pt' in self.model_path: 364 | # labels_path = self.model_path.replace('model.pt', f'labels.pt') 365 | # else: 366 | # labels_path = osp.join(self.model_path, 'labels.pt') 367 | # torch.save({'output': output_labels.cpu(), 368 | # 'target': target_labels.cpu(), 369 | # }, labels_path) 370 | # print(f'labels saved to {labels_path}!') 371 | # logging.info(f'labels saved to {labels_path}!') 372 | 373 | # a0 = self.worker.adj.cpu().to_dense().numpy()[633,:] 374 | # print(np.where(a0!=0)) 375 | 376 | if eval_degree: 377 | self.eval_degree(output) 378 | 379 | output_info = f'''[{mode}] Test set results: '''\ 380 | f'''loss = {loss_test.item():.4f} ''' 381 | output_info += f'rare_class_f1 = {acc_test[0]:.4f} prec = {acc_test[1]:.4f} reca = {acc_test[2]:.4f} ap_score = {acc_test[3]:.4f}' if self.worker.transfer else \ 382 | f'''f1_score [micro, macro, weighted] = {acc_test[0].item():.4f} {acc_test[1].item():.4f} {acc_test[2].item():.4f}''' 383 | print(output_info) 384 | logging.info(output_info) 385 | 386 | 387 | def test(self, eval_degree=False): 388 | self.model.eval() 389 | 390 | # if self.mode in ( 'vanilla', 'clusteradj', 'degcn' ): 391 | # if self.mode == 'clusteradj' and self.args.fnormalize: 392 | # logging.info(f'eventual coeff: {self.model.gc1.coeff.item()}, {self.model.gc2.coeff.item()}') 393 | 394 | # # test on noisy graph 395 | # output = self.forward(mode='test') 396 | # self.eval_output(output, 'noisy', eval_degree) 397 | 398 | # # test on clean graph 399 | # self.worker.update_adj() 400 | # output = self.forward(mode='test') 401 | # self.eval_output(output, 'clean', eval_degree) 402 | 403 | # else: 404 | # test on clean graph 405 | output = self.forward(mode='test') 406 | self.eval_output(output, 'clean', eval_degree=eval_degree) 407 | 408 | 409 | def __del__(self): 410 | if hasattr(self, 'writer'): 411 | self.writer.close() -------------------------------------------------------------------------------- /utils/load.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict 2 | import itertools 3 | import json 4 | import math 5 | import networkx as nx 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import pickle as pkl 10 | from queue import Queue 11 | import random 12 | import scipy.sparse as sp 13 | from sklearn.model_selection import train_test_split 14 | from sklearn.preprocessing import StandardScaler 15 | import sys 16 | import torch 17 | from tqdm import tqdm 18 | 19 | def parse_index_file(filename): 20 | """Parse index file.""" 21 | index = [] 22 | for line in open(filename): 23 | index.append(int(line.strip())) 24 | return index 25 | 26 | 27 | def get_noise(noise_type, size, seed, eps=10, delta=1e-5, sensitivity=2): 28 | np.random.seed(seed) 29 | 30 | if noise_type == 'laplace': 31 | noise = np.random.laplace(0, sensitivity/eps, size) 32 | elif noise_type == 'gaussian': 33 | c = np.sqrt(2*np.log(1.25/delta)) 34 | stddev = c * sensitivity / eps 35 | noise = np.random.normal(0, stddev, size) 36 | else: 37 | raise NotImplementedError('noise {} not implemented!'.format(noise_type)) 38 | 39 | return noise 40 | 41 | 42 | def feature_reader(dataset="cora", scale='large', train_ratio=0.5, feature_size=-1): 43 | if dataset.startswith('twitch/') or dataset.startswith('wikipedia/') \ 44 | or dataset.startswith('facebook'): 45 | if dataset.startswith('twitch') or dataset.startswith('wikipedia/'): 46 | identifier = dataset[dataset.find('/')+1:] 47 | filename = './data/{}/musae_{}_features.json'.format(dataset, identifier) 48 | 49 | else: 50 | filename = './data/facebook/musae_facebook_features.json' 51 | with open(filename) as f: 52 | data = json.load(f) 53 | n_nodes = len(data) 54 | 55 | items = sorted(set(itertools.chain.from_iterable(data.values()))) 56 | n_features = 3170 if dataset.startswith('twitch') else max(items) + 1 57 | 58 | features = np.zeros((n_nodes, n_features)) 59 | for idx, elem in data.items(): 60 | features[int(idx), elem] = 1 61 | 62 | if dataset.startswith('twitch'): 63 | data = pd.read_csv('./data/{}/musae_{}_target.csv'.format(dataset, identifier)) 64 | mature = list(map(int, data['mature'].values)) 65 | new_id = list(map(int, data['new_id'].values)) 66 | idx_map = {elem : i for i, elem in enumerate(new_id)} 67 | labels = [mature[idx_map[idx]] for idx in range(n_nodes)] 68 | elif dataset.startswith('wikipedia/'): 69 | data = pd.read_csv('./data/{}/musae_{}_target.csv'.format(dataset, identifier)) 70 | labels = list(map(int, data['target'].values)) 71 | else: 72 | data = pd.read_csv('./data/facebook/musae_facebook_target.csv') 73 | labels = data['page_type'].values.tolist() 74 | all_labels = sorted(set(labels)) 75 | label_dict = {label: i for i, label in enumerate(all_labels)} 76 | labels = list(map(label_dict.get, labels)) 77 | 78 | labels = torch.LongTensor(labels) 79 | 80 | unique, count = torch.unique(labels, return_counts=True) 81 | 82 | # len_all = len(labels) 83 | # len_valid = len_test = int(0.1 * len_all) 84 | # len_train = int(train_ratio * len_all) 85 | # idx_all = list(range(len_all)) 86 | 87 | # random.seed(42) 88 | # random.shuffle(idx_all) 89 | # idx_test = idx_all[:len_test] 90 | # idx_valid = idx_all[len_test:len_test+len_valid] 91 | # idx_train = idx_all[len_test+len_valid:len_test+len_valid+len_train] 92 | 93 | return features, labels 94 | 95 | elif dataset.startswith('deezer'): 96 | identifier = dataset[dataset.find('/')+1:] 97 | filename = 
'./data/deezer/{}_genres.json'.format(identifier) 98 | with open(filename) as f: 99 | data = json.load(f) 100 | n_nodes = len(data) 101 | 102 | items = sorted(set(itertools.chain.from_iterable(data.values()))) 103 | item_mapping = {item : i for i, item in enumerate(items)} 104 | n_classes = 84 105 | 106 | labels = np.zeros((n_nodes, n_classes)) 107 | for idx, elem in data.items(): 108 | elem_int = np.asarray(list(map(item_mapping.get, elem))) 109 | labels[int(idx), elem_int] = 1 110 | 111 | labels = torch.LongTensor(labels) 112 | return labels 113 | 114 | 115 | elif dataset in ('reddit', 'flickr', 'ppi', 'ppi-large'): 116 | # role 117 | role = json.load(open(f'./data/{dataset}/role.json')) 118 | idx_train = np.asarray(sorted(role['tr'])) 119 | idx_valid = np.asarray(sorted(role['va'])) 120 | idx_test = np.asarray(sorted(role['te'])) 121 | 122 | # features 123 | features = np.load(f'./data/{dataset}/feats.npy') 124 | features_train = features[idx_train] 125 | 126 | scaler = StandardScaler() 127 | scaler.fit(features_train) 128 | features = scaler.transform(features) 129 | features = torch.FloatTensor(features) 130 | features_train = features[idx_train] 131 | 132 | n_nodes = len(features) 133 | 134 | # label 135 | class_map = json.load(open(f'./data/{dataset}/class_map.json')) 136 | 137 | multi_label = 1 138 | for key, value in class_map.items(): 139 | if type(value) == list: 140 | multi_label = len(value) # single-label vs multi-label 141 | break 142 | 143 | labels = np.zeros((n_nodes, multi_label)) 144 | for key, value in class_map.items(): 145 | labels[int(key)] = value 146 | labels = torch.LongTensor(labels) 147 | 148 | return features, features_train, labels, idx_train, idx_valid, idx_test 149 | 150 | else: 151 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] 152 | objects = [] 153 | for i in range(len(names)): 154 | with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as f: 155 | if sys.version_info > (3, 0): 156 | objects.append(pkl.load(f, encoding='latin1')) 157 | else: 158 | objects.append(pkl.load(f)) 159 | x, y, tx, ty, allx, ally = tuple(objects) 160 | 161 | test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset)) 162 | test_idx_range = np.sort(test_idx_reorder) 163 | 164 | if dataset == 'citeseer': 165 | # Fix citeseer dataset (there are some isolated nodes in the graph) 166 | # Find isolated nodes, add them as zero-vecs into the right position 167 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) 168 | zero_ind = list(set(test_idx_range_full) - set(test_idx_reorder)) 169 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) 170 | tx_extended[test_idx_range-min(test_idx_range), :] = tx 171 | tx = tx_extended 172 | ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) 173 | ty_extended[test_idx_range-min(test_idx_range), :] = ty 174 | ty_extended[zero_ind-min(test_idx_range), np.random.randint(0, y.shape[1], len(zero_ind))] = 1 175 | ty = ty_extended 176 | 177 | features = sp.vstack((allx, tx)).tolil() 178 | features[test_idx_reorder, :] = features[test_idx_range, :] 179 | 180 | labels = np.vstack((ally, ty)) 181 | labels[test_idx_reorder, :] = labels[test_idx_range, :] 182 | 183 | idx_test = test_idx_range.tolist() 184 | idx_train = list(range(len(y))) 185 | idx_valid = list(range(len(y), len(y)+500)) 186 | 187 | print('#idx_train', len(idx_train)) 188 | print('#idx_valid', len(idx_valid)) 189 | print('#idx_test', len(idx_test)) 190 | 191 | features = preprocess_features(features) 192 | features = 
torch.FloatTensor(np.array(features.todense())) 193 | features_train = torch.clone(features) 194 | labels = torch.LongTensor(np.where(labels)[1]).unsqueeze(-1) 195 | 196 | return features, features_train, labels, idx_train, idx_valid, idx_test 197 | 198 | 199 | def preprocess_features(features): 200 | """Row-normalize feature matrix and convert to tuple representation""" 201 | rowsum = np.array(features.sum(1)) 202 | r_inv = np.power(rowsum, -1).flatten() 203 | r_inv[np.isinf(r_inv)] = 0. 204 | r_mat_inv = sp.diags(r_inv) 205 | features = r_mat_inv.dot(features) 206 | return features 207 | 208 | 209 | def load_csr_mat_from_npz(filename): 210 | adj = np.load(filename) 211 | data = adj['data'] 212 | indices = adj['indices'] 213 | indptr = adj['indptr'] 214 | N, M = adj['shape'] 215 | 216 | return sp.csr_matrix((data, indices, indptr), (N, M)) 217 | 218 | 219 | def construct_balanced_edge_sets(dataset, sample_type, adj, n_samples): 220 | indices = adj.indices 221 | indptr = adj.indptr 222 | n_nodes = adj.shape[0] 223 | 224 | dic = defaultdict(list) 225 | for u in range(n_nodes): 226 | begg, endd = indptr[u: u+2] 227 | dic[u] = indices[begg: endd] 228 | 229 | edge_set = [] 230 | nonedge_set = [] 231 | 232 | # construct edge set 233 | for u in range(n_nodes): 234 | for v in dic[u]: 235 | if v > u: 236 | edge_set.append((u, v)) 237 | n_samples = len(edge_set) 238 | 239 | # random sample equal number of pairs to compose a nonoedge set 240 | while 1: 241 | u = np.random.choice(n_nodes) 242 | v = np.random.choice(n_nodes) 243 | if v not in dic[u] and u not in dic[v]: 244 | nonedge_set.append((u, v)) 245 | if len(nonedge_set) == n_samples: break 246 | 247 | print(f'sampling done! len(edge_set) = {len(edge_set)}, len(nonedge_set) = {len(nonedge_set)}') 248 | 249 | return (edge_set, nonedge_set), list(range(n_nodes)) 250 | 251 | 252 | def construct_edge_sets(dataset, sample_type, adj, n_samples): 253 | indices = adj.indices 254 | indptr = adj.indptr 255 | n_nodes = adj.shape[0] 256 | 257 | # construct edge set 258 | edge_set = [] 259 | while 1: 260 | u = np.random.choice(n_nodes) 261 | begg, endd = indptr[u: u+2] 262 | v_range = indices[begg: endd] 263 | if len(v_range): 264 | v = np.random.choice(v_range) 265 | edge_set.append((u, v)) 266 | if len(edge_set) == n_samples: break 267 | 268 | # construct non-edge set 269 | nonedge_set = [] 270 | 271 | # # randomly select non-neighbors 272 | # for _ in tqdm(range(n_samples)): 273 | # u = np.random.choice(n_nodes) 274 | # begg, endd = indptr[u: u+2] 275 | # v_range = indices[begg: endd] 276 | # while 1: 277 | # v = np.random.choice(n_nodes) 278 | # if v not in v_range: 279 | # nonedge_set.append((u, v)) 280 | # break 281 | 282 | # randomly select nodes with two-hop distance 283 | while 1: 284 | u = np.random.choice(n_nodes) 285 | begg, endd = indptr[u: u+2] 286 | v_range = indices[begg: endd] 287 | 288 | vv_range_all = [] 289 | for v in v_range: 290 | begg, endd = indptr[v: v+2] 291 | vv_range = set(indices[begg: endd]) - set(v_range) 292 | if vv_range: 293 | vv_range_all.append(vv_range) 294 | 295 | if vv_range_all: 296 | vv_range = np.random.choice(vv_range_all) 297 | vv = np.random.choice(list(vv_range)) 298 | nonedge_set.append((u, vv)) 299 | if len(nonedge_set) == n_samples: break 300 | 301 | return edge_set, nonedge_set 302 | 303 | 304 | def _get_edge_sets_among_nodes(indices, indptr, nodes): 305 | # construct edge list for each node 306 | dic = defaultdict(list) 307 | 308 | for u in nodes: 309 | begg, endd = indptr[u: u+2] 310 | dic[u] = 
indices[begg: endd] 311 | 312 | n_nodes = len(nodes) 313 | edge_set = [] 314 | nonedge_set = [] 315 | for i in range(n_nodes): 316 | for j in range(i+1, n_nodes): 317 | u, v = nodes[i], nodes[j] 318 | if v in dic[u]: 319 | edge_set.append((u, v)) 320 | else: 321 | nonedge_set.append((u, v)) 322 | 323 | print('#nodes =', len(nodes)) 324 | print('#edges_set =', len(edge_set)) 325 | print('#nonedge_set =', len(nonedge_set)) 326 | return edge_set, nonedge_set 327 | 328 | 329 | def _get_degree(n_nodes, indptr): 330 | deg = np.zeros(n_nodes, dtype=np.int32) 331 | for i in range(n_nodes): 332 | deg[i] = indptr[i+1] - indptr[i] 333 | 334 | ind = np.argsort(deg) 335 | return deg, ind 336 | 337 | 338 | def construct_edge_sets_from_random_subgraph(dataset, sample_type, adj, n_samples): 339 | indices = adj.indices 340 | indptr = adj.indptr 341 | n_nodes = adj.shape[0] 342 | 343 | if sample_type == 'unbalanced': 344 | indice_all = range(n_nodes) 345 | 346 | else: 347 | deg, ind = _get_degree(n_nodes, indptr) 348 | # unique, count = np.unique(deg, return_counts=True) 349 | # l = list(zip(unique, count)) 350 | # print(l) 351 | # print(len(np.where(deg <= 10)[0])) 352 | # print(len(np.where(deg >= 10)[0])) 353 | 354 | if dataset.startswith('twitch'): 355 | lo = 5 if 'PTBR' not in dataset else 10 356 | hi = 10 357 | elif dataset in ( 'flickr', 'ppi' ) or dataset.startswith('deezer'): 358 | lo = 15 359 | hi = 30 360 | elif dataset in ( 'cora' ): 361 | lo = 3 362 | hi = 4 363 | elif dataset in ( 'citeseer' ): 364 | lo = 3 365 | hi = 3 366 | elif dataset in ( 'pubmed' ): 367 | lo = 10 368 | hi = 10 369 | else: 370 | raise NotImplementedError(f'lo and hi for dataset = {dataset} not set!') 371 | 372 | if sample_type == 'unbalanced-lo': 373 | indice_all = np.where(deg <= lo)[0] 374 | else: 375 | indice_all = np.where(deg >= hi)[0] 376 | 377 | print('#indice =', len(indice_all)) 378 | 379 | nodes = np.random.choice(indice_all, n_samples, replace=False) # choose from low degree nodes 380 | 381 | return _get_edge_sets_among_nodes(indices, indptr, nodes), nodes 382 | 383 | 384 | 385 | def construct_edge_sets_through_bfs(sample_type, adj, n_hop): 386 | indices = adj.indices 387 | indptr = adj.indptr 388 | n_nodes = adj.shape[0] 389 | 390 | deg, ind = _get_degree(n_nodes, indptr) 391 | 392 | sorted_deg = deg[ind] 393 | 394 | unique, count = np.unique(deg, return_counts=True) 395 | l = list(zip(unique, count)) 396 | print(l) 397 | 398 | deg_lo = sorted_deg[np.where(sorted_deg)[0][0]] 399 | deg_hi = sorted_deg[n_nodes - 1] 400 | 401 | print(deg_lo, deg_hi) 402 | 403 | # indice_lo = np.where(deg == deg_lo)[0] 404 | # indice_hi = np.where(deg == deg_hi)[0] 405 | 406 | indice_lo = np.where(deg <= 5)[0] 407 | # indice_hi = np.where(deg >= 100)[0] 408 | # print(len(indice_lo), len(indice_hi)) 409 | 410 | # randomly sample a starting node 411 | # may replace with choosing the node with the highest/lowest degree later 412 | # src = np.random.choice(n_nodes) 413 | src = np.random.choice(indice_lo) 414 | 415 | que = Queue() 416 | vis = np.zeros(n_nodes, dtype=np.int8) 417 | 418 | que.put((src, 0)) 419 | vis[src] = 1 420 | 421 | while 1: 422 | if que.empty(): break 423 | 424 | head = que.get() 425 | u, dep = head 426 | 427 | if dep == n_hop: continue 428 | 429 | begg, endd = indptr[u: u+2] 430 | v_range = indices[begg: endd] 431 | for v in v_range: 432 | if vis[v]: continue 433 | que.put((v, dep+1)) 434 | vis[v] = 1 435 | 436 | nodes = np.where(vis)[0] 437 | return _get_edge_sets_among_nodes(indices, indptr, nodes) 438 | 439 | 
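# ---- editor's note: illustrative usage sketch (the dataset name and sample count below are assumptions,
# not taken from the original code). Attacker.prepare_test_data() calls each sampler as
# func(dataset, sample_type, adj, n_test) and expects a ((edge_set, nonedge_set), test_nodes) return, e.g.:
#
#   np.random.seed(args.sample_seed)
#   (edges, nonedges), nodes = construct_edge_sets_from_random_subgraph(
#       'cora', 'unbalanced', adj_ori, 500)   # adj_ori: scipy CSR adjacency of the clean graph
#
# construct_edge_sets() above returns only (edge_set, nonedge_set), and construct_edge_sets_through_bfs()
# takes (sample_type, adj, n_hop), so the 'balanced' and 'bfs' sample types would need a small adapter to
# match that calling convention. ----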
440 | def load_edges_from_npz(filename): 441 | adj = np.load(filename) 442 | indices = adj['indices'] 443 | indptr = adj['indptr'] 444 | n_nodes = adj['shape'][0] 445 | edges = [] 446 | for i in tqdm(range(n_nodes)): 447 | begg = indptr[i] 448 | endd = indptr[i+1] 449 | edges += [(i, elem) for elem in indices[begg:endd] if elem > i] 450 | return np.asarray(edges) 451 | 452 | def graph_reader(args, dataset="cora", n_nodes=-1): 453 | if dataset.startswith('twitch') or dataset.startswith('wikipedia/'): 454 | identifier = dataset[dataset.find('/')+1:] 455 | data = pd.read_csv('./data/{}/musae_{}_edges.csv'.format(dataset, identifier)) 456 | edges = data.values 457 | adj = sp.csr_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 458 | shape=(n_nodes, n_nodes), 459 | dtype=np.float32) 460 | return adj + adj.T 461 | 462 | elif dataset.startswith('deezer'): 463 | identifier = dataset[dataset.find('/')+1:] 464 | data = pd.read_csv('./data/deezer/{}_edges.csv'.format(identifier)) 465 | edges = data.values 466 | adj = sp.csr_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 467 | shape=(n_nodes, n_nodes), 468 | dtype=np.float32) 469 | adj += adj.T 470 | 471 | degree = np.zeros(n_nodes, dtype=np.int64) 472 | for u, v in edges: 473 | degree[u] += 1 474 | degree[v] += 1 475 | 476 | features = np.zeros((n_nodes, 500)) 477 | features[np.arange(n_nodes), degree] = 1 478 | 479 | return adj, torch.FloatTensor(features) 480 | 481 | elif dataset == 'facebook': 482 | data = pd.read_csv('./data/facebook/musae_facebook_edges.csv') 483 | return data.values 484 | 485 | elif dataset in ('reddit', 'flickr', 'ppi', 'ppi-large'): 486 | adj_full = load_csr_mat_from_npz(f'./data/{dataset}/adj_full.npz') 487 | print(f'loading {dataset} graph done!') 488 | return adj_full 489 | 490 | else: 491 | with open("data/ind.{}.{}".format(dataset, 'graph'), 'rb') as f: 492 | if sys.version_info > (3, 0): 493 | graph = pkl.load(f, encoding='latin1') 494 | else: 495 | graph = pkl.load(f) 496 | 497 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph)) 498 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 499 | return sp.csr_matrix(adj_normalized) 500 | 501 | 502 | def sparse_to_tuple(sparse_mx): 503 | """Convert sparse matrix to tuple representation.""" 504 | def to_tuple(mx): 505 | if not sp.isspmatrix_coo(mx): 506 | mx = mx.tocoo() 507 | coords = np.vstack((mx.row, mx.col)).transpose() 508 | values = mx.data 509 | shape = mx.shape 510 | return coords, values, shape 511 | 512 | if isinstance(sparse_mx, list): 513 | for i in range(len(sparse_mx)): 514 | sparse_mx[i] = to_tuple(sparse_mx[i]) 515 | else: 516 | sparse_mx = to_tuple(sparse_mx) 517 | 518 | return sparse_mx 519 | 520 | 521 | def normalize_adj(adj): 522 | """Symmetrically normalize adjacency matrix.""" 523 | adj = sp.coo_matrix(adj) 524 | rowsum = np.array(adj.sum(1)) 525 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 526 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 527 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 528 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 529 | 530 | def normalize(mx): 531 | """Row-normalize sparse matrix""" 532 | rowsum = np.array(mx.sum(1))*1.0 533 | r_inv = np.power(rowsum, -1).flatten() 534 | r_inv[np.isinf(r_inv)] = 0. 
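# zero-degree rows would give an infinite inverse, so they are zeroed out here; the diagonal matrix
# built next yields D^{-1} * mx, i.e. every non-empty row rescaled to sum to 1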
535 | r_mat_inv = sp.diags(r_inv) 536 | mx = r_mat_inv.dot(mx) 537 | 538 | return mx 539 | 540 | def t_normalize(tensor): 541 | """Row-normalize sparse matrix""" 542 | 543 | rowsum = tensor.sum(1) 544 | r_inv = rowsum ** -1 545 | r_inv[torch.isinf(r_inv)] = 0 546 | r_mat_inv = torch.diag(r_inv) 547 | ret = torch.mm(r_mat_inv, tensor) 548 | 549 | return ret 550 | 551 | 552 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 553 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 554 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 555 | indices = torch.from_numpy( 556 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 557 | values = torch.from_numpy(sparse_mx.data) 558 | shape = torch.Size(sparse_mx.shape) 559 | return torch.sparse.FloatTensor(indices, values, shape) 560 | 561 | 562 | def aug_normalized_adjacency(adj): 563 | adj = adj + sp.eye(adj.shape[0]) 564 | adj = sp.coo_matrix(adj) 565 | row_sum = np.array(adj.sum(1)) 566 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 567 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 568 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 569 | return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo() 570 | 571 | 572 | def gcn(adj): 573 | adj = sp.coo_matrix(adj) 574 | row_sum = np.array(adj.sum(1)) 575 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 576 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 577 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 578 | return (sp.eye(adj.shape[0]) + d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)).tocoo() 579 | 580 | 581 | def bingge_norm_adjacency(adj): 582 | adj = adj + sp.eye(adj.shape[0]) 583 | adj = sp.coo_matrix(adj) 584 | row_sum = np.array(adj.sum(1)) 585 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 586 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 587 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 588 | return (d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt) + sp.eye(adj.shape[0])).tocoo() 589 | 590 | 591 | def normalized_adjacency(adj): 592 | adj = sp.coo_matrix(adj) 593 | row_sum = np.array(adj.sum(1)) 594 | d_inv_sqrt = np.power(row_sum, -0.5).flatten() 595 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
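# plain symmetric normalization D^{-1/2} A D^{-1/2}: unlike aug_normalized_adjacency above, no
# self-loops are added first (this is the 'NormAdj' option in fetch_normalization below)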
596 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 597 | return (d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)).tocoo() 598 | 599 | 600 | def random_walk(adj): 601 | adj = sp.coo_matrix(adj) 602 | row_sum = np.array(adj.sum(1)) 603 | d_inv = np.power(row_sum, -1.0).flatten() 604 | d_mat = sp.diags(d_inv) 605 | return d_mat.dot(adj).tocoo() 606 | 607 | 608 | def aug_random_walk(adj): 609 | adj = adj + sp.eye(adj.shape[0]) 610 | adj = sp.coo_matrix(adj) 611 | row_sum = np.array(adj.sum(1)) 612 | d_inv = np.power(row_sum, -1.0).flatten() 613 | d_mat = sp.diags(d_inv) 614 | return (d_mat.dot(adj)).tocoo() 615 | 616 | 617 | def fetch_normalization(type): 618 | switcher = { 619 | 'FirstOrderGCN': gcn, # A' = I + D^-1/2 * A * D^-1/2 620 | 'BingGeNormAdj': bingge_norm_adjacency, # A' = I + (D + I)^-1/2 * (A + I) * (D + I)^-1/2 621 | 'NormAdj': normalized_adjacency, # D^-1/2 * A * D^-1/2 622 | 'AugRWalk': aug_random_walk, # A' = (D + I)^-1*(A + I) 623 | 'RWalk': random_walk, # A' = D^-1*A 624 | 'AugNormAdj': aug_normalized_adjacency, # A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 625 | } 626 | func = switcher.get(type, lambda: "Invalid normalization technique.") 627 | return func 628 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import random 5 | import scipy.sparse as sp 6 | from sklearn.preprocessing import StandardScaler 7 | import time 8 | from tqdm import tqdm 9 | 10 | import torch 11 | from torch.nn.parameter import Parameter 12 | 13 | from utils import feature_reader, get_noise, graph_reader, \ 14 | normalize, sparse_mx_to_torch_sparse_tensor, t_normalize, fetch_normalization, \ 15 | compressive_sensing 16 | 17 | class Worker(): 18 | def __init__(self, args, dataset='', mode=''): 19 | self.args = args 20 | self.dataset = dataset 21 | self.mode = mode 22 | self.transfer = (not self.dataset.startswith('twitch-train') and self.dataset.startswith('twitch')) or \ 23 | self.dataset.startswith('wikipedia') or \ 24 | self.dataset.startswith('deezer') 25 | self.load_data() 26 | 27 | 28 | def build_cluster_adj(self, clean=False, fnormalize=True): 29 | adj = np.zeros((self.n_nodes, self.n_clusters), dtype=np.float64) 30 | 31 | for dst, src in self.edges.tolist(): 32 | adj[src, self.fake_labels[dst]] += 1 33 | adj[dst, self.fake_labels[src]] += 1 34 | 35 | if self.mode in ( 'clusteradj' ) and not clean: 36 | adj += get_noise( 37 | self.args.noise_type, size=self.n_nodes*self.n_clusters, seed=self.args.noise_seed, 38 | eps=self.args.epsilon, delta=self.args.delta).reshape(self.n_nodes, self.n_clusters) 39 | 40 | adj = np.clip(adj, a_min=0, a_max=None) 41 | 42 | if fnormalize: 43 | adj = normalize(adj) 44 | else: 45 | adj = normalize(np.dot(adj, self.prj.cpu().numpy()) + np.eye(self.n_nodes)) 46 | 47 | return torch.FloatTensor(adj) 48 | 49 | # adj = sp.coo_matrix(adj) 50 | if fnormalize: 51 | adj = normalize(adj) 52 | elif self.mode != 'degree_nlp': 53 | adj = normalize(np.dot(adj, self.prj.cpu().numpy()) + np.eye(self.n_nodes)) 54 | 55 | # adj = sparse_mx_to_torch_sparse_tensor(adj) 56 | 57 | # if not fnormalize and not self.mode == 'degree_mlp': 58 | # adj = t_normalize(torch.mm(adj, self.prj) + torch.eye(self.n_nodes)) 59 | 60 | # return adj 61 | return torch.FloatTensor(adj) 62 | 63 | 64 | def build_cluster_prj(self): 65 | unique, count = np.unique(self.fake_labels.cpu(), return_counts=True) 66 | 67 | count_dict = {k:v 
for k, v in zip(unique, count)} 68 | 69 | prj = np.zeros((self.n_clusters, self.n_nodes)) 70 | 71 | for i, label in enumerate(self.fake_labels): 72 | label = label.item() 73 | prj[label, i] = 1 / count_dict[label] 74 | 75 | # prj = sp.coo_matrix(prj) 76 | # return sparse_mx_to_torch_sparse_tensor(prj) 77 | return torch.FloatTensor(prj) 78 | 79 | 80 | def build_degree_vec(self): 81 | vec = np.zeros((self.n_nodes, 1)) 82 | for u, v in self.edges: 83 | vec[u][0] += 1 84 | vec[v][0] += 1 85 | 86 | return torch.FloatTensor(vec) 87 | 88 | 89 | def break_down(self): 90 | unique, count = torch.unique(self.fake_labels, return_counts=True) 91 | indice = [(self.fake_labels == x).nonzero().squeeze() for x in unique] 92 | 93 | if self.args.break_method == 'kmeans': 94 | 95 | min_size = int(torch.min(count).item() * self.args.break_ratio + 0.5) 96 | if min_size in (0, 1): 97 | print('skip the step of generating broken labels!') 98 | 99 | else: 100 | 101 | # print('min_size', min_size) 102 | split = [self.labeler.get_equal_size(val, min_size) for val in count] 103 | 104 | # print('split', split) 105 | 106 | t0 = time.time() 107 | start = 0 108 | 109 | for idx, (n_clusters, quota) in zip(indice, split): 110 | self.fake_labels[idx] = self.labeler.get_cluster_labels( 111 | self.features[idx].cpu(), n_clusters, quota=quota, start=start, same_size=True).cuda() 112 | start += n_clusters 113 | 114 | self.n_clusters = start 115 | 116 | print('generating broken down fake labels done using {} secs!'.format(time.time()-t0)) 117 | # torch.save(self.fake_labels, 'flabels_{}.pt'.format(self.n_clusters)) 118 | 119 | else: 120 | print(f'break_method = {self.args.break_method} not implemented!') 121 | exit(-1) 122 | 123 | logging.info(f'num_cluster = {self.n_clusters}') 124 | 125 | 126 | def build_adj_vanilla(self): 127 | adj = np.zeros((self.n_nodes, self.n_nodes), dtype=np.float64) 128 | for dst, src in self.edges: 129 | adj[src][dst] = adj[dst][src] = 1 130 | 131 | t0 = time.time() 132 | # adj += get_noise(self.args.noise_type, size=self.n_nodes*self.n_nodes, seed=self.args.noise_seed, 133 | # eps=self.args.epsilon, delta=self.args.delta).reshape(self.n_nodes, self.n_nodes) 134 | # adj = np.clip(adj, a_min=0, a_max=None) 135 | 136 | s = 2 / (np.exp(self.args.epsilon) + 1) 137 | print(f's={s:.4f}') 138 | bernoulli = np.random.binomial(1, s, (self.n_nodes, self.n_nodes)) # 2-D mask so np.where below yields (row, col) pairs, as in perturb_adj_discrete 139 | entry = np.where(bernoulli) 140 | for u, v in zip(*entry): 141 | if u >= v: continue 142 | x = np.random.binomial(1, 0.5) 143 | adj[u][v] = adj[v][u] = x 144 | 145 | print('adding noise done using {} secs!'.format(time.time() - t0)) 146 | return adj 147 | 148 | 149 | # def calc_index(self, k): 150 | # lo = (-1 + math.sqrt(8 * k + 9)) / 2 151 | # # hi = (1 + math.sqrt(8 * k + 1)) / 2 152 | # if lo - int(lo) < 1e-6: 153 | # i = int(lo) 154 | # else: 155 | # i = int(lo) + 1 156 | # j = k - (i-1) * i // 2 157 | # assert(j <= i-1) 158 | # assert(i < self.n_nodes and j < self.n_nodes), f'{k}, {i}, {j} invalid!'
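# ---- editor's note (added comment): the commented-out helper above and the active calc_index(self, N, k)
# below appear to invert a flattened index over one triangle of an N x N matrix (diagonal excluded),
# recovering the pair (i, j); e.g. the commented-out index-based variant at the end of
# perturb_adj_discrete() calls self.calc_index(k) for each flipped entry. ----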
159 | 160 | # return i, j 161 | 162 | 163 | def calc_index(self, N, k): 164 | lo = (2*N-3 - math.sqrt((2*N-3)**2 - 4*(2*k-2*N+4))) / 2 165 | # hi = (1 + math.sqrt(8 * k + 1)) / 2 166 | if lo - int(lo) < 1e-6: 167 | i = int(lo) 168 | else: 169 | i = int(lo) + 1 170 | 171 | j = k - (2*N-1-i)*i//2 + i + 1 172 | 173 | assert(j > i and j < N and i < N) 174 | 175 | return i, j 176 | 177 | 178 | def construct_sparse_mat(self, indice, N): 179 | cur_row = -1 180 | new_indices = [] 181 | new_indptr = [] 182 | 183 | for i, j in tqdm(indice): 184 | if i >= j: 185 | continue 186 | 187 | while i > cur_row: 188 | new_indptr.append(len(new_indices)) 189 | cur_row += 1 190 | 191 | new_indices.append(j) 192 | 193 | while N > cur_row: 194 | new_indptr.append(len(new_indices)) 195 | cur_row += 1 196 | 197 | data = np.ones(len(new_indices), dtype=np.int64) 198 | indices = np.asarray(new_indices, dtype=np.int64) 199 | indptr = np.asarray(new_indptr, dtype=np.int64) 200 | 201 | mat = sp.csr_matrix((data, indices, indptr), (N, N)) 202 | 203 | return mat + mat.T 204 | 205 | 206 | def perturb_adj(self, adj, perturb_type): 207 | if perturb_type == 'discrete': 208 | return self.perturb_adj_discrete(adj) 209 | else: 210 | return self.perturb_adj_continuous(adj) 211 | 212 | 213 | def perturb_adj_discrete(self, adj): 214 | s = 2 / (np.exp(self.args.epsilon) + 1) 215 | print(f's = {s:.4f}') 216 | N = adj.shape[0] 217 | 218 | t = time.time() 219 | # bernoulli = np.random.binomial(1, s, N * (N-1) // 2) 220 | # entry = np.where(bernoulli)[0] 221 | 222 | np.random.seed(self.args.noise_seed) 223 | bernoulli = np.random.binomial(1, s, (N, N)) 224 | print(f'generating perturbing vector done using {time.time() - t} secs!') 225 | logging.info(f'generating perturbing vector done using {time.time() - t} secs!') 226 | entry = np.asarray(list(zip(*np.where(bernoulli)))) 227 | 228 | dig_1 = np.random.binomial(1, 1/2, len(entry)) 229 | indice_1 = entry[np.where(dig_1 == 1)[0]] 230 | indice_0 = entry[np.where(dig_1 == 0)[0]] 231 | 232 | add_mat = self.construct_sparse_mat(indice_1, N) 233 | minus_mat = self.construct_sparse_mat(indice_0, N) 234 | 235 | # # add_mat = np.zeros_like(adj.A) 236 | # add_row = [] 237 | # add_col = [] 238 | # # minus_mat = np.zeros_like(adj.A) 239 | # minus_row = [] 240 | # minus_col = [] 241 | 242 | # for i in tqdm(range(N)): 243 | # for j in range(i+1, N): 244 | # x = np.random.binomial(1, s, 1) 245 | # if x == 1: 246 | # x = np.random.binomial(1, 1/2, 1) 247 | # if x == 1: 248 | # # add_mat[i, j] = x 249 | # add_row.append(i) 250 | # add_col.append(j) 251 | # else: 252 | # # minus_mat[i, j] = x 253 | # minus_row.append(i) 254 | # minus_col.append(j) 255 | # add_data = np.ones(len(add_row), dtype=np.int32) 256 | # minus_data = np.ones(len(minus_row), dtype=np.int32) 257 | # # add_mat = sp.csr_matrix(add_mat) 258 | # # minus_mat = sp.csr_matrix(minus_mat) 259 | # add_mat = sp.csr_matrix(add_data, (add_row, add_col)) 260 | # minus_mat = sp.csr_matrix(minus_data, (minus_row, minus_col)) 261 | # add_mat += add_mat.T 262 | # minus_mat += minus_mat.T 263 | 264 | adj_noisy = adj + add_mat - minus_mat 265 | 266 | adj_noisy.data[np.where(adj_noisy.data == -1)[0]] = 0 267 | adj_noisy.data[np.where(adj_noisy.data == 2)[0]] = 1 268 | 269 | # adj = sp.lil_matrix(adj) 270 | # for k in tqdm(indice_1): 271 | # i, j = self.calc_index(k) 272 | # adj[i, j] = adj[j, i] = 1 273 | 274 | # for k in tqdm(indice_0): 275 | # i, j = self.calc_index(k) 276 | # adj[i, j] = adj[j, i] = 0 277 | 278 | return adj_noisy 279 | 280 | 281 | def 
perturb_adj_continuous(self, adj): 282 | self.n_nodes = adj.shape[0] 283 | n_edges = len(adj.data) // 2 284 | 285 | N = self.n_nodes 286 | t = time.time() 287 | 288 | A = sp.tril(adj, k=-1) 289 | print('getting the lower triangle of adj matrix done!') 290 | 291 | eps_1 = self.args.epsilon * 0.01 292 | eps_2 = self.args.epsilon - eps_1 293 | noise = get_noise(noise_type=self.args.noise_type, size=(N, N), seed=self.args.noise_seed, 294 | eps=eps_2, delta=self.args.delta, sensitivity=1) 295 | noise *= np.tri(*noise.shape, k=-1, dtype=bool) # builtin bool: np.bool is removed in recent NumPy 296 | print(f'generating noise done using {time.time() - t} secs!') 297 | 298 | A += noise 299 | print(f'adding noise to the adj matrix done!') 300 | 301 | t = time.time() 302 | n_edges_keep = n_edges + int( 303 | get_noise(noise_type=self.args.noise_type, size=1, seed=self.args.noise_seed, 304 | eps=eps_1, delta=self.args.delta, sensitivity=1)[0]) 305 | print(f'edge number from {n_edges} to {n_edges_keep}') 306 | 307 | t = time.time() 308 | a_r = A.A.ravel() 309 | 310 | n_splits = 50 311 | len_h = len(a_r) // n_splits 312 | ind_list = [] 313 | for i in tqdm(range(n_splits - 1)): 314 | ind = np.argpartition(a_r[len_h*i:len_h*(i+1)], -n_edges_keep)[-n_edges_keep:] 315 | ind_list.append(ind + len_h * i) 316 | 317 | ind = np.argpartition(a_r[len_h*(n_splits-1):], -n_edges_keep)[-n_edges_keep:] 318 | ind_list.append(ind + len_h * (n_splits - 1)) 319 | 320 | ind_subset = np.hstack(ind_list) 321 | a_subset = a_r[ind_subset] 322 | ind = np.argpartition(a_subset, -n_edges_keep)[-n_edges_keep:] 323 | 324 | row_idx = [] 325 | col_idx = [] 326 | for idx in ind: 327 | idx = ind_subset[idx] 328 | row_idx.append(idx // N) 329 | col_idx.append(idx % N) 330 | assert(col_idx[-1] < row_idx[-1]) # entries come from the strict lower triangle, so column < row 331 | data_idx = np.ones(n_edges_keep, dtype=np.int32) 332 | print(f'data preparation done using {time.time() - t} secs!') 333 | 334 | mat = sp.csr_matrix((data_idx, (row_idx, col_idx)), shape=(N, N)) 335 | return mat + mat.T 336 | 337 | 338 | def build_adj_original(self, edges): 339 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 340 | shape=(self.n_nodes, self.n_nodes), 341 | dtype=np.float32) 342 | 343 | # build symmetric adjacency matrix 344 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 345 | return adj 346 | 347 | 348 | def build_adj_mat(self, edges, mode='vanilla-clean'): 349 | if mode in ( 'vanilla-clean', 'degcn-clean' ): 350 | adj = self.build_adj_original(edges) 351 | 352 | elif mode in ( 'vanilla', 'degcn' ): 353 | adj = self.build_adj_vanilla() 354 | if mode == 'degcn': 355 | # temp = np.zeros((self.n_nodes, self.n_nodes)) 356 | # temp[adj > 0.5] = 1 357 | # adj = temp 358 | 359 | # print(len(self.edges)) 360 | self.edges = [] 361 | for u, v in zip(*np.where(adj)): 362 | if u > v: continue 363 | self.edges.append((u, v)) 364 | 365 | print(len(self.edges)) 366 | 367 | adj = normalize(adj + sp.eye(adj.shape[0])) 368 | adj = sparse_mx_to_torch_sparse_tensor(adj) if mode in ( 'vanilla-clean', 'degcn-clean' ) else torch.FloatTensor(adj) 369 | return adj 370 | 371 | 372 | def sgc_precompute(self, adj, features, mode='sgc-clean'): 373 | # # if mode == 'sgc-clean': 374 | # adj = self.build_adj_original() 375 | # # else: 376 | # # adj = self.build_adj_vanilla() 377 | 378 | normalizer = fetch_normalization(self.args.norm) 379 | adj = sparse_mx_to_torch_sparse_tensor(normalizer(adj)).float().cuda() 380 | 381 | # adj_normalizer = fetch_normalization(self.args.normalization) 382 | # adj = adj_normalizer(adj) 383 | # adj =
sparse_mx_to_torch_sparse_tensor(adj).float().cuda() 384 | 385 | # for _ in range(self.args.degree): 386 | features = torch.spmm(adj, features) 387 | 388 | return features 389 | 390 | 391 | def decompose_graph(self): 392 | self.sub_adj = [] 393 | num = len(self.edges) 394 | for i in range(3): 395 | sub_edges = [self.edges[i] for i in range(i, num, 3)] 396 | self.sub_adj.append(self.build_adj_mat(np.asarray(sub_edges), mode='degcn-clean')) 397 | 398 | 399 | def construct_hop_dict(self): 400 | self.edge_dict = {u : set() for u in range(self.n_nodes)} 401 | for u, v in self.edges: 402 | self.edge_dict[u].add(v) 403 | self.edge_dict[v].add(u) 404 | 405 | self.two_hop_dict = {u : set() for u in range(self.n_nodes)} 406 | 407 | self.one_hop_edges = [] 408 | for u in self.edge_dict: 409 | for v in self.edge_dict[u]: 410 | for p in self.edge_dict[v]: 411 | if p > u and u not in self.edge_dict[p]: 412 | self.one_hop_edges.append((u, p)) 413 | self.two_hop_dict[u].add(p) 414 | self.two_hop_dict[p].add(u) 415 | 416 | self.two_hop_edges = [] 417 | for u in self.edge_dict: 418 | for v in self.edge_dict[u]: 419 | for p in self.two_hop_dict[v]: 420 | if p > u and u not in self.edge_dict[p] and u not in self.two_hop_dict[p]: 421 | self.two_hop_edges.append((u, p)) 422 | 423 | 424 | def load_data(self): 425 | if self.dataset in ( 'reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed' ): 426 | self.features, self.features_train, self.labels, self.idx_train, self.idx_val, self.idx_test \ 427 | = feature_reader(dataset=self.dataset, scale=self.args.scale, 428 | train_ratio=self.args.train_ratio, feature_size=self.args.feature_size) 429 | 430 | if torch.cuda.is_available(): 431 | self.features = self.features.cuda() 432 | self.features_train = self.features_train.cuda() 433 | self.labels = self.labels.cuda() 434 | 435 | self.n_nodes = len(self.labels) 436 | self.n_features = self.features.shape[1] 437 | self.multi_label = self.labels.shape[1] 438 | if self.multi_label == 1: 439 | self.n_classes = self.labels.max().item() + 1 440 | else: 441 | self.n_classes = self.multi_label 442 | 443 | elif self.dataset.startswith( 'twitch-train' ): 444 | p = self.dataset.find('/') 445 | self.features, self.labels = feature_reader(dataset=f'twitch/{self.dataset[p+1:]}') 446 | self.n_nodes = len(self.labels) 447 | self.n_nodes_1 = int(0.8 * self.n_nodes) 448 | self.n_nodes_2 = self.n_nodes - self.n_nodes_1 449 | self.idx_train = np.random.choice(self.n_nodes, self.n_nodes_1, replace=False) 450 | self.idx_val = np.asarray( list( set(range(self.n_nodes)) - set(range(self.n_nodes_1)) ) ) 451 | 452 | self.features_train = self.features[self.idx_train] 453 | 454 | scaler = StandardScaler() 455 | scaler.fit(self.features_train) 456 | self.features = scaler.transform(self.features) 457 | self.features = torch.FloatTensor(self.features) 458 | self.features_train = self.features[self.idx_train] 459 | 460 | if torch.cuda.is_available(): 461 | self.features = self.features.cuda() 462 | self.features_train = self.features_train.cuda() 463 | self.labels = self.labels.cuda() 464 | 465 | self.n_features = 3170 466 | self.multi_label = 1 467 | self.n_classes = 2 468 | 469 | 470 | elif self.dataset.startswith( 'twitch' ): 471 | p_0 = self.dataset.find('/') 472 | data_folder = self.dataset[:p_0] 473 | 474 | p = self.dataset.rfind('/')+1 475 | self.dataset1 = self.dataset[:p-1] 476 | self.dataset2 = f'{data_folder}/{self.dataset[p:]}' 477 | 478 | self.features_1, self.labels_1 = feature_reader(dataset=self.dataset1) 479 | 
self.features_2, self.labels_2 = feature_reader(dataset=self.dataset2) 480 | 481 | scaler = StandardScaler() 482 | scaler.fit(self.features_1) 483 | self.features_1 = torch.FloatTensor(scaler.transform(self.features_1)) 484 | self.features_2 = torch.FloatTensor(scaler.transform(self.features_2)) 485 | 486 | if torch.cuda.is_available(): 487 | self.features_1 = self.features_1.cuda() 488 | self.features_2 = self.features_2.cuda() 489 | self.labels_1 = self.labels_1.cuda() 490 | self.labels_2 = self.labels_2.cuda() 491 | 492 | self.n_nodes_1 = len(self.labels_1) 493 | self.n_nodes_2 = len(self.labels_2) 494 | self.n_features = 3170 495 | self.multi_label = 1 496 | self.n_classes = 2 497 | 498 | elif self.dataset.startswith( 'deezer' ): 499 | p_0 = self.dataset.find('/') 500 | data_folder = self.dataset[:p_0] 501 | 502 | p = self.dataset.rfind('/')+1 503 | self.dataset1 = self.dataset[:p-1] 504 | self.dataset2 = f'{data_folder}/{self.dataset[p:]}' 505 | 506 | self.labels_1 = feature_reader(dataset=self.dataset1) 507 | self.labels_2 = feature_reader(dataset=self.dataset2) 508 | 509 | if torch.cuda.is_available(): 510 | self.labels_1 = self.labels_1.cuda() 511 | self.labels_2 = self.labels_2.cuda() 512 | 513 | self.n_nodes_1 = len(self.labels_1) 514 | self.n_nodes_2 = len(self.labels_2) 515 | self.n_classes = self.multi_label = 84 516 | 517 | else: 518 | raise NotImplementedError(f'dataset = {self.dataset} not implemented!') 519 | 520 | print(f'loading {self.dataset} features done!') 521 | 522 | # print('feature_size', self.features.shape) 523 | 524 | # print('====================================') 525 | # print('|| n_nodes =', self.n_nodes) 526 | # print('|| n_features =', self.n_features) 527 | # print('|| n_classes =', self.n_classes, '(', self.multi_label, ')') 528 | # print('====================================') 529 | 530 | if self.args.mode in ( 'mlp', 'lr' ): return 531 | 532 | if self.dataset in ( 'reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed' ): 533 | self.adj_full = graph_reader(args=self.args, dataset=self.dataset, n_nodes=self.n_nodes) 534 | 535 | # construct training data 536 | if self.dataset in ( 'cora', 'citeseer', 'pubmed' ): 537 | self.adj_train = sp.csr_matrix.copy(self.adj_full) 538 | self.adj_ori = sp.csr_matrix.copy(self.adj_full) 539 | else: 540 | self.adj_train = self.adj_full[self.idx_train, :][:, self.idx_train] 541 | self.adj_ori = sp.csr_matrix.copy(self.adj_full) 542 | 543 | elif self.dataset.startswith( 'twitch-train' ): 544 | p = self.dataset.find('/') 545 | self.adj_full = graph_reader(args=self.args, dataset=f'twitch/{self.dataset[p+1:]}', n_nodes=self.n_nodes) 546 | self.adj_train = self.adj_full[self.idx_train, :][:, self.idx_train] 547 | self.adj_ori = sp.csr_matrix.copy(self.adj_full) 548 | 549 | elif self.dataset.startswith( 'twitch' ): 550 | self.adj_1 = graph_reader(args=self.args, dataset=self.dataset1, n_nodes=self.n_nodes_1) 551 | self.adj_2 = graph_reader(args=self.args, dataset=self.dataset2, n_nodes=self.n_nodes_2) 552 | self.adj_ori = sp.csr_matrix.copy(self.adj_2) 553 | 554 | elif self.dataset.startswith( 'deezer' ): 555 | self.adj_1, self.features_1 = graph_reader(args=self.args, dataset=self.dataset1, n_nodes=self.n_nodes_1) 556 | self.adj_2, self.features_2 = graph_reader(args=self.args, dataset=self.dataset2, n_nodes=self.n_nodes_2) 557 | self.adj_ori = sp.csr_matrix.copy(self.adj_2) 558 | self.n_features = self.features_1.shape[-1] 559 | 560 | if torch.cuda.is_available(): 561 | self.features_1 = self.features_1.cuda() 
562 | self.features_2 = self.features_2.cuda() 563 | 564 | else: 565 | self.edges = graph_reader(args=self.args, dataset=self.dataset) 566 | 567 | # self.construct_hop_dict() 568 | 569 | # self.exist_edges = random.sample(self.edges.tolist(), self.n_test) 570 | # self.nonexist_edges = random.sample(self.one_hop_edges, self.n_test) 571 | 572 | # self.nonexist_edges = random.sample(self.two_hop_edges, self.n_test) 573 | # self.nonexist_edges = random.sample(self.two_hop_edges+self.one_hop_edges, self.n_test) 574 | # self.nonexist_edges = [] 575 | # cnt_nonexist = 0 576 | # while 1: 577 | # u = np.random.choice(self.n_nodes) 578 | # v = np.random.choice(self.n_nodes) 579 | # if u != v and v not in self.edge_dict[u]: 580 | # self.nonexist_edges.append((u, v)) 581 | # cnt_nonexist += 1 582 | # if cnt_nonexist == self.n_test: break 583 | 584 | # self.labeler = Labeler(self.features, self.labels, self.n_classes, 585 | # self.idx_train, self.idx_val, self.idx_test) 586 | 587 | self.prepare_data() 588 | 589 | def prepare_data(self): 590 | if self.mode in ( 'sgc-clean', 'sgc' ): 591 | if self.dataset in ( 'reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed' ): 592 | self.features_train = self.sgc_precompute(self.adj_train, self.features_train, mode=self.mode) 593 | self.features = self.sgc_precompute(self.adj_full, self.features, mode=self.mode) 594 | self.adj = self.adj_train = None 595 | 596 | elif self.transfer: 597 | self.features_1 = self.sgc_precompute(self.adj_1, self.features_1, mode=self.mode) 598 | self.features_2 = self.sgc_precompute(self.adj_2, self.features_2, mode=self.mode) 599 | self.adj_1 = self.adj_2 = None 600 | 601 | else: 602 | raise NotImplementedError(f'dataset = {self.dataset} not implemented!') 603 | 604 | print('SGC Precomputing done!') 605 | 606 | elif self.mode in ( 'clusteradj', 'clusteradj-clean' ): 607 | self.generate_fake_labels() 608 | if self.args.break_down: 609 | self.break_down() 610 | 611 | self.prj = self.build_cluster_prj() 612 | self.adj = self.build_cluster_adj(fnormalize=self.args.fnormalize) 613 | 614 | elif self.mode in ( 'vanilla', 'vanilla-clean', 'cs' ): 615 | if self.dataset in ( 'reddit', 'flickr', 'ppi', 'ppi-large', 'cora', 'citeseer', 'pubmed' ) \ 616 | or self.dataset.startswith('twitch-train'): 617 | if self.mode == 'vanilla': 618 | self.adj_full = self.perturb_adj(self.adj_full, self.args.perturb_type) 619 | self.adj_train = self.perturb_adj(self.adj_train, self.args.perturb_type) 620 | print('perturbing done!') 621 | 622 | # normalize adjacency matrix 623 | if self.dataset not in ( 'cora', 'citeseer', 'pubmed' ): 624 | normalizer = fetch_normalization(self.args.norm) 625 | self.adj_train = normalizer(self.adj_train) 626 | self.adj_full = normalizer(self.adj_full) 627 | 628 | self.adj_train = sparse_mx_to_torch_sparse_tensor(self.adj_train) 629 | self.adj_full = sparse_mx_to_torch_sparse_tensor(self.adj_full) 630 | 631 | elif self.transfer: 632 | if self.mode == 'vanilla': 633 | self.adj_1 = self.perturb_adj(self.adj_1, self.args.perturb_type) 634 | self.adj_2 = self.perturb_adj(self.adj_2, self.args.perturb_type) 635 | print('perturbing done!') 636 | 637 | elif self.mode == 'cs': 638 | self.adj_1 = compressive_sensing(self.args, self.adj_1) 639 | self.adj_2 = compressive_sensing(self.args, self.adj_2) 640 | print('compressive sensing done!') 641 | 642 | # normalize adjacency matrix 643 | normalizer = fetch_normalization(self.args.norm) 644 | self.adj_1 = sparse_mx_to_torch_sparse_tensor(normalizer(self.adj_1)) 645 | self.adj_2 
= sparse_mx_to_torch_sparse_tensor(normalizer(self.adj_2)) 646 | 647 | else: 648 | # self.adj = self.build_adj_mat(self.edges, mode=self.mode) 649 | raise NotImplementedError(f'dataset = {self.dataset} not implemented!') 650 | 651 | print('Normalizing Adj done!') 652 | 653 | elif self.mode in ( 'degree_mlp', 'basic_mlp' ): 654 | self.adj = None 655 | 656 | elif self.mode in ( 'degcn', 'degcn-clean' ): 657 | self.adj = self.build_adj_mat(self.edges, mode=self.mode) 658 | self.decompose_graph() 659 | 660 | else: 661 | raise NotImplementedError('mode = {} not implemented!'.format(self.mode)) 662 | 663 | # self.calculate_connectivity() 664 | 665 | if torch.cuda.is_available(): 666 | if hasattr(self, 'adj') and self.adj is not None: 667 | self.adj = self.adj.cuda() 668 | if hasattr(self, 'adj_train') and self.adj_train is not None: 669 | self.adj_train = self.adj_train.cuda() 670 | self.adj_full = self.adj_full.cuda() 671 | if hasattr(self, 'adj_1') and self.adj_1 is not None: 672 | self.adj_1 = self.adj_1.cuda() 673 | self.adj_2 = self.adj_2.cuda() 674 | if hasattr(self, 'prj'): 675 | self.prj = self.prj.cuda() 676 | if hasattr(self, 'sub_adj'): 677 | for i in range(len(self.sub_adj)): 678 | self.sub_adj[i] = self.sub_adj[i].cuda() 679 | 680 | 681 | def generate_fake_labels(self): 682 | cluster_method = self.args.cluster_method 683 | t0 = time.time() 684 | 685 | if cluster_method == 'random': 686 | self.n_clusters = self.args.n_clusters 687 | self.fake_labels = self.labeler.get_random_labels(self.n_clusters, self.args.cluster_seed) 688 | 689 | elif cluster_method == 'hierarchical': 690 | init_method = self.args.init_method 691 | self.n_clusters = self.n_classes 692 | 693 | if init_method == 'naive': 694 | self.fake_labels = self.labeler.get_naive_labels(self.args.assign_seed) 695 | 696 | elif init_method == 'voting': 697 | self.fake_labels = self.labeler.get_majority_labels(self.edges, self.args.assign_seed) 698 | 699 | elif init_method == 'knn': 700 | self.fake_labels = self.labeler.get_knn_labels(self.args.knn) 701 | 702 | elif init_method == 'gt': 703 | self.fake_labels = self.labels.clone() 704 | 705 | else: 706 | raise NotImplementedError('init_method={} in cluster_method=label not implemented!'.format(init_method)) 707 | 708 | elif cluster_method in ( 'kmeans', 'sskmeans' ): 709 | self.n_clusters = self.args.n_clusters 710 | self.fake_labels = self.labeler.get_kmeans_labels( 711 | self.n_clusters, self.args.knn, cluster_method, same_size=self.args.same_size) 712 | 713 | else: 714 | raise NotImplementedError('cluster_method={} not implemented!'.format(cluster_method)) 715 | 716 | print('generating fake labels done using {} secs!'.format(time.time()-t0)) 717 | # torch.save(self.fake_labels, 'flabels_{}.pt'.format(self.n_clusters)) 718 | 719 | 720 | def calculate_connectivity(self): 721 | n_edges = len(self.edges) 722 | kappa = n_edges / (0.5*self.n_nodes*(self.n_nodes-1)) 723 | labels = self.fake_labels 724 | 725 | edge_adj = np.zeros((self.n_clusters, self.n_clusters)) 726 | for edge in self.edges: 727 | u, v = labels[edge[0]], labels[edge[1]] 728 | edge_adj[u][v] += 1 729 | edge_adj[v][u] += 1 730 | 731 | unique, count = np.unique(labels, return_counts=True) 732 | 733 | kappa_intra = 0 734 | for i in range(self.n_clusters): 735 | kappa_intra += edge_adj[i][i] / (0.5 * count[i] * (count[i]-1)) 736 | kappa_intra /= self.n_clusters 737 | 738 | kappa_inter = 0 739 | for i in range(self.n_clusters): 740 | for j in range(i+1, self.n_clusters): 741 | kappa_inter += edge_adj[i][j] / (count[i] * 
count[j])
742 |         kappa_inter /= (0.5 * self.n_clusters * (self.n_clusters - 1))
743 |
744 |         print('k_inter = {:.4f}, k = {:.4f}, k_intra = {:.4f}'.format(kappa_inter, kappa, kappa_intra))
745 |         logging.info('k_inter = {:.4f}, k = {:.4f}, k_intra = {:.4f}'.format(kappa_inter, kappa, kappa_intra))
746 |
747 |
748 |     def calculate_degree(self):
749 |         if self.dataset.startswith( 'twitch' ):
750 |             degrees = np.zeros(self.n_nodes_2)
751 |             adj = self.adj_2
752 |         else:
753 |             degrees = np.zeros(self.n_nodes)
754 |             adj = self.adj_train
755 |
756 |         self.edges = []  # note: cleared here and not repopulated; only the degree counts are used
757 |         for u, v in zip(*np.where(adj.cpu().to_dense())):
758 |             if u > v: continue
759 |             degrees[u] += 1
760 |             degrees[v] += 1
761 |         return degrees
762 |
763 |
764 |     def update_adj(self):
765 |         if self.mode == 'clusteradj':
766 |             self.adj = self.build_cluster_adj(clean=True, fnormalize=self.args.fnormalize)
767 |
768 |         elif self.mode == 'vanilla':
769 |             self.adj = self.build_adj_mat(self.edges)
770 |
771 |         elif self.mode == 'sgc':
772 |             self.features = self.sgc_precompute(self.adj_full, self.features, mode=self.mode)  # sgc_precompute requires adj and features; the full graph is assumed here
773 |
774 |         elif self.mode == 'degcn':
775 |             pass
776 |
777 |         if torch.cuda.is_available() and self.adj is not None:  # self.adj is None in sgc mode
778 |             self.adj = self.adj.cuda()
--------------------------------------------------------------------------------
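
For reference, the edge-perturbation step in perturb_adj_continuous above is easier to follow in isolation. The sketch below reproduces the same idea on a small graph: Laplace noise is added to the strict lower triangle of the adjacency matrix, a separately noised edge count decides how many of the largest entries to keep, and the result is symmetrized. It is a minimal sketch, assuming get_noise with noise_type='laplace' draws Laplace noise of scale sensitivity/eps; laplace_noise and perturb_adj_sketch are illustrative names only, and the chunked argpartition used in the repo for memory efficiency is omitted.

import numpy as np
import scipy.sparse as sp

def laplace_noise(size, eps, sensitivity=1.0, seed=None):
    # Hypothetical stand-in for the repo's get_noise helper with noise_type='laplace'.
    rng = np.random.default_rng(seed)
    return rng.laplace(loc=0.0, scale=sensitivity / eps, size=size)

def perturb_adj_sketch(adj, epsilon, seed=0):
    """Noise the lower triangle, then keep a noisy number of the largest entries."""
    n = adj.shape[0]
    n_edges = adj.nnz // 2                           # undirected graph: each edge is stored twice
    eps_1, eps_2 = 0.01 * epsilon, 0.99 * epsilon    # budget split, as in perturb_adj_continuous

    A = sp.tril(adj, k=-1).toarray().astype(float)
    A += laplace_noise((n, n), eps_2, seed=seed) * np.tri(n, k=-1, dtype=bool)

    # privately estimated number of edges to keep
    n_keep = max(1, n_edges + int(laplace_noise(1, eps_1, seed=seed + 1)[0]))

    flat = A.ravel()
    idx = np.argpartition(flat, -n_keep)[-n_keep:]   # indices of the n_keep largest entries
    rows, cols = np.unravel_index(idx, (n, n))

    out = sp.csr_matrix((np.ones(n_keep), (rows, cols)), shape=(n, n))
    return out + out.T                               # symmetrize

# Tiny usage example on a random symmetric graph:
upper = sp.triu(sp.random(50, 50, density=0.05, random_state=0), k=1)
adj = ((upper + upper.T) > 0).astype(float).tocsr()
print(perturb_adj_sketch(adj, epsilon=4.0).nnz)

On this toy graph the returned matrix should contain roughly as many nonzeros as the input, with individual edges flipped according to the noise; the repo's version follows the same recipe at a much larger scale.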