├── images └── fig1.png ├── framework ├── models │ ├── __init__.py │ ├── gat.py │ ├── gcn.py │ ├── gin.py │ ├── rgcn.py │ ├── graph_classification │ │ ├── gcn.py │ │ └── gcn_delete.py │ ├── deletion.py │ └── rgat.py ├── utils.py ├── __init__.py ├── data_loader.py ├── trainer │ ├── gradient_ascent_with_mp.py │ ├── descent_to_delete.py │ ├── approx_retrain.py │ ├── member_infer.py │ ├── gradient_ascent.py │ ├── gnndelete_embdis.py │ ├── retrain.py │ └── graph_eraser.py ├── evaluation.py └── training_args.py ├── run_original.sh ├── run_delete.sh ├── LICENSE ├── graph_stat.py ├── train_node.py ├── .gitignore ├── train_gnn.py ├── README.md ├── delete_gnn.py ├── delete_node.py ├── delete_node_feature.py └── prepare_dataset.py /images/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/GNNDelete/HEAD/images/fig1.png -------------------------------------------------------------------------------- /framework/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcn import GCN 2 | from .gat import GAT 3 | from .gin import GIN 4 | from .rgcn import RGCN 5 | from .rgat import RGAT 6 | from .deletion import GCNDelete, GATDelete, GINDelete, RGCNDelete, RGATDelete -------------------------------------------------------------------------------- /run_original.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate pyg 4 | 5 | DATA=$1 6 | MODEL=$2 7 | SEED=$3 8 | UN=original 9 | 10 | export WANDB_MODE=offline 11 | export WANDB_PROJECT=zitniklab-gnn-unlearning 12 | 13 | export WANDB_NAME="$UN"_"$DATA"_"$MODEL"_"$SEED" 14 | export WANDB_RUN_NAME="$UN"_"$DATA"_"$MODEL"_"$SEED" 15 | export WANDB_RUN_ID="$UN"_"$DATA"_"$MODEL"_"$SEED" 16 | 17 | python train_gnn.py --lr 1e-3 \ 18 | --epochs 1500 \ 19 | --dataset "$DATA" \ 20 | --random_seed "$SEED" \ 21 | --unlearning_model "$UN" \ 22 | --gnn "$MODEL" -------------------------------------------------------------------------------- /run_delete.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate pyg 4 | 5 | DATA=$1 6 | MODEL=$2 7 | UN=$3 8 | DF=$4 9 | DF_SIZE=$5 10 | SEED=$6 11 | 12 | export WANDB_MODE=offline 13 | export WANDB_PROJECT=zitniklab-gnn-unlearning 14 | 15 | 16 | export WANDB_NAME="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED" 17 | export WANDB_RUN_NAME="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED" 18 | export WANDB_RUN_ID="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED" 19 | 20 | python delete_gnn.py --lr 1e-3 \ 21 | --epochs 1500 \ 22 | --dataset "$DATA" \ 23 | --random_seed "$SEED" \ 24 | --unlearning_model "$UN" \ 25 | --gnn "$MODEL" \ 26 | --df "$DF" \ 27 | --df_size "$DF_SIZE" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Machine Learning for Medicine and Science @ Harvard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to 
the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /framework/models/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GATConv 5 | 6 | 7 | class GAT(nn.Module): 8 | def __init__(self, args, **kwargs): 9 | super().__init__() 10 | 11 | self.conv1 = GATConv(args.in_dim, args.hidden_dim) 12 | self.conv2 = GATConv(args.hidden_dim, args.out_dim) 13 | # self.dropout = nn.Dropout(args.dropout) 14 | 15 | def forward(self, x, edge_index, return_all_emb=False): 16 | x1 = self.conv1(x, edge_index) 17 | x = F.relu(x1) 18 | # x = self.dropout(x) 19 | x2 = self.conv2(x, edge_index) 20 | 21 | if return_all_emb: 22 | return x1, x2 23 | 24 | return x2 25 | 26 | def decode(self, z, pos_edge_index, neg_edge_index=None): 27 | if neg_edge_index is not None: 28 | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) 29 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 30 | 31 | else: 32 | edge_index = pos_edge_index 33 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 34 | 35 | return logits 36 | -------------------------------------------------------------------------------- /framework/models/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv 5 | 6 | 7 | class GCN(nn.Module): 8 | def __init__(self, args, **kwargs): 9 | super().__init__() 10 | 11 | self.conv1 = GCNConv(args.in_dim, args.hidden_dim) 12 | self.conv2 = GCNConv(args.hidden_dim, args.out_dim) 13 | # self.dropout = nn.Dropout(args.dropout) 14 | 15 | def forward(self, x, edge_index, return_all_emb=False): 16 | x1 = self.conv1(x, edge_index) 17 | x = F.relu(x1) 18 | # x = self.dropout(x) 19 | x2 = self.conv2(x, edge_index) 20 | 21 | if return_all_emb: 22 | return x1, x2 23 | 24 | return x2 25 | 26 | def decode(self, z, pos_edge_index, neg_edge_index=None): 27 | if neg_edge_index is not None: 28 | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) 29 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 30 | 31 | else: 32 | edge_index = pos_edge_index 33 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 34 | 35 | return logits 36 | -------------------------------------------------------------------------------- /graph_stat.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch_geometric.data import Data 3 | import torch_geometric.transforms as T 4 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18RR 5 | from ogb.linkproppred import PygLinkPropPredDataset 6 | 7 | 8 | data_dir = './data' 
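# Note: the trailing [-2:] slice on the list below restricts main() to the last two datasets ('ogbl-biokg' and 'ogbl-wikikg2').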
9 | datasets = ['Cora', 'PubMed', 'DBLP', 'CS', 'Physics', 'ogbl-citation2', 'ogbl-collab', 'FB15k-237', 'WordNet18RR', 'ogbl-biokg', 'ogbl-wikikg2'][-2:] 10 | 11 | def get_stat(d): 12 |     if d in ['Cora', 'PubMed', 'DBLP']: 13 |         dataset = CitationFull(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures()) 14 |     if d in ['CS', 'Physics']: 15 |         dataset = Coauthor(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures()) 16 |     if d in ['Flickr']: 17 |         dataset = Flickr(os.path.join(data_dir, d), transform=T.NormalizeFeatures()) 18 |     if 'ogbl' in d: 19 |         dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d) 20 | 21 |     data = dataset[0] 22 |     print(d) 23 |     print('Number of nodes:', data.num_nodes) 24 |     print('Number of edges:', data.num_edges) 25 |     print('Number of max deleted edges:', int(0.05 * data.num_edges)) 26 |     if hasattr(data, 'edge_type'): 27 |         print('Number of edge types:', data.edge_type.unique().shape[0]) 28 | 29 | def main(): 30 |     for d in datasets: 31 |         get_stat(d) 32 | 33 | if __name__ == "__main__": 34 |     main() 35 | -------------------------------------------------------------------------------- /framework/models/gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GINConv 5 | 6 | 7 | class GIN(nn.Module): 8 |     def __init__(self, args, **kwargs): 9 |         super().__init__() 10 | 11 |         self.conv1 = GINConv(nn.Linear(args.in_dim, args.hidden_dim)) 12 |         self.conv2 = GINConv(nn.Linear(args.hidden_dim, args.out_dim)) 13 |         # self.transition = nn.Sequential( 14 |         #     nn.ReLU(), 15 |         #     # nn.Dropout(p=args.dropout) 16 |         # ) 17 |         # self.mlp1 = nn.Sequential( 18 |         #     nn.Linear(args.in_dim, args.hidden_dim), 19 |         #     nn.ReLU(), 20 |         # ) 21 |         # self.mlp2 = nn.Sequential( 22 |         #     nn.Linear(args.hidden_dim, args.out_dim), 23 |         #     nn.ReLU(), 24 |         # ) 25 | 26 |     def forward(self, x, edge_index, return_all_emb=False): 27 |         x1 = self.conv1(x, edge_index) 28 |         x = F.relu(x1) 29 |         x2 = self.conv2(x, edge_index) 30 | 31 |         if return_all_emb: 32 |             return x1, x2 33 | 34 |         return x2 35 | 36 |     def decode(self, z, pos_edge_index, neg_edge_index=None): 37 |         if neg_edge_index is not None: 38 |             edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) 39 |             logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 40 | 41 |         else: 42 |             edge_index = pos_edge_index 43 |             logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) 44 | 45 |         return logits 46 | -------------------------------------------------------------------------------- /framework/models/rgcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import RGCNConv, FastRGCNConv 6 | from sklearn.metrics import roc_auc_score, average_precision_score 7 | 8 | 9 | class RGCN(nn.Module): 10 |     def __init__(self, args, num_nodes, num_edge_type, **kwargs): 11 |         super().__init__() 12 |         self.args = args 13 |         self.num_edge_type = num_edge_type 14 | 15 |         # Encoder: RGCN 16 |         self.node_emb = nn.Embedding(num_nodes, args.in_dim) 17 |         if num_edge_type > 20: 18 |             self.conv1 = RGCNConv(args.in_dim, args.hidden_dim, num_edge_type * 2, num_blocks=4) 19 |             self.conv2 = RGCNConv(args.hidden_dim, args.out_dim, num_edge_type * 2, num_blocks=4) 20 |         else: 21 |             self.conv1 = RGCNConv(args.in_dim, args.hidden_dim, num_edge_type * 2) 22 |             self.conv2 =
RGCNConv(args.hidden_dim, args.out_dim, num_edge_type * 2) 23 | self.relu = nn.ReLU() 24 | 25 | # Decoder: DistMult 26 | self.W = nn.Parameter(torch.Tensor(num_edge_type, args.out_dim)) 27 | nn.init.xavier_uniform_(self.W, gain=nn.init.calculate_gain('relu')) 28 | 29 | def forward(self, x, edge, edge_type, return_all_emb=False): 30 | x = self.node_emb(x) 31 | x1 = self.conv1(x, edge, edge_type) 32 | x = self.relu(x1) 33 | x2 = self.conv2(x, edge, edge_type) 34 | 35 | if return_all_emb: 36 | return x1, x2 37 | 38 | return x2 39 | 40 | def decode(self, z, edge_index, edge_type): 41 | h = z[edge_index[0]] 42 | t = z[edge_index[1]] 43 | r = self.W[edge_type] 44 | 45 | logits = torch.sum(h * r * t, dim=1) 46 | 47 | return logits 48 | 49 | class RGCNDelete(RGCN): 50 | def __init__(self): 51 | pass 52 | -------------------------------------------------------------------------------- /framework/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import networkx as nx 4 | 5 | 6 | def get_node_edge(graph): 7 | degree_sorted_ascend = sorted(graph.degree, key=lambda x: x[1]) 8 | 9 | return degree_sorted_ascend[-1][0] 10 | 11 | def h_hop_neighbor(G, node, h): 12 | path_lengths = nx.single_source_dijkstra_path_length(G, node) 13 | return [node for node, length in path_lengths.items() if length == h] 14 | 15 | def get_enclosing_subgraph(graph, edge_to_delete): 16 | subgraph = {0: [edge_to_delete]} 17 | s, t = edge_to_delete 18 | 19 | neighbor_s = [] 20 | neighbor_t = [] 21 | for h in range(1, 2+1): 22 | neighbor_s += h_hop_neighbor(graph, s, h) 23 | neighbor_t += h_hop_neighbor(graph, t, h) 24 | 25 | nodes = neighbor_s + neighbor_t + [s, t] 26 | 27 | subgraph[h] = list(graph.subgraph(nodes).edges()) 28 | 29 | return subgraph 30 | 31 | @torch.no_grad() 32 | def get_link_labels(pos_edge_index, neg_edge_index): 33 | E = pos_edge_index.size(1) + neg_edge_index.size(1) 34 | link_labels = torch.zeros(E, dtype=torch.float, device=pos_edge_index.device) 35 | link_labels[:pos_edge_index.size(1)] = 1. 36 | return link_labels 37 | 38 | @torch.no_grad() 39 | def get_link_labels_kg(pos_edge_index, neg_edge_index): 40 | E = pos_edge_index.size(1) + neg_edge_index.size(1) 41 | link_labels = torch.zeros(E, dtype=torch.float, device=pos_edge_index.device) 42 | link_labels[:pos_edge_index.size(1)] = 1. 
43 | 44 | return link_labels 45 | 46 | @torch.no_grad() 47 | def negative_sampling_kg(edge_index, edge_type): 48 | '''Generate negative samples but keep the node type the same''' 49 | 50 | edge_index_copy = edge_index.clone() 51 | for et in edge_type.unique(): 52 | mask = (edge_type == et) 53 | old_source = edge_index_copy[0, mask] 54 | new_index = torch.randperm(old_source.shape[0]) 55 | new_source = old_source[new_index] 56 | edge_index_copy[0, mask] = new_source 57 | 58 | return edge_index_copy 59 | -------------------------------------------------------------------------------- /train_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import pickle 4 | import torch 5 | from torch_geometric.seed import seed_everything 6 | from torch_geometric.utils import to_undirected, is_undirected 7 | import torch_geometric.transforms as T 8 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR 9 | from torch_geometric.seed import seed_everything 10 | 11 | from framework import get_model, get_trainer 12 | from framework.training_args import parse_args 13 | from framework.trainer.base import NodeClassificationTrainer 14 | from framework.utils import negative_sampling_kg 15 | 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | def main(): 20 | args = parse_args() 21 | args.checkpoint_dir = 'checkpoint_node' 22 | args.dataset = 'DBLP' 23 | args.unlearning_model = 'original' 24 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, str(args.random_seed)) 25 | os.makedirs(args.checkpoint_dir, exist_ok=True) 26 | seed_everything(args.random_seed) 27 | 28 | # Dataset 29 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures()) 30 | data = dataset[0] 31 | print('Original data', data) 32 | 33 | split = T.RandomNodeSplit() 34 | data = split(data) 35 | assert is_undirected(data.edge_index) 36 | 37 | print('Split data', data) 38 | args.in_dim = data.x.shape[1] 39 | args.out_dim = dataset.num_classes 40 | 41 | wandb.init(config=args) 42 | 43 | # Model 44 | model = get_model(args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type).to(device) 45 | wandb.watch(model, log_freq=100) 46 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#, weight_decay=args.weight_decay) 47 | 48 | # Train 49 | trainer = NodeClassificationTrainer(args) 50 | trainer.train(model, data, optimizer, args) 51 | 52 | # Test 53 | trainer.test(model, data) 54 | trainer.save_log() 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | experiment/ 2 | data/ 3 | checkpoint*/ 4 | data 5 | checkpoint* 6 | wandb 7 | wandb/ 8 | results/ 9 | *csv 10 | 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a 
template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /framework/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import GCN, GAT, GIN, RGCN, RGAT, GCNDelete, GATDelete, GINDelete, RGCNDelete, RGATDelete 2 | from .trainer.base import Trainer, KGTrainer, NodeClassificationTrainer 3 | from .trainer.retrain import RetrainTrainer, KGRetrainTrainer 4 | from .trainer.gnndelete import GNNDeleteTrainer 5 | from .trainer.gnndelete_nodeemb import GNNDeleteNodeembTrainer, KGGNNDeleteNodeembTrainer 6 | from .trainer.gradient_ascent import GradientAscentTrainer, KGGradientAscentTrainer 7 | from .trainer.descent_to_delete import DtdTrainer 8 | from .trainer.approx_retrain import ApproxTrainer 9 | from .trainer.graph_eraser import GraphEraserTrainer 10 | from .trainer.graph_editor import GraphEditorTrainer 11 | from .trainer.member_infer import MIAttackTrainer, MIAttackTrainerNode 12 | 13 | 14 | trainer_mapping = { 15 | 'original': Trainer, 16 | 'original_node': NodeClassificationTrainer, 17 | 'retrain': RetrainTrainer, 18 | 'gnndelete': GNNDeleteTrainer, 19 | 'gradient_ascent': GradientAscentTrainer, 20 | 'descent_to_delete': DtdTrainer, 21 | 'approx_retrain': ApproxTrainer, 22 | 'gnndelete_mse': GNNDeleteTrainer, 23 | 'gnndelete_kld': GNNDeleteTrainer, 24 | 'gnndelete_nodeemb': GNNDeleteNodeembTrainer, 25 | 'gnndelete_cosine': GNNDeleteTrainer, 26 | 'graph_eraser': GraphEraserTrainer, 27 | 'graph_editor': GraphEditorTrainer, 28 | 'member_infer_all': MIAttackTrainer, 29 | 'member_infer_sub': MIAttackTrainer, 30 | 
'member_infer_all_node': MIAttackTrainerNode, 31 | 'member_infer_sub_node': MIAttackTrainerNode, 32 | } 33 | 34 | kg_trainer_mapping = { 35 | 'original': KGTrainer, 36 | 'retrain': KGRetrainTrainer, 37 | 'gnndelete': KGGNNDeleteNodeembTrainer, 38 | 'gradient_ascent': KGGradientAscentTrainer, 39 | 'descent_to_delete': DtdTrainer, 40 | 'approx_retrain': ApproxTrainer, 41 | 'gnndelete_mse': GNNDeleteTrainer, 42 | 'gnndelete_kld': GNNDeleteTrainer, 43 | 'gnndelete_cosine': GNNDeleteTrainer, 44 | 'gnndelete_nodeemb': KGGNNDeleteNodeembTrainer, 45 | 'graph_eraser': GraphEraserTrainer, 46 | 'member_infer_all': MIAttackTrainer, 47 | 'member_infer_sub': MIAttackTrainer, 48 | } 49 | 50 | 51 | def get_model(args, mask_1hop=None, mask_2hop=None, num_nodes=None, num_edge_type=None): 52 | 53 | if 'gnndelete' in args.unlearning_model: 54 | model_mapping = {'gcn': GCNDelete, 'gat': GATDelete, 'gin': GINDelete, 'rgcn': RGCNDelete, 'rgat': RGATDelete} 55 | 56 | else: 57 | model_mapping = {'gcn': GCN, 'gat': GAT, 'gin': GIN, 'rgcn': RGCN, 'rgat': RGAT} 58 | 59 | return model_mapping[args.gnn](args, mask_1hop=mask_1hop, mask_2hop=mask_2hop, num_nodes=num_nodes, num_edge_type=num_edge_type) 60 | 61 | 62 | def get_trainer(args): 63 | if args.gnn in ['rgcn', 'rgat']: 64 | return kg_trainer_mapping[args.unlearning_model](args) 65 | 66 | else: 67 | return trainer_mapping[args.unlearning_model](args) 68 | -------------------------------------------------------------------------------- /train_gnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import pickle 4 | import torch 5 | from torch_geometric.seed import seed_everything 6 | from torch_geometric.utils import to_undirected, is_undirected 7 | from torch_geometric.datasets import RelLinkPredDataset, WordNet18 8 | from torch_geometric.seed import seed_everything 9 | 10 | from framework import get_model, get_trainer 11 | from framework.training_args import parse_args 12 | from framework.trainer.base import Trainer 13 | from framework.utils import negative_sampling_kg 14 | 15 | 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | def main(): 21 | args = parse_args() 22 | args.unlearning_model = 'original' 23 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, str(args.random_seed)) 24 | os.makedirs(args.checkpoint_dir, exist_ok=True) 25 | seed_everything(args.random_seed) 26 | 27 | # Dataset 28 | with open(os.path.join(args.data_dir, args.dataset, f'd_{args.random_seed}.pkl'), 'rb') as f: 29 | dataset, data = pickle.load(f) 30 | print('Directed dataset:', dataset, data) 31 | if args.gnn not in ['rgcn', 'rgat']: 32 | args.in_dim = dataset.num_features 33 | 34 | wandb.init(config=args) 35 | 36 | # Use proper training data for original and Dr 37 | if args.gnn in ['rgcn', 'rgat']: 38 | if not hasattr(data, 'train_mask'): 39 | data.train_mask = torch.ones(data.edge_index.shape[1], dtype=torch.bool) 40 | 41 | # data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool) 42 | # data.edge_index_mask = data.dtrain_mask.repeat(2) 43 | 44 | else: 45 | data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool) 46 | 47 | # To undirected 48 | if args.gnn in ['rgcn', 'rgat']: 49 | r, c = data.train_pos_edge_index 50 | rev_edge_index = torch.stack([c, r], dim=0) 51 | rev_edge_type = data.train_edge_type + args.num_edge_type 52 | 53 | data.edge_index = 
torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1) 54 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0) 55 | # data.train_mask = data.train_mask.repeat(2) 56 | 57 | data.dr_mask = torch.ones(data.edge_index.shape[1], dtype=torch.bool) 58 | assert is_undirected(data.edge_index) 59 | 60 | else: 61 | train_pos_edge_index = to_undirected(data.train_pos_edge_index) 62 | data.train_pos_edge_index = train_pos_edge_index 63 | data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool) 64 | assert is_undirected(data.train_pos_edge_index) 65 | 66 | 67 | print('Undirected dataset:', data) 68 | 69 | # Model 70 | model = get_model(args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type).to(device) 71 | wandb.watch(model, log_freq=100) 72 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#, weight_decay=args.weight_decay) 73 | 74 | # Train 75 | trainer = get_trainer(args) 76 | trainer.train(model, data, optimizer, args) 77 | 78 | # Test 79 | trainer.test(model, data) 80 | trainer.save_log() 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /framework/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch_geometric.data import Data, GraphSAINTRandomWalkSampler 4 | 5 | 6 | def load_dict(filename): 7 | '''Load entity and relation to id mapping''' 8 | 9 | mapping = {} 10 | with open(filename, 'r') as f: 11 | for l in f: 12 | l = l.strip().split('\t') 13 | mapping[l[0]] = l[1] 14 | 15 | return mapping 16 | 17 | def load_edges(filename): 18 | with open(filename, 'r') as f: 19 | r = f.readlines() 20 | r = [i.strip().split('\t') for i in r] 21 | 22 | return r 23 | 24 | def generate_true_dict(all_triples): 25 | heads = {(r, t) : [] for _, r, t in all_triples} 26 | tails = {(h, r) : [] for h, r, _ in all_triples} 27 | 28 | for h, r, t in all_triples: 29 | heads[r, t].append(h) 30 | tails[h, r].append(t) 31 | 32 | return heads, tails 33 | 34 | def get_loader(args, delete=[]): 35 | prefix = os.path.join('./data', args.dataset) 36 | 37 | # Edges 38 | train = load_edges(os.path.join(prefix, 'train.txt')) 39 | valid = load_edges(os.path.join(prefix, 'valid.txt')) 40 | test = load_edges(os.path.join(prefix, 'test.txt')) 41 | train = [(int(i[0]), int(i[1]), int(i[2])) for i in train] 42 | valid = [(int(i[0]), int(i[1]), int(i[2])) for i in valid] 43 | test = [(int(i[0]), int(i[1]), int(i[2])) for i in test] 44 | train_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in train] 45 | valid_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in valid] 46 | test_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in test] 47 | train = train + train_rev 48 | valid = valid + valid_rev 49 | test = test + test_rev 50 | all_edge = train + valid + test 51 | 52 | true_triples = generate_true_dict(all_edge) 53 | 54 | edge = torch.tensor([(int(i[0]), int(i[2])) for i in all_edge], dtype=torch.long).t() 55 | edge_type = torch.tensor([int(i[1]) for i in all_edge], dtype=torch.long)#.view(-1, 1) 56 | 57 | # Masks 58 | train_size = len(train) 59 | valid_size = len(valid) 60 | test_size = len(test) 61 | total_size = train_size + valid_size + test_size 62 | 63 | train_mask = torch.zeros((total_size,)).bool() 64 | train_mask[:train_size] = True 65 | 66 | valid_mask = torch.zeros((total_size,)).bool() 67 | valid_mask[train_size:train_size + valid_size] = True 68 | 69 | test_mask = 
torch.zeros((total_size,)).bool() 70 | test_mask[-test_size:] = True 71 | 72 | # Graph size 73 | num_nodes = edge.flatten().unique().shape[0] 74 | num_edges = edge.shape[1] 75 | num_edge_type = edge_type.unique().shape[0] 76 | 77 | # Node feature 78 | x = torch.rand((num_nodes, args.in_dim)) 79 | 80 | # Delete edges 81 | if len(delete) > 0: 82 | delete_idx = torch.tensor(delete, dtype=torch.long) 83 | num_train_edges = train_size // 2 84 | train_mask[delete_idx] = False 85 | train_mask[delete_idx + num_train_edges] = False 86 | train_size -= 2 * len(delete) 87 | 88 | node_id = torch.arange(num_nodes) 89 | dataset = Data( 90 | edge_index=edge, edge_type=edge_type, x=x, node_id=node_id, 91 | train_mask=train_mask, valid_mask=valid_mask, test_mask=test_mask) 92 | 93 | dataloader = GraphSAINTRandomWalkSampler( 94 | dataset, batch_size=args.batch_size, walk_length=args.walk_length, num_steps=args.num_steps) 95 | 96 | print(f'Dataset: {args.dataset}, Num nodes: {num_nodes}, Num edges: {num_edges//2}, Num relation types: {num_edge_type}') 97 | print(f'Train edges: {train_size//2}, Valid edges: {valid_size//2}, Test edges: {test_size//2}') 98 | 99 | return dataloader, valid, test, true_triples, num_nodes, num_edges, num_edge_type 100 | -------------------------------------------------------------------------------- /framework/trainer/gradient_ascent_with_mp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm, trange 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.utils import negative_sampling 7 | 8 | from .base import Trainer 9 | from ..evaluation import * 10 | from ..utils import * 11 | 12 | 13 | class GradientAscentWithMessagePassingTrainer(Trainer): 14 | def __init__(self,): 15 | self.trainer_log = {'unlearning_model': 'gradient_ascent_with_mp', 'log': []} 16 | 17 | def freeze_unused_mask(self, model, edge_to_delete, subgraph, h): 18 | gradient_mask = torch.zeros_like(delete_model.operator) 19 | 20 | edges = subgraph[h] 21 | for s, t in edges: 22 | if s < t: 23 | gradient_mask[s, t] = 1 24 | gradient_mask = gradient_mask.to(device) 25 | model.operator.register_hook(lambda grad: grad.mul_(gradient_mask)) 26 | 27 | def train(self, model_retrain, model, data, optimizer, args): 28 | best_loss = 100000 29 | for epoch in trange(args.epochs, desc='Unlerning'): 30 | model.train() 31 | total_step = 0 32 | total_loss = 0 33 | 34 | ## Gradient Ascent 35 | neg_edge_index = negative_sampling( 36 | edge_index=data.train_pos_edge_index[:, data.ga_mask], 37 | num_nodes=data.num_nodes, 38 | num_neg_samples=data.ga_mask.sum()) 39 | 40 | # print('data train to unlearn', data.train_pos_edge_index[:, data.ga_mask]) 41 | z = model(data.x, data.train_pos_edge_index[:, data.ga_mask]) 42 | logits = model.decode(z, data.train_pos_edge_index[:, data.ga_mask]) 43 | label = torch.tensor([1], dtype=torch.float, device='cuda') 44 | loss_ga = -F.binary_cross_entropy_with_logits(logits, label) 45 | 46 | ## Message Passing 47 | neg_edge_index = negative_sampling( 48 | edge_index=data.train_pos_edge_index[:, data.mp_mask], 49 | num_nodes=data.num_nodes, 50 | num_neg_samples=data.mp_mask.sum()) 51 | 52 | z = model(data.x, data.train_pos_edge_index[:, data.mp_mask]) 53 | logits = model.decode(z, data.train_pos_edge_index[:, data.mp_mask]) 54 | label = self.get_link_labels(data.train_pos_edge_index[:, data.mp_mask], dtype=torch.float, device='cuda') 55 | loss_mp = F.binary_cross_entropy_with_logits(logits, label) 
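            # The two terms are summed below: loss_ga is the negated binary cross-entropy on the
            # edges marked for deletion (data.ga_mask), so minimizing it performs gradient ascent
            # on those edges, while loss_mp is the standard link-prediction loss on the retained
            # edges (data.mp_mask), keeping message passing over the rest of the graph intact.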
56 | 57 | loss = loss_ga + loss_mp 58 | loss.backward() 59 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 60 | optimizer.step() 61 | optimizer.zero_grad() 62 | 63 | total_step += 1 64 | total_loss += loss.item() 65 | 66 | msg = [ 67 | f'Epoch: {epoch:>4d}', 68 | f'train loss: {total_loss / total_step:.6f}' 69 | ] 70 | tqdm.write(' | '.join(msg)) 71 | 72 | valid_loss, auc, aup = self.eval(model, data, 'val') 73 | 74 | self.trainer_log['log'].append({ 75 | 'dt_loss': valid_loss, 76 | 'dt_auc': auc, 77 | 'dt_aup': aup 78 | }) 79 | 80 | # Eval unlearn 81 | loss, auc, aup = self.test(model, data) 82 | self.trainer_log['dt_loss'] = loss 83 | self.trainer_log['dt_auc'] = auc 84 | self.trainer_log['dt_aup'] = aup 85 | 86 | self.trainer_log['ve'] = verification_error(model, model_retrain).cpu().item() 87 | self.trainer_log['dr_kld'] = output_kldiv(model, model_retrain, data=data).cpu().item() 88 | 89 | embedding = get_node_embedding_data(model, data) 90 | logits = model.decode(embedding, data.train_pos_edge_index[:, data.dtrain_mask]).sigmoid().detach().cpu() 91 | self.trainer_log['df_score'] = logits[:1].cpu().item() 92 | 93 | 94 | # Save 95 | ckpt = { 96 | 'model_state': model.state_dict(), 97 | 'node_emb': z, 98 | 'optimizer_state': optimizer.state_dict(), 99 | } 100 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model.pt')) 101 | print(self.trainer_log) 102 | with open(os.path.join(args.checkpoint_dir, 'trainer_log.json'), 'w') as f: 103 | json.dump(self.trainer_log, f) 104 | -------------------------------------------------------------------------------- /framework/trainer/descent_to_delete.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import wandb 4 | from tqdm import tqdm, trange 5 | import torch 6 | import torch.nn.functional as F 7 | from torch_geometric.utils import negative_sampling 8 | 9 | from .base import Trainer 10 | from ..evaluation import * 11 | from ..utils import * 12 | 13 | 14 | class DtdTrainer(Trainer): 15 | '''This code is adapte from https://github.com/ChrisWaites/descent-to-delete''' 16 | 17 | def compute_sigma(self, num_examples, iterations, lipshitz, smooth, strong, epsilon, delta): 18 | """Theorem 3.1 https://arxiv.org/pdf/2007.02923.pdf""" 19 | 20 | print('delta', delta) 21 | gamma = (smooth - strong) / (smooth + strong) 22 | numerator = 4 * np.sqrt(2) * lipshitz * np.power(gamma, iterations) 23 | denominator = (strong * num_examples * (1 - np.power(gamma, iterations))) * ((np.sqrt(np.log(1 / delta) + epsilon)) - np.sqrt(np.log(1 / delta))) 24 | # print('sigma', numerator, denominator, numerator / denominator) 25 | 26 | return numerator / denominator 27 | 28 | def publish(self, model, sigma): 29 | """Publishing function which adds Gaussian noise with scale sigma.""" 30 | 31 | with torch.no_grad(): 32 | for n, p in model.named_parameters(): 33 | p.copy_(p + torch.empty_like(p).normal_(0, sigma)) 34 | 35 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 36 | start_time = time.time() 37 | best_valid_loss = 100000 38 | 39 | # MI Attack before unlearning 40 | if attack_model_all is not None: 41 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 42 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 43 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 44 | if attack_model_sub is not None: 45 | mi_logit_sub_before, mi_sucrate_sub_before = 
member_infer_attack(model, attack_model_sub, data) 46 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 47 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 48 | 49 | for epoch in trange(args.epochs, desc='Unlerning'): 50 | model.train() 51 | 52 | # Positive and negative sample 53 | neg_edge_index = negative_sampling( 54 | edge_index=data.train_pos_edge_index[:, data.dr_mask], 55 | num_nodes=data.num_nodes, 56 | num_neg_samples=data.dr_mask.sum()) 57 | 58 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 59 | logits = model.decode(z, data.train_pos_edge_index[:, data.dr_mask], neg_edge_index) 60 | label = get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index) 61 | loss = F.binary_cross_entropy_with_logits(logits, label) 62 | 63 | loss.backward() 64 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 65 | optimizer.step() 66 | optimizer.zero_grad() 67 | 68 | log = { 69 | 'Epoch': epoch, 70 | 'train_loss': loss.item(), 71 | } 72 | wandb.log(log) 73 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 74 | tqdm.write(' | '.join(msg)) 75 | 76 | valid_loss, auc, aup, df_logt, logit_all_pair = self.eval(model, data, 'val') 77 | 78 | self.trainer_log['log'].append({ 79 | 'dt_loss': valid_loss, 80 | 'dt_auc': auc, 81 | 'dt_aup': aup 82 | }) 83 | 84 | train_size = data.dr_mask.sum().cpu().item() 85 | sigma = self.compute_sigma( 86 | train_size, 87 | args.epochs, 88 | 1 + args.weight_decay, 89 | 4 - args.weight_decay, 90 | args.weight_decay, 91 | 5, 92 | 1 / train_size / train_size) 93 | 94 | self.publish(model, sigma) 95 | 96 | self.trainer_log['sigma'] = sigma 97 | self.trainer_log['training_time'] = time.time() - start_time 98 | 99 | # Save 100 | ckpt = { 101 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()}, 102 | 'optimizer_state': optimizer.state_dict(), 103 | } 104 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 105 | -------------------------------------------------------------------------------- /framework/models/graph_classification/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ogb.graphproppred.mol_encoder import AtomEncoder 5 | from torch_geometric.nn import GCNConv, MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 6 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 7 | from torch_geometric.utils import degree 8 | from torch_geometric.nn.inits import uniform 9 | from torch_scatter import scatter_mean 10 | 11 | 12 | ''' 13 | Source: OGB github 14 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/mol/main_pyg.py 15 | ''' 16 | ### GCN convolution along the graph structure 17 | class GCNConv(MessagePassing): 18 | def __init__(self, emb_dim): 19 | super().__init__(aggr='add') 20 | 21 | self.linear = nn.Linear(emb_dim, emb_dim) 22 | self.root_emb = nn.Embedding(1, emb_dim) 23 | self.bond_encoder = BondEncoder(emb_dim=emb_dim) 24 | 25 | def forward(self, x, edge_index, edge_attr): 26 | x = self.linear(x) 27 | edge_embedding = self.bond_encoder(edge_attr) 28 | 29 | row, col = edge_index 30 | 31 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) 32 | deg = degree(row, x.size(0), dtype=x.dtype) + 1 33 | deg_inv_sqrt = deg.pow(-0.5) 34 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 35 | 36 | norm = 
deg_inv_sqrt[row] * deg_inv_sqrt[col] 37 | 38 | return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) 39 | 40 | def message(self, x_j, edge_attr, norm): 41 | return norm.view(-1, 1) * F.relu(x_j + edge_attr) 42 | 43 | def update(self, aggr_out): 44 | return aggr_out 45 | 46 | ### GNN to generate node embedding 47 | class GNN_node(nn.Module): 48 | def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False): 49 | super().__init__() 50 | self.num_layer = num_layer 51 | self.drop_ratio = drop_ratio 52 | self.JK = JK 53 | ### add residual connection or not 54 | self.residual = residual 55 | 56 | self.atom_encoder = AtomEncoder(emb_dim) 57 | 58 | ###List of GNNs 59 | self.convs = nn.ModuleList() 60 | self.batch_norms = nn.ModuleList() 61 | 62 | for layer in range(num_layer): 63 | self.convs.append(GCNConv(emb_dim)) 64 | self.batch_norms.append(nn.BatchNorm1d(emb_dim)) 65 | 66 | def forward(self, batched_data): 67 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 68 | 69 | ### computing input node embedding 70 | 71 | h_list = [self.atom_encoder(x)] 72 | for layer in range(self.num_layer): 73 | 74 | h = self.convs[layer](h_list[layer], edge_index, edge_attr) 75 | h = self.batch_norms[layer](h) 76 | 77 | if layer == self.num_layer - 1: 78 | #remove relu for the last layer 79 | h = F.dropout(h, self.drop_ratio, training=self.training) 80 | else: 81 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) 82 | 83 | if self.residual: 84 | h += h_list[layer] 85 | 86 | h_list.append(h) 87 | 88 | ### Different implementations of Jk-concat 89 | if self.JK == "last": 90 | node_representation = h_list[-1] 91 | elif self.JK == "sum": 92 | node_representation = 0 93 | for layer in range(self.num_layer + 1): 94 | node_representation += h_list[layer] 95 | 96 | return node_representation 97 | 98 | class GNN(torch.nn.Module): 99 | def __init__(self, num_tasks, num_layer=2, emb_dim=300, virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"): 100 | super().__init__() 101 | 102 | self.num_layer = num_layer 103 | self.drop_ratio = drop_ratio 104 | self.JK = JK 105 | self.emb_dim = emb_dim 106 | self.num_tasks = num_tasks 107 | self.graph_pooling = graph_pooling 108 | 109 | ### GNN to generate node embeddings 110 | self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual) 111 | 112 | ### Pooling function to generate whole-graph embeddings 113 | if self.graph_pooling == "sum": 114 | self.pool = global_add_pool 115 | elif self.graph_pooling == "mean": 116 | self.pool = global_mean_pool 117 | elif self.graph_pooling == "max": 118 | self.pool = global_max_pool 119 | elif self.graph_pooling == "attention": 120 | self.pool = GlobalAttention(gate_nn=nn.Sequential(Linear(emb_dim, 2*emb_dim), nn.BatchNorm1d(2*emb_dim), nn.ReLU(), nn.Linear(2*emb_dim, 1))) 121 | elif self.graph_pooling == "set2set": 122 | self.pool = Set2Set(emb_dim, processing_steps = 2) 123 | else: 124 | raise ValueError("Invalid graph pooling type.") 125 | 126 | if graph_pooling == "set2set": 127 | self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks) 128 | else: 129 | self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks) 130 | 131 | def forward(self, batched_data): 132 | h_node = self.gnn_node(batched_data) 133 | 134 | h_graph = self.pool(h_node, batched_data.batch) 135 | 136 | return 
self.graph_pred_linear(h_graph) 137 | -------------------------------------------------------------------------------- /framework/trainer/approx_retrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | from tqdm import tqdm, trange 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.utils import negative_sampling 7 | from torch.utils.data import DataLoader, TensorDataset 8 | 9 | from .base import Trainer 10 | from ..evaluation import * 11 | from ..utils import * 12 | 13 | 14 | DTYPE = np.float16 15 | 16 | class ApproxTrainer(Trainer): 17 | '''This code is adapted from https://github.com/zleizzo/datadeletion''' 18 | 19 | def gram_schmidt(self, X): 20 | """ 21 | Uses numpy's qr factorization method to perform Gram-Schmidt. 22 | Args: 23 | X: (k x d matrix) X[i] = i-th vector 24 | Returns: 25 | U: (k x d matrix) U[i] = i-th orthonormal vector 26 | C: (k x k matrix) Coefficient matrix, C[i] = coeffs for X[i], X = CU 27 | """ 28 | (k, d) = X.shape 29 | if k <= d: 30 | q, r = np.linalg.qr(np.transpose(X)) 31 | else: 32 | q, r = np.linalg.qr(np.transpose(X), mode='complete') 33 | U = np.transpose(q) 34 | C = np.transpose(r) 35 | return U, C 36 | 37 | def LKO_pred(self, X, Y, ind, H=None, reg=1e-4): 38 | """ 39 | Computes the LKO model's prediction values on the left-out points. 40 | Args: 41 | X: (n x d matrix) Covariate matrix 42 | Y: (n x 1 vector) Response vector 43 | ind: (k x 1 list) List of indices to be removed 44 | H: (n x n matrix, optional) Hat matrix X (X^T X)^{-1} X^T 45 | Returns: 46 | LKO: (k x 1 vector) Retrained model's predictions on X[i], i in ind 47 | """ 48 | n = len(Y) 49 | k = len(ind) 50 | d = len(X[0, :]) 51 | if H is None: 52 | H = np.matmul(X, np.linalg.solve(np.matmul(X.T, X) + reg * np.eye(d), X.T)) 53 | 54 | LOO = np.zeros(k) 55 | for i in range(k): 56 | idx = ind[i] 57 | # This is the LOO residual y_i - \hat{y}^{LOO}_i 58 | LOO[i] = (Y[idx] - np.matmul(H[idx, :], Y)) / (1 - H[idx, idx]) 59 | 60 | # S = I - T from the paper 61 | S = np.eye(k) 62 | for i in range(k): 63 | for j in range(k): 64 | if j != i: 65 | idx_i = ind[i] 66 | idx_j = ind[j] 67 | S[i, j] = -H[idx_i, idx_j] / (1 - H[idx_i, idx_i]) 68 | 69 | LKO = np.linalg.solve(S, LOO) 70 | 71 | return Y[ind] - LKO 72 | 73 | 74 | def lin_res(self, X, Y, theta, ind, H=None, reg=1e-4): 75 | """ 76 | Approximate retraining via the projective residual update. 
77 | Args: 78 | X: (n x d matrix) Covariate matrix 79 | Y: (n x 1 vector) Response vector 80 | theta: (d x 1 vector) Current value of parameters to be updated 81 | ind: (k x 1 list) List of indices to be removed 82 | H: (n x n matrix, optional) Hat matrix X (X^T X)^{-1} X^T 83 | Returns: 84 | updated: (d x 1 vector) Updated parameters 85 | """ 86 | d = len(X[0]) 87 | k = len(ind) 88 | 89 | # Step 1: Compute LKO predictions 90 | LKO = self.LKO_pred(X, Y, ind, H, reg) 91 | 92 | # Step 2: Eigendecompose B 93 | # 2.I 94 | U, C = self.gram_schmidt(X[ind, :]) 95 | # 2.II 96 | Cmatrix = np.matmul(C.T, C) 97 | eigenval, a = np.linalg.eigh(Cmatrix) 98 | V = np.matmul(a.T, U) 99 | 100 | # Step 3: Perform the update 101 | # 3.I 102 | grad = np.zeros_like(theta) # 2D grad 103 | for i in range(k): 104 | grad += (X[ind[i], :] * theta - LKO[i]) * X[ind[i], :] 105 | # 3.II 106 | step = np.zeros_like(theta) # 2D grad 107 | for i in range(k): 108 | factor = 1 / eigenval[i] if eigenval[i] > 1e-10 else 0 109 | step += factor * V[i, :] * grad * V[i, :] 110 | # 3.III 111 | return step 112 | # update = theta - step 113 | # return update 114 | 115 | @torch.no_grad() 116 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model=None): 117 | model.eval() 118 | best_loss = 100000 119 | 120 | neg_edge_index = negative_sampling( 121 | edge_index=data.train_pos_edge_index[:, data.dr_mask], 122 | num_nodes=data.num_nodes, 123 | num_neg_samples=data.dr_mask.sum()) 124 | 125 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 126 | edge_index_all = torch.cat([data.train_pos_edge_index[:, data.dr_mask], neg_edge_index], dim=1) 127 | 128 | X = z[edge_index_all[0]] * z[edge_index_all[1]] 129 | Y = self.get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index) 130 | X = X.cpu() 131 | Y = Y.cpu() 132 | 133 | # According to the code, theta should be of (d, d). So only update the weights of the last layer 134 | theta = model.conv2.lin.weight.cpu().numpy() 135 | ind = [int(i) for i in self.args.df_idx.split(',')] 136 | 137 | # Not enough RAM for solving matrix inverse. 
So break into multiple batches 138 | update = [] 139 | loader = DataLoader(TensorDataset(X, Y), batch_size=4096, num_workers=8) 140 | for x, y in tqdm(loader, desc='Unlearning'): 141 | 142 | x = x.numpy() 143 | y = y.numpy() 144 | 145 | update_step = self.lin_res(x, y, theta.T, ind) 146 | update.append(torch.tensor(update_step)) 147 | 148 | update = torch.stack(update).mean(0) 149 | model.conv2.lin.weight = torch.nn.Parameter(model.conv2.lin.weight - update.t().cuda()) 150 | 151 | print(f'Update model weights from {torch.norm(torch.tensor(theta))} to {torch.norm(model.conv2.lin.weight)}') 152 | 153 | valid_loss, auc, aup, df_logt, logit_all_pair = self.eval(model, data, 'val') 154 | 155 | self.trainer_log['log'].append({ 156 | 'dt_loss': valid_loss, 157 | 'dt_auc': auc, 158 | 'dt_aup': aup 159 | }) 160 | 161 | # Save 162 | ckpt = { 163 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()}, 164 | 'node_emb': None, 165 | 'optimizer_state': None, 166 | } 167 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 168 | -------------------------------------------------------------------------------- /framework/models/graph_classification/gcn_delete.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ogb.graphproppred.mol_encoder import AtomEncoder 5 | from torch_geometric.nn import GCNConv, MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 6 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder 7 | from torch_geometric.utils import degree 8 | from torch_geometric.nn.inits import uniform 9 | from torch_scatter import scatter_mean 10 | 11 | from ..deletion import DeletionLayer 12 | 13 | 14 | def remove_edges(edge_index, edge_attr=None, ratio=0.025): 15 | row, col = edge_index 16 | mask = row < col 17 | row, col = row[mask], col[mask] 18 | 19 | if edge_attr is not None: 20 | edge_attr = edge_attr[mask] 21 | 22 | num_edges = len(row) 23 | num_remove = max(1, int(num_edges * ratio)) 24 | 25 | selected = torch.randperm(num_edges)[:num_edges - num_remove] 26 | 27 | row = row[selected] 28 | col = col[selected] 29 | edge_attr = edge_attr[selected] 30 | 31 | return torch.stack([row, col], dim=0), edge_attr 32 | 33 | ''' 34 | Source: OGB github 35 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/mol/main_pyg.py 36 | ''' 37 | ### GCN convolution along the graph structure 38 | class GCNConv(MessagePassing): 39 | def __init__(self, emb_dim): 40 | super().__init__(aggr='add') 41 | 42 | self.linear = nn.Linear(emb_dim, emb_dim) 43 | self.root_emb = nn.Embedding(1, emb_dim) 44 | self.bond_encoder = BondEncoder(emb_dim=emb_dim) 45 | 46 | def forward(self, x, edge_index, edge_attr): 47 | x = self.linear(x) 48 | edge_embedding = self.bond_encoder(edge_attr) 49 | 50 | row, col = edge_index 51 | 52 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) 53 | deg = degree(row, x.size(0), dtype=x.dtype) + 1 54 | deg_inv_sqrt = deg.pow(-0.5) 55 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 56 | 57 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] 58 | 59 | return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) 60 | 61 | def message(self, x_j, edge_attr, norm): 62 | return norm.view(-1, 1) * F.relu(x_j + edge_attr) 63 | 64 | def update(self, aggr_out): 65 | return aggr_out 66 | 67 | ### GNN to 
generate node embedding 68 | class GNN_node_delete(nn.Module): 69 | def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, mask_1hop=None, mask_2hop=None): 70 | super().__init__() 71 | self.num_layer = num_layer 72 | self.drop_ratio = drop_ratio 73 | self.JK = JK 74 | ### add residual connection or not 75 | self.residual = residual 76 | 77 | self.atom_encoder = AtomEncoder(emb_dim) 78 | 79 | ###List of GNNs 80 | self.deletes = nn.ModuleList([ 81 | DeletionLayer(emb_dim, None), 82 | DeletionLayer(emb_dim, None) 83 | ]) 84 | self.convs = nn.ModuleList() 85 | self.batch_norms = nn.ModuleList() 86 | 87 | for layer in range(num_layer): 88 | self.convs.append(GCNConv(emb_dim)) 89 | self.batch_norms.append(nn.BatchNorm1d(emb_dim)) 90 | 91 | def forward(self, batched_data): 92 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch 93 | edge_index, edge_attr = remove_edges(edge_index, edge_attr) 94 | 95 | ### computing input node embedding 96 | 97 | h_list = [self.atom_encoder(x)] 98 | for layer in range(self.num_layer): 99 | 100 | h = self.convs[layer](h_list[layer], edge_index, edge_attr) 101 | h = self.deletes[layer](h) 102 | h = self.batch_norms[layer](h) 103 | 104 | if layer == self.num_layer - 1: 105 | #remove relu for the last layer 106 | h = F.dropout(h, self.drop_ratio, training=self.training) 107 | else: 108 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) 109 | 110 | if self.residual: 111 | h += h_list[layer] 112 | 113 | h_list.append(h) 114 | 115 | ### Different implementations of Jk-concat 116 | if self.JK == "last": 117 | node_representation = h_list[-1] 118 | elif self.JK == "sum": 119 | node_representation = 0 120 | for layer in range(self.num_layer + 1): 121 | node_representation += h_list[layer] 122 | 123 | return node_representation 124 | 125 | class GNN(torch.nn.Module): 126 | def __init__(self, num_tasks, num_layer=2, emb_dim=300, virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"): 127 | super().__init__() 128 | 129 | self.num_layer = num_layer 130 | self.drop_ratio = drop_ratio 131 | self.JK = JK 132 | self.emb_dim = emb_dim 133 | self.num_tasks = num_tasks 134 | self.graph_pooling = graph_pooling 135 | 136 | ### GNN to generate node embeddings 137 | self.gnn_node = GNN_node_delete(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual) 138 | 139 | ### Pooling function to generate whole-graph embeddings 140 | if self.graph_pooling == "sum": 141 | self.pool = global_add_pool 142 | elif self.graph_pooling == "mean": 143 | self.pool = global_mean_pool 144 | elif self.graph_pooling == "max": 145 | self.pool = global_max_pool 146 | elif self.graph_pooling == "attention": 147 | self.pool = GlobalAttention(gate_nn=nn.Sequential(Linear(emb_dim, 2*emb_dim), nn.BatchNorm1d(2*emb_dim), nn.ReLU(), nn.Linear(2*emb_dim, 1))) 148 | elif self.graph_pooling == "set2set": 149 | self.pool = Set2Set(emb_dim, processing_steps = 2) 150 | else: 151 | raise ValueError("Invalid graph pooling type.") 152 | 153 | if graph_pooling == "set2set": 154 | self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks) 155 | else: 156 | self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks) 157 | 158 | def forward(self, batched_data): 159 | h_node = self.gnn_node(batched_data) 160 | 161 | h_graph = self.pool(h_node, batched_data.batch) 162 | 163 | return self.graph_pred_linear(h_graph) 164 | 
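# A minimal usage sketch for the graph-classification deletion model defined above. Assumptions:
# ogb and torch_geometric are installed, the ogbg-molhiv dataset can be downloaded to './data'
# (an illustrative path), DeletionLayer (framework/models/deletion.py) acts as a pass-through when
# constructed with a None mask, and num_layer stays at 2 to match the two hard-coded deletion
# layers in GNN_node_delete.
import torch
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root='./data')
loader = DataLoader(dataset, batch_size=32, shuffle=False)

model = GNN(num_tasks=dataset.num_tasks, num_layer=2, emb_dim=300, graph_pooling='mean')
model.eval()
with torch.no_grad():
    batch = next(iter(loader))
    out = model(batch)  # graph-level predictions, shape [batch_size, num_tasks]
    print(out.shape)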
-------------------------------------------------------------------------------- /framework/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from sklearn.metrics import roc_auc_score, average_precision_score 4 | from .utils import get_link_labels 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | @torch.no_grad() 10 | def eval_lp(model, stage, data=None, loader=None): 11 | model.eval() 12 | 13 | # For full batch 14 | if data is not None: 15 | pos_edge_index = data[f'{stage}_pos_edge_index'] 16 | neg_edge_index = data[f'{stage}_neg_edge_index'] 17 | 18 | if hasattr(data, 'dtrain_mask') and data.dtrain_mask is not None: 19 | embedding = model(data.x.to(device), data.train_pos_edge_index[:, data.dtrain_mask].to(device)) 20 | else: 21 | embedding = model(data.x.to(device), data.train_pos_edge_index.to(device)) 22 | 23 | logits = model.decode(embedding, pos_edge_index, neg_edge_index).sigmoid() 24 | label = get_link_labels(pos_edge_index, neg_edge_index) 25 | 26 | # For mini batch 27 | if loader is not None: 28 | logits = [] 29 | label = [] 30 | for batch in loader: 31 | edge_index = batch.edge_index.to(device) 32 | 33 | if hasattr(batch, 'edge_type'): 34 | edge_type = batch.edge_type.to(device) 35 | 36 | embedding1 = model1(edge_index, edge_type) 37 | embedding2 = model2(edge_index, edge_type) 38 | 39 | s1 = model.decode(embedding1, edge_index, edge_type) 40 | s2 = model.decode(embedding2, edge_index, edge_type) 41 | 42 | else: 43 | embedding1 = model1(edge_index) 44 | embedding2 = model2(edge_index) 45 | 46 | s1 = model.decode(embedding1, edge_index) 47 | s2 = model.decode(embedding2, edge_index) 48 | 49 | embedding = model(data.train_pos_edge_index.to(device)) 50 | 51 | lg = model.decode(embedding, pos_edge_index, neg_edge_index).sigmoid() 52 | lb = get_link_labels(pos_edge_index, neg_edge_index) 53 | 54 | logits.append(lg) 55 | label.append(lb) 56 | 57 | loss = F.binary_cross_entropy_with_logits(logits, label) 58 | auc = roc_auc_score(label.cpu(), logits.cpu()) 59 | aup = average_precision_score(label.cpu(), logits.cpu()) 60 | 61 | return loss, auc, aup 62 | 63 | @torch.no_grad() 64 | def verification_error(model1, model2): 65 | '''L2 distance between aproximate model and re-trained model''' 66 | 67 | model1 = model1.to('cpu') 68 | model2 = model2.to('cpu') 69 | 70 | modules1 = {n: p for n, p in model1.named_parameters()} 71 | modules2 = {n: p for n, p in model2.named_parameters()} 72 | 73 | all_names = set(modules1.keys()) & set(modules2.keys()) 74 | 75 | print(all_names) 76 | 77 | diff = torch.tensor(0.0).float() 78 | for n in all_names: 79 | diff += torch.norm(modules1[n] - modules2[n]) 80 | 81 | return diff 82 | 83 | @torch.no_grad() 84 | def member_infer_attack(target_model, attack_model, data, logits=None): 85 | '''Membership inference attack''' 86 | 87 | edge = data.train_pos_edge_index[:, data.df_mask] 88 | z = target_model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 89 | feature1 = target_model.decode(z, edge).sigmoid() 90 | feature0 = 1 - feature1 91 | feature = torch.stack([feature0, feature1], dim=1) 92 | # feature = torch.cat([z[edge[0]], z[edge][1]], dim=-1) 93 | logits = attack_model(feature) 94 | _, pred = torch.max(logits, 1) 95 | suc_rate = 1 - pred.float().mean() 96 | 97 | return torch.softmax(logits, dim=-1).squeeze().tolist(), suc_rate.cpu().item() 98 | 99 | @torch.no_grad() 100 | def 
member_infer_attack_node(target_model, attack_model, data, logits=None): 101 | '''Membership inference attack''' 102 | 103 | edge = data.train_pos_edge_index[:, data.df_mask] 104 | z = target_model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 105 | feature = torch.cat([z[edge[0]], z[edge][1]], dim=-1) 106 | logits = attack_model(feature) 107 | _, pred = torch.max(logits, 1) 108 | suc_rate = 1 - pred.float().mean() 109 | 110 | return torch.softmax(logits, dim=-1).squeeze().tolist(), suc_rate.cpu().item() 111 | 112 | @torch.no_grad() 113 | def get_node_embedding_data(model, data): 114 | model.eval() 115 | 116 | if hasattr(data, 'dtrain_mask') and data.dtrain_mask is not None: 117 | node_embedding = model(data.x.to(device), data.train_pos_edge_index[:, data.dtrain_mask].to(device)) 118 | else: 119 | node_embedding = model(data.x.to(device), data.train_pos_edge_index.to(device)) 120 | 121 | return node_embedding 122 | 123 | @torch.no_grad() 124 | def output_kldiv(model1, model2, data=None, loader=None): 125 | '''KL-Divergence between output distribution of model and re-trained model''' 126 | 127 | model1.eval() 128 | model2.eval() 129 | 130 | # For full batch 131 | if data is not None: 132 | embedding1 = get_node_embedding_data(model1, data).to(device) 133 | embedding2 = get_node_embedding_data(model2, data).to(device) 134 | 135 | if data.edge_index is not None: 136 | edge_index = data.edge_index.to(device) 137 | if data.train_pos_edge_index is not None: 138 | edge_index = data.train_pos_edge_index.to(device) 139 | 140 | 141 | if hasattr(data, 'edge_type'): 142 | edge_type = data.edge_type.to(device) 143 | score1 = model1.decode(embedding1, edge_index, edge_type) 144 | score2 = model2.decode(embedding2, edge_index, edge_type) 145 | else: 146 | score1 = model1.decode(embedding1, edge_index) 147 | score2 = model2.decode(embedding2, edge_index) 148 | 149 | # For mini batch 150 | if loader is not None: 151 | score1 = [] 152 | score2 = [] 153 | for batch in loader: 154 | edge_index = batch.edge_index.to(device) 155 | 156 | if hasattr(batch, 'edge_type'): 157 | edge_type = batch.edge_type.to(device) 158 | 159 | embedding1 = model1(edge, edge_type) 160 | embedding2 = model2(edge, edge_type) 161 | 162 | s1 = model.decode(embedding1, edge, edge_type) 163 | s2 = model.decode(embedding2, edge, edge_type) 164 | 165 | else: 166 | embedding1 = model1(edge) 167 | embedding2 = model2(edge) 168 | 169 | s1 = model.decode(embedding1, edge) 170 | s2 = model.decode(embedding2, edge) 171 | 172 | score1.append(s1) 173 | score2.append(s2) 174 | 175 | score1 = torch.hstack(score1) 176 | score2 = torch.hstack(score2) 177 | 178 | kldiv = F.kl_div( 179 | F.log_softmax(score1, dim=-1), 180 | F.softmax(score2, dim=-1) 181 | ) 182 | 183 | return kldiv 184 | 185 | -------------------------------------------------------------------------------- /framework/training_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | num_edge_type_mapping = { 5 | 'FB15k-237': 237, 6 | 'WordNet18': 18, 7 | 'WordNet18RR': 11, 8 | 'ogbl-biokg': 51 9 | } 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | # Model 15 | parser.add_argument('--unlearning_model', type=str, default='retrain', 16 | help='unlearning method') 17 | parser.add_argument('--gnn', type=str, default='gcn', 18 | help='GNN architecture') 19 | parser.add_argument('--in_dim', type=int, default=128, 20 | help='input dimension') 21 | parser.add_argument('--hidden_dim', type=int, 
default=128, 22 | help='hidden dimension') 23 | parser.add_argument('--out_dim', type=int, default=64, 24 | help='output dimension') 25 | 26 | # Data 27 | parser.add_argument('--data_dir', type=str, default='./data', 28 | help='data dir') 29 | parser.add_argument('--df', type=str, default='none', 30 | help='Df set to use') 31 | parser.add_argument('--df_idx', type=str, default='none', 32 | help='indices of data to be deleted') 33 | parser.add_argument('--df_size', type=float, default=0.5, 34 | help='Df size') 35 | parser.add_argument('--dataset', type=str, default='Cora', 36 | help='dataset') 37 | parser.add_argument('--random_seed', type=int, default=42, 38 | help='random seed') 39 | parser.add_argument('--batch_size', type=int, default=8192, 40 | help='batch size for GraphSAINTRandomWalk sampler') 41 | parser.add_argument('--walk_length', type=int, default=2, 42 | help='random walk length for GraphSAINTRandomWalk sampler') 43 | parser.add_argument('--num_steps', type=int, default=32, 44 | help='number of steps for GraphSAINTRandomWalk sampler') 45 | 46 | # Training 47 | parser.add_argument('--lr', type=float, default=1e-3, 48 | help='initial learning rate') 49 | parser.add_argument('--weight_decay', type=float, default=0.0005, 50 | help='weight decay') 51 | parser.add_argument('--optimizer', type=str, default='Adam', 52 | help='optimizer to use') 53 | parser.add_argument('--epochs', type=int, default=3000, 54 | help='number of epochs to train') 55 | parser.add_argument('--valid_freq', type=int, default=100, 56 | help='# of epochs to do validation') 57 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint', 58 | help='checkpoint folder') 59 | parser.add_argument('--alpha', type=float, default=0.5, 60 | help='alpha in loss function') 61 | parser.add_argument('--neg_sample_random', type=str, default='non_connected', 62 | help='type of negative samples for randomness') 63 | parser.add_argument('--loss_fct', type=str, default='mse_mean', 64 | help='loss function. one of {mse, kld, cosine}') 65 | parser.add_argument('--loss_type', type=str, default='both_layerwise', 66 | help='type of loss. 
one of {both_all, both_layerwise, only2_layerwise, only2_all, only1}') 67 | 68 | # GraphEraser 69 | parser.add_argument('--num_clusters', type=int, default=10, 70 | help='top k for evaluation') 71 | parser.add_argument('--kmeans_max_iters', type=int, default=1, 72 | help='top k for evaluation') 73 | parser.add_argument('--shard_size_delta', type=float, default=0.005) 74 | parser.add_argument('--terminate_delta', type=int, default=0) 75 | 76 | # GraphEditor 77 | parser.add_argument('--eval_steps', type=int, default=1) 78 | parser.add_argument('--runs', type=int, default=1) 79 | 80 | parser.add_argument('--num_remove_links', type=int, default=11) 81 | parser.add_argument('--parallel_unlearning', type=int, default=4) 82 | 83 | parser.add_argument('--lam', type=float, default=0) 84 | parser.add_argument('--regen_feats', action='store_true') 85 | parser.add_argument('--regen_neighbors', action='store_true') 86 | parser.add_argument('--regen_links', action='store_true') 87 | parser.add_argument('--regen_subgraphs', action='store_true') 88 | parser.add_argument('--hop_neighbors', type=int, default=20) 89 | 90 | 91 | # Evaluation 92 | parser.add_argument('--topk', type=int, default=500, 93 | help='top k for evaluation') 94 | parser.add_argument('--eval_on_cpu', type=bool, default=False, 95 | help='whether to evaluate on CPU') 96 | 97 | # KG 98 | parser.add_argument('--num_edge_type', type=int, default=None, 99 | help='number of edges types') 100 | 101 | args = parser.parse_args() 102 | 103 | if 'ogbl' in args.dataset: 104 | args.eval_on_cpu = True 105 | 106 | # For KG 107 | if args.gnn in ['rgcn', 'rgat']: 108 | args.lr = 1e-3 109 | args.epochs = 3000 110 | args.valid_freq = 500 111 | args.batch_size //= 2 112 | args.num_edge_type = num_edge_type_mapping[args.dataset] 113 | args.eval_on_cpu = True 114 | # args.in_dim = 512 115 | # args.hidden_dim = 256 116 | # args.out_dim = 128 117 | 118 | if args.unlearning_model in ['original', 'retrain']: 119 | args.epochs = 2000 120 | args.valid_freq = 500 121 | 122 | # For large graphs 123 | if args.gnn not in ['rgcn', 'rgat'] and 'ogbl' in args.dataset: 124 | args.epochs = 600 125 | args.valid_freq = 200 126 | if args.gnn in ['rgcn', 'rgat'] and 'ogbl' in args.dataset: 127 | args.batch_size = 1024 128 | 129 | if 'gnndelete' in args.unlearning_model: 130 | if args.gnn not in ['rgcn', 'rgat'] and 'ogbl' in args.dataset: 131 | args.epochs = 600 132 | args.valid_freq = 100 133 | if args.gnn in ['rgcn', 'rgat']: 134 | if args.dataset == 'WordNet18': 135 | args.epochs = 50 136 | args.valid_freq = 2 137 | args.batch_size = 1024 138 | if args.dataset == 'ogbl-biokg': 139 | args.epochs = 50 140 | args.valid_freq = 10 141 | args.batch_size = 64 142 | 143 | elif args.unlearning_model == 'gradient_ascent': 144 | args.epochs = 10 145 | args.valid_freq = 1 146 | 147 | elif args.unlearning_model == 'descent_to_delete': 148 | args.epochs = 1 149 | 150 | elif args.unlearning_model == 'graph_editor': 151 | args.epochs = 400 152 | args.valid_freq = 200 153 | 154 | 155 | if args.dataset == 'ogbg-molhiv': 156 | args.epochs = 100 157 | args.valid_freq = 5 158 | 159 | return args 160 | -------------------------------------------------------------------------------- /framework/models/deletion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from . 
import GCN, GAT, GIN, RGCN, RGAT 6 | 7 | 8 | class DeletionLayer(nn.Module): 9 | def __init__(self, dim, mask): 10 | super().__init__() 11 | self.dim = dim 12 | self.mask = mask 13 | self.deletion_weight = nn.Parameter(torch.ones(dim, dim) / 1000) 14 | # self.deletion_weight = nn.Parameter(torch.eye(dim, dim)) 15 | # init.xavier_uniform_(self.deletion_weight) 16 | 17 | def forward(self, x, mask=None): 18 | '''Only apply deletion operator to the local nodes identified by mask''' 19 | 20 | if mask is None: 21 | mask = self.mask 22 | 23 | if mask is not None: 24 | new_rep = x.clone() 25 | new_rep[mask] = torch.matmul(new_rep[mask], self.deletion_weight) 26 | 27 | return new_rep 28 | 29 | return x 30 | 31 | class DeletionLayerKG(nn.Module): 32 | def __init__(self, dim, mask): 33 | super().__init__() 34 | self.dim = dim 35 | self.mask = mask 36 | self.deletion_weight = nn.Parameter(torch.ones(dim, dim) / 1000) 37 | 38 | def forward(self, x, mask=None): 39 | '''Only apply deletion operator to the local nodes identified by mask''' 40 | 41 | if mask is None: 42 | mask = self.mask 43 | 44 | if mask is not None: 45 | new_rep = x.clone() 46 | new_rep[mask] = torch.matmul(new_rep[mask], self.deletion_weight) 47 | 48 | return new_rep 49 | 50 | return x 51 | 52 | class GCNDelete(GCN): 53 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs): 54 | super().__init__(args) 55 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop) 56 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop) 57 | 58 | self.conv1.requires_grad = False 59 | self.conv2.requires_grad = False 60 | 61 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False): 62 | # with torch.no_grad(): 63 | x1 = self.conv1(x, edge_index) 64 | 65 | x1 = self.deletion1(x1, mask_1hop) 66 | 67 | x = F.relu(x1) 68 | 69 | x2 = self.conv2(x, edge_index) 70 | x2 = self.deletion2(x2, mask_2hop) 71 | 72 | if return_all_emb: 73 | return x1, x2 74 | 75 | return x2 76 | 77 | def get_original_embeddings(self, x, edge_index, return_all_emb=False): 78 | return super().forward(x, edge_index, return_all_emb) 79 | 80 | class GATDelete(GAT): 81 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs): 82 | super().__init__(args) 83 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop) 84 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop) 85 | 86 | self.conv1.requires_grad = False 87 | self.conv2.requires_grad = False 88 | 89 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False): 90 | with torch.no_grad(): 91 | x1 = self.conv1(x, edge_index) 92 | x1 = self.deletion1(x1, mask_1hop) 93 | 94 | x = F.relu(x1) 95 | 96 | x2 = self.conv2(x, edge_index) 97 | x2 = self.deletion2(x2, mask_2hop) 98 | 99 | if return_all_emb: 100 | return x1, x2 101 | 102 | return x2 103 | 104 | def get_original_embeddings(self, x, edge_index, return_all_emb=False): 105 | return super().forward(x, edge_index, return_all_emb) 106 | 107 | class GINDelete(GIN): 108 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs): 109 | super().__init__(args) 110 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop) 111 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop) 112 | 113 | self.conv1.requires_grad = False 114 | self.conv2.requires_grad = False 115 | 116 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False): 117 | with torch.no_grad(): 118 | x1 = self.conv1(x, edge_index) 119 | 120 | x1 = self.deletion1(x1, mask_1hop) 121 | 122 | 
x = F.relu(x1) 123 | 124 | x2 = self.conv2(x, edge_index) 125 | x2 = self.deletion2(x2, mask_2hop) 126 | 127 | if return_all_emb: 128 | return x1, x2 129 | 130 | return x2 131 | 132 | def get_original_embeddings(self, x, edge_index, return_all_emb=False): 133 | return super().forward(x, edge_index, return_all_emb) 134 | 135 | class RGCNDelete(RGCN): 136 | def __init__(self, args, num_nodes, num_edge_type, mask_1hop=None, mask_2hop=None, **kwargs): 137 | super().__init__(args, num_nodes, num_edge_type) 138 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop) 139 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop) 140 | 141 | self.node_emb.requires_grad = False 142 | self.conv1.requires_grad = False 143 | self.conv2.requires_grad = False 144 | 145 | def forward(self, x, edge_index, edge_type, mask_1hop=None, mask_2hop=None, return_all_emb=False): 146 | with torch.no_grad(): 147 | x = self.node_emb(x) 148 | x1 = self.conv1(x, edge_index, edge_type) 149 | 150 | x1 = self.deletion1(x1, mask_1hop) 151 | 152 | x = F.relu(x1) 153 | 154 | x2 = self.conv2(x, edge_index, edge_type) 155 | x2 = self.deletion2(x2, mask_2hop) 156 | 157 | if return_all_emb: 158 | return x1, x2 159 | 160 | return x2 161 | 162 | def get_original_embeddings(self, x, edge_index, edge_type, return_all_emb=False): 163 | return super().forward(x, edge_index, edge_type, return_all_emb) 164 | 165 | class RGATDelete(RGAT): 166 | def __init__(self, args, num_nodes, num_edge_type, mask_1hop=None, mask_2hop=None, **kwargs): 167 | super().__init__(args, num_nodes, num_edge_type) 168 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop) 169 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop) 170 | 171 | self.node_emb.requires_grad = False 172 | self.conv1.requires_grad = False 173 | self.conv2.requires_grad = False 174 | 175 | def forward(self, x, edge_index, edge_type, mask_1hop=None, mask_2hop=None, return_all_emb=False): 176 | with torch.no_grad(): 177 | x = self.node_emb(x) 178 | x1 = self.conv1(x, edge_index, edge_type) 179 | 180 | x1 = self.deletion1(x1, mask_1hop) 181 | 182 | x = F.relu(x1) 183 | 184 | x2 = self.conv2(x, edge_index, edge_type) 185 | x2 = self.deletion2(x2, mask_2hop) 186 | 187 | if return_all_emb: 188 | return x1, x2 189 | 190 | return x2 191 | 192 | def get_original_embeddings(self, x, edge_index, edge_type, return_all_emb=False): 193 | return super().forward(x, edge_index, edge_type, return_all_emb) 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # GNNDelete: A General Unlearning Strategy for Graph Neural Networks 3 | 4 | #### Authors: 5 | - [Jiali Cheng]() (jiali.cheng.ccnr@gmail.com) 6 | - [George Dasoulas](https://gdasoulas.github.io/) (george.dasoulas1@gmail.com) 7 | - [Huan He](https://github.com/mims-harvard/Raindrop) (huan_he@hms.harvard.edu) 8 | - [Chirag Agarwal](https://chirag126.github.io/) (chiragagarwall12@gmail.com) 9 | - [Marinka Zitnik](https://zitniklab.hms.harvard.edu/) (marinka@hms.harvard.edu) 10 | 11 | #### [Project website](https://zitniklab.hms.harvard.edu/projects/GNNDelete/) 12 | 13 | #### GNNDelete Paper: [ICLR 2023](https://openreview.net/forum?id=X9yCkmT5Qrl), [Preprint]() 14 | 15 | 16 | ## Overview 17 | 18 | This repository contains the code to preprocess datasets, train GNN models, and perform data deletion on trained GNN models for manuscript *GNNDelete: A General Graph Unlearning Strategy*. 
We propose GNNDelete, a model-agnostic, layer-wise unlearning operator. It formalizes the required properties for graph unlearning in the form of Deleted Edge Consistency and Neighborhood Influence, and optimizes both properties during unlearning. GNNDelete updates latent representations to delete nodes and edges from the model while keeping the rest of the learned knowledge intact. 19 | 20 |
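For intuition, the deletion operator amounts to a small trainable transformation that is applied only to the representations of the nodes affected by a deletion, leaving every other node representation untouched. Below is a minimal, self-contained sketch of this idea (class and variable names are illustrative; the full implementation lives in `framework/models/deletion.py`):

```
import torch
import torch.nn as nn

class DeletionOperator(nn.Module):
    '''Sketch of the Del operator: a learnable dim x dim map applied only to
    nodes inside the enclosing subgraph of the deleted information.'''
    def __init__(self, dim, mask=None):
        super().__init__()
        self.mask = mask  # boolean mask over nodes affected by the deletion
        self.deletion_weight = nn.Parameter(torch.ones(dim, dim) / 1000)

    def forward(self, x, mask=None):
        mask = self.mask if mask is None else mask
        if mask is None:
            return x  # nothing to delete
        out = x.clone()
        out[mask] = out[mask] @ self.deletion_weight  # update affected nodes only
        return out

# Toy usage: 5 nodes with 8-dimensional embeddings; nodes 1 and 3 are affected
x = torch.randn(5, 8)
affected = torch.tensor([False, True, False, True, False])
delete_op = DeletionOperator(dim=8, mask=affected)
print(delete_op(x).shape)  # torch.Size([5, 8]); rows 0, 2, 4 are unchanged
```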

21 | 22 |

23 | 24 | 25 | ## Key idea of GNNDelete 26 | 27 | To unlearn information from a trained GNN, its influence on both the GNN model weights and the representations of neighboring nodes in the graph must be deleted from the model. However, existing methods based on retraining or weight modification either degrade model weights that are shared across all nodes or are ineffective because of the strong dependency of deleted edges on their local graph neighborhood. 28 | 29 | Our model formulates the unlearning problem as a representation learning task. It formalizes the required properties for graph unlearning in the form of Deleted Edge Consistency and Neighborhood Influence. GNNDelete updates latent representations to delete nodes and edges from the model while keeping the rest of the learned knowledge intact. 30 | 31 | **Overview of GNNDelete approach.** Our model extends the standard (Msg, Agg, Upd) GNN framework into (Msg, Agg, Upd, Del). Upon unlearning, GNNDelete inserts trainable deletion operators after the GNN layers. The **Del** operator updates the representations of the nodes affected by the deletion (based on the local enclosing subgraph of the deleted information). The updated representations are optimized to meet the objectives of **Deleted Edge Consistency** and **Neighborhood Influence**. We only train the deletion operators, while freezing the rest of the GNN weights. 32 | 33 | 34 | ## Datasets 35 | 36 | We prepared seven commonly used datasets of different scales, including homogeneous and heterogeneous graphs. 37 | 38 | Please run the following command to perform the train-test split and sample the deleted information (two strategies are described in the paper): 39 | ``` 40 | python prepare_dataset.py 41 | ``` 42 | 43 | The following table summarizes the statistics of the seven datasets: 44 | | Dataset | # Nodes | # Edges | # Unique edge types | Max # deleted edges | 45 | |------------|:-------:|:---------:|:-------------------:|:-------------------:| 46 | | Cora | 19,793 | 126,842 | N/A | 6,342 | 47 | | PubMed | 19,717 | 88,648 | N/A | 4,432 | 48 | | DBLP | 17,716 | 105,734 | N/A | 5,286 | 49 | | CS | 18,333 | 163,788 | N/A | 8,189 | 50 | | OGB-Collab | 235,368 | 1,285,465 | N/A | 117,905 | 51 | | WordNet18 | 40,943 | 151,442 | 18 | 7,072 | 52 | | BioKG | 93,773 | 5,088,434 | 51 | 127,210 | 53 | 54 | 55 | ## Experimental setups 56 | 57 | We evaluated our model on three unlearning tasks and compared it with six baselines. The baselines include three state-of-the-art models designed for graph unlearning (GraphEraser, GraphEditor, Certified Graph Unlearning) and three general unlearning methods (retraining, gradient ascent, Descent-to-Delete). The three unlearning tasks are: 58 | 59 | **Unlearning task 1: delete edges.** We delete a set of edges from a trained GNN model. 60 | 61 | **Unlearning task 2: delete nodes.** We delete a set of nodes from a trained GNN model. 62 | 63 | **Unlearning task 3: delete node features.** We delete the node features of a set of nodes from a trained GNN model. 64 | 65 | The deleted information can be sampled with two different strategies: 66 | 67 | **A simpler setting: Out setting.** The deleted information is sampled from **outside** the enclosing subgraph of the test set. 68 | 69 | **A harder setting: In setting.** The deleted information is sampled from **within** the enclosing subgraph of the test set. 70 | 71 | 72 | ## Requirements 73 | 74 | GNNDelete has been tested with Python >= 3.6.
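Beyond the standard scientific Python stack, the code relies on PyTorch, PyTorch Geometric, wandb, scikit-learn, and tqdm. As a rough sketch (the environment name and versions below are illustrative and not pinned by this repository), an environment can be prepared with:

```
conda create -n gnndelete python=3.9 -y
conda activate gnndelete
pip install torch torch-geometric
```

Installing PyTorch and PyTorch Geometric first makes it easier to pick wheels that match your CUDA version; the remaining dependencies come from `requirements.txt` as described below.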
75 | 76 | Please install the required packages by running 77 | 78 | ``` 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | 83 | ## Running the code 84 | 85 | **Train GNN** The first step is to train a GNN model on either link prediction or node classification: 86 | 87 | ```python train_gnn.py``` 88 | 89 | Or 90 | 91 | ```python train_node.py``` 92 | 93 | **Train Membership Inference attacker (_Optional_)** We use the model in _Membership Inference Attack on Graph Neural Networks_ as our MI model. Please refer to the [official implementation](https://github.com/iyempissy/rebMIGraph). 94 | 95 | **Unlearn** We can then delete information from the trained GNN model. Based on what you want to delete, run one of the three scripts below (a complete example invocation is given at the end of this README). 96 | 97 | To unlearn edges, please run 98 | ``` 99 | python delete_gnn.py 100 | ``` 101 | 102 | To unlearn nodes, please run 103 | ``` 104 | python delete_node.py 105 | ``` 106 | 107 | To unlearn node features, please run 108 | ``` 109 | python delete_node_feature.py 110 | ``` 111 | 112 | 113 | **Baselines** 114 | We compare GNNDelete to several baselines: 115 | - Retraining from scratch: please run the above unlearning scripts with `--unlearning_model retrain` 116 | - Gradient ascent: please run the above unlearning scripts with `--unlearning_model gradient_ascent` 117 | - Descent-to-Delete: please run the above unlearning scripts with `--unlearning_model descent_to_delete` 118 | - GraphEditor: please run the above unlearning scripts with `--unlearning_model graph_editor` 119 | - GraphEraser: please refer to the [official implementation](https://github.com/MinChen00/Graph-Unlearning) 120 | - Certified Graph Unlearning: please refer to the [official implementation](https://github.com/thupchnsky/sgc_unlearn) 121 | 122 | 123 | ## Citation 124 | 125 | If you find *GNNDelete* useful for your research, please consider citing this paper: 126 | 127 | ``` 128 | @inproceedings{cheng2023gnndelete, 129 | title={{GNND}elete: A General Unlearning Strategy for Graph Neural Networks}, 130 | author={Jiali Cheng and George Dasoulas and Huan He and Chirag Agarwal and Marinka Zitnik}, 131 | booktitle={International Conference on Learning Representations}, 132 | year={2023}, 133 | url={https://openreview.net/forum?id=X9yCkmT5Qrl} 134 | } 135 | ``` 136 | 137 | 138 | ## Miscellaneous 139 | 140 | Please send any questions you might have about the code and/or the algorithm to . 141 | 142 | 143 | 144 | ## License 145 | 146 | The GNNDelete codebase is released under the MIT license.
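## Example usage

As a worked example of the train-then-unlearn workflow described above (the flag values, including the exact `--unlearning_model` identifier for the GNNDelete variant and the `--df` setting, are illustrative rather than prescriptive; see `framework/training_args.py` for the full argument list and defaults):

```
# 1. Train the original GNN for link prediction
python train_gnn.py --dataset Cora --gnn gcn --unlearning_model original --epochs 1500 --random_seed 42

# 2. Unlearn a subset of training edges with GNNDelete
#    --df selects which pre-sampled set of deleted edges to use (produced by prepare_dataset.py);
#    --df_size values below 100 are interpreted as a percentage of training edges (see delete_gnn.py)
python delete_gnn.py --dataset Cora --gnn gcn --unlearning_model gnndelete --df out --df_size 5.0 --random_seed 42
```

Checkpoints and logs are written under `--checkpoint_dir` (default `./checkpoint`), organized by dataset, GNN architecture, unlearning method, deletion setting, and random seed, following the paths constructed in `delete_gnn.py`.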
147 | -------------------------------------------------------------------------------- /framework/trainer/member_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import wandb 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import trange, tqdm 8 | from torch_geometric.utils import negative_sampling 9 | from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, f1_score 10 | 11 | from .base import Trainer 12 | from ..evaluation import * 13 | from ..utils import * 14 | 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | class MIAttackTrainer(Trainer): 19 | '''This code is adapted from https://github.com/iyempissy/rebMIGraph''' 20 | 21 | def __init__(self, args): 22 | self.args = args 23 | self.trainer_log = { 24 | 'unlearning_model': 'member_infer', 25 | 'dataset': args.dataset, 26 | 'seed': args.random_seed, 27 | 'shadow_log': [], 28 | 'attack_log': []} 29 | self.logit_all_pair = None 30 | 31 | with open(os.path.join(self.args.checkpoint_dir, 'training_args.json'), 'w') as f: 32 | json.dump(vars(args), f) 33 | 34 | def train_shadow(self, model, data, optimizer, args): 35 | best_valid_loss = 1000000 36 | 37 | all_neg = [] 38 | # Train shadow model using the test data 39 | for epoch in trange(args.epochs, desc='Train shadow model'): 40 | model.train() 41 | 42 | # Positive and negative sample 43 | neg_edge_index = negative_sampling( 44 | edge_index=data.test_pos_edge_index, 45 | num_nodes=data.num_nodes, 46 | num_neg_samples=data.test_pos_edge_index.shape[1]) 47 | 48 | z = model(data.x, data.test_pos_edge_index) 49 | logits = model.decode(z, data.test_pos_edge_index, neg_edge_index) 50 | label = get_link_labels(data.test_pos_edge_index, neg_edge_index) 51 | loss = F.binary_cross_entropy_with_logits(logits, label) 52 | 53 | loss.backward() 54 | optimizer.step() 55 | optimizer.zero_grad() 56 | 57 | all_neg.append(neg_edge_index.cpu()) 58 | 59 | if (epoch+1) % args.valid_freq == 0: 60 | valid_loss, auc, aup, df_logit, logit_all_pair = self.eval_shadow(model, data, 'val') 61 | 62 | log = { 63 | 'shadow_epoch': epoch, 64 | 'shadow_train_loss': loss.item(), 65 | 'shadow_valid_loss': valid_loss, 66 | 'shadow_valid_auc': auc, 67 | 'shadow_valid_aup': aup, 68 | 'shadow_df_logit': df_logit 69 | } 70 | wandb.log(log) 71 | self.trainer_log['shadow_log'].append(log) 72 | 73 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 74 | tqdm.write(' | '.join(msg)) 75 | 76 | if valid_loss < best_valid_loss: 77 | best_valid_loss = valid_loss 78 | best_epoch = epoch 79 | 80 | self.trainer_log['shadow_best_epoch'] = best_epoch 81 | self.trainer_log['shadow_best_valid_loss'] = best_valid_loss 82 | 83 | print(f'Save best checkpoint at epoch {epoch:04d}. 
Valid loss = {valid_loss:.4f}') 84 | ckpt = { 85 | 'model_state': model.state_dict(), 86 | 'optimizer_state': optimizer.state_dict(), 87 | } 88 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'shadow_model_best.pt')) 89 | 90 | return torch.cat(all_neg, dim=-1) 91 | 92 | @torch.no_grad() 93 | def eval_shadow(self, model, data, stage='val'): 94 | model.eval() 95 | pos_edge_index = data[f'{stage}_pos_edge_index'] 96 | neg_edge_index = data[f'{stage}_neg_edge_index'] 97 | 98 | z = model(data.x, data.val_pos_edge_index) 99 | logits = model.decode(z, pos_edge_index, neg_edge_index).sigmoid() 100 | label = self.get_link_labels(pos_edge_index, neg_edge_index) 101 | 102 | loss = F.binary_cross_entropy_with_logits(logits, label).cpu().item() 103 | auc = roc_auc_score(label.cpu(), logits.cpu()) 104 | aup = average_precision_score(label.cpu(), logits.cpu()) 105 | df_logit = float('nan') 106 | 107 | logit_all_pair = (z @ z.t()).cpu() 108 | 109 | log = { 110 | f'{stage}_loss': loss, 111 | f'{stage}_auc': auc, 112 | f'{stage}_aup': aup, 113 | f'{stage}_df_logit': df_logit, 114 | } 115 | wandb.log(log) 116 | msg = [f'{i}: {j:.4f}' if isinstance(j, (np.floating, float)) else f'{i}: {j:>4d}' for i, j in log.items()] 117 | tqdm.write(' | '.join(msg)) 118 | 119 | return loss, auc, aup, df_logit, logit_all_pair 120 | 121 | def train_attack(self, model, train_loader, valid_loader, optimizer, args): 122 | loss_fct = nn.CrossEntropyLoss() 123 | best_auc = 0 124 | best_epoch = 0 125 | for epoch in trange(50, desc='Train attack model'): 126 | model.train() 127 | 128 | train_loss = 0 129 | for x, y in train_loader: 130 | logits = model(x.to(device)) 131 | loss = loss_fct(logits, y.to(device)) 132 | 133 | loss.backward() 134 | optimizer.step() 135 | optimizer.zero_grad() 136 | 137 | train_loss += loss.item() 138 | 139 | valid_loss, valid_acc, valid_auc, valid_f1 = self.eval_attack(model, valid_loader) 140 | 141 | log = { 142 | 'attack_train_loss': train_loss / len(train_loader), 143 | 'attack_valid_loss': valid_loss, 144 | 'attack_valid_acc': valid_acc, 145 | 'attack_valid_auc': valid_auc, 146 | 'attack_valid_f1': valid_f1} 147 | wandb.log(log) 148 | self.trainer_log['attack_log'].append(log) 149 | 150 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 151 | tqdm.write(' | '.join(msg)) 152 | 153 | 154 | if valid_auc > best_auc: 155 | best_auc = valid_auc 156 | best_epoch = epoch 157 | self.trainer_log['attack_best_auc'] = valid_auc 158 | self.trainer_log['attack_best_epoch'] = epoch 159 | 160 | ckpt = { 161 | 'model_state': model.state_dict(), 162 | 'optimizer_state': optimizer.state_dict(), 163 | } 164 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'attack_model_best.pt')) 165 | 166 | @torch.no_grad() 167 | def eval_attack(self, model, eval_loader): 168 | loss_fct = nn.CrossEntropyLoss() 169 | pred = [] 170 | label = [] 171 | for x, y in eval_loader: 172 | logits = model(x.to(device)) 173 | loss = loss_fct(logits, y.to(device)) 174 | _, p = torch.max(logits, 1) 175 | 176 | pred.extend(p.cpu()) 177 | label.extend(y) 178 | 179 | pred = torch.stack(pred) 180 | label = torch.stack(label) 181 | 182 | return loss.item(), accuracy_score(label.numpy(), pred.numpy()), roc_auc_score(label.numpy(), pred.numpy()), f1_score(label.numpy(), pred.numpy(), average='macro') 183 | 184 | @torch.no_grad() 185 | def prepare_attack_training_data(self, model, data, all_neg=None): 186 | '''Prepare the training data of attack model (Present vs. 
Absent) 187 | Present edges (label = 1): training data of shadow model (Test pos and neg edges) 188 | Absent edges (label = 0): validation data of shadow model (Valid pos and neg edges) 189 | ''' 190 | 191 | z = model(data.x, data.test_pos_edge_index) 192 | 193 | # Sample same size of neg as pos 194 | sample_idx = torch.randperm(all_neg.shape[1])[:data.test_pos_edge_index.shape[1]] 195 | neg_subset = all_neg[:, sample_idx] 196 | 197 | present_edge_index = torch.cat([data.test_pos_edge_index, data.test_neg_edge_index], dim=-1) 198 | 199 | if 'sub' in self.args.unlearning_model: 200 | absent_edge_index = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index], dim=-1) 201 | else: #if 'all' in self.args.unlearning_model: 202 | absent_edge_index = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index, data.train_pos_edge_index, neg_subset.to(device)], dim=-1) 203 | 204 | edge_index = torch.cat([present_edge_index, absent_edge_index], dim=-1) 205 | 206 | feature = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1).cpu() 207 | label = get_link_labels(present_edge_index, absent_edge_index).long().cpu() 208 | 209 | return feature, label 210 | -------------------------------------------------------------------------------- /delete_gnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import wandb 5 | import pickle 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected 10 | from torch_geometric.data import Data 11 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 12 | from torch_geometric.seed import seed_everything 13 | 14 | from framework import get_model, get_trainer 15 | from framework.models.gcn import GCN 16 | from framework.training_args import parse_args 17 | from framework.utils import * 18 | from train_mi import MLPAttacker 19 | 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | def load_args(path): 24 | with open(path, 'r') as f: 25 | d = json.load(f) 26 | parser = argparse.ArgumentParser() 27 | for k, v in d.items(): 28 | parser.add_argument('--' + k, default=v) 29 | try: 30 | parser.add_argument('--df_size', default=0.5) 31 | except: 32 | pass 33 | args = parser.parse_args() 34 | 35 | for k, v in d.items(): 36 | setattr(args, k, v) 37 | 38 | return args 39 | 40 | @torch.no_grad() 41 | def get_node_embedding(model, data): 42 | model.eval() 43 | node_embedding = model(data.x.to(device), data.edge_index.to(device)) 44 | 45 | return node_embedding 46 | 47 | @torch.no_grad() 48 | def get_output(model, node_embedding, data): 49 | model.eval() 50 | node_embedding = node_embedding.to(device) 51 | edge = data.edge_index.to(device) 52 | output = model.decode(node_embedding, edge, edge_type) 53 | 54 | return output 55 | 56 | torch.autograd.set_detect_anomaly(True) 57 | def main(): 58 | args = parse_args() 59 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed)) 60 | attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed)) 61 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed)) 62 | seed_everything(args.random_seed) 63 | 64 | if 'gnndelete' in args.unlearning_model: 65 | args.checkpoint_dir = os.path.join( 66 | args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, 67 
| '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]), 68 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 69 | else: 70 | args.checkpoint_dir = os.path.join( 71 | args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, 72 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 73 | os.makedirs(args.checkpoint_dir, exist_ok=True) 74 | 75 | # Dataset 76 | with open(os.path.join(args.data_dir, args.dataset, f'd_{args.random_seed}.pkl'), 'rb') as f: 77 | dataset, data = pickle.load(f) 78 | print('Directed dataset:', dataset, data) 79 | if args.gnn not in ['rgcn', 'rgat']: 80 | args.in_dim = dataset.num_features 81 | 82 | print('Training args', args) 83 | wandb.init(config=args) 84 | 85 | # Df and Dr 86 | assert args.df != 'none' 87 | 88 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted 89 | df_size = int(args.df_size) 90 | else: # df_size is the ratio 91 | df_size = int(args.df_size / 100 * data.train_pos_edge_index.shape[1]) 92 | print(f'Original size: {data.train_pos_edge_index.shape[1]:,}') 93 | print(f'Df size: {df_size:,}') 94 | 95 | df_mask_all = torch.load(os.path.join(args.data_dir, args.dataset, f'df_{args.random_seed}.pt'))[args.df] 96 | df_nonzero = df_mask_all.nonzero().squeeze() 97 | 98 | idx = torch.randperm(df_nonzero.shape[0])[:df_size] 99 | df_global_idx = df_nonzero[idx] 100 | 101 | print('Deleting the following edges:', df_global_idx) 102 | 103 | # df_idx = [int(i) for i in args.df_idx.split(',')] 104 | # df_idx_global = df_mask.nonzero()[df_idx] 105 | 106 | dr_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool) 107 | dr_mask[df_global_idx] = False 108 | 109 | df_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool) 110 | df_mask[df_global_idx] = True 111 | 112 | # For testing 113 | data.directed_df_edge_index = data.train_pos_edge_index[:, df_mask] 114 | if args.gnn in ['rgcn', 'rgat']: 115 | data.directed_df_edge_type = data.train_edge_type[df_mask] 116 | 117 | 118 | # data.dr_mask = dr_mask 119 | # data.df_mask = df_mask 120 | # data.edge_index = data.train_pos_edge_index[:, dr_mask] 121 | 122 | # assert df_mask.sum() == len(df_global_idx) 123 | # assert dr_mask.shape[0] - len(df_global_idx) == data.train_pos_edge_index[:, dr_mask].shape[1] 124 | # data.dtrain_mask = dr_mask 125 | 126 | 127 | # Edges in S_Df 128 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph( 129 | data.train_pos_edge_index[:, df_mask].flatten().unique(), 130 | 2, 131 | data.train_pos_edge_index, 132 | num_nodes=data.num_nodes) 133 | data.sdf_mask = two_hop_mask 134 | 135 | # Nodes in S_Df 136 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph( 137 | data.train_pos_edge_index[:, df_mask].flatten().unique(), 138 | 1, 139 | data.train_pos_edge_index, 140 | num_nodes=data.num_nodes) 141 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool) 142 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool) 143 | 144 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True 145 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True 146 | 147 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique()) 148 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique()) 149 | 150 | data.sdf_node_1hop_mask = sdf_node_1hop 151 | data.sdf_node_2hop_mask = sdf_node_2hop 152 | 153 | 154 | # To undirected for message passing 155 | # print(is_undir0.0175ected(data.train_pos_edge_index), data.train_pos_edge_index.shape, 
two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 156 | assert not is_undirected(data.train_pos_edge_index) 157 | 158 | if args.gnn in ['rgcn', 'rgat']: 159 | r, c = data.train_pos_edge_index 160 | rev_edge_index = torch.stack([c, r], dim=0) 161 | rev_edge_type = data.train_edge_type + args.num_edge_type 162 | 163 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1) 164 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0) 165 | 166 | if hasattr(data, 'train_mask'): 167 | data.train_mask = data.train_mask.repeat(2).view(-1) 168 | 169 | two_hop_mask = two_hop_mask.repeat(2).view(-1) 170 | df_mask = df_mask.repeat(2).view(-1) 171 | dr_mask = dr_mask.repeat(2).view(-1) 172 | assert is_undirected(data.edge_index) 173 | 174 | else: 175 | train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()]) 176 | two_hop_mask = two_hop_mask.bool() 177 | df_mask = df_mask.bool() 178 | dr_mask = ~df_mask 179 | 180 | data.train_pos_edge_index = train_pos_edge_index 181 | data.edge_index = train_pos_edge_index 182 | assert is_undirected(data.train_pos_edge_index) 183 | 184 | 185 | print('Undirected dataset:', data) 186 | 187 | data.sdf_mask = two_hop_mask 188 | data.df_mask = df_mask 189 | data.dr_mask = dr_mask 190 | # data.dtrain_mask = dr_mask 191 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 192 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.df_mask.shape, ) 193 | # raise 194 | 195 | # Model 196 | model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 197 | 198 | if args.unlearning_model != 'retrain': # Start from trained GNN model 199 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')): 200 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt')) 201 | if logits_ori is not None: 202 | logits_ori = logits_ori.to(device) 203 | else: 204 | logits_ori = None 205 | 206 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device) 207 | model.load_state_dict(model_ckpt['model_state'], strict=False) 208 | 209 | else: # Initialize a new GNN model 210 | retrain = None 211 | logits_ori = None 212 | 213 | model = model.to(device) 214 | 215 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model: 216 | parameters_to_optimize = [ 217 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 218 | ] 219 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 220 | 221 | if 'layerwise' in args.loss_type: 222 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr) 223 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr) 224 | optimizer = [optimizer1, optimizer2] 225 | else: 226 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr) 227 | 228 | else: 229 | if 'gnndelete' in args.unlearning_model: 230 | parameters_to_optimize = [ 231 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 232 | ] 233 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 234 | 235 | else: 236 | parameters_to_optimize = [ 237 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0} 238 | ] 239 | print('parameters_to_optimize', [n for n, p in 
model.named_parameters()]) 240 | 241 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay) 242 | 243 | wandb.watch(model, log_freq=100) 244 | 245 | # MI attack model 246 | attack_model_all = None 247 | # attack_model_all = MLPAttacker(args) 248 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt')) 249 | # attack_model_all.load_state_dict(attack_ckpt['model_state']) 250 | # attack_model_all = attack_model_all.to(device) 251 | 252 | attack_model_sub = None 253 | # attack_model_sub = MLPAttacker(args) 254 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt')) 255 | # attack_model_sub.load_state_dict(attack_ckpt['model_state']) 256 | # attack_model_sub = attack_model_sub.to(device) 257 | 258 | # Train 259 | trainer = get_trainer(args) 260 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 261 | 262 | # Test 263 | if args.unlearning_model != 'retrain': 264 | retrain_path = os.path.join( 265 | 'checkpoint', args.dataset, args.gnn, 'retrain', 266 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]), 267 | 'model_best.pt') 268 | if os.path.exists(retrain_path): 269 | retrain_ckpt = torch.load(retrain_path, map_location=device) 270 | retrain_args = copy.deepcopy(args) 271 | retrain_args.unlearning_model = 'retrain' 272 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 273 | retrain.load_state_dict(retrain_ckpt['model_state']) 274 | retrain = retrain.to(device) 275 | retrain.eval() 276 | else: 277 | retrain = None 278 | else: 279 | retrain = None 280 | 281 | test_results = trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub) 282 | print(test_results[-1]) 283 | trainer.save_log() 284 | 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /delete_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import wandb 5 | import pickle 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected 10 | from torch_geometric.data import Data 11 | import torch_geometric.transforms as T 12 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR 13 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 14 | from torch_geometric.seed import seed_everything 15 | 16 | from framework import get_model, get_trainer 17 | from framework.models.gcn import GCN 18 | from framework.models.deletion import GCNDelete 19 | from framework.training_args import parse_args 20 | from framework.utils import * 21 | from framework.trainer.gnndelete_nodeemb import GNNDeleteNodeClassificationTrainer 22 | from train_mi import MLPAttacker 23 | 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | torch.autograd.set_detect_anomaly(True) 27 | 28 | def to_directed(edge_index): 29 | row, col = edge_index 30 | mask = row < col 31 | return torch.cat([row[mask], col[mask]], dim=0) 32 | 33 | def main(): 34 | args = parse_args() 35 | args.checkpoint_dir = 'checkpoint_node' 36 | args.dataset = 'DBLP' 37 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed)) 38 
| attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed)) 39 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed)) 40 | seed_everything(args.random_seed) 41 | 42 | if 'gnndelete' in args.unlearning_model: 43 | args.checkpoint_dir = os.path.join( 44 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion', 45 | '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]), 46 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 47 | else: 48 | args.checkpoint_dir = os.path.join( 49 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion', 50 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 51 | os.makedirs(args.checkpoint_dir, exist_ok=True) 52 | 53 | # Dataset 54 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures()) 55 | data = dataset[0] 56 | print('Original data', data) 57 | 58 | split = T.RandomNodeSplit() 59 | data = split(data) 60 | assert is_undirected(data.edge_index) 61 | 62 | print('Split data', data) 63 | args.in_dim = data.x.shape[1] 64 | args.out_dim = dataset.num_classes 65 | 66 | wandb.init(config=args) 67 | 68 | # Df and Dr 69 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted 70 | df_size = int(args.df_size) 71 | else: # df_size is the ratio 72 | df_size = int(args.df_size / 100 * data.train_pos_edge_index.shape[1]) 73 | print(f'Original size: {data.num_nodes:,}') 74 | print(f'Df size: {df_size:,}') 75 | 76 | # Delete nodes 77 | df_nodes = torch.randperm(data.num_nodes)[:df_size] 78 | global_node_mask = torch.ones(data.num_nodes, dtype=torch.bool) 79 | global_node_mask[df_nodes] = False 80 | 81 | dr_mask_node = global_node_mask 82 | df_mask_node = ~global_node_mask 83 | assert df_mask_node.sum() == df_size 84 | 85 | # Delete edges associated with deleted nodes from training set 86 | res = [torch.eq(data.edge_index, aelem).logical_or_(torch.eq(data.edge_index, aelem)) for aelem in df_nodes] 87 | df_mask_edge = torch.any(torch.stack(res, dim=0), dim = 0) 88 | df_mask_edge = df_mask_edge.sum(0).bool() 89 | dr_mask_edge = ~df_mask_edge 90 | 91 | df_edge = data.edge_index[:, df_mask_edge] 92 | data.directed_df_edge_index = to_directed(df_edge) 93 | # print(df_edge.shape, directed_df_edge_index.shape) 94 | # raise 95 | 96 | print('Deleting the following nodes:', df_nodes) 97 | 98 | # # Delete edges associated with deleted nodes from valid and test set 99 | # res = [torch.eq(data.val_pos_edge_index, aelem).logical_or_(torch.eq(data.val_pos_edge_index, aelem)) for aelem in df_nodes] 100 | # mask = torch.any(torch.stack(res, dim=0), dim = 0) 101 | # mask = mask.sum(0).bool() 102 | # mask = ~mask 103 | # data.val_pos_edge_index = data.val_pos_edge_index[:, mask] 104 | # data.val_neg_edge_index = data.val_neg_edge_index[:, :data.val_pos_edge_index.shape[1]] 105 | 106 | # res = [torch.eq(data.test_pos_edge_index, aelem).logical_or_(torch.eq(data.test_pos_edge_index, aelem)) for aelem in df_nodes] 107 | # mask = torch.any(torch.stack(res, dim=0), dim = 0) 108 | # mask = mask.sum(0).bool() 109 | # mask = ~mask 110 | # data.test_pos_edge_index = data.test_pos_edge_index[:, mask] 111 | # data.test_neg_edge_index = data.test_neg_edge_index[:, :data.test_pos_edge_index.shape[1]] 112 | 113 | 114 | # For testing 115 | # data.directed_df_edge_index = 
data.train_pos_edge_index[:, df_mask_edge] 116 | # if args.gnn in ['rgcn', 'rgat']: 117 | # data.directed_df_edge_type = data.train_edge_type[df_mask] 118 | 119 | # Edges in S_Df 120 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph( 121 | data.edge_index[:, df_mask_edge].flatten().unique(), 122 | 2, 123 | data.edge_index, 124 | num_nodes=data.num_nodes) 125 | 126 | # Nodes in S_Df 127 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph( 128 | data.edge_index[:, df_mask_edge].flatten().unique(), 129 | 1, 130 | data.edge_index, 131 | num_nodes=data.num_nodes) 132 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool) 133 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool) 134 | 135 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True 136 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True 137 | 138 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique()) 139 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique()) 140 | 141 | data.sdf_node_1hop_mask = sdf_node_1hop 142 | data.sdf_node_2hop_mask = sdf_node_2hop 143 | 144 | 145 | # To undirected for message passing 146 | # print(is_undir0.0175ected(data.train_pos_edge_index), data.train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 147 | # assert not is_undirected(data.edge_index) 148 | print(is_undirected(data.edge_index)) 149 | 150 | if args.gnn in ['rgcn', 'rgat']: 151 | r, c = data.train_pos_edge_index 152 | rev_edge_index = torch.stack([c, r], dim=0) 153 | rev_edge_type = data.train_edge_type + args.num_edge_type 154 | 155 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1) 156 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0) 157 | # data.train_mask = data.train_mask.repeat(2) 158 | 159 | two_hop_mask = two_hop_mask.repeat(2).view(-1) 160 | df_mask = df_mask.repeat(2).view(-1) 161 | dr_mask = dr_mask.repeat(2).view(-1) 162 | assert is_undirected(data.edge_index) 163 | 164 | else: 165 | # train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()]) 166 | two_hop_mask = two_hop_mask.bool() 167 | df_mask_edge = df_mask_edge.bool() 168 | dr_mask_edge = ~df_mask_edge 169 | 170 | # data.train_pos_edge_index = train_pos_edge_index 171 | # assert is_undirected(data.train_pos_edge_index) 172 | 173 | 174 | print('Undirected dataset:', data) 175 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 176 | 177 | data.sdf_mask = two_hop_mask 178 | data.df_mask = df_mask_edge 179 | data.dr_mask = dr_mask_edge 180 | data.dtrain_mask = dr_mask_edge 181 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.two_hop_mask.shape, data.df_mask.shape, data.two_hop_mask.shape) 182 | # raise 183 | 184 | # Model 185 | model = GCNDelete(args) 186 | # model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 187 | 188 | if args.unlearning_model != 'retrain': # Start from trained GNN model 189 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')): 190 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt')) 191 | if logits_ori is not None: 192 | logits_ori = logits_ori.to(device) 193 | else: 194 | logits_ori = None 195 | 196 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device) 197 | model.load_state_dict(model_ckpt['model_state'], 
strict=False) 198 | 199 | else: # Initialize a new GNN model 200 | retrain = None 201 | logits_ori = None 202 | 203 | model = model.to(device) 204 | 205 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model: 206 | parameters_to_optimize = [ 207 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 208 | ] 209 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 210 | 211 | if 'layerwise' in args.loss_type: 212 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr) 213 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr) 214 | optimizer = [optimizer1, optimizer2] 215 | else: 216 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr) 217 | 218 | else: 219 | if 'gnndelete' in args.unlearning_model: 220 | parameters_to_optimize = [ 221 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 222 | ] 223 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 224 | 225 | else: 226 | parameters_to_optimize = [ 227 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0} 228 | ] 229 | print('parameters_to_optimize', [n for n, p in model.named_parameters()]) 230 | 231 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay) 232 | 233 | wandb.watch(model, log_freq=100) 234 | 235 | # MI attack model 236 | attack_model_all = None 237 | # attack_model_all = MLPAttacker(args) 238 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt')) 239 | # attack_model_all.load_state_dict(attack_ckpt['model_state']) 240 | # attack_model_all = attack_model_all.to(device) 241 | 242 | attack_model_sub = None 243 | # attack_model_sub = MLPAttacker(args) 244 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt')) 245 | # attack_model_sub.load_state_dict(attack_ckpt['model_state']) 246 | # attack_model_sub = attack_model_sub.to(device) 247 | 248 | # Train 249 | trainer = GNNDeleteNodeClassificationTrainer(args) 250 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 251 | 252 | # Test 253 | if args.unlearning_model != 'retrain': 254 | retrain_path = os.path.join( 255 | 'checkpoint', args.dataset, args.gnn, 'retrain', 256 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 257 | retrain_ckpt = torch.load(os.path.join(retrain_path, 'model_best.pt'), map_location=device) 258 | retrain_args = copy.deepcopy(args) 259 | retrain_args.unlearning_model = 'retrain' 260 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 261 | retrain.load_state_dict(retrain_ckpt['model_state']) 262 | retrain = retrain.to(device) 263 | retrain.eval() 264 | 265 | else: 266 | retrain = None 267 | 268 | trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub) 269 | trainer.save_log() 270 | 271 | 272 | if __name__ == "__main__": 273 | main() 274 | -------------------------------------------------------------------------------- /delete_node_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import wandb 5 | import pickle 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, 
is_undirected 10 | from torch_geometric.data import Data 11 | import torch_geometric.transforms as T 12 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR 13 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 14 | from torch_geometric.seed import seed_everything 15 | 16 | from framework import get_model, get_trainer 17 | from framework.models.gcn import GCN 18 | from framework.models.deletion import GCNDelete 19 | from framework.training_args import parse_args 20 | from framework.utils import * 21 | from framework.trainer.gnndelete_nodeemb import GNNDeleteNodeClassificationTrainer 22 | from train_mi import MLPAttacker 23 | 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | torch.autograd.set_detect_anomaly(True) 27 | 28 | def to_directed(edge_index): 29 | row, col = edge_index 30 | mask = row < col 31 | return torch.cat([row[mask], col[mask]], dim=0) 32 | 33 | def main(): 34 | args = parse_args() 35 | args.checkpoint_dir = 'checkpoint_node_feature' 36 | args.dataset = 'DBLP' 37 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed)) 38 | attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed)) 39 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed)) 40 | seed_everything(args.random_seed) 41 | 42 | if 'gnndelete' in args.unlearning_model: 43 | args.checkpoint_dir = os.path.join( 44 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion', 45 | '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]), 46 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 47 | else: 48 | args.checkpoint_dir = os.path.join( 49 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion', 50 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 51 | os.makedirs(args.checkpoint_dir, exist_ok=True) 52 | 53 | # Dataset 54 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures()) 55 | data = dataset[0] 56 | print('Original data', data) 57 | 58 | split = T.RandomNodeSplit() 59 | data = split(data) 60 | assert is_undirected(data.edge_index) 61 | 62 | print('Split data', data) 63 | args.in_dim = data.x.shape[1] 64 | args.out_dim = dataset.num_classes 65 | 66 | wandb.init(config=args) 67 | 68 | # Df and Dr 69 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted 70 | df_size = int(args.df_size) 71 | else: # df_size is the ratio 72 | df_size = int(args.df_size / 100 * data.train_pos_edge_index.shape[1]) 73 | print(f'Original size: {data.num_nodes:,}') 74 | print(f'Df size: {df_size:,}') 75 | 76 | # Delete node feature 77 | df_nodes = torch.randperm(data.num_nodes)[:df_size] 78 | global_node_mask = torch.ones(data.num_nodes, dtype=torch.bool) 79 | # global_node_mask[df_nodes] = False 80 | data.x[df_nodes] = 0 81 | assert data.x[df_nodes].sum() == 0 82 | 83 | dr_mask_node = torch.ones(data.num_nodes, dtype=torch.bool) 84 | df_mask_node = ~global_node_mask 85 | # assert df_mask_node.sum() == df_size 86 | 87 | # Delete edges associated with deleted nodes from training set 88 | res = [torch.eq(data.edge_index, aelem).logical_or_(torch.eq(data.edge_index, aelem)) for aelem in df_nodes] 89 | df_mask_edge = torch.any(torch.stack(res, dim=0), dim 
= 0) 90 | df_mask_edge = df_mask_edge.sum(0).bool() 91 | dr_mask_edge = ~df_mask_edge 92 | 93 | df_edge = data.edge_index[:, df_mask_edge] 94 | data.directed_df_edge_index = to_directed(df_edge) 95 | # print(df_edge.shape, directed_df_edge_index.shape) 96 | # raise 97 | 98 | print('Deleting the following nodes:', df_nodes) 99 | 100 | # # Delete edges associated with deleted nodes from valid and test set 101 | # res = [torch.eq(data.val_pos_edge_index, aelem).logical_or_(torch.eq(data.val_pos_edge_index, aelem)) for aelem in df_nodes] 102 | # mask = torch.any(torch.stack(res, dim=0), dim = 0) 103 | # mask = mask.sum(0).bool() 104 | # mask = ~mask 105 | # data.val_pos_edge_index = data.val_pos_edge_index[:, mask] 106 | # data.val_neg_edge_index = data.val_neg_edge_index[:, :data.val_pos_edge_index.shape[1]] 107 | 108 | # res = [torch.eq(data.test_pos_edge_index, aelem).logical_or_(torch.eq(data.test_pos_edge_index, aelem)) for aelem in df_nodes] 109 | # mask = torch.any(torch.stack(res, dim=0), dim = 0) 110 | # mask = mask.sum(0).bool() 111 | # mask = ~mask 112 | # data.test_pos_edge_index = data.test_pos_edge_index[:, mask] 113 | # data.test_neg_edge_index = data.test_neg_edge_index[:, :data.test_pos_edge_index.shape[1]] 114 | 115 | 116 | # For testing 117 | # data.directed_df_edge_index = data.train_pos_edge_index[:, df_mask_edge] 118 | # if args.gnn in ['rgcn', 'rgat']: 119 | # data.directed_df_edge_type = data.train_edge_type[df_mask] 120 | 121 | # Edges in S_Df 122 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph( 123 | data.edge_index[:, df_mask_edge].flatten().unique(), 124 | 2, 125 | data.edge_index, 126 | num_nodes=data.num_nodes) 127 | 128 | # Nodes in S_Df 129 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph( 130 | data.edge_index[:, df_mask_edge].flatten().unique(), 131 | 1, 132 | data.edge_index, 133 | num_nodes=data.num_nodes) 134 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool) 135 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool) 136 | 137 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True 138 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True 139 | 140 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique()) 141 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique()) 142 | 143 | data.sdf_node_1hop_mask = sdf_node_1hop 144 | data.sdf_node_2hop_mask = sdf_node_2hop 145 | 146 | 147 | # To undirected for message passing 148 | # print(is_undir0.0175ected(data.train_pos_edge_index), data.train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 149 | # assert not is_undirected(data.edge_index) 150 | print(is_undirected(data.edge_index)) 151 | 152 | if args.gnn in ['rgcn', 'rgat']: 153 | r, c = data.train_pos_edge_index 154 | rev_edge_index = torch.stack([c, r], dim=0) 155 | rev_edge_type = data.train_edge_type + args.num_edge_type 156 | 157 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1) 158 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0) 159 | # data.train_mask = data.train_mask.repeat(2) 160 | 161 | two_hop_mask = two_hop_mask.repeat(2).view(-1) 162 | df_mask = df_mask.repeat(2).view(-1) 163 | dr_mask = dr_mask.repeat(2).view(-1) 164 | assert is_undirected(data.edge_index) 165 | 166 | else: 167 | # train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()]) 168 | two_hop_mask = two_hop_mask.bool() 169 | df_mask_edge = df_mask_edge.bool() 170 | dr_mask_edge = 
~df_mask_edge 171 | 172 | # data.train_pos_edge_index = train_pos_edge_index 173 | # assert is_undirected(data.train_pos_edge_index) 174 | 175 | 176 | print('Undirected dataset:', data) 177 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape) 178 | 179 | data.sdf_mask = two_hop_mask 180 | data.df_mask = df_mask_edge 181 | data.dr_mask = dr_mask_edge 182 | data.dtrain_mask = dr_mask_edge 183 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.two_hop_mask.shape, data.df_mask.shape, data.two_hop_mask.shape) 184 | # raise 185 | 186 | # Model 187 | model = GCNDelete(args) 188 | # model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 189 | 190 | if args.unlearning_model != 'retrain': # Start from trained GNN model 191 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')): 192 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt')) 193 | if logits_ori is not None: 194 | logits_ori = logits_ori.to(device) 195 | else: 196 | logits_ori = None 197 | 198 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device) 199 | model.load_state_dict(model_ckpt['model_state'], strict=False) 200 | 201 | else: # Initialize a new GNN model 202 | retrain = None 203 | logits_ori = None 204 | 205 | model = model.to(device) 206 | 207 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model: 208 | parameters_to_optimize = [ 209 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 210 | ] 211 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 212 | 213 | if 'layerwise' in args.loss_type: 214 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr) 215 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr) 216 | optimizer = [optimizer1, optimizer2] 217 | else: 218 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr) 219 | 220 | else: 221 | if 'gnndelete' in args.unlearning_model: 222 | parameters_to_optimize = [ 223 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0} 224 | ] 225 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n]) 226 | 227 | else: 228 | parameters_to_optimize = [ 229 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0} 230 | ] 231 | print('parameters_to_optimize', [n for n, p in model.named_parameters()]) 232 | 233 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay) 234 | 235 | wandb.watch(model, log_freq=100) 236 | 237 | # MI attack model 238 | attack_model_all = None 239 | # attack_model_all = MLPAttacker(args) 240 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt')) 241 | # attack_model_all.load_state_dict(attack_ckpt['model_state']) 242 | # attack_model_all = attack_model_all.to(device) 243 | 244 | attack_model_sub = None 245 | # attack_model_sub = MLPAttacker(args) 246 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt')) 247 | # attack_model_sub.load_state_dict(attack_ckpt['model_state']) 248 | # attack_model_sub = attack_model_sub.to(device) 249 | 250 | # Train 251 | trainer = GNNDeleteNodeClassificationTrainer(args) 252 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 
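The Df masks above are built by comparing every edge endpoint against each deleted node in a Python list comprehension, and to_directed keeps only the row < col direction of each undirected edge. Below is a minimal vectorized sketch of the same two steps (not part of the script, assuming a PyTorch version that provides torch.isin); note that it returns the directed Df edges in the conventional [2, E] layout via torch.stack, whereas the to_directed helper above concatenates the two index vectors into a single flat 1-D tensor.

import torch

def df_edge_mask(edge_index: torch.Tensor, df_nodes: torch.Tensor) -> torch.Tensor:
    # An edge belongs to Df if either endpoint is one of the deleted nodes.
    return torch.isin(edge_index[0], df_nodes) | torch.isin(edge_index[1], df_nodes)

def to_directed_pairs(edge_index: torch.Tensor) -> torch.Tensor:
    # Keep one direction (row < col) of every undirected edge; result has shape [2, E'].
    row, col = edge_index
    mask = row < col
    return torch.stack([row[mask], col[mask]], dim=0)

# Toy usage: a 4-node undirected graph stored with both edge directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
df_nodes = torch.tensor([3])
mask = df_edge_mask(edge_index, df_nodes)              # edges touching node 3
directed_df = to_directed_pairs(edge_index[:, mask])   # tensor([[2], [3]])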
253 | 254 | # Test 255 | if args.unlearning_model != 'retrain': 256 | retrain_path = os.path.join( 257 | 'checkpoint', args.dataset, args.gnn, 'retrain', 258 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]])) 259 | retrain_ckpt = torch.load(os.path.join(retrain_path, 'model_best.pt'), map_location=device) 260 | retrain_args = copy.deepcopy(args) 261 | retrain_args.unlearning_model = 'retrain' 262 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type) 263 | retrain.load_state_dict(retrain_ckpt['model_state']) 264 | retrain = retrain.to(device) 265 | retrain.eval() 266 | 267 | else: 268 | retrain = None 269 | 270 | trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub) 271 | trainer.save_log() 272 | 273 | 274 | if __name__ == "__main__": 275 | main() 276 | -------------------------------------------------------------------------------- /framework/trainer/gradient_ascent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import wandb 4 | from tqdm import tqdm, trange 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch_geometric.utils import negative_sampling 9 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 10 | 11 | from .base import Trainer, KGTrainer 12 | from ..evaluation import * 13 | from ..utils import * 14 | 15 | 16 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 17 | 18 | def weight(model): 19 | t = 0 20 | for p in model.parameters(): 21 | t += torch.norm(p) 22 | 23 | return t 24 | 25 | class GradientAscentTrainer(Trainer): 26 | 27 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 28 | if 'ogbl' in self.args.dataset: 29 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 30 | 31 | else: 32 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 33 | 34 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 35 | model = model.to('cuda') 36 | data = data.to('cuda') 37 | 38 | start_time = time.time() 39 | best_metric = 0 40 | 41 | # MI Attack before unlearning 42 | if attack_model_all is not None: 43 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 44 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 45 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 46 | if attack_model_sub is not None: 47 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 48 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 49 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 50 | 51 | 52 | for epoch in trange(args.epochs, desc='Unlerning'): 53 | model.train() 54 | 55 | start_time = time.time() 56 | 57 | # Positive and negative sample 58 | neg_edge_index = negative_sampling( 59 | edge_index=data.train_pos_edge_index[:, data.df_mask], 60 | num_nodes=data.num_nodes, 61 | num_neg_samples=data.df_mask.sum()) 62 | 63 | z = model(data.x, data.train_pos_edge_index) 64 | logits = model.decode(z, data.train_pos_edge_index[:, data.df_mask]) 65 | label = torch.ones_like(logits, dtype=torch.float, device='cuda') 66 | loss = 
-F.binary_cross_entropy_with_logits(logits, label) 67 | 68 | # print('aaaaaaaaaaaaaa', data.df_mask.sum(), weight(model)) 69 | 70 | loss.backward() 71 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 72 | optimizer.step() 73 | optimizer.zero_grad() 74 | 75 | end_time = time.time() 76 | epoch_time = end_time - start_time 77 | 78 | step_log = { 79 | 'Epoch': epoch, 80 | 'train_loss': loss.item(), 81 | 'train_time': epoch_time 82 | } 83 | wandb.log(step_log) 84 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()] 85 | tqdm.write(' | '.join(msg)) 86 | 87 | if (epoch + 1) % self.args.valid_freq == 0: 88 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 89 | valid_log['epoch'] = epoch 90 | 91 | train_log = { 92 | 'epoch': epoch, 93 | 'train_loss': loss.item(), 94 | 'train_time': epoch_time, 95 | } 96 | 97 | for log in [train_log, valid_log]: 98 | wandb.log(log) 99 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 100 | tqdm.write(' | '.join(msg)) 101 | self.trainer_log['log'].append(log) 102 | 103 | if dt_auc + df_auc > best_metric: 104 | best_metric = dt_auc + df_auc 105 | best_epoch = epoch 106 | 107 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 108 | ckpt = { 109 | 'model_state': model.state_dict(), 110 | 'optimizer_state': optimizer.state_dict(), 111 | } 112 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 113 | 114 | self.trainer_log['training_time'] = time.time() - start_time 115 | 116 | # Save 117 | ckpt = { 118 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()}, 119 | 'optimizer_state': optimizer.state_dict(), 120 | } 121 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 122 | 123 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 124 | best_metric = 0 125 | 126 | # MI Attack before unlearning 127 | if attack_model_all is not None: 128 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 129 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 130 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 131 | if attack_model_sub is not None: 132 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 133 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 134 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 135 | 136 | data.edge_index = data.train_pos_edge_index 137 | data.node_id = torch.arange(data.x.shape[0]) 138 | loader = GraphSAINTRandomWalkSampler( 139 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps, 140 | ) 141 | for epoch in trange(args.epochs, desc='Unlerning'): 142 | model.train() 143 | 144 | epoch_loss = 0 145 | epoch_time = 0 146 | for step, batch in enumerate(tqdm(loader, leave=False)): 147 | start_time = time.time() 148 | batch = batch.to(device) 149 | 150 | z = model(batch.x, batch.edge_index[:, batch.dr_mask]) 151 | 152 | # Positive and negative sample 153 | neg_edge_index = negative_sampling( 154 | edge_index=batch.edge_index[:, batch.df_mask], 155 | num_nodes=z.size(0)) 156 | 157 | logits = model.decode(z, batch.edge_index[:, batch.df_mask]) 158 | label = torch.ones_like(logits, dtype=torch.float, device=device) 159 | loss = -F.binary_cross_entropy_with_logits(logits, label) 160 
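# A minimal illustration (not part of this trainer) of why the negated objective above unlearns
# the Df edges: with all-ones labels, d/d_logit BCE(logit, 1) = sigmoid(logit) - 1, so negating
# the loss flips the gradient to 1 - sigmoid(logit) > 0, and each optimizer step therefore lowers
# the Df logits, pushing the predicted probability of every deleted edge toward 0. Assuming only
# torch and F as already imported at the top of this file:
#   logits = torch.zeros(4, requires_grad=True)
#   loss = -F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
#   loss.backward()
#   logits.grad   # tensor([0.1250, 0.1250, 0.1250, 0.1250]), all positive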
| 161 | loss.backward() 162 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 163 | optimizer.step() 164 | optimizer.zero_grad() 165 | 166 | end_time = time.time() 167 | epoch_loss += loss.item() 168 | epoch_time += end_time - start_time 169 | 170 | epoch_loss /= step 171 | epoch_time /= step 172 | 173 | if (epoch+1) % args.valid_freq == 0: 174 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 175 | 176 | train_log = { 177 | 'epoch': epoch, 178 | 'train_loss': epoch_loss / step, 179 | 'train_time': epoch_time / step, 180 | } 181 | 182 | for log in [train_log, valid_log]: 183 | wandb.log(log) 184 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 185 | tqdm.write(' | '.join(msg)) 186 | self.trainer_log['log'].append(log) 187 | 188 | if dt_auc + df_auc > best_metric: 189 | best_metric = dt_auc + df_auc 190 | best_epoch = epoch 191 | 192 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 193 | ckpt = { 194 | 'model_state': model.state_dict(), 195 | 'optimizer_state': optimizer.state_dict(), 196 | } 197 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 198 | torch.save(z, os.path.join(args.checkpoint_dir, 'node_embeddings.pt')) 199 | 200 | # Save 201 | ckpt = { 202 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()}, 203 | 'optimizer_state': optimizer.state_dict(), 204 | } 205 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 206 | 207 | class KGGradientAscentTrainer(KGTrainer): 208 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 209 | model = model.to(device) 210 | start_time = time.time() 211 | best_metric = 0 212 | 213 | # MI Attack before unlearning 214 | if attack_model_all is not None: 215 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 216 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 217 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 218 | if attack_model_sub is not None: 219 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 220 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 221 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 222 | 223 | loader = GraphSAINTRandomWalkSampler( 224 | data, batch_size=128, walk_length=args.walk_length, num_steps=args.num_steps, 225 | ) 226 | for epoch in trange(args.epochs, desc='Epoch'): 227 | model.train() 228 | 229 | epoch_loss = 0 230 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)): 231 | batch = batch.to(device) 232 | 233 | # Message passing 234 | edge_index = batch.edge_index[:, batch.dr_mask] 235 | edge_type = batch.edge_type[batch.dr_mask] 236 | z = model(batch.x, edge_index, edge_type) 237 | 238 | # Positive and negative sample 239 | decoding_edge_index = batch.edge_index[:, batch.df_mask] 240 | decoding_edge_type = batch.edge_type[batch.df_mask] 241 | decoding_mask = (decoding_edge_type < args.num_edge_type) # Only select directed edges for link prediction 242 | decoding_edge_index = decoding_edge_index[:, decoding_mask] 243 | decoding_edge_type = decoding_edge_type[decoding_mask] 244 | 245 | logits = model.decode(z, decoding_edge_index, decoding_edge_type) 246 | label = torch.ones_like(logits, dtype=torch.float, device=device) 247 | loss = 
-F.binary_cross_entropy_with_logits(logits, label) 248 | 249 | loss.backward() 250 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 251 | optimizer.step() 252 | optimizer.zero_grad() 253 | 254 | log = { 255 | 'epoch': epoch, 256 | 'step': step, 257 | 'train_loss': loss.item(), 258 | } 259 | wandb.log(log) 260 | # msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 261 | # tqdm.write(' | '.join(msg)) 262 | 263 | epoch_loss += loss.item() 264 | 265 | if (epoch + 1) % args.valid_freq == 0: 266 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 267 | 268 | train_log = { 269 | 'epoch': epoch, 270 | 'train_loss': epoch_loss / step 271 | } 272 | 273 | for log in [train_log, valid_log]: 274 | wandb.log(log) 275 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 276 | tqdm.write(' | '.join(msg)) 277 | 278 | self.trainer_log['log'].append(train_log) 279 | self.trainer_log['log'].append(valid_log) 280 | 281 | if dt_auc + df_auc > best_metric: 282 | best_metric = dt_auc + df_auc 283 | best_epoch = epoch 284 | 285 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 286 | ckpt = { 287 | 'model_state': model.state_dict(), 288 | 'optimizer_state': optimizer.state_dict(), 289 | } 290 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 291 | 292 | self.trainer_log['training_time'] = time.time() - start_time 293 | 294 | # Save models and node embeddings 295 | print('Saving final checkpoint') 296 | ckpt = { 297 | 'model_state': model.state_dict(), 298 | 'optimizer_state': optimizer.state_dict(), 299 | } 300 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 301 | 302 | print(f'Training finished. 
Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}') 303 | 304 | self.trainer_log['best_epoch'] = best_epoch 305 | self.trainer_log['best_metric'] = best_metric 306 | self.trainer_log['training_time'] = np.mean([i['epoch_time'] for i in self.trainer_log['log'] if 'epoch_time' in i]) 307 | -------------------------------------------------------------------------------- /framework/trainer/gnndelete_embdis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import wandb 4 | from tqdm import tqdm, trange 5 | import torch 6 | import torch.nn as nn 7 | from torch_geometric.utils import negative_sampling, k_hop_subgraph 8 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 9 | import torch.nn.functional as F 10 | from .base import Trainer 11 | from ..evaluation import * 12 | from ..utils import * 13 | 14 | 15 | def BoundedKLD(logits, truth): 16 | return 1 - torch.exp(-F.kl_div(F.log_softmax(logits, -1), truth.softmax(-1), None, None, 'batchmean')) 17 | 18 | class GNNDeleteEmbeddingDistanceTrainer(Trainer): 19 | 20 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 21 | if 'ogbl' in self.args.dataset: 22 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 23 | 24 | else: 25 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 26 | 27 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 28 | model = model.to('cuda') 29 | data = data.to('cuda') 30 | 31 | best_metric = 0 32 | if 'kld' in args.unlearning_model: 33 | loss_fct = BoundedKLD 34 | else: 35 | loss_fct = nn.MSELoss() 36 | # neg_size = 10 37 | 38 | # MI Attack before unlearning 39 | if attack_model_all is not None: 40 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 41 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 42 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 43 | if attack_model_sub is not None: 44 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 45 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 46 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 47 | 48 | # All node pairs in S_Df without Df.
For Local Causality 49 | ## S_Df all pair mask 50 | sdf_all_pair_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool) 51 | idx = torch.combinations(torch.arange(data.num_nodes)[data.sdf_node_2hop_mask], with_replacement=True).t() 52 | sdf_all_pair_mask[idx[0], idx[1]] = True 53 | sdf_all_pair_mask[idx[1], idx[0]] = True 54 | 55 | # print(data.sdf_node_2hop_mask.sum()) 56 | # print(sdf_all_pair_mask.nonzero()) 57 | # print(data.train_pos_edge_index[:, data.df_mask][0], data.train_pos_edge_index[:, data.df_mask][1]) 58 | 59 | assert sdf_all_pair_mask.sum().cpu() == data.sdf_node_2hop_mask.sum().cpu() * data.sdf_node_2hop_mask.sum().cpu() 60 | 61 | ## Remove Df itself 62 | sdf_all_pair_mask[data.train_pos_edge_index[:, data.df_mask][0], data.train_pos_edge_index[:, data.df_mask][1]] = False 63 | sdf_all_pair_mask[data.train_pos_edge_index[:, data.df_mask][1], data.train_pos_edge_index[:, data.df_mask][0]] = False 64 | 65 | ## Lower triangular mask 66 | idx = torch.tril_indices(data.num_nodes, data.num_nodes, -1) 67 | lower_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool) 68 | lower_mask[idx[0], idx[1]] = True 69 | 70 | ## The final mask is the intersection 71 | sdf_all_pair_without_df_mask = sdf_all_pair_mask & lower_mask 72 | 73 | # print('aaaaaaaaaaaa', data.sdf_node_2hop_mask.sum(), a, sdf_all_pair_mask.sum()) 74 | # print('aaaaaaaaaaaa', lower_mask.sum()) 75 | # print('aaaaaaaaaaaa', sdf_all_pair_without_df_mask.sum()) 76 | # print('aaaaaaaaaaaa', data.sdf_node_2hop_mask.sum()) 77 | # assert sdf_all_pair_without_df_mask.sum() == \ 78 | # data.sdf_node_2hop_mask.sum().cpu() * (data.sdf_node_2hop_mask.sum().cpu() - 1) // 2 - data.df_mask.sum().cpu() 79 | 80 | 81 | # Node representation for local causality 82 | with torch.no_grad(): 83 | z1_ori, z2_ori = model.get_original_embeddings(data.x, data.train_pos_edge_index[:, data.dtrain_mask], return_all_emb=True) 84 | 85 | total_time = 0 86 | for epoch in trange(args.epochs, desc='Unlerning'): 87 | model.train() 88 | start_time = time.time() 89 | 90 | z1, z2 = model(data.x, data.train_pos_edge_index[:, data.sdf_mask], return_all_emb=True) 91 | print('current deletion weight', model.deletion1.deletion_weight.sum(), model.deletion2.deletion_weight.sum()) 92 | 93 | # Effectiveness and Randomness 94 | neg_size = data.df_mask.sum() 95 | neg_edge_index = negative_sampling( 96 | edge_index=data.train_pos_edge_index, 97 | num_nodes=data.num_nodes, 98 | num_neg_samples=neg_size) 99 | 100 | df_logits = model.decode(z2, data.train_pos_edge_index[:, data.df_mask], neg_edge_index) 101 | loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:]) 102 | # df_logits = model.decode( 103 | # z, 104 | # data.train_pos_edge_index[:, data.df_mask].repeat(1, neg_size), 105 | # neg_edge_index).sigmoid() 106 | 107 | # loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:]) 108 | # print('df_logits', df_logits) 109 | # raise 110 | 111 | # Local causality 112 | if sdf_all_pair_without_df_mask.sum() != 0: 113 | loss_l = loss_fct(z1_ori[data.sdf_node_1hop_mask], z1[data.sdf_node_1hop_mask]) + \ 114 | loss_fct(z2_ori[data.sdf_node_2hop_mask], z2[data.sdf_node_2hop_mask]) 115 | print('local proba', loss_l.item()) 116 | 117 | else: 118 | loss_l = torch.tensor(0) 119 | print('local proba', 0) 120 | 121 | 122 | alpha = 0.5 123 | if 'ablation_random' in self.args.unlearning_model: 124 | loss_l = torch.tensor(0) 125 | loss = loss_e 126 | elif 'ablation_locality' in self.args.unlearning_model: 127 | loss_e = torch.tensor(0) 128 | loss = 
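# A sketch of the intent of the branches around this point (comment only, not original code):
# loss_e pushes the scores of deleted edges toward the scores of freshly sampled negative pairs
# (deleted edge consistency), while loss_l keeps the 1-hop / 2-hop S_Df embeddings close to the
# frozen original embeddings z1_ori / z2_ori (neighborhood influence); the ablation_* modes keep
# only one term, and otherwise the objective is the convex combination
#   loss = alpha * loss_e + (1 - alpha) * loss_l   # alpha = 0.5 above
# Note that of the two logging dicts further down in this loop, the one inside the valid_freq
# block records loss_e under 'train_loss_l' and loss_l under 'train_loss_e', i.e. with the keys
# swapped relative to the per-epoch log.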
loss_l 129 | else: 130 | loss = alpha * loss_e + (1 - alpha) * loss_l 131 | 132 | loss.backward() 133 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 134 | optimizer.step() 135 | optimizer.zero_grad() 136 | 137 | end_time = time.time() 138 | 139 | log = { 140 | 'epoch': epoch, 141 | 'train_loss': loss.item(), 142 | 'train_loss_l': loss_l.item(), 143 | 'train_loss_e': loss_e.item(), 144 | 'train_time': end_time - start_time, 145 | } 146 | # wandb.log(log) 147 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 148 | tqdm.write(' | '.join(msg)) 149 | 150 | if (epoch+1) % args.valid_freq == 0: 151 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 152 | 153 | train_log = { 154 | 'epoch': epoch, 155 | 'train_loss': loss.item(), 156 | 'train_loss_l': loss_e.item(), 157 | 'train_loss_e': loss_l.item(), 158 | 'train_time': end_time - start_time, 159 | } 160 | 161 | for log in [train_log, valid_log]: 162 | wandb.log(log) 163 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 164 | tqdm.write(' | '.join(msg)) 165 | self.trainer_log['log'].append(log) 166 | 167 | if dt_auc + df_auc > best_metric: 168 | best_metric = dt_auc + df_auc 169 | best_epoch = epoch 170 | 171 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 172 | ckpt = { 173 | 'model_state': model.state_dict(), 174 | 'optimizer_state': optimizer.state_dict(), 175 | } 176 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 177 | 178 | # Save 179 | ckpt = { 180 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()}, 181 | 'optimizer_state': optimizer.state_dict(), 182 | } 183 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 184 | 185 | # Save 186 | ckpt = { 187 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()}, 188 | 'optimizer_state': optimizer.state_dict(), 189 | } 190 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 191 | 192 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 193 | start_time = time.time() 194 | best_loss = 100000 195 | if 'kld' in args.unlearning_model: 196 | loss_fct = BoundedKLD 197 | else: 198 | loss_fct = nn.MSELoss() 199 | # neg_size = 10 200 | 201 | # MI Attack before unlearning 202 | if attack_model_all is not None: 203 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 204 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 205 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 206 | if attack_model_sub is not None: 207 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 208 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 209 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 210 | 211 | z_ori = self.get_embedding(model, data, on_cpu=True) 212 | z_ori_two_hop = z_ori[data.sdf_node_2hop_mask] 213 | 214 | data.edge_index = data.train_pos_edge_index 215 | data.node_id = torch.arange(data.x.shape[0]) 216 | loader = GraphSAINTRandomWalkSampler( 217 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps, 218 | ) 219 | for epoch in trange(args.epochs, desc='Unlerning'): 220 | model.train() 221 | 222 | print('current deletion weight', model.deletion1.deletion_weight.sum(), 
model.deletion2.deletion_weight.sum()) 223 | 224 | epoch_loss_e = 0 225 | epoch_loss_l = 0 226 | epoch_loss = 0 227 | for step, batch in enumerate(tqdm(loader, leave=False)): 228 | # print('data', batch) 229 | # print('two hop nodes', batch.sdf_node_2hop_mask.sum()) 230 | batch = batch.to('cuda') 231 | 232 | train_pos_edge_index = batch.edge_index 233 | z = model(batch.x, train_pos_edge_index[:, batch.sdf_mask], batch.sdf_node_1hop_mask, batch.sdf_node_2hop_mask) 234 | z_two_hop = z[batch.sdf_node_2hop_mask] 235 | 236 | # Effectiveness and Randomness 237 | neg_size = batch.df_mask.sum() 238 | neg_edge_index = negative_sampling( 239 | edge_index=train_pos_edge_index, 240 | num_nodes=z.size(0), 241 | num_neg_samples=neg_size) 242 | 243 | df_logits = model.decode(z, train_pos_edge_index[:, batch.df_mask], neg_edge_index) 244 | loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:]) 245 | 246 | # Local causality 247 | mask = torch.zeros(data.x.shape[0], dtype=torch.bool) 248 | mask[batch.node_id[batch.sdf_node_2hop_mask]] = True 249 | z_ori_subset = z_ori[mask].to('cuda') 250 | 251 | # Only take the lower triangular part 252 | num_nodes = z_ori_subset.shape[0] 253 | idx = torch.tril_indices(num_nodes, num_nodes, -1) 254 | local_lower_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) 255 | local_lower_mask[idx[0], idx[1]] = True 256 | 257 | logits_ori = (z_ori_subset @ z_ori_subset.t())[local_lower_mask].sigmoid() 258 | logits = (z_two_hop @ z_two_hop.t())[local_lower_mask].sigmoid() 259 | 260 | loss_l = loss_fct(logits, logits_ori) 261 | 262 | 263 | alpha = 0.5 264 | if 'ablation_random' in self.args.unlearning_model: 265 | loss_l = torch.tensor(0) 266 | loss = loss_e 267 | elif 'ablation_locality' in self.args.unlearning_model: 268 | loss_e = torch.tensor(0) 269 | loss = loss_l 270 | else: 271 | loss = alpha * loss_e + (1 - alpha) * loss_l 272 | 273 | loss.backward() 274 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 275 | optimizer.step() 276 | optimizer.zero_grad() 277 | 278 | epoch_loss_e += loss_e.item() 279 | epoch_loss_l += loss_l.item() 280 | epoch_loss += loss.item() 281 | 282 | epoch_loss_e /= step 283 | epoch_loss_l /= step 284 | epoch_loss /= step 285 | 286 | 287 | if (epoch+1) % args.valid_freq == 0: 288 | valid_loss, auc, aup, df_logt, logit_all_pair = self.eval(model, data, 'val') 289 | 290 | log = { 291 | 'epoch': epoch, 292 | 'train_loss': epoch_loss, 293 | 'train_loss_e': epoch_loss_e, 294 | 'train_loss_l': epoch_loss_l, 295 | 'valid_dt_loss': valid_loss, 296 | 'valid_dt_auc': auc, 297 | 'valid_dt_aup': aup, 298 | } 299 | wandb.log(log) 300 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 301 | tqdm.write(' | '.join(msg)) 302 | 303 | self.trainer_log['log'].append(log) 304 | 305 | self.trainer_log['training_time'] = time.time() - start_time 306 | 307 | # Save 308 | ckpt = { 309 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()}, 310 | 'optimizer_state': optimizer.state_dict(), 311 | } 312 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 313 | -------------------------------------------------------------------------------- /framework/trainer/retrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import wandb 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch_geometric.utils import 
negative_sampling 10 | from torch_geometric.loader import GraphSAINTRandomWalkSampler 11 | 12 | from .base import Trainer, KGTrainer 13 | from ..evaluation import * 14 | from ..utils import * 15 | 16 | 17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 18 | 19 | class RetrainTrainer(Trainer): 20 | 21 | def freeze_unused_mask(self, model, edge_to_delete, subgraph, h): 22 | gradient_mask = torch.zeros_like(delete_model.operator) 23 | 24 | edges = subgraph[h] 25 | for s, t in edges: 26 | if s < t: 27 | gradient_mask[s, t] = 1 28 | gradient_mask = gradient_mask.to(device) 29 | model.operator.register_hook(lambda grad: grad.mul_(gradient_mask)) 30 | 31 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 32 | if 'ogbl' in self.args.dataset: 33 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 34 | 35 | else: 36 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub) 37 | 38 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 39 | model = model.to('cuda') 40 | data = data.to('cuda') 41 | 42 | best_metric = 0 43 | loss_fct = nn.MSELoss() 44 | 45 | # MI Attack before unlearning 46 | if attack_model_all is not None: 47 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 48 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 49 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 50 | if attack_model_sub is not None: 51 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 52 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 53 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 54 | 55 | for epoch in trange(args.epochs, desc='Unlearning'): 56 | model.train() 57 | 58 | start_time = time.time() 59 | total_step = 0 60 | total_loss = 0 61 | 62 | neg_edge_index = negative_sampling( 63 | edge_index=data.train_pos_edge_index[:, data.dr_mask], 64 | num_nodes=data.num_nodes, 65 | num_neg_samples=data.dr_mask.sum()) 66 | 67 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 68 | logits = model.decode(z, data.train_pos_edge_index[:, data.dr_mask], neg_edge_index) 69 | label = self.get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index) 70 | loss = F.binary_cross_entropy_with_logits(logits, label) 71 | 72 | loss.backward() 73 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 74 | optimizer.step() 75 | optimizer.zero_grad() 76 | 77 | total_step += 1 78 | total_loss += loss.item() 79 | 80 | end_time = time.time() 81 | epoch_time = end_time - start_time 82 | 83 | step_log = { 84 | 'Epoch': epoch, 85 | 'train_loss': loss.item(), 86 | 'train_time': epoch_time 87 | } 88 | wandb.log(step_log) 89 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()] 90 | tqdm.write(' | '.join(msg)) 91 | 92 | if (epoch + 1) % self.args.valid_freq == 0: 93 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 94 | valid_log['epoch'] = epoch 95 | 96 | train_log = { 97 | 'epoch': epoch, 98 | 'train_loss': loss.item(), 99 | 'train_time': epoch_time, 100 | } 101 | 102 | for log in [train_log, valid_log]: 103 | wandb.log(log) 104 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else 
f'{i}: {j:.4f}' for i, j in log.items()] 105 | tqdm.write(' | '.join(msg)) 106 | self.trainer_log['log'].append(log) 107 | 108 | if dt_auc + df_auc > best_metric: 109 | best_metric = dt_auc + df_auc 110 | best_epoch = epoch 111 | 112 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 113 | ckpt = { 114 | 'model_state': model.state_dict(), 115 | 'optimizer_state': optimizer.state_dict(), 116 | } 117 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 118 | 119 | self.trainer_log['training_time'] = time.time() - start_time 120 | 121 | # Save 122 | ckpt = { 123 | 'model_state': model.state_dict(), 124 | 'optimizer_state': optimizer.state_dict(), 125 | } 126 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 127 | 128 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}') 129 | 130 | self.trainer_log['best_epoch'] = best_epoch 131 | self.trainer_log['best_metric'] = best_metric 132 | 133 | 134 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 135 | start_time = time.time() 136 | best_metric = 0 137 | 138 | # MI Attack before unlearning 139 | if attack_model_all is not None: 140 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 141 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 142 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 143 | if attack_model_sub is not None: 144 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 145 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 146 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 147 | 148 | data.edge_index = data.train_pos_edge_index 149 | loader = GraphSAINTRandomWalkSampler( 150 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps, 151 | ) 152 | for epoch in trange(args.epochs, desc='Epoch'): 153 | model.train() 154 | 155 | start_time = time.time() 156 | epoch_loss = 0 157 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)): 158 | batch = batch.to(device) 159 | 160 | # Positive and negative sample 161 | train_pos_edge_index = batch.edge_index[:, batch.dr_mask] 162 | z = model(batch.x, train_pos_edge_index) 163 | 164 | neg_edge_index = negative_sampling( 165 | edge_index=train_pos_edge_index, 166 | num_nodes=z.size(0)) 167 | 168 | logits = model.decode(z, train_pos_edge_index, neg_edge_index) 169 | label = get_link_labels(train_pos_edge_index, neg_edge_index) 170 | loss = F.binary_cross_entropy_with_logits(logits, label) 171 | 172 | loss.backward() 173 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 174 | optimizer.step() 175 | optimizer.zero_grad() 176 | 177 | step_log = { 178 | 'epoch': epoch, 179 | 'step': step, 180 | 'train_loss': loss.item(), 181 | } 182 | wandb.log(step_log) 183 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()] 184 | tqdm.write(' | '.join(msg)) 185 | 186 | epoch_loss += loss.item() 187 | 188 | end_time = time.time() 189 | epoch_time = end_time - start_time 190 | 191 | if (epoch+1) % args.valid_freq == 0: 192 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 193 | valid_log['epoch'] = epoch 194 | 195 | train_log = { 196 | 'epoch': epoch, 197 | 'train_loss': loss.item(), 198 | 'train_time': epoch_time, 199 | 
} 200 | 201 | for log in [train_log, valid_log]: 202 | wandb.log(log) 203 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 204 | tqdm.write(' | '.join(msg)) 205 | self.trainer_log['log'].append(log) 206 | 207 | if dt_auc + df_auc > best_metric: 208 | best_metric = dt_auc + df_auc 209 | best_epoch = epoch 210 | 211 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 212 | ckpt = { 213 | 'model_state': model.state_dict(), 214 | 'optimizer_state': optimizer.state_dict(), 215 | } 216 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 217 | torch.save(z, os.path.join(args.checkpoint_dir, 'node_embeddings.pt')) 218 | 219 | self.trainer_log['training_time'] = time.time() - start_time 220 | 221 | # Save models and node embeddings 222 | print('Saving final checkpoint') 223 | ckpt = { 224 | 'model_state': model.state_dict(), 225 | 'optimizer_state': optimizer.state_dict(), 226 | } 227 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 228 | 229 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}') 230 | 231 | self.trainer_log['best_epoch'] = best_epoch 232 | self.trainer_log['best_metric'] = best_metric 233 | 234 | 235 | class KGRetrainTrainer(KGTrainer): 236 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 237 | model = model.to(device) 238 | start_time = time.time() 239 | best_metric = 0 240 | 241 | # MI Attack before unlearning 242 | if attack_model_all is not None: 243 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data) 244 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before 245 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before 246 | if attack_model_sub is not None: 247 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data) 248 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before 249 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before 250 | 251 | loader = GraphSAINTRandomWalkSampler( 252 | data, batch_size=128, walk_length=2, num_steps=args.num_steps, 253 | ) 254 | for epoch in trange(args.epochs, desc='Epoch'): 255 | model.train() 256 | 257 | epoch_loss = 0 258 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)): 259 | batch = batch.to(device) 260 | 261 | # Message passing 262 | edge_index = batch.edge_index[:, batch.dr_mask] 263 | edge_type = batch.edge_type[batch.dr_mask] 264 | z = model(batch.x, edge_index, edge_type) 265 | 266 | # Positive and negative sample 267 | decoding_mask = (edge_type < args.num_edge_type) # Only select directed edges for link prediction 268 | decoding_edge_index = edge_index[:, decoding_mask] 269 | decoding_edge_type = edge_type[decoding_mask] 270 | 271 | neg_edge_index = negative_sampling_kg( 272 | edge_index=decoding_edge_index, 273 | edge_type=decoding_edge_type) 274 | 275 | pos_logits = model.decode(z, decoding_edge_index, decoding_edge_type) 276 | neg_logits = model.decode(z, neg_edge_index, decoding_edge_type) 277 | logits = torch.cat([pos_logits, neg_logits], dim=-1) 278 | label = get_link_labels(decoding_edge_index, neg_edge_index) 279 | # reg_loss = z.pow(2).mean() + model.W.pow(2).mean() 280 | loss = F.binary_cross_entropy_with_logits(logits, label)# + 1e-2 * reg_loss 281 | 282 | loss.backward() 283 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 284 | 
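# The positive/negative construction above relies on negative_sampling_kg, which is pulled in via
# `from ..utils import *` and not shown in this file. A rough sketch of one common variant,
# assuming it simply corrupts the tail of every positive triple with a random entity
# (illustrative only, not the repository's implementation):
#   def negative_sampling_kg(edge_index, edge_type):
#       tail = torch.randint(int(edge_index.max()) + 1, (edge_index.size(1),),
#                            device=edge_index.device)
#       return torch.stack([edge_index[0], tail], dim=0)
# get_link_labels then presumably returns ones for the positive logits and zeros for the corrupted
# ones, so the binary cross-entropy above is a standard 1:1 negative-sampling link-prediction loss.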
optimizer.step() 285 | optimizer.zero_grad() 286 | 287 | log = { 288 | 'epoch': epoch, 289 | 'step': step, 290 | 'train_loss': loss.item(), 291 | } 292 | wandb.log(log) 293 | # msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 294 | # tqdm.write(' | '.join(msg)) 295 | 296 | epoch_loss += loss.item() 297 | 298 | if (epoch + 1) % args.valid_freq == 0: 299 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val') 300 | 301 | train_log = { 302 | 'epoch': epoch, 303 | 'train_loss': epoch_loss / step 304 | } 305 | 306 | for log in [train_log, valid_log]: 307 | wandb.log(log) 308 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 309 | tqdm.write(' | '.join(msg)) 310 | 311 | self.trainer_log['log'].append(train_log) 312 | self.trainer_log['log'].append(valid_log) 313 | 314 | if dt_auc + df_auc > best_metric: 315 | best_metric = dt_auc + df_auc 316 | best_epoch = epoch 317 | 318 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}') 319 | ckpt = { 320 | 'model_state': model.state_dict(), 321 | 'optimizer_state': optimizer.state_dict(), 322 | } 323 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt')) 324 | 325 | self.trainer_log['training_time'] = time.time() - start_time 326 | 327 | # Save models and node embeddings 328 | print('Saving final checkpoint') 329 | ckpt = { 330 | 'model_state': model.state_dict(), 331 | 'optimizer_state': optimizer.state_dict(), 332 | } 333 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt')) 334 | 335 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best valid loss = {best_metric:.4f}') 336 | 337 | self.trainer_log['best_epoch'] = best_epoch 338 | self.trainer_log['best_metric'] = best_metric 339 | self.trainer_log['training_time'] = np.mean([i['epoch_time'] for i in self.trainer_log['log'] if 'epoch_time' in i]) 340 | -------------------------------------------------------------------------------- /framework/trainer/graph_eraser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from tqdm import tqdm, trange 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch_geometric.utils import negative_sampling, subgraph 10 | 11 | from .base import Trainer 12 | from ..evaluation import * 13 | from ..utils import * 14 | 15 | 16 | class ConstrainedKmeans: 17 | '''This code is from https://github.com/MinChen00/Graph-Unlearning''' 18 | 19 | def __init__(self, args, data_feat, num_clusters, node_threshold, terminate_delta, max_iteration=20): 20 | self.args = args 21 | self.data_feat = data_feat 22 | self.num_clusters = num_clusters 23 | self.node_threshold = node_threshold 24 | self.terminate_delta = terminate_delta 25 | self.max_iteration = max_iteration 26 | 27 | def initialization(self): 28 | centroids = np.random.choice(np.arange(self.data_feat.shape[0]), self.num_clusters, replace=False) 29 | self.centroid = {} 30 | for i in range(self.num_clusters): 31 | self.centroid[i] = self.data_feat[centroids[i]] 32 | 33 | def clustering(self): 34 | centroid = copy.deepcopy(self.centroid) 35 | km_delta = [] 36 | 37 | # pbar = tqdm(total=self.max_iteration) 38 | # pbar.set_description('Clustering') 39 | 40 | for i in trange(self.max_iteration, desc='Graph partition'): 41 | # self.logger.info('iteration %s' % (i,)) 42 | 43 | 
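# Each pass below greedily assigns nodes to their nearest centroid in order of increasing
# distance, skipping any cluster that already holds node_threshold nodes (_node_reassignment),
# then recomputes the centroids (_centroid_updating) and stops once the total centroid movement
# is at most terminate_delta. A minimal usage sketch with illustrative numbers (mirroring how
# GraphEraserTrainer calls this class further down in this file; node_emb stands for a
# [num_nodes, dim] numpy array of node embeddings):
#   cluster = ConstrainedKmeans(args, node_emb, num_clusters=10,
#                               node_threshold=500, terminate_delta=1e-3)
#   cluster.initialization()
#   communities, deltas = cluster.clustering()   # dict: shard id -> array of node ids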
self._node_reassignment() 44 | self._centroid_updating() 45 | 46 | # record the average change of centroids, if the change is smaller than a very small value, then terminate 47 | delta = self._centroid_delta(centroid, self.centroid) 48 | km_delta.append(delta) 49 | centroid = copy.deepcopy(self.centroid) 50 | 51 | if delta <= self.terminate_delta: 52 | break 53 | print("delta: %s" % delta) 54 | # pbar.close() 55 | return self.clusters, km_delta 56 | 57 | def _node_reassignment(self): 58 | self.clusters = {} 59 | for i in range(self.num_clusters): 60 | self.clusters[i] = np.zeros(0, dtype=np.uint64) 61 | 62 | distance = np.zeros([self.num_clusters, self.data_feat.shape[0]]) 63 | 64 | for i in range(self.num_clusters): 65 | distance[i] = np.sum(np.power((self.data_feat - self.centroid[i]), 2), axis=1) 66 | 67 | sort_indices = np.unravel_index(np.argsort(distance, axis=None), distance.shape) 68 | clusters = sort_indices[0] 69 | users = sort_indices[1] 70 | selected_nodes = np.zeros(0, dtype=np.int64) 71 | counter = 0 72 | 73 | while len(selected_nodes) < self.data_feat.shape[0]: 74 | cluster = int(clusters[counter]) 75 | user = users[counter] 76 | if self.clusters[cluster].size < self.node_threshold: 77 | self.clusters[cluster] = np.append(self.clusters[cluster], np.array(int(user))) 78 | selected_nodes = np.append(selected_nodes, np.array(int(user))) 79 | 80 | # delete all the following pairs for the selected user 81 | user_indices = np.where(users == user)[0] 82 | a = np.arange(users.size) 83 | b = user_indices[user_indices > counter] 84 | remain_indices = a[np.where(np.logical_not(np.isin(a, b)))[0]] 85 | clusters = clusters[remain_indices] 86 | users = users[remain_indices] 87 | 88 | counter += 1 89 | 90 | def _centroid_updating(self): 91 | for i in range(self.num_clusters): 92 | self.centroid[i] = np.mean(self.data_feat[self.clusters[i].astype(int)], axis=0) 93 | 94 | def _centroid_delta(self, centroid_pre, centroid_cur): 95 | delta = 0.0 96 | for i in range(len(centroid_cur)): 97 | delta += np.sum(np.abs(centroid_cur[i] - centroid_pre[i])) 98 | 99 | return delta 100 | 101 | def generate_shard_data(self, data): 102 | shard_data = {} 103 | for shard in trange(self.args['num_shards'], desc='Generate shard data'): 104 | train_shard_indices = list(self.community_to_node[shard]) 105 | shard_indices = np.union1d(train_shard_indices, self.test_indices) 106 | 107 | x = data.x[shard_indices] 108 | y = data.y[shard_indices] 109 | edge_index = utils.filter_edge_index_1(data, shard_indices) 110 | 111 | data = Data(x=x, edge_index=torch.from_numpy(edge_index), y=y) 112 | data.train_mask = torch.from_numpy(np.isin(shard_indices, train_shard_indices)) 113 | data.test_mask = torch.from_numpy(np.isin(shard_indices, self.test_indices)) 114 | 115 | shard_data[shard] = data 116 | 117 | self.data_store.save_shard_data(self.shard_data) 118 | 119 | class OptimalAggregator: 120 | def __init__(self, run, target_model, data, args): 121 | self.args = args 122 | 123 | self.run = run 124 | self.target_model = target_model 125 | self.data = data 126 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 127 | 128 | self.num_shards = args.num_clusters 129 | 130 | def generate_train_data(self): 131 | data_store = DataStore(self.args) 132 | train_indices, _ = data_store.load_train_test_split() 133 | 134 | # sample a set of nodes from train_indices 135 | if self.args["num_opt_samples"] == 1000: 136 | train_indices = np.random.choice(train_indices, size=1000, replace=False) 137 | elif 
self.args["num_opt_samples"] == 10000: 138 | train_indices = np.random.choice(train_indices, size=int(train_indices.shape[0] * 0.1), replace=False) 139 | elif self.args["num_opt_samples"] == 1: 140 | train_indices = np.random.choice(train_indices, size=int(train_indices.shape[0]), replace=False) 141 | 142 | train_indices = np.sort(train_indices) 143 | self.logger.info("Using %s samples for optimization" % (int(train_indices.shape[0]))) 144 | 145 | x = self.data.x[train_indices] 146 | y = self.data.y[train_indices] 147 | edge_index = utils.filter_edge_index(self.data.edge_index, train_indices) 148 | 149 | train_data = Data(x=x, edge_index=torch.from_numpy(edge_index), y=y) 150 | train_data.train_mask = torch.zeros(train_indices.shape[0], dtype=torch.bool) 151 | train_data.test_mask = torch.ones(train_indices.shape[0], dtype=torch.bool) 152 | self.true_labels = y 153 | 154 | self.posteriors = {} 155 | for shard in range(self.num_shards): 156 | self.target_model.data = train_data 157 | data_store.load_target_model(self.run, self.target_model, shard) 158 | self.posteriors[shard] = self.target_model.posterior().to(self.device) 159 | 160 | def optimization(self): 161 | weight_para = nn.Parameter(torch.full((self.num_shards,), fill_value=1.0 / self.num_shards), requires_grad=True) 162 | optimizer = optim.Adam([weight_para], lr=self.args['opt_lr']) 163 | scheduler = MultiStepLR(optimizer, milestones=[500, 1000], gamma=self.args['opt_lr']) 164 | 165 | train_dset = OptDataset(self.posteriors, self.true_labels) 166 | train_loader = DataLoader(train_dset, batch_size=32, shuffle=True, num_workers=0) 167 | 168 | min_loss = 1000.0 169 | for epoch in range(self.args.epochs): 170 | loss_all = 0.0 171 | 172 | for posteriors, labels in train_loader: 173 | labels = labels.to(self.device) 174 | 175 | optimizer.zero_grad() 176 | loss = self._loss_fn(posteriors, labels, weight_para) 177 | loss.backward() 178 | loss_all += loss 179 | 180 | optimizer.step() 181 | with torch.no_grad(): 182 | weight_para[:] = torch.clamp(weight_para, min=0.0) 183 | 184 | scheduler.step() 185 | 186 | if loss_all < min_loss: 187 | ret_weight_para = copy.deepcopy(weight_para) 188 | min_loss = loss_all 189 | 190 | self.logger.info('epoch: %s, loss: %s' % (epoch, loss_all)) 191 | 192 | return ret_weight_para / torch.sum(ret_weight_para) 193 | 194 | def _loss_fn(self, posteriors, labels, weight_para): 195 | aggregate_posteriors = torch.zeros_like(posteriors[0]) 196 | for shard in range(self.num_shards): 197 | aggregate_posteriors += weight_para[shard] * posteriors[shard] 198 | 199 | aggregate_posteriors = F.softmax(aggregate_posteriors, dim=1) 200 | loss_1 = F.cross_entropy(aggregate_posteriors, labels) 201 | loss_2 = torch.sqrt(torch.sum(weight_para ** 2)) 202 | 203 | return loss_1 + loss_2 204 | 205 | class Aggregator: 206 | def __init__(self, run, target_model, data, shard_data, args): 207 | self.args = args 208 | 209 | self.run = run 210 | self.target_model = target_model 211 | self.data = data 212 | self.shard_data = shard_data 213 | 214 | self.num_shards = args.num_clusters 215 | 216 | def generate_posterior(self, suffix=""): 217 | self.true_label = self.shard_data[0].y[self.shard_data[0]['test_mask']].detach().cpu().numpy() 218 | self.posteriors = {} 219 | 220 | for shard in range(self.args.num_clusters): 221 | self.target_model.data = self.shard_data[shard] 222 | self.data_store.load_target_model(self.run, self.target_model, shard, suffix) 223 | self.posteriors[shard] = self.target_model.posterior() 224 | 225 | def 
_optimal_aggregator(self): 226 | optimal = OptimalAggregator(self.run, self.target_model, self.data, self.args) 227 | optimal.generate_train_data() 228 | weight_para = optimal.optimization() 229 | self.data_store.save_optimal_weight(weight_para, run=self.run) 230 | 231 | posterior = self.posteriors[0] * weight_para[0] 232 | for shard in range(1, self.num_shards): 233 | posterior += self.posteriors[shard] * weight_para[shard] 234 | 235 | return f1_score(self.true_label, posterior.argmax(axis=1).cpu().numpy(), average="micro") 236 | 237 | class GraphEraserTrainer(Trainer): 238 | 239 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None): 240 | 241 | with torch.no_grad(): 242 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask]) 243 | 244 | # Retrain the model 245 | for c in model.children(): 246 | print('before', torch.norm(c.lin.weight), torch.norm(c.bias)) 247 | for c in model.children(): 248 | c.reset_parameters() 249 | for c in model.children(): 250 | print('after', torch.norm(c.lin.weight), torch.norm(c.bias)) 251 | model = model.cpu() 252 | 253 | num_nodes = data.num_nodes 254 | node_threshold = math.ceil( 255 | num_nodes / args.num_clusters + args.shard_size_delta * (num_nodes - num_nodes / args.num_clusters)) 256 | print(f'Number of nodes: {num_nodes}. Shard threshold: {node_threshold}') 257 | 258 | cluster = ConstrainedKmeans( 259 | args, 260 | z.cpu().numpy(), 261 | args.num_clusters, 262 | node_threshold, 263 | args.terminate_delta, 264 | args.kmeans_max_iters) 265 | cluster.initialization() 266 | 267 | community, km_deltas = cluster.clustering() 268 | # with open(os.path.join(args.checkpoint_dir, 'kmeans_delta.pkl'), 'wb') as f: 269 | # pickle.dump(km_deltas, f) 270 | 271 | community_to_node = {} 272 | for i in range(args.num_clusters): 273 | community_to_node[i] = np.array(community[i].astype(int)) 274 | 275 | models = {} 276 | test_result = [] 277 | for shard_id in trange(args.num_clusters, desc='Sharded retraining'): 278 | model_shard_id = copy.deepcopy(model).to('cuda') 279 | optimizer = torch.optim.Adam(model_shard_id.parameters(), lr=args.lr) 280 | 281 | subset_train, _ = subgraph( 282 | torch.tensor(community[shard_id], dtype=torch.long, device=device), 283 | data.train_pos_edge_index, 284 | num_nodes=data.num_nodes) 285 | 286 | self.train_model(model_shard_id, data, subset_train, optimizer, args, shard_id) 287 | 288 | with torch.no_grad(): 289 | z = model_shard_id(data.x, subset_train) 290 | logits = model_shard_id.decode(data.test_pos_edge_index, data.test_neg_edge_index) 291 | 292 | weight_para = nn.Parameter(torch.full((self.num_shards,), fill_value=1.0 / self.num_shards), requires_grad=True) 293 | optimizer = optim.Adam([weight_para], lr=self.args.lr) 294 | 295 | 296 | aggregator.generate_posterior() 297 | self.aggregate_f1_score = aggregator.aggregate() 298 | aggregate_time = time.time() - start_time 299 | self.logger.info("Partition cost %s seconds." 
% aggregate_time) 300 | 301 | self.logger.info("Final Test F1: %s" % (self.aggregate_f1_score,)) 302 | 303 | 304 | 305 | def train_model(self, model, data, subset_train, optimizer, args, shard_id): 306 | 307 | best_loss = 100000 308 | for epoch in range(args.epochs): 309 | model.train() 310 | 311 | neg_edge_index = negative_sampling( 312 | edge_index=subset_train, 313 | num_nodes=data.num_nodes, 314 | num_neg_samples=subset_train.shape[1]) 315 | 316 | z = model(data.x, subset_train) 317 | logits = model.decode(z, subset_train, neg_edge_index) 318 | label = self.get_link_labels(subset_train, neg_edge_index) 319 | loss = F.binary_cross_entropy_with_logits(logits, label) 320 | 321 | loss.backward() 322 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 323 | optimizer.step() 324 | optimizer.zero_grad() 325 | 326 | valid_loss, auc, aup, _, _, = self.eval_model(model, data, subset_train, 'val') 327 | log = { 328 | 'train_loss': loss.item(), 329 | 'valid_loss': valid_loss, 330 | 'valid_auc': auc, 331 | 'valid_aup': aup, 332 | } 333 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()] 334 | tqdm.write(' | '.join(msg)) 335 | self.trainer_log[f'shard_{shard_id}'] = log 336 | 337 | torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f'model_{shard_id}.pt')) 338 | 339 | @torch.no_grad() 340 | def eval_model(self, model, data, subset_train, stage='val', pred_all=False): 341 | model.eval() 342 | pos_edge_index = data[f'{stage}_pos_edge_index'] 343 | neg_edge_index = data[f'{stage}_neg_edge_index'] 344 | 345 | z = model(data.x, subset_train) 346 | logits = model.decode(z, pos_edge_index, neg_edge_index).sigmoid() 347 | label = self.get_link_labels(pos_edge_index, neg_edge_index) 348 | 349 | loss = F.binary_cross_entropy_with_logits(logits, label).cpu().item() 350 | auc = roc_auc_score(label.cpu(), logits.cpu()) 351 | aup = average_precision_score(label.cpu(), logits.cpu()) 352 | 353 | if self.args.unlearning_model in ['original', 'retrain']: 354 | df_logit = float('nan') 355 | else: 356 | # df_logit = float('nan') 357 | df_logit = model.decode(z, subset_train).sigmoid().detach().cpu().item() 358 | 359 | if pred_all: 360 | logit_all_pair = (z @ z.t()).cpu() 361 | else: 362 | logit_all_pair = None 363 | 364 | log = { 365 | f'{stage}_loss': loss, 366 | f'{stage}_auc': auc, 367 | f'{stage}_aup': aup, 368 | f'{stage}_df_logit': df_logit, 369 | } 370 | wandb.log(log) 371 | msg = [f'{i}: {j:.4f}' if isinstance(j, (np.floating, float)) else f'{i}: {j:>4d}' for i, j in log.items()] 372 | tqdm.write(' | '.join(msg)) 373 | 374 | return loss, auc, aup, df_logit, logit_all_pair 375 | -------------------------------------------------------------------------------- /framework/models/rgat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from sklearn.metrics import roc_auc_score, average_precision_score 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | from torch.nn import Parameter, ReLU 13 | from torch_scatter import scatter_add 14 | from torch_sparse import SparseTensor 15 | 16 | from torch_geometric.nn.conv import MessagePassing 17 | from torch_geometric.nn.dense.linear import Linear 18 | from torch_geometric.nn.inits import glorot, ones, zeros 19 | from torch_geometric.typing import Adj, OptTensor, Size 20 | from torch_geometric.utils 
import softmax 21 | 22 | 23 | # Source: torch_geometric 24 | class RGATConv(MessagePassing): 25 | _alpha: OptTensor 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | num_relations: int, 32 | num_bases: Optional[int] = None, 33 | num_blocks: Optional[int] = None, 34 | mod: Optional[str] = None, 35 | attention_mechanism: str = "across-relation", 36 | attention_mode: str = "additive-self-attention", 37 | heads: int = 1, 38 | dim: int = 1, 39 | concat: bool = True, 40 | negative_slope: float = 0.2, 41 | dropout: float = 0.0, 42 | edge_dim: Optional[int] = None, 43 | bias: bool = True, 44 | **kwargs, 45 | ): 46 | kwargs.setdefault('aggr', 'add') 47 | super().__init__(node_dim=0, **kwargs) 48 | 49 | self.heads = heads 50 | self.negative_slope = negative_slope 51 | self.dropout = dropout 52 | self.mod = mod 53 | self.activation = ReLU() 54 | self.concat = concat 55 | self.attention_mode = attention_mode 56 | self.attention_mechanism = attention_mechanism 57 | self.dim = dim 58 | self.edge_dim = edge_dim 59 | 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.num_relations = num_relations 63 | self.num_bases = num_bases 64 | self.num_blocks = num_blocks 65 | 66 | mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled'] 67 | 68 | if (self.attention_mechanism != "within-relation" 69 | and self.attention_mechanism != "across-relation"): 70 | raise ValueError('attention mechanism must either be ' 71 | '"within-relation" or "across-relation"') 72 | 73 | if (self.attention_mode != "additive-self-attention" 74 | and self.attention_mode != "multiplicative-self-attention"): 75 | raise ValueError('attention mode must either be ' 76 | '"additive-self-attention" or ' 77 | '"multiplicative-self-attention"') 78 | 79 | if self.attention_mode == "additive-self-attention" and self.dim > 1: 80 | raise ValueError('"additive-self-attention" mode cannot be ' 81 | 'applied when value of d is greater than 1. 
' 82 | 'Use "multiplicative-self-attention" instead.') 83 | 84 | if self.dropout > 0.0 and self.mod in mod_types: 85 | raise ValueError('mod must be None with dropout value greater ' 86 | 'than 0 in order to sample attention ' 87 | 'coefficients stochastically') 88 | 89 | if num_bases is not None and num_blocks is not None: 90 | raise ValueError('Can not apply both basis-decomposition and ' 91 | 'block-diagonal-decomposition at the same time.') 92 | 93 | # The learnable parameters to compute both attention logits and 94 | # attention coefficients: 95 | self.q = Parameter( 96 | torch.Tensor(self.heads * self.out_channels, 97 | self.heads * self.dim)) 98 | self.k = Parameter( 99 | torch.Tensor(self.heads * self.out_channels, 100 | self.heads * self.dim)) 101 | 102 | if bias and concat: 103 | self.bias = Parameter( 104 | torch.Tensor(self.heads * self.dim * self.out_channels)) 105 | elif bias and not concat: 106 | self.bias = Parameter(torch.Tensor(self.dim * self.out_channels)) 107 | else: 108 | self.register_parameter('bias', None) 109 | 110 | if edge_dim is not None: 111 | self.lin_edge = Linear(self.edge_dim, 112 | self.heads * self.out_channels, bias=False, 113 | weight_initializer='glorot') 114 | self.e = Parameter( 115 | torch.Tensor(self.heads * self.out_channels, 116 | self.heads * self.dim)) 117 | else: 118 | self.lin_edge = None 119 | self.register_parameter('e', None) 120 | 121 | if num_bases is not None: 122 | self.att = Parameter( 123 | torch.Tensor(self.num_relations, self.num_bases)) 124 | self.basis = Parameter( 125 | torch.Tensor(self.num_bases, self.in_channels, 126 | self.heads * self.out_channels)) 127 | elif num_blocks is not None: 128 | assert ( 129 | self.in_channels % self.num_blocks == 0 130 | and (self.heads * self.out_channels) % self.num_blocks == 0), ( 131 | "both 'in_channels' and 'heads * out_channels' must be " 132 | "multiple of 'num_blocks' used") 133 | self.weight = Parameter( 134 | torch.Tensor(self.num_relations, self.num_blocks, 135 | self.in_channels // self.num_blocks, 136 | (self.heads * self.out_channels) // 137 | self.num_blocks)) 138 | else: 139 | self.weight = Parameter( 140 | torch.Tensor(self.num_relations, self.in_channels, 141 | self.heads * self.out_channels)) 142 | 143 | self.w = Parameter(torch.ones(self.out_channels)) 144 | self.l1 = Parameter(torch.Tensor(1, self.out_channels)) 145 | self.b1 = Parameter(torch.Tensor(1, self.out_channels)) 146 | self.l2 = Parameter(torch.Tensor(self.out_channels, self.out_channels)) 147 | self.b2 = Parameter(torch.Tensor(1, self.out_channels)) 148 | 149 | self._alpha = None 150 | 151 | self.reset_parameters() 152 | 153 | def reset_parameters(self): 154 | if self.num_bases is not None: 155 | glorot(self.basis) 156 | glorot(self.att) 157 | else: 158 | glorot(self.weight) 159 | glorot(self.q) 160 | glorot(self.k) 161 | zeros(self.bias) 162 | ones(self.l1) 163 | zeros(self.b1) 164 | torch.full(self.l2.size(), 1 / self.out_channels) 165 | zeros(self.b2) 166 | if self.lin_edge is not None: 167 | glorot(self.lin_edge) 168 | glorot(self.e) 169 | 170 | def forward(self, x: Tensor, edge_index: Adj, edge_type: OptTensor = None, 171 | edge_attr: OptTensor = None, size: Size = None, 172 | return_attention_weights=None): 173 | # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor) # noqa 174 | out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x, 175 | size=size, edge_attr=edge_attr) 176 | 177 | alpha = self._alpha 178 | assert alpha is not None 179 | self._alpha = None 180 | 181 | if 
isinstance(return_attention_weights, bool): 182 | if isinstance(edge_index, Tensor): 183 | return out, (edge_index, alpha) 184 | elif isinstance(edge_index, SparseTensor): 185 | return out, edge_index.set_value(alpha, layout='coo') 186 | else: 187 | return out 188 | 189 | 190 | def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, 191 | edge_attr: OptTensor, index: Tensor, ptr: OptTensor, 192 | size_i: Optional[int]) -> Tensor: 193 | 194 | if self.num_bases is not None: # Basis-decomposition ================= 195 | w = torch.matmul(self.att, self.basis.view(self.num_bases, -1)) 196 | w = w.view(self.num_relations, self.in_channels, 197 | self.heads * self.out_channels) 198 | if self.num_blocks is not None: # Block-diagonal-decomposition ======= 199 | if (x_i.dtype == torch.long and x_j.dtype == torch.long 200 | and self.num_blocks is not None): 201 | raise ValueError('Block-diagonal decomposition not supported ' 202 | 'for non-continuous input features.') 203 | w = self.weight 204 | x_i = x_i.view(-1, 1, w.size(1), w.size(2)) 205 | x_j = x_j.view(-1, 1, w.size(1), w.size(2)) 206 | w = torch.index_select(w, 0, edge_type) 207 | outi = torch.einsum('abcd,acde->ace', x_i, w) 208 | outi = outi.contiguous().view(-1, self.heads * self.out_channels) 209 | outj = torch.einsum('abcd,acde->ace', x_j, w) 210 | outj = outj.contiguous().view(-1, self.heads * self.out_channels) 211 | else: # No regularization/Basis-decomposition ======================== 212 | if self.num_bases is None: 213 | w = self.weight 214 | w = torch.index_select(w, 0, edge_type) 215 | outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2) 216 | outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2) 217 | 218 | qi = torch.matmul(outi, self.q) 219 | kj = torch.matmul(outj, self.k) 220 | 221 | alpha_edge, alpha = 0, torch.tensor([0]) 222 | if edge_attr is not None: 223 | if edge_attr.dim() == 1: 224 | edge_attr = edge_attr.view(-1, 1) 225 | assert self.lin_edge is not None, ( 226 | "Please set 'edge_dim = edge_attr.size(-1)' while calling the " 227 | "RGATConv layer") 228 | edge_attributes = self.lin_edge(edge_attr).view( 229 | -1, self.heads * self.out_channels) 230 | if edge_attributes.size(0) != edge_attr.size(0): 231 | edge_attributes = torch.index_select(edge_attributes, 0, 232 | edge_type) 233 | alpha_edge = torch.matmul(edge_attributes, self.e) 234 | 235 | if self.attention_mode == "additive-self-attention": 236 | if edge_attr is not None: 237 | alpha = torch.add(qi, kj) + alpha_edge 238 | else: 239 | alpha = torch.add(qi, kj) 240 | alpha = F.leaky_relu(alpha, self.negative_slope) 241 | elif self.attention_mode == "multiplicative-self-attention": 242 | if edge_attr is not None: 243 | alpha = (qi * kj) * alpha_edge 244 | else: 245 | alpha = qi * kj 246 | 247 | if self.attention_mechanism == "within-relation": 248 | across_out = torch.zeros_like(alpha) 249 | for r in range(self.num_relations): 250 | mask = edge_type == r 251 | across_out[mask] = softmax(alpha[mask], index[mask]) 252 | alpha = across_out 253 | elif self.attention_mechanism == "across-relation": 254 | alpha = softmax(alpha, index, ptr, size_i) 255 | 256 | self._alpha = alpha 257 | 258 | if self.mod == "additive": 259 | if self.attention_mode == "additive-self-attention": 260 | ones = torch.ones_like(alpha) 261 | h = (outj.view(-1, self.heads, self.out_channels) * 262 | ones.view(-1, self.heads, 1)) 263 | h = torch.mul(self.w, h) 264 | 265 | return (outj.view(-1, self.heads, self.out_channels) * 266 | alpha.view(-1, self.heads, 1) + h) 267 | elif self.attention_mode 
== "multiplicative-self-attention": 268 | ones = torch.ones_like(alpha) 269 | h = (outj.view(-1, self.heads, 1, self.out_channels) * 270 | ones.view(-1, self.heads, self.dim, 1)) 271 | h = torch.mul(self.w, h) 272 | 273 | return (outj.view(-1, self.heads, 1, self.out_channels) * 274 | alpha.view(-1, self.heads, self.dim, 1) + h) 275 | 276 | elif self.mod == "scaled": 277 | if self.attention_mode == "additive-self-attention": 278 | ones = alpha.new_ones(index.size()) 279 | degree = scatter_add(ones, index, 280 | dim_size=size_i)[index].unsqueeze(-1) 281 | degree = torch.matmul(degree, self.l1) + self.b1 282 | degree = self.activation(degree) 283 | degree = torch.matmul(degree, self.l2) + self.b2 284 | 285 | return torch.mul( 286 | outj.view(-1, self.heads, self.out_channels) * 287 | alpha.view(-1, self.heads, 1), 288 | degree.view(-1, 1, self.out_channels)) 289 | elif self.attention_mode == "multiplicative-self-attention": 290 | ones = alpha.new_ones(index.size()) 291 | degree = scatter_add(ones, index, 292 | dim_size=size_i)[index].unsqueeze(-1) 293 | degree = torch.matmul(degree, self.l1) + self.b1 294 | degree = self.activation(degree) 295 | degree = torch.matmul(degree, self.l2) + self.b2 296 | 297 | return torch.mul( 298 | outj.view(-1, self.heads, 1, self.out_channels) * 299 | alpha.view(-1, self.heads, self.dim, 1), 300 | degree.view(-1, 1, 1, self.out_channels)) 301 | 302 | elif self.mod == "f-additive": 303 | alpha = torch.where(alpha > 0, alpha + 1, alpha) 304 | 305 | elif self.mod == "f-scaled": 306 | ones = alpha.new_ones(index.size()) 307 | degree = scatter_add(ones, index, 308 | dim_size=size_i)[index].unsqueeze(-1) 309 | alpha = alpha * degree 310 | 311 | elif self.training and self.dropout > 0: 312 | alpha = F.dropout(alpha, p=self.dropout, training=True) 313 | 314 | else: 315 | alpha = alpha # original 316 | 317 | if self.attention_mode == "additive-self-attention": 318 | return alpha.view(-1, self.heads, 1) * outj.view( 319 | -1, self.heads, self.out_channels) 320 | else: 321 | return (alpha.view(-1, self.heads, self.dim, 1) * 322 | outj.view(-1, self.heads, 1, self.out_channels)) 323 | 324 | def update(self, aggr_out: Tensor) -> Tensor: 325 | if self.attention_mode == "additive-self-attention": 326 | if self.concat is True: 327 | aggr_out = aggr_out.view(-1, self.heads * self.out_channels) 328 | else: 329 | aggr_out = aggr_out.mean(dim=1) 330 | 331 | if self.bias is not None: 332 | aggr_out = aggr_out + self.bias 333 | 334 | return aggr_out 335 | else: 336 | if self.concat is True: 337 | aggr_out = aggr_out.view( 338 | -1, self.heads * self.dim * self.out_channels) 339 | else: 340 | aggr_out = aggr_out.mean(dim=1) 341 | aggr_out = aggr_out.view(-1, self.dim * self.out_channels) 342 | 343 | if self.bias is not None: 344 | aggr_out = aggr_out + self.bias 345 | 346 | return aggr_out 347 | 348 | def __repr__(self) -> str: 349 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 350 | self.in_channels, 351 | self.out_channels, self.heads) 352 | 353 | class RGAT(nn.Module): 354 | def __init__(self, args, num_nodes, num_edge_type, **kwargs): 355 | super().__init__() 356 | self.args = args 357 | self.num_edge_type = num_edge_type 358 | 359 | # Encoder: RGAT 360 | self.node_emb = nn.Embedding(num_nodes, args.in_dim) 361 | if num_edge_type > 20: 362 | self.conv1 = RGATConv(args.in_dim, args.hidden_dim, num_edge_type * 2, num_blocks=4) 363 | self.conv2 = RGATConv(args.hidden_dim, args.out_dim, num_edge_type * 2, num_blocks=4) 364 | else: 365 | self.conv1 = 
RGATConv(args.in_dim, args.hidden_dim, num_edge_type * 2) 366 | self.conv2 = RGATConv(args.hidden_dim, args.out_dim, num_edge_type * 2) 367 | self.relu = nn.ReLU() 368 | 369 | # Decoder: DistMult 370 | self.W = nn.Parameter(torch.Tensor(num_edge_type, args.out_dim)) 371 | nn.init.xavier_uniform_(self.W, gain=nn.init.calculate_gain('relu')) 372 | 373 | def forward(self, x, edge, edge_type, return_all_emb=False): 374 | x = self.node_emb(x) 375 | x1 = self.conv1(x, edge, edge_type) 376 | x = self.relu(x1) 377 | x2 = self.conv2(x, edge, edge_type) 378 | 379 | if return_all_emb: 380 | return x1, x2 381 | 382 | return x2 383 | 384 | def decode(self, z, edge_index, edge_type): 385 | h = z[edge_index[0]] 386 | t = z[edge_index[1]] 387 | r = self.W[edge_type] 388 | 389 | logits = torch.sum(h * r * t, dim=1) 390 | 391 | return logits 392 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | import torch 5 | import pandas as pd 6 | import networkx as nx 7 | from tqdm import tqdm 8 | from torch_geometric.seed import seed_everything 9 | import torch_geometric.transforms as T 10 | from torch_geometric.data import Data 11 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR 12 | from torch_geometric.utils import train_test_split_edges, k_hop_subgraph, negative_sampling, to_undirected, is_undirected, to_networkx 13 | from ogb.linkproppred import PygLinkPropPredDataset 14 | from framework.utils import * 15 | 16 | 17 | data_dir = './data' 18 | df_size = [i / 100 for i in range(10)] + [i / 10 for i in range(10)] + [i for i in range(10)] # Df_size in percentage 19 | seeds = [42, 21, 13, 87, 100] 20 | graph_datasets = ['Cora', 'PubMed', 'DBLP', 'CS', 'ogbl-citation2', 'ogbl-collab'][4:] 21 | kg_datasets = ['FB15k-237', 'WordNet18', 'WordNet18RR', 'ogbl-biokg'][-1:] 22 | os.makedirs(data_dir, exist_ok=True) 23 | 24 | 25 | num_edge_type_mapping = { 26 | 'FB15k-237': 237, 27 | 'WordNet18': 18, 28 | 'WordNet18RR': 11 29 | } 30 | 31 | def train_test_split_edges_no_neg_adj_mask(data, val_ratio: float = 0.05, test_ratio: float = 0.1, two_hop_degree=None, kg=False): 32 | '''Avoid adding neg_adj_mask''' 33 | 34 | num_nodes = data.num_nodes 35 | row, col = data.edge_index 36 | edge_attr = data.edge_attr 37 | if kg: 38 | edge_type = data.edge_type 39 | data.edge_index = data.edge_attr = data.edge_weight = data.edge_year = data.edge_type = None 40 | 41 | if not kg: 42 | # Return upper triangular portion. 
43 | mask = row < col 44 | row, col = row[mask], col[mask] 45 | 46 | if edge_attr is not None: 47 | edge_attr = edge_attr[mask] 48 | 49 | n_v = int(math.floor(val_ratio * row.size(0))) 50 | n_t = int(math.floor(test_ratio * row.size(0))) 51 | 52 | if two_hop_degree is not None: # Use low degree edges for test sets 53 | low_degree_mask = two_hop_degree < 50 54 | 55 | low = low_degree_mask.nonzero().squeeze() 56 | high = (~low_degree_mask).nonzero().squeeze() 57 | 58 | low = low[torch.randperm(low.size(0))] 59 | high = high[torch.randperm(high.size(0))] 60 | 61 | perm = torch.cat([low, high]) 62 | 63 | else: 64 | perm = torch.randperm(row.size(0)) 65 | 66 | row = row[perm] 67 | col = col[perm] 68 | 69 | # Train 70 | r, c = row[n_v + n_t:], col[n_v + n_t:] 71 | 72 | if kg: 73 | 74 | # data.edge_index and data.edge_type has reverse edges and edge types for message passing 75 | pos_edge_index = torch.stack([r, c], dim=0) 76 | # rev_pos_edge_index = torch.stack([r, c], dim=0) 77 | train_edge_type = edge_type[n_v + n_t:] 78 | # train_rev_edge_type = edge_type[n_v + n_t:] + edge_type.unique().shape[0] 79 | 80 | # data.edge_index = torch.cat((torch.stack([r, c], dim=0), torch.stack([r, c], dim=0)), dim=1) 81 | # data.edge_type = torch.cat([train_edge_type, train_rev_edge_type], dim=0) 82 | 83 | data.edge_index = pos_edge_index 84 | data.edge_type = train_edge_type 85 | 86 | # data.train_pos_edge_index and data.train_edge_type only has one direction edges and edge types for decoding 87 | data.train_pos_edge_index = torch.stack([r, c], dim=0) 88 | data.train_edge_type = train_edge_type 89 | 90 | else: 91 | data.train_pos_edge_index = torch.stack([r, c], dim=0) 92 | if edge_attr is not None: 93 | # out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:]) 94 | data.train_pos_edge_index, data.train_pos_edge_attr = out 95 | else: 96 | data.train_pos_edge_index = data.train_pos_edge_index 97 | # data.train_pos_edge_index = to_undirected(data.train_pos_edge_index) 98 | 99 | assert not is_undirected(data.train_pos_edge_index) 100 | 101 | 102 | # Test 103 | r, c = row[:n_t], col[:n_t] 104 | data.test_pos_edge_index = torch.stack([r, c], dim=0) 105 | 106 | if kg: 107 | data.test_edge_type = edge_type[:n_t] 108 | neg_edge_index = negative_sampling_kg( 109 | edge_index=data.test_pos_edge_index, 110 | edge_type=data.test_edge_type) 111 | else: 112 | neg_edge_index = negative_sampling( 113 | edge_index=data.test_pos_edge_index, 114 | num_nodes=data.num_nodes, 115 | num_neg_samples=data.test_pos_edge_index.shape[1]) 116 | 117 | data.test_neg_edge_index = neg_edge_index 118 | 119 | # Valid 120 | r, c = row[n_t:n_t+n_v], col[n_t:n_t+n_v] 121 | data.val_pos_edge_index = torch.stack([r, c], dim=0) 122 | 123 | if kg: 124 | data.val_edge_type = edge_type[n_t:n_t+n_v] 125 | neg_edge_index = negative_sampling_kg( 126 | edge_index=data.val_pos_edge_index, 127 | edge_type=data.val_edge_type) 128 | else: 129 | neg_edge_index = negative_sampling( 130 | edge_index=data.val_pos_edge_index, 131 | num_nodes=data.num_nodes, 132 | num_neg_samples=data.val_pos_edge_index.shape[1]) 133 | 134 | data.val_neg_edge_index = neg_edge_index 135 | 136 | return data 137 | 138 | def process_graph(): 139 | for d in graph_datasets: 140 | 141 | if d in ['Cora', 'PubMed', 'DBLP']: 142 | dataset = CitationFull(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures()) 143 | elif d in ['CS', 'Physics']: 144 | dataset = Coauthor(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures()) 145 | elif d in ['Flickr']: 146 | dataset = 
Flickr(os.path.join(data_dir, d), transform=T.NormalizeFeatures()) 147 | elif 'ogbl' in d: 148 | dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d) 149 | else: 150 | raise NotImplementedError 151 | 152 | print('Processing:', d) 153 | print(dataset) 154 | data = dataset[0] 155 | data.train_mask = data.val_mask = data.test_mask = None 156 | graph = to_networkx(data) 157 | 158 | # Get two hop degree for all nodes 159 | node_to_neighbors = {} 160 | for n in tqdm(graph.nodes(), desc='Two hop neighbors'): 161 | neighbor_1 = set(graph.neighbors(n)) 162 | neighbor_2 = sum([list(graph.neighbors(i)) for i in neighbor_1], []) 163 | neighbor_2 = set(neighbor_2) 164 | neighbor = neighbor_1 | neighbor_2 165 | 166 | node_to_neighbors[n] = neighbor 167 | 168 | two_hop_degree = [] 169 | row, col = data.edge_index 170 | mask = row < col 171 | row, col = row[mask], col[mask] 172 | for r, c in tqdm(zip(row, col), total=len(row)): 173 | neighbor_row = node_to_neighbors[r.item()] 174 | neighbor_col = node_to_neighbors[c.item()] 175 | neighbor = neighbor_row | neighbor_col 176 | 177 | num = len(neighbor) 178 | 179 | two_hop_degree.append(num) 180 | 181 | two_hop_degree = torch.tensor(two_hop_degree) 182 | 183 | for s in seeds: 184 | seed_everything(s) 185 | 186 | # D 187 | data = dataset[0] 188 | if 'ogbl' in d: 189 | data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05, two_hop_degree=two_hop_degree) 190 | else: 191 | data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05) 192 | print(s, data) 193 | 194 | with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'wb') as f: 195 | pickle.dump((dataset, data), f) 196 | 197 | # Two ways to sample Df from the training set 198 | ## 1. Df is within 2 hop local enclosing subgraph of Dtest 199 | ## 2. Df is outside of 2 hop local enclosing subgraph of Dtest 200 | 201 | # All the candidate edges (train edges) 202 | # graph = to_networkx(Data(edge_index=data.train_pos_edge_index, x=data.x)) 203 | 204 | # Get the 2 hop local enclosing subgraph for all test edges 205 | _, local_edges, _, mask = k_hop_subgraph( 206 | data.test_pos_edge_index.flatten().unique(), 207 | 2, 208 | data.train_pos_edge_index, 209 | num_nodes=dataset[0].num_nodes) 210 | distant_edges = data.train_pos_edge_index[:, ~mask] 211 | print('Number of edges. 
Local: ', local_edges.shape[1], 'Distant:', distant_edges.shape[1]) 212 | 213 | in_mask = mask 214 | out_mask = ~mask 215 | 216 | # df_in_mask = torch.zeros_like(mask) 217 | # df_out_mask = torch.zeros_like(mask) 218 | 219 | # df_in_all_idx = in_mask.nonzero().squeeze() 220 | # df_out_all_idx = out_mask.nonzero().squeeze() 221 | # df_in_selected_idx = df_in_all_idx[torch.randperm(df_in_all_idx.shape[0])[:df_size]] 222 | # df_out_selected_idx = df_out_all_idx[torch.randperm(df_out_all_idx.shape[0])[:df_size]] 223 | 224 | # df_in_mask[df_in_selected_idx] = True 225 | # df_out_mask[df_out_selected_idx] = True 226 | 227 | # assert (in_mask & out_mask).sum() == 0 228 | # assert (df_in_mask & df_out_mask).sum() == 0 229 | 230 | 231 | # local_edges = set() 232 | # for i in range(data.test_pos_edge_index.shape[1]): 233 | # edge = data.test_pos_edge_index[:, i].tolist() 234 | # subgraph = get_enclosing_subgraph(graph, edge) 235 | # local_edges = local_edges | set(subgraph[2]) 236 | 237 | # distant_edges = graph.edges() - local_edges 238 | 239 | # print('aaaaaaa', len(local_edges), len(distant_edges)) 240 | # local_edges = torch.tensor(sorted(list([i for i in local_edges if i[0] < i[1]]))) 241 | # distant_edges = torch.tensor(sorted(list([i for i in distant_edges if i[0] < i[1]]))) 242 | 243 | 244 | # df_in = torch.randperm(local_edges.shape[1])[:df_size] 245 | # df_out = torch.randperm(distant_edges.shape[1])[:df_size] 246 | 247 | # df_in = local_edges[:, df_in] 248 | # df_out = distant_edges[:, df_out] 249 | 250 | # df_in_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool) 251 | # df_out_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool) 252 | 253 | # for row in df_in: 254 | # i = (data.train_pos_edge_index.T == row).all(axis=1).nonzero() 255 | # df_in_mask[i] = True 256 | 257 | # for row in df_out: 258 | # i = (data.train_pos_edge_index.T == row).all(axis=1).nonzero() 259 | # df_out_mask[i] = True 260 | 261 | torch.save( 262 | {'out': out_mask, 'in': in_mask}, 263 | os.path.join(data_dir, d, f'df_{s}.pt') 264 | ) 265 | 266 | def process_kg(): 267 | for d in kg_datasets: 268 | 269 | # Create the dataset to calculate node degrees 270 | if d in ['FB15k-237']: 271 | dataset = RelLinkPredDataset(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures()) 272 | data = dataset[0] 273 | data.x = torch.arange(data.num_nodes) 274 | edge_index = torch.cat([data.train_edge_index, data.valid_edge_index, data.test_edge_index], dim=1) 275 | edge_type = torch.cat([data.train_edge_type, data.valid_edge_type, data.test_edge_type]) 276 | data = Data(edge_index=edge_index, edge_type=edge_type) 277 | 278 | elif d in ['WordNet18RR']: 279 | dataset = WordNet18RR(os.path.join(data_dir, d), transform=T.NormalizeFeatures()) 280 | data = dataset[0] 281 | data.x = torch.arange(data.num_nodes) 282 | data.train_mask = data.val_mask = data.test_mask = None 283 | 284 | elif d in ['WordNet18']: 285 | dataset = WordNet18(os.path.join(data_dir, d), transform=T.NormalizeFeatures()) 286 | data = dataset[0] 287 | data.x = torch.arange(data.num_nodes) 288 | 289 | # Use original split 290 | data.train_pos_edge_index = data.edge_index[:, data.train_mask] 291 | data.train_edge_type = data.edge_type[data.train_mask] 292 | 293 | data.val_pos_edge_index = data.edge_index[:, data.val_mask] 294 | data.val_edge_type = data.edge_type[data.val_mask] 295 | data.val_neg_edge_index = negative_sampling_kg(data.val_pos_edge_index, data.val_edge_type) 296 | 297 | data.test_pos_edge_index = 
data.edge_index[:, data.test_mask] 298 | data.test_edge_type = data.edge_type[data.test_mask] 299 | data.test_neg_edge_index = negative_sampling_kg(data.test_pos_edge_index, data.test_edge_type) 300 | 301 | elif 'ogbl' in d: 302 | dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d) 303 | 304 | split_edge = dataset.get_edge_split() 305 | train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"] 306 | entity_dict = dict() 307 | cur_idx = 0 308 | for key in dataset[0]['num_nodes_dict']: 309 | entity_dict[key] = (cur_idx, cur_idx + dataset[0]['num_nodes_dict'][key]) 310 | cur_idx += dataset[0]['num_nodes_dict'][key] 311 | nentity = sum(dataset[0]['num_nodes_dict'].values()) 312 | 313 | valid_head_neg = valid_edge.pop('head_neg') 314 | valid_tail_neg = valid_edge.pop('tail_neg') 315 | test_head_neg = test_edge.pop('head_neg') 316 | test_tail_neg = test_edge.pop('tail_neg') 317 | 318 | train = pd.DataFrame(train_edge) 319 | valid = pd.DataFrame(valid_edge) 320 | test = pd.DataFrame(test_edge) 321 | 322 | # Convert to global index 323 | train['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(train['head'], train['head_type'])] 324 | train['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(train['tail'], train['tail_type'])] 325 | 326 | valid['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(valid['head'], valid['head_type'])] 327 | valid['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(valid['tail'], valid['tail_type'])] 328 | 329 | test['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(test['head'], test['head_type'])] 330 | test['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(test['tail'], test['tail_type'])] 331 | 332 | valid_pos_edge_index = torch.tensor([valid['head'], valid['tail']]) 333 | valid_edge_type = torch.tensor(valid.relation) 334 | valid_neg_edge_index = torch.stack([valid_pos_edge_index[0], valid_tail_neg[:, 0]]) 335 | 336 | test_pos_edge_index = torch.tensor([test['head'], test['tail']]) 337 | test_edge_type = torch.tensor(test.relation) 338 | test_neg_edge_index = torch.stack([test_pos_edge_index[0], test_tail_neg[:, 0]]) 339 | 340 | train_directed = train[train.head_type != train.tail_type] 341 | train_undirected = train[train.head_type == train.tail_type] 342 | train_undirected_uni = train_undirected[train_undirected['head'] < train_undirected['tail']] 343 | train_uni = pd.concat([train_directed, train_undirected_uni], ignore_index=True) 344 | 345 | train_pos_edge_index = torch.tensor([train_uni['head'], train_uni['tail']]) 346 | train_edge_type = torch.tensor(train_uni.relation) 347 | 348 | r, c = train_pos_edge_index 349 | rev_edge_index = torch.stack([c, r]) 350 | rev_edge_type = train_edge_type + 51 351 | 352 | edge_index = torch.cat([train_pos_edge_index, rev_edge_index], dim=1) 353 | edge_type = torch.cat([train_edge_type, rev_edge_type], dim=0) 354 | 355 | data = Data( 356 | x=torch.arange(nentity), edge_index=edge_index, edge_type=edge_type, 357 | train_pos_edge_index=train_pos_edge_index, train_edge_type=train_edge_type, 358 | val_pos_edge_index=valid_pos_edge_index, val_edge_type=valid_edge_type, val_neg_edge_index=valid_neg_edge_index, 359 | test_pos_edge_index=test_pos_edge_index, test_edge_type=test_edge_type, test_neg_edge_index=test_neg_edge_index) 360 | 361 | else: 362 | raise NotImplementedError 363 | 364 | print('Processing:', d) 365 | print(dataset) 366 | 367 | for s in seeds: 368 | seed_everything(s) 369 | 370 | # D 371 | # data = 
train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05, two_hop_degree=two_hop_degree, kg=True) 372 | print(s, data) 373 | 374 | with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'wb') as f: 375 | pickle.dump((dataset, data), f) 376 | 377 | # Two ways to sample Df from the training set 378 | ## 1. Df is within 2 hop local enclosing subgraph of Dtest 379 | ## 2. Df is outside of 2 hop local enclosing subgraph of Dtest 380 | 381 | # All the candidate edges (train edges) 382 | # graph = to_networkx(Data(edge_index=data.train_pos_edge_index, x=data.x)) 383 | 384 | # Get the 2 hop local enclosing subgraph for all test edges 385 | _, local_edges, _, mask = k_hop_subgraph( 386 | data.test_pos_edge_index.flatten().unique(), 387 | 2, 388 | data.train_pos_edge_index, 389 | num_nodes=dataset[0].num_nodes) 390 | distant_edges = data.train_pos_edge_index[:, ~mask] 391 | print('Number of edges. Local: ', local_edges.shape[1], 'Distant:', distant_edges.shape[1]) 392 | 393 | in_mask = mask 394 | out_mask = ~mask 395 | 396 | torch.save( 397 | {'out': out_mask, 'in': in_mask}, 398 | os.path.join(data_dir, d, f'df_{s}.pt') 399 | ) 400 | 401 | 402 | def main(): 403 | process_graph() 404 | # process_kg() 405 | 406 | if __name__ == "__main__": 407 | main() 408 | --------------------------------------------------------------------------------
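Note: prepare_dataset.py only materializes the candidate masks ('in' = train edges inside the 2-hop enclosing subgraph of the test edges, 'out' = train edges outside it) together with the split data; turning them into a concrete forget set Df and retain set Dr is left to downstream loading code. The snippet below is a minimal, illustrative sketch of one way the saved d_{seed}.pkl and df_{seed}.pt artifacts could be consumed to build Df/Dr edge masks of a requested size. The names df_size, df_from, and the fraction-based sizing are assumptions made for this example only; the project's actual data-loading code (not shown in this section) may differ.

# Hedged sketch, not the repository's loader: consume the artifacts written by
# prepare_dataset.py and derive Df (edges to unlearn) and Dr (edges to retain).
import os
import pickle
import torch

data_dir = './data'
d, s = 'Cora', 42        # example dataset and seed, matching the lists in prepare_dataset.py
df_size = 0.05           # assumption: forget-set size as a fraction of training edges
df_from = 'in'           # 'in' = within the 2-hop subgraph of Dtest, 'out' = outside it

# Load the split Data object and the local/distant candidate masks saved per seed.
with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'rb') as f:
    dataset, data = pickle.load(f)
masks = torch.load(os.path.join(data_dir, d, f'df_{s}.pt'))

# Sample Df uniformly from the chosen candidate pool over train_pos_edge_index columns.
candidates = masks[df_from].nonzero().squeeze()
num_df = int(df_size * data.train_pos_edge_index.shape[1])
selected = candidates[torch.randperm(candidates.shape[0])[:num_df]]

df_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool)
df_mask[selected] = True
dr_mask = ~df_mask       # retained edges, analogous to data.dr_mask used by the trainers

df_edges = data.train_pos_edge_index[:, df_mask]   # edges to forget
dr_edges = data.train_pos_edge_index[:, dr_mask]   # edges kept for message passing
print(f'Df edges: {df_edges.shape[1]}, Dr edges: {dr_edges.shape[1]}')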