├── images
│   └── fig1.png
├── framework
│   ├── models
│   │   ├── __init__.py
│   │   ├── gat.py
│   │   ├── gcn.py
│   │   ├── gin.py
│   │   ├── rgcn.py
│   │   ├── graph_classification
│   │   │   ├── gcn.py
│   │   │   └── gcn_delete.py
│   │   ├── deletion.py
│   │   └── rgat.py
│   ├── utils.py
│   ├── __init__.py
│   ├── data_loader.py
│   ├── trainer
│   │   ├── gradient_ascent_with_mp.py
│   │   ├── descent_to_delete.py
│   │   ├── approx_retrain.py
│   │   ├── member_infer.py
│   │   ├── gradient_ascent.py
│   │   ├── gnndelete_embdis.py
│   │   ├── retrain.py
│   │   └── graph_eraser.py
│   ├── evaluation.py
│   └── training_args.py
├── run_original.sh
├── run_delete.sh
├── LICENSE
├── graph_stat.py
├── train_node.py
├── .gitignore
├── train_gnn.py
├── README.md
├── delete_gnn.py
├── delete_node.py
├── delete_node_feature.py
└── prepare_dataset.py
/images/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/GNNDelete/HEAD/images/fig1.png
--------------------------------------------------------------------------------
/framework/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcn import GCN
2 | from .gat import GAT
3 | from .gin import GIN
4 | from .rgcn import RGCN
5 | from .rgat import RGAT
6 | from .deletion import GCNDelete, GATDelete, GINDelete, RGCNDelete, RGATDelete
--------------------------------------------------------------------------------
/run_original.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source activate pyg
4 |
5 | DATA=$1
6 | MODEL=$2
7 | SEED=$3
8 | UN=original
9 |
10 | export WANDB_MODE=offline
11 | export WANDB_PROJECT=zitniklab-gnn-unlearning
12 |
13 | export WANDB_NAME="$UN"_"$DATA"_"$MODEL"_"$SEED"
14 | export WANDB_RUN_NAME="$UN"_"$DATA"_"$MODEL"_"$SEED"
15 | export WANDB_RUN_ID="$UN"_"$DATA"_"$MODEL"_"$SEED"
16 |
17 | python train_gnn.py --lr 1e-3 \
18 | --epochs 1500 \
19 | --dataset "$DATA" \
20 | --random_seed "$SEED" \
21 | --unlearning_model "$UN" \
22 | --gnn "$MODEL"
--------------------------------------------------------------------------------
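The script above expects three positional arguments: dataset, GNN architecture, and random seed. An illustrative invocation (the argument values are examples, not defaults):

    bash run_original.sh Cora gcn 42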
/run_delete.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source activate pyg
4 |
5 | DATA=$1
6 | MODEL=$2
7 | UN=$3
8 | DF=$4
9 | DF_SIZE=$5
10 | SEED=$6
11 |
12 | export WANDB_MODE=offline
13 | export WANDB_PROJECT=zitniklab-gnn-unlearning
14 |
15 |
16 | export WANDB_NAME="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED"
17 | export WANDB_RUN_NAME="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED"
18 | export WANDB_RUN_ID="$UN"_"$DATA"_"$MODEL"_"$DF"_"$DF_SIZE"_"$SEED"
19 |
20 | python delete_gnn.py --lr 1e-3 \
21 | --epochs 1500 \
22 | --dataset "$DATA" \
23 | --random_seed "$SEED" \
24 | --unlearning_model "$UN" \
25 | --gnn "$MODEL" \
26 | --df "$DF" \
27 | --df_size "$DF_SIZE"
--------------------------------------------------------------------------------
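run_delete.sh takes three additional positional arguments: the unlearning method, the Df split name, and the Df size. An illustrative call (DF_SPLIT is a placeholder for whatever split name delete_gnn.py expects; the other values are examples only):

    bash run_delete.sh Cora gcn gnndelete DF_SPLIT 0.05 42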
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Machine Learning for Medicine and Science @ Harvard
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/framework/models/gat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GATConv
5 |
6 |
7 | class GAT(nn.Module):
8 | def __init__(self, args, **kwargs):
9 | super().__init__()
10 |
11 | self.conv1 = GATConv(args.in_dim, args.hidden_dim)
12 | self.conv2 = GATConv(args.hidden_dim, args.out_dim)
13 | # self.dropout = nn.Dropout(args.dropout)
14 |
15 | def forward(self, x, edge_index, return_all_emb=False):
16 | x1 = self.conv1(x, edge_index)
17 | x = F.relu(x1)
18 | # x = self.dropout(x)
19 | x2 = self.conv2(x, edge_index)
20 |
21 | if return_all_emb:
22 | return x1, x2
23 |
24 | return x2
25 |
26 | def decode(self, z, pos_edge_index, neg_edge_index=None):
27 | if neg_edge_index is not None:
28 | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
29 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
30 |
31 | else:
32 | edge_index = pos_edge_index
33 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
34 |
35 | return logits
36 |
--------------------------------------------------------------------------------
/framework/models/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GCNConv
5 |
6 |
7 | class GCN(nn.Module):
8 | def __init__(self, args, **kwargs):
9 | super().__init__()
10 |
11 | self.conv1 = GCNConv(args.in_dim, args.hidden_dim)
12 | self.conv2 = GCNConv(args.hidden_dim, args.out_dim)
13 | # self.dropout = nn.Dropout(args.dropout)
14 |
15 | def forward(self, x, edge_index, return_all_emb=False):
16 | x1 = self.conv1(x, edge_index)
17 | x = F.relu(x1)
18 | # x = self.dropout(x)
19 | x2 = self.conv2(x, edge_index)
20 |
21 | if return_all_emb:
22 | return x1, x2
23 |
24 | return x2
25 |
26 | def decode(self, z, pos_edge_index, neg_edge_index=None):
27 | if neg_edge_index is not None:
28 | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
29 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
30 |
31 | else:
32 | edge_index = pos_edge_index
33 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
34 |
35 | return logits
36 |
--------------------------------------------------------------------------------
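A minimal usage sketch for the two-layer encoder plus dot-product decoder above (not part of the repo; dimensions follow the defaults in framework/training_args.py, the toy graph is made up):

    from types import SimpleNamespace
    import torch
    from framework.models import GCN

    args = SimpleNamespace(in_dim=128, hidden_dim=128, out_dim=64)
    model = GCN(args)

    x = torch.randn(4, args.in_dim)                     # 4 toy nodes
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])   # 3 toy edges

    z = model(x, edge_index)              # node embeddings, shape [4, 64]
    scores = model.decode(z, edge_index)  # one logit per edge, shape [3]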
/graph_stat.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch_geometric.data import Data
3 | import torch_geometric.transforms as T
4 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18RR
5 | from ogb.linkproppred import PygLinkPropPredDataset
6 |
7 |
8 | data_dir = './data'
9 | datasets = ['Cora', 'PubMed', 'DBLP', 'CS', 'Physics', 'ogbl-citation2', 'ogbl-collab', 'FB15k-237', 'WordNet18RR', 'ogbl-biokg', 'ogbl-wikikg2'][-2:]
10 |
11 | def get_stat(d):
12 | if d in ['Cora', 'PubMed', 'DBLP']:
13 | dataset = CitationFull(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
14 | if d in ['CS', 'Physics']:
15 | dataset = Coauthor(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
16 | if d in ['Flickr']:
17 | dataset = Flickr(os.path.join(data_dir, d), transform=T.NormalizeFeatures())
18 | if 'ogbl' in d:
19 | dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d)
20 |
21 | data = dataset[0]
22 | print(d)
23 | print('Number of nodes:', data.num_nodes)
24 | print('Number of edges:', data.num_edges)
25 | print('Number of max deleted edges:', int(0.05 * data.num_edges))
26 | if hasattr(data, 'edge_type'):
27 | print('Number of edge types:', data.edge_type.unique().shape)
28 |
29 | def main():
30 | for d in datasets:
31 | get_stat(d)
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/framework/models/gin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GINConv
5 |
6 |
7 | class GIN(nn.Module):
8 | def __init__(self, args, **kwargs):
9 | super().__init__()
10 |
11 | self.conv1 = GINConv(nn.Linear(args.in_dim, args.hidden_dim))
12 | self.conv2= GINConv(nn.Linear(args.hidden_dim, args.out_dim))
13 | # self.transition = nn.Sequential(
14 | # nn.ReLU(),
15 | # # nn.Dropout(p=args.dropout)
16 | # )
17 | # self.mlp1 = nn.Sequential(
18 | # nn.Linear(args.in_dim, args.hidden_dim),
19 | # nn.ReLU(),
20 | # )
21 | # self.mlp2 = nn.Sequential(
22 | # nn.Linear(args.hidden_dim, args.out_dim),
23 | # nn.ReLU(),
24 | # )
25 |
26 | def forward(self, x, edge_index, return_all_emb=False):
27 | x1 = self.conv1(x, edge_index)
28 | x = F.relu(x1)
29 | x2 = self.conv2(x, edge_index)
30 |
31 | if return_all_emb:
32 | return x1, x2
33 |
34 | return x2
35 |
36 | def decode(self, z, pos_edge_index, neg_edge_index=None):
37 | if neg_edge_index is not None:
38 | edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
39 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
40 |
41 | else:
42 | edge_index = pos_edge_index
43 | logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
44 |
45 | return logits
46 |
--------------------------------------------------------------------------------
/framework/models/rgcn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import RGCNConv, FastRGCNConv
6 | from sklearn.metrics import roc_auc_score, average_precision_score
7 |
8 |
9 | class RGCN(nn.Module):
10 | def __init__(self, args, num_nodes, num_edge_type, **kwargs):
11 | super().__init__()
12 | self.args = args
13 | self.num_edge_type = num_edge_type
14 |
15 | # Encoder: RGCN
16 | self.node_emb = nn.Embedding(num_nodes, args.in_dim)
17 | if num_edge_type > 20:
18 | self.conv1 = RGCNConv(args.in_dim, args.hidden_dim, num_edge_type * 2, num_blocks=4)
19 | self.conv2 = RGCNConv(args.hidden_dim, args.out_dim, num_edge_type * 2, num_blocks=4)
20 | else:
21 | self.conv1 = RGCNConv(args.in_dim, args.hidden_dim, num_edge_type * 2)
22 | self.conv2 = RGCNConv(args.hidden_dim, args.out_dim, num_edge_type * 2)
23 | self.relu = nn.ReLU()
24 |
25 | # Decoder: DistMult
26 | self.W = nn.Parameter(torch.Tensor(num_edge_type, args.out_dim))
27 | nn.init.xavier_uniform_(self.W, gain=nn.init.calculate_gain('relu'))
28 |
29 | def forward(self, x, edge, edge_type, return_all_emb=False):
30 | x = self.node_emb(x)
31 | x1 = self.conv1(x, edge, edge_type)
32 | x = self.relu(x1)
33 | x2 = self.conv2(x, edge, edge_type)
34 |
35 | if return_all_emb:
36 | return x1, x2
37 |
38 | return x2
39 |
40 | def decode(self, z, edge_index, edge_type):
41 | h = z[edge_index[0]]
42 | t = z[edge_index[1]]
43 | r = self.W[edge_type]
44 |
45 | logits = torch.sum(h * r * t, dim=1)
46 |
47 | return logits
48 |
49 | class RGCNDelete(RGCN):
50 | def __init__(self):
51 | pass
52 |
--------------------------------------------------------------------------------
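RGCN.decode above implements DistMult scoring: the logit for a triple (h, r, t) is the elementwise product of the head embedding, the relation vector W_r, and the tail embedding, summed over dimensions. A toy illustration with made-up tensors whose shapes mirror the model (not repo code):

    import torch

    z = torch.randn(5, 64)        # embeddings for 5 entities (out_dim = 64)
    W = torch.randn(3, 64)        # one relation vector per edge type
    edge_index = torch.tensor([[0, 1],     # heads
                               [2, 4]])    # tails
    edge_type = torch.tensor([1, 0])

    score = (z[edge_index[0]] * W[edge_type] * z[edge_index[1]]).sum(dim=1)  # shape [2]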
/framework/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import networkx as nx
4 |
5 |
6 | def get_node_edge(graph):
7 | degree_sorted_ascend = sorted(graph.degree, key=lambda x: x[1])
8 |
9 | return degree_sorted_ascend[-1][0]
10 |
11 | def h_hop_neighbor(G, node, h):
12 | path_lengths = nx.single_source_dijkstra_path_length(G, node)
13 | return [node for node, length in path_lengths.items() if length == h]
14 |
15 | def get_enclosing_subgraph(graph, edge_to_delete):
16 | subgraph = {0: [edge_to_delete]}
17 | s, t = edge_to_delete
18 |
19 | neighbor_s = []
20 | neighbor_t = []
21 | for h in range(1, 2+1):
22 | neighbor_s += h_hop_neighbor(graph, s, h)
23 | neighbor_t += h_hop_neighbor(graph, t, h)
24 |
25 | nodes = neighbor_s + neighbor_t + [s, t]
26 |
27 | subgraph[h] = list(graph.subgraph(nodes).edges())
28 |
29 | return subgraph
30 |
31 | @torch.no_grad()
32 | def get_link_labels(pos_edge_index, neg_edge_index):
33 | E = pos_edge_index.size(1) + neg_edge_index.size(1)
34 | link_labels = torch.zeros(E, dtype=torch.float, device=pos_edge_index.device)
35 | link_labels[:pos_edge_index.size(1)] = 1.
36 | return link_labels
37 |
38 | @torch.no_grad()
39 | def get_link_labels_kg(pos_edge_index, neg_edge_index):
40 | E = pos_edge_index.size(1) + neg_edge_index.size(1)
41 | link_labels = torch.zeros(E, dtype=torch.float, device=pos_edge_index.device)
42 | link_labels[:pos_edge_index.size(1)] = 1.
43 |
44 | return link_labels
45 |
46 | @torch.no_grad()
47 | def negative_sampling_kg(edge_index, edge_type):
48 | '''Generate negative samples but keep the node type the same'''
49 |
50 | edge_index_copy = edge_index.clone()
51 | for et in edge_type.unique():
52 | mask = (edge_type == et)
53 | old_source = edge_index_copy[0, mask]
54 | new_index = torch.randperm(old_source.shape[0])
55 | new_source = old_source[new_index]
56 | edge_index_copy[0, mask] = new_source
57 |
58 | return edge_index_copy
59 |
--------------------------------------------------------------------------------
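negative_sampling_kg corrupts triples by permuting the head entities within each relation type, so a corrupted triple keeps a head of a plausible type while the tail row is left unchanged. A small sketch with toy tensors (not repo data):

    import torch
    from framework.utils import negative_sampling_kg

    edge_index = torch.tensor([[0, 1, 2, 3],    # heads
                               [4, 5, 6, 7]])   # tails
    edge_type = torch.tensor([0, 0, 1, 1])

    neg = negative_sampling_kg(edge_index, edge_type)
    # Heads 0/1 may swap with each other and heads 2/3 with each other;
    # the tail row of the returned edge_index is untouched.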
/train_node.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import pickle
4 | import torch
5 | from torch_geometric.seed import seed_everything
6 | from torch_geometric.utils import to_undirected, is_undirected
7 | import torch_geometric.transforms as T
8 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR
9 | from torch_geometric.seed import seed_everything
10 |
11 | from framework import get_model, get_trainer
12 | from framework.training_args import parse_args
13 | from framework.trainer.base import NodeClassificationTrainer
14 | from framework.utils import negative_sampling_kg
15 |
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 | def main():
20 | args = parse_args()
21 | args.checkpoint_dir = 'checkpoint_node'
22 | args.dataset = 'DBLP'
23 | args.unlearning_model = 'original'
24 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, str(args.random_seed))
25 | os.makedirs(args.checkpoint_dir, exist_ok=True)
26 | seed_everything(args.random_seed)
27 |
28 | # Dataset
29 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures())
30 | data = dataset[0]
31 | print('Original data', data)
32 |
33 | split = T.RandomNodeSplit()
34 | data = split(data)
35 | assert is_undirected(data.edge_index)
36 |
37 | print('Split data', data)
38 | args.in_dim = data.x.shape[1]
39 | args.out_dim = dataset.num_classes
40 |
41 | wandb.init(config=args)
42 |
43 | # Model
44 | model = get_model(args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type).to(device)
45 | wandb.watch(model, log_freq=100)
46 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#, weight_decay=args.weight_decay)
47 |
48 | # Train
49 | trainer = NodeClassificationTrainer(args)
50 | trainer.train(model, data, optimizer, args)
51 |
52 | # Test
53 | trainer.test(model, data)
54 | trainer.save_log()
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | experiment/
2 | data/
3 | checkpoint*/
4 | data
5 | checkpoint*
6 | wandb
7 | wandb/
8 | results/
9 | *csv
10 |
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
--------------------------------------------------------------------------------
/framework/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import GCN, GAT, GIN, RGCN, RGAT, GCNDelete, GATDelete, GINDelete, RGCNDelete, RGATDelete
2 | from .trainer.base import Trainer, KGTrainer, NodeClassificationTrainer
3 | from .trainer.retrain import RetrainTrainer, KGRetrainTrainer
4 | from .trainer.gnndelete import GNNDeleteTrainer
5 | from .trainer.gnndelete_nodeemb import GNNDeleteNodeembTrainer, KGGNNDeleteNodeembTrainer
6 | from .trainer.gradient_ascent import GradientAscentTrainer, KGGradientAscentTrainer
7 | from .trainer.descent_to_delete import DtdTrainer
8 | from .trainer.approx_retrain import ApproxTrainer
9 | from .trainer.graph_eraser import GraphEraserTrainer
10 | from .trainer.graph_editor import GraphEditorTrainer
11 | from .trainer.member_infer import MIAttackTrainer, MIAttackTrainerNode
12 |
13 |
14 | trainer_mapping = {
15 | 'original': Trainer,
16 | 'original_node': NodeClassificationTrainer,
17 | 'retrain': RetrainTrainer,
18 | 'gnndelete': GNNDeleteTrainer,
19 | 'gradient_ascent': GradientAscentTrainer,
20 | 'descent_to_delete': DtdTrainer,
21 | 'approx_retrain': ApproxTrainer,
22 | 'gnndelete_mse': GNNDeleteTrainer,
23 | 'gnndelete_kld': GNNDeleteTrainer,
24 | 'gnndelete_nodeemb': GNNDeleteNodeembTrainer,
25 | 'gnndelete_cosine': GNNDeleteTrainer,
26 | 'graph_eraser': GraphEraserTrainer,
27 | 'graph_editor': GraphEditorTrainer,
28 | 'member_infer_all': MIAttackTrainer,
29 | 'member_infer_sub': MIAttackTrainer,
30 | 'member_infer_all_node': MIAttackTrainerNode,
31 | 'member_infer_sub_node': MIAttackTrainerNode,
32 | }
33 |
34 | kg_trainer_mapping = {
35 | 'original': KGTrainer,
36 | 'retrain': KGRetrainTrainer,
37 | 'gnndelete': KGGNNDeleteNodeembTrainer,
38 | 'gradient_ascent': KGGradientAscentTrainer,
39 | 'descent_to_delete': DtdTrainer,
40 | 'approx_retrain': ApproxTrainer,
41 | 'gnndelete_mse': GNNDeleteTrainer,
42 | 'gnndelete_kld': GNNDeleteTrainer,
43 | 'gnndelete_cosine': GNNDeleteTrainer,
44 | 'gnndelete_nodeemb': KGGNNDeleteNodeembTrainer,
45 | 'graph_eraser': GraphEraserTrainer,
46 | 'member_infer_all': MIAttackTrainer,
47 | 'member_infer_sub': MIAttackTrainer,
48 | }
49 |
50 |
51 | def get_model(args, mask_1hop=None, mask_2hop=None, num_nodes=None, num_edge_type=None):
52 |
53 | if 'gnndelete' in args.unlearning_model:
54 | model_mapping = {'gcn': GCNDelete, 'gat': GATDelete, 'gin': GINDelete, 'rgcn': RGCNDelete, 'rgat': RGATDelete}
55 |
56 | else:
57 | model_mapping = {'gcn': GCN, 'gat': GAT, 'gin': GIN, 'rgcn': RGCN, 'rgat': RGAT}
58 |
59 | return model_mapping[args.gnn](args, mask_1hop=mask_1hop, mask_2hop=mask_2hop, num_nodes=num_nodes, num_edge_type=num_edge_type)
60 |
61 |
62 | def get_trainer(args):
63 | if args.gnn in ['rgcn', 'rgat']:
64 | return kg_trainer_mapping[args.unlearning_model](args)
65 |
66 | else:
67 | return trainer_mapping[args.unlearning_model](args)
68 |
--------------------------------------------------------------------------------
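A sketch of how the two factory functions above are typically wired together, mirroring train_gnn.py (argument values are illustrative):

    from framework import get_model, get_trainer
    from framework.training_args import parse_args

    args = parse_args()                    # e.g. --gnn gcn --unlearning_model original
    model = get_model(args, num_nodes=1000, num_edge_type=None)
    trainer = get_trainer(args)            # resolves the trainer class from the mappings above
    # trainer.train(model, data, optimizer, args), then trainer.test(model, data)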
/train_gnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import pickle
4 | import torch
5 | from torch_geometric.seed import seed_everything
6 | from torch_geometric.utils import to_undirected, is_undirected
7 | from torch_geometric.datasets import RelLinkPredDataset, WordNet18
8 | from torch_geometric.seed import seed_everything
9 |
10 | from framework import get_model, get_trainer
11 | from framework.training_args import parse_args
12 | from framework.trainer.base import Trainer
13 | from framework.utils import negative_sampling_kg
14 |
15 |
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 |
20 | def main():
21 | args = parse_args()
22 | args.unlearning_model = 'original'
23 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model, str(args.random_seed))
24 | os.makedirs(args.checkpoint_dir, exist_ok=True)
25 | seed_everything(args.random_seed)
26 |
27 | # Dataset
28 | with open(os.path.join(args.data_dir, args.dataset, f'd_{args.random_seed}.pkl'), 'rb') as f:
29 | dataset, data = pickle.load(f)
30 | print('Directed dataset:', dataset, data)
31 | if args.gnn not in ['rgcn', 'rgat']:
32 | args.in_dim = dataset.num_features
33 |
34 | wandb.init(config=args)
35 |
36 | # Use proper training data for original and Dr
37 | if args.gnn in ['rgcn', 'rgat']:
38 | if not hasattr(data, 'train_mask'):
39 | data.train_mask = torch.ones(data.edge_index.shape[1], dtype=torch.bool)
40 |
41 | # data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)
42 | # data.edge_index_mask = data.dtrain_mask.repeat(2)
43 |
44 | else:
45 | data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)
46 |
47 | # To undirected
48 | if args.gnn in ['rgcn', 'rgat']:
49 | r, c = data.train_pos_edge_index
50 | rev_edge_index = torch.stack([c, r], dim=0)
51 | rev_edge_type = data.train_edge_type + args.num_edge_type
52 |
53 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1)
54 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0)
55 | # data.train_mask = data.train_mask.repeat(2)
56 |
57 | data.dr_mask = torch.ones(data.edge_index.shape[1], dtype=torch.bool)
58 | assert is_undirected(data.edge_index)
59 |
60 | else:
61 | train_pos_edge_index = to_undirected(data.train_pos_edge_index)
62 | data.train_pos_edge_index = train_pos_edge_index
63 | data.dtrain_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)
64 | assert is_undirected(data.train_pos_edge_index)
65 |
66 |
67 | print('Undirected dataset:', data)
68 |
69 | # Model
70 | model = get_model(args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type).to(device)
71 | wandb.watch(model, log_freq=100)
72 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#, weight_decay=args.weight_decay)
73 |
74 | # Train
75 | trainer = get_trainer(args)
76 | trainer.train(model, data, optimizer, args)
77 |
78 | # Test
79 | trainer.test(model, data)
80 | trainer.save_log()
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
/framework/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch_geometric.data import Data, GraphSAINTRandomWalkSampler
4 |
5 |
6 | def load_dict(filename):
7 | '''Load entity and relation to id mapping'''
8 |
9 | mapping = {}
10 | with open(filename, 'r') as f:
11 | for l in f:
12 | l = l.strip().split('\t')
13 | mapping[l[0]] = l[1]
14 |
15 | return mapping
16 |
17 | def load_edges(filename):
18 | with open(filename, 'r') as f:
19 | r = f.readlines()
20 | r = [i.strip().split('\t') for i in r]
21 |
22 | return r
23 |
24 | def generate_true_dict(all_triples):
25 | heads = {(r, t) : [] for _, r, t in all_triples}
26 | tails = {(h, r) : [] for h, r, _ in all_triples}
27 |
28 | for h, r, t in all_triples:
29 | heads[r, t].append(h)
30 | tails[h, r].append(t)
31 |
32 | return heads, tails
33 |
34 | def get_loader(args, delete=[]):
35 | prefix = os.path.join('./data', args.dataset)
36 |
37 | # Edges
38 | train = load_edges(os.path.join(prefix, 'train.txt'))
39 | valid = load_edges(os.path.join(prefix, 'valid.txt'))
40 | test = load_edges(os.path.join(prefix, 'test.txt'))
41 | train = [(int(i[0]), int(i[1]), int(i[2])) for i in train]
42 | valid = [(int(i[0]), int(i[1]), int(i[2])) for i in valid]
43 | test = [(int(i[0]), int(i[1]), int(i[2])) for i in test]
44 | train_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in train]
45 | valid_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in valid]
46 | test_rev = [(int(i[2]), int(i[1]), int(i[0])) for i in test]
47 | train = train + train_rev
48 | valid = valid + valid_rev
49 | test = test + test_rev
50 | all_edge = train + valid + test
51 |
52 | true_triples = generate_true_dict(all_edge)
53 |
54 | edge = torch.tensor([(int(i[0]), int(i[2])) for i in all_edge], dtype=torch.long).t()
55 | edge_type = torch.tensor([int(i[1]) for i in all_edge], dtype=torch.long)#.view(-1, 1)
56 |
57 | # Masks
58 | train_size = len(train)
59 | valid_size = len(valid)
60 | test_size = len(test)
61 | total_size = train_size + valid_size + test_size
62 |
63 | train_mask = torch.zeros((total_size,)).bool()
64 | train_mask[:train_size] = True
65 |
66 | valid_mask = torch.zeros((total_size,)).bool()
67 | valid_mask[train_size:train_size + valid_size] = True
68 |
69 | test_mask = torch.zeros((total_size,)).bool()
70 | test_mask[-test_size:] = True
71 |
72 | # Graph size
73 | num_nodes = edge.flatten().unique().shape[0]
74 | num_edges = edge.shape[1]
75 | num_edge_type = edge_type.unique().shape[0]
76 |
77 | # Node feature
78 | x = torch.rand((num_nodes, args.in_dim))
79 |
80 | # Delete edges
81 | if len(delete) > 0:
82 | delete_idx = torch.tensor(delete, dtype=torch.long)
83 | num_train_edges = train_size // 2
84 | train_mask[delete_idx] = False
85 | train_mask[delete_idx + num_train_edges] = False
86 | train_size -= 2 * len(delete)
87 |
88 | node_id = torch.arange(num_nodes)
89 | dataset = Data(
90 | edge_index=edge, edge_type=edge_type, x=x, node_id=node_id,
91 | train_mask=train_mask, valid_mask=valid_mask, test_mask=test_mask)
92 |
93 | dataloader = GraphSAINTRandomWalkSampler(
94 | dataset, batch_size=args.batch_size, walk_length=args.walk_length, num_steps=args.num_steps)
95 |
96 | print(f'Dataset: {args.dataset}, Num nodes: {num_nodes}, Num edges: {num_edges//2}, Num relation types: {num_edge_type}')
97 | print(f'Train edges: {train_size//2}, Valid edges: {valid_size//2}, Test edges: {test_size//2}')
98 |
99 | return dataloader, valid, test, true_triples, num_nodes, num_edges, num_edge_type
100 |
--------------------------------------------------------------------------------
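A sketch of calling get_loader; it assumes ./data/&lt;dataset&gt;/{train,valid,test}.txt exist as tab-separated (head, relation, tail) id triples. The batch and walk settings below are illustrative, not repo defaults:

    from types import SimpleNamespace
    from framework.data_loader import get_loader

    args = SimpleNamespace(dataset='FB15k-237', in_dim=128,
                           batch_size=1024, walk_length=2, num_steps=32)
    (loader, valid, test, true_triples,
     num_nodes, num_edges, num_edge_type) = get_loader(args)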
/framework/trainer/gradient_ascent_with_mp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import tqdm, trange
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_geometric.utils import negative_sampling
7 |
8 | from .base import Trainer
9 | from ..evaluation import *
10 | from ..utils import *
11 |
12 |
13 | class GradientAscentWithMessagePassingTrainer(Trainer):
14 | def __init__(self,):
15 | self.trainer_log = {'unlearning_model': 'gradient_ascent_with_mp', 'log': []}
16 |
17 | def freeze_unused_mask(self, model, edge_to_delete, subgraph, h):
18 | gradient_mask = torch.zeros_like(model.operator)
19 |
20 | edges = subgraph[h]
21 | for s, t in edges:
22 | if s < t:
23 | gradient_mask[s, t] = 1
24 | gradient_mask = gradient_mask.to(device)
25 | model.operator.register_hook(lambda grad: grad.mul_(gradient_mask))
26 |
27 | def train(self, model_retrain, model, data, optimizer, args):
28 | best_loss = 100000
29 | for epoch in trange(args.epochs, desc='Unlearning'):
30 | model.train()
31 | total_step = 0
32 | total_loss = 0
33 |
34 | ## Gradient Ascent
35 | neg_edge_index = negative_sampling(
36 | edge_index=data.train_pos_edge_index[:, data.ga_mask],
37 | num_nodes=data.num_nodes,
38 | num_neg_samples=data.ga_mask.sum())
39 |
40 | # print('data train to unlearn', data.train_pos_edge_index[:, data.ga_mask])
41 | z = model(data.x, data.train_pos_edge_index[:, data.ga_mask])
42 | logits = model.decode(z, data.train_pos_edge_index[:, data.ga_mask])
43 | label = torch.ones_like(logits)
44 | loss_ga = -F.binary_cross_entropy_with_logits(logits, label)
45 |
46 | ## Message Passing
47 | neg_edge_index = negative_sampling(
48 | edge_index=data.train_pos_edge_index[:, data.mp_mask],
49 | num_nodes=data.num_nodes,
50 | num_neg_samples=data.mp_mask.sum())
51 |
52 | z = model(data.x, data.train_pos_edge_index[:, data.mp_mask])
53 | logits = model.decode(z, data.train_pos_edge_index[:, data.mp_mask], neg_edge_index)
54 | label = self.get_link_labels(data.train_pos_edge_index[:, data.mp_mask], neg_edge_index)
55 | loss_mp = F.binary_cross_entropy_with_logits(logits, label)
56 |
57 | loss = loss_ga + loss_mp
58 | loss.backward()
59 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
60 | optimizer.step()
61 | optimizer.zero_grad()
62 |
63 | total_step += 1
64 | total_loss += loss.item()
65 |
66 | msg = [
67 | f'Epoch: {epoch:>4d}',
68 | f'train loss: {total_loss / total_step:.6f}'
69 | ]
70 | tqdm.write(' | '.join(msg))
71 |
72 | valid_loss, auc, aup = self.eval(model, data, 'val')
73 |
74 | self.trainer_log['log'].append({
75 | 'dt_loss': valid_loss,
76 | 'dt_auc': auc,
77 | 'dt_aup': aup
78 | })
79 |
80 | # Eval unlearn
81 | loss, auc, aup = self.test(model, data)
82 | self.trainer_log['dt_loss'] = loss
83 | self.trainer_log['dt_auc'] = auc
84 | self.trainer_log['dt_aup'] = aup
85 |
86 | self.trainer_log['ve'] = verification_error(model, model_retrain).cpu().item()
87 | self.trainer_log['dr_kld'] = output_kldiv(model, model_retrain, data=data).cpu().item()
88 |
89 | embedding = get_node_embedding_data(model, data)
90 | logits = model.decode(embedding, data.train_pos_edge_index[:, data.dtrain_mask]).sigmoid().detach().cpu()
91 | self.trainer_log['df_score'] = logits[:1].cpu().item()
92 |
93 |
94 | # Save
95 | ckpt = {
96 | 'model_state': model.state_dict(),
97 | 'node_emb': z,
98 | 'optimizer_state': optimizer.state_dict(),
99 | }
100 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model.pt'))
101 | print(self.trainer_log)
102 | with open(os.path.join(args.checkpoint_dir, 'trainer_log.json'), 'w') as f:
103 | json.dump(self.trainer_log, f)
104 |
--------------------------------------------------------------------------------
/framework/trainer/descent_to_delete.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import wandb
4 | from tqdm import tqdm, trange
5 | import torch
6 | import torch.nn.functional as F
7 | from torch_geometric.utils import negative_sampling
8 |
9 | from .base import Trainer
10 | from ..evaluation import *
11 | from ..utils import *
12 |
13 |
14 | class DtdTrainer(Trainer):
15 | '''This code is adapted from https://github.com/ChrisWaites/descent-to-delete'''
16 |
17 | def compute_sigma(self, num_examples, iterations, lipshitz, smooth, strong, epsilon, delta):
18 | """Theorem 3.1 https://arxiv.org/pdf/2007.02923.pdf"""
19 |
20 | print('delta', delta)
21 | gamma = (smooth - strong) / (smooth + strong)
22 | numerator = 4 * np.sqrt(2) * lipshitz * np.power(gamma, iterations)
23 | denominator = (strong * num_examples * (1 - np.power(gamma, iterations))) * ((np.sqrt(np.log(1 / delta) + epsilon)) - np.sqrt(np.log(1 / delta)))
24 | # print('sigma', numerator, denominator, numerator / denominator)
25 |
26 | return numerator / denominator
27 |
28 | def publish(self, model, sigma):
29 | """Publishing function which adds Gaussian noise with scale sigma."""
30 |
31 | with torch.no_grad():
32 | for n, p in model.named_parameters():
33 | p.copy_(p + torch.empty_like(p).normal_(0, sigma))
34 |
35 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
36 | start_time = time.time()
37 | best_valid_loss = 100000
38 |
39 | # MI Attack before unlearning
40 | if attack_model_all is not None:
41 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
42 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
43 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
44 | if attack_model_sub is not None:
45 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
46 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
47 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
48 |
49 | for epoch in trange(args.epochs, desc='Unlearning'):
50 | model.train()
51 |
52 | # Positive and negative sample
53 | neg_edge_index = negative_sampling(
54 | edge_index=data.train_pos_edge_index[:, data.dr_mask],
55 | num_nodes=data.num_nodes,
56 | num_neg_samples=data.dr_mask.sum())
57 |
58 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask])
59 | logits = model.decode(z, data.train_pos_edge_index[:, data.dr_mask], neg_edge_index)
60 | label = get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index)
61 | loss = F.binary_cross_entropy_with_logits(logits, label)
62 |
63 | loss.backward()
64 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
65 | optimizer.step()
66 | optimizer.zero_grad()
67 |
68 | log = {
69 | 'Epoch': epoch,
70 | 'train_loss': loss.item(),
71 | }
72 | wandb.log(log)
73 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
74 | tqdm.write(' | '.join(msg))
75 |
76 | valid_loss, auc, aup, df_logt, logit_all_pair = self.eval(model, data, 'val')
77 |
78 | self.trainer_log['log'].append({
79 | 'dt_loss': valid_loss,
80 | 'dt_auc': auc,
81 | 'dt_aup': aup
82 | })
83 |
84 | train_size = data.dr_mask.sum().cpu().item()
85 | sigma = self.compute_sigma(
86 | train_size,
87 | args.epochs,
88 | 1 + args.weight_decay,
89 | 4 - args.weight_decay,
90 | args.weight_decay,
91 | 5,
92 | 1 / train_size / train_size)
93 |
94 | self.publish(model, sigma)
95 |
96 | self.trainer_log['sigma'] = sigma
97 | self.trainer_log['training_time'] = time.time() - start_time
98 |
99 | # Save
100 | ckpt = {
101 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()},
102 | 'optimizer_state': optimizer.state_dict(),
103 | }
104 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
105 |
--------------------------------------------------------------------------------
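compute_sigma above follows Theorem 3.1 of the Descent-to-Delete paper cited in its docstring (arXiv:2007.02923). With Lipschitz constant $L$, smoothness $M$, strong-convexity parameter $m$, $T$ iterations, $n$ retained examples, and privacy parameters $(\epsilon, \delta)$:

$$\gamma = \frac{M - m}{M + m}, \qquad \sigma = \frac{4\sqrt{2}\, L\, \gamma^{T}}{m\, n\, \bigl(1-\gamma^{T}\bigr)\Bigl(\sqrt{\log(1/\delta) + \epsilon} - \sqrt{\log(1/\delta)}\Bigr)}.$$

The train method instantiates this with $L = 1 + \text{weight\_decay}$, $M = 4 - \text{weight\_decay}$, $m = \text{weight\_decay}$, $\epsilon = 5$, and $\delta = 1/n^2$, then perturbs every parameter with Gaussian noise of scale $\sigma$ in publish.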
/framework/models/graph_classification/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ogb.graphproppred.mol_encoder import AtomEncoder
5 | from torch_geometric.nn import GCNConv, MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
6 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
7 | from torch_geometric.utils import degree
8 | from torch_geometric.nn.inits import uniform
9 | from torch_scatter import scatter_mean
10 |
11 |
12 | '''
13 | Source: OGB github
14 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/mol/main_pyg.py
15 | '''
16 | ### GCN convolution along the graph structure
17 | class GCNConv(MessagePassing):
18 | def __init__(self, emb_dim):
19 | super().__init__(aggr='add')
20 |
21 | self.linear = nn.Linear(emb_dim, emb_dim)
22 | self.root_emb = nn.Embedding(1, emb_dim)
23 | self.bond_encoder = BondEncoder(emb_dim=emb_dim)
24 |
25 | def forward(self, x, edge_index, edge_attr):
26 | x = self.linear(x)
27 | edge_embedding = self.bond_encoder(edge_attr)
28 |
29 | row, col = edge_index
30 |
31 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
32 | deg = degree(row, x.size(0), dtype=x.dtype) + 1
33 | deg_inv_sqrt = deg.pow(-0.5)
34 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
35 |
36 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
37 |
38 | return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)
39 |
40 | def message(self, x_j, edge_attr, norm):
41 | return norm.view(-1, 1) * F.relu(x_j + edge_attr)
42 |
43 | def update(self, aggr_out):
44 | return aggr_out
45 |
46 | ### GNN to generate node embedding
47 | class GNN_node(nn.Module):
48 | def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False):
49 | super().__init__()
50 | self.num_layer = num_layer
51 | self.drop_ratio = drop_ratio
52 | self.JK = JK
53 | ### add residual connection or not
54 | self.residual = residual
55 |
56 | self.atom_encoder = AtomEncoder(emb_dim)
57 |
58 | ###List of GNNs
59 | self.convs = nn.ModuleList()
60 | self.batch_norms = nn.ModuleList()
61 |
62 | for layer in range(num_layer):
63 | self.convs.append(GCNConv(emb_dim))
64 | self.batch_norms.append(nn.BatchNorm1d(emb_dim))
65 |
66 | def forward(self, batched_data):
67 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
68 |
69 | ### computing input node embedding
70 |
71 | h_list = [self.atom_encoder(x)]
72 | for layer in range(self.num_layer):
73 |
74 | h = self.convs[layer](h_list[layer], edge_index, edge_attr)
75 | h = self.batch_norms[layer](h)
76 |
77 | if layer == self.num_layer - 1:
78 | #remove relu for the last layer
79 | h = F.dropout(h, self.drop_ratio, training=self.training)
80 | else:
81 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
82 |
83 | if self.residual:
84 | h += h_list[layer]
85 |
86 | h_list.append(h)
87 |
88 | ### Different implementations of Jk-concat
89 | if self.JK == "last":
90 | node_representation = h_list[-1]
91 | elif self.JK == "sum":
92 | node_representation = 0
93 | for layer in range(self.num_layer + 1):
94 | node_representation += h_list[layer]
95 |
96 | return node_representation
97 |
98 | class GNN(torch.nn.Module):
99 | def __init__(self, num_tasks, num_layer=2, emb_dim=300, virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
100 | super().__init__()
101 |
102 | self.num_layer = num_layer
103 | self.drop_ratio = drop_ratio
104 | self.JK = JK
105 | self.emb_dim = emb_dim
106 | self.num_tasks = num_tasks
107 | self.graph_pooling = graph_pooling
108 |
109 | ### GNN to generate node embeddings
110 | self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
111 |
112 | ### Pooling function to generate whole-graph embeddings
113 | if self.graph_pooling == "sum":
114 | self.pool = global_add_pool
115 | elif self.graph_pooling == "mean":
116 | self.pool = global_mean_pool
117 | elif self.graph_pooling == "max":
118 | self.pool = global_max_pool
119 | elif self.graph_pooling == "attention":
120 | self.pool = GlobalAttention(gate_nn=nn.Sequential(nn.Linear(emb_dim, 2*emb_dim), nn.BatchNorm1d(2*emb_dim), nn.ReLU(), nn.Linear(2*emb_dim, 1)))
121 | elif self.graph_pooling == "set2set":
122 | self.pool = Set2Set(emb_dim, processing_steps = 2)
123 | else:
124 | raise ValueError("Invalid graph pooling type.")
125 |
126 | if graph_pooling == "set2set":
127 | self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
128 | else:
129 | self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
130 |
131 | def forward(self, batched_data):
132 | h_node = self.gnn_node(batched_data)
133 |
134 | h_graph = self.pool(h_node, batched_data.batch)
135 |
136 | return self.graph_pred_linear(h_graph)
137 |
--------------------------------------------------------------------------------
/framework/trainer/approx_retrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | from tqdm import tqdm, trange
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_geometric.utils import negative_sampling
7 | from torch.utils.data import DataLoader, TensorDataset
8 |
9 | from .base import Trainer
10 | from ..evaluation import *
11 | from ..utils import *
12 |
13 |
14 | DTYPE = np.float16
15 |
16 | class ApproxTrainer(Trainer):
17 | '''This code is adapted from https://github.com/zleizzo/datadeletion'''
18 |
19 | def gram_schmidt(self, X):
20 | """
21 | Uses numpy's qr factorization method to perform Gram-Schmidt.
22 | Args:
23 | X: (k x d matrix) X[i] = i-th vector
24 | Returns:
25 | U: (k x d matrix) U[i] = i-th orthonormal vector
26 | C: (k x k matrix) Coefficient matrix, C[i] = coeffs for X[i], X = CU
27 | """
28 | (k, d) = X.shape
29 | if k <= d:
30 | q, r = np.linalg.qr(np.transpose(X))
31 | else:
32 | q, r = np.linalg.qr(np.transpose(X), mode='complete')
33 | U = np.transpose(q)
34 | C = np.transpose(r)
35 | return U, C
36 |
37 | def LKO_pred(self, X, Y, ind, H=None, reg=1e-4):
38 | """
39 | Computes the LKO model's prediction values on the left-out points.
40 | Args:
41 | X: (n x d matrix) Covariate matrix
42 | Y: (n x 1 vector) Response vector
43 | ind: (k x 1 list) List of indices to be removed
44 | H: (n x n matrix, optional) Hat matrix X (X^T X)^{-1} X^T
45 | Returns:
46 | LKO: (k x 1 vector) Retrained model's predictions on X[i], i in ind
47 | """
48 | n = len(Y)
49 | k = len(ind)
50 | d = len(X[0, :])
51 | if H is None:
52 | H = np.matmul(X, np.linalg.solve(np.matmul(X.T, X) + reg * np.eye(d), X.T))
53 |
54 | LOO = np.zeros(k)
55 | for i in range(k):
56 | idx = ind[i]
57 | # This is the LOO residual y_i - \hat{y}^{LOO}_i
58 | LOO[i] = (Y[idx] - np.matmul(H[idx, :], Y)) / (1 - H[idx, idx])
59 |
60 | # S = I - T from the paper
61 | S = np.eye(k)
62 | for i in range(k):
63 | for j in range(k):
64 | if j != i:
65 | idx_i = ind[i]
66 | idx_j = ind[j]
67 | S[i, j] = -H[idx_i, idx_j] / (1 - H[idx_i, idx_i])
68 |
69 | LKO = np.linalg.solve(S, LOO)
70 |
71 | return Y[ind] - LKO
72 |
73 |
74 | def lin_res(self, X, Y, theta, ind, H=None, reg=1e-4):
75 | """
76 | Approximate retraining via the projective residual update.
77 | Args:
78 | X: (n x d matrix) Covariate matrix
79 | Y: (n x 1 vector) Response vector
80 | theta: (d x 1 vector) Current value of parameters to be updated
81 | ind: (k x 1 list) List of indices to be removed
82 | H: (n x n matrix, optional) Hat matrix X (X^T X)^{-1} X^T
83 | Returns:
84 | step: (d x 1 vector) Update step; the caller applies theta - step
85 | """
86 | d = len(X[0])
87 | k = len(ind)
88 |
89 | # Step 1: Compute LKO predictions
90 | LKO = self.LKO_pred(X, Y, ind, H, reg)
91 |
92 | # Step 2: Eigendecompose B
93 | # 2.I
94 | U, C = self.gram_schmidt(X[ind, :])
95 | # 2.II
96 | Cmatrix = np.matmul(C.T, C)
97 | eigenval, a = np.linalg.eigh(Cmatrix)
98 | V = np.matmul(a.T, U)
99 |
100 | # Step 3: Perform the update
101 | # 3.I
102 | grad = np.zeros_like(theta) # 2D grad
103 | for i in range(k):
104 | grad += (X[ind[i], :] * theta - LKO[i]) * X[ind[i], :]
105 | # 3.II
106 | step = np.zeros_like(theta) # 2D grad
107 | for i in range(k):
108 | factor = 1 / eigenval[i] if eigenval[i] > 1e-10 else 0
109 | step += factor * V[i, :] * grad * V[i, :]
110 | # 3.III
111 | return step
112 | # update = theta - step
113 | # return update
114 |
115 | @torch.no_grad()
116 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model=None):
117 | model.eval()
118 | best_loss = 100000
119 |
120 | neg_edge_index = negative_sampling(
121 | edge_index=data.train_pos_edge_index[:, data.dr_mask],
122 | num_nodes=data.num_nodes,
123 | num_neg_samples=data.dr_mask.sum())
124 |
125 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask])
126 | edge_index_all = torch.cat([data.train_pos_edge_index[:, data.dr_mask], neg_edge_index], dim=1)
127 |
128 | X = z[edge_index_all[0]] * z[edge_index_all[1]]
129 | Y = self.get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index)
130 | X = X.cpu()
131 | Y = Y.cpu()
132 |
133 | # According to the code, theta should be of (d, d). So only update the weights of the last layer
134 | theta = model.conv2.lin.weight.cpu().numpy()
135 | ind = [int(i) for i in self.args.df_idx.split(',')]
136 |
137 | # Not enough RAM for solving matrix inverse. So break into multiple batches
138 | update = []
139 | loader = DataLoader(TensorDataset(X, Y), batch_size=4096, num_workers=8)
140 | for x, y in tqdm(loader, desc='Unlearning'):
141 |
142 | x = x.numpy()
143 | y = y.numpy()
144 |
145 | update_step = self.lin_res(x, y, theta.T, ind)
146 | update.append(torch.tensor(update_step))
147 |
148 | update = torch.stack(update).mean(0)
149 | model.conv2.lin.weight = torch.nn.Parameter(model.conv2.lin.weight - update.t().cuda())
150 |
151 | print(f'Update model weights from {torch.norm(torch.tensor(theta))} to {torch.norm(model.conv2.lin.weight)}')
152 |
153 | valid_loss, auc, aup, df_logt, logit_all_pair = self.eval(model, data, 'val')
154 |
155 | self.trainer_log['log'].append({
156 | 'dt_loss': valid_loss,
157 | 'dt_auc': auc,
158 | 'dt_aup': aup
159 | })
160 |
161 | # Save
162 | ckpt = {
163 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()},
164 | 'node_emb': None,
165 | 'optimizer_state': None,
166 | }
167 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
168 |
--------------------------------------------------------------------------------
/framework/models/graph_classification/gcn_delete.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ogb.graphproppred.mol_encoder import AtomEncoder
5 | from torch_geometric.nn import GCNConv, MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
6 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
7 | from torch_geometric.utils import degree
8 | from torch_geometric.nn.inits import uniform
9 | from torch_scatter import scatter_mean
10 |
11 | from ..deletion import DeletionLayer
12 |
13 |
14 | def remove_edges(edge_index, edge_attr=None, ratio=0.025):
15 | row, col = edge_index
16 | mask = row < col
17 | row, col = row[mask], col[mask]
18 |
19 | if edge_attr is not None:
20 | edge_attr = edge_attr[mask]
21 |
22 | num_edges = len(row)
23 | num_remove = max(1, int(num_edges * ratio))
24 |
25 | selected = torch.randperm(num_edges)[:num_edges - num_remove]
26 |
27 | row = row[selected]
28 | col = col[selected]
29 | edge_attr = edge_attr[selected]
30 |
31 | return torch.stack([row, col], dim=0), edge_attr
32 |
33 | '''
34 | Source: OGB github
35 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/mol/main_pyg.py
36 | '''
37 | ### GCN convolution along the graph structure
38 | class GCNConv(MessagePassing):
39 | def __init__(self, emb_dim):
40 | super().__init__(aggr='add')
41 |
42 | self.linear = nn.Linear(emb_dim, emb_dim)
43 | self.root_emb = nn.Embedding(1, emb_dim)
44 | self.bond_encoder = BondEncoder(emb_dim=emb_dim)
45 |
46 | def forward(self, x, edge_index, edge_attr):
47 | x = self.linear(x)
48 | edge_embedding = self.bond_encoder(edge_attr)
49 |
50 | row, col = edge_index
51 |
52 | #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
53 | deg = degree(row, x.size(0), dtype=x.dtype) + 1
54 | deg_inv_sqrt = deg.pow(-0.5)
55 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
56 |
57 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
58 |
59 | return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)
60 |
61 | def message(self, x_j, edge_attr, norm):
62 | return norm.view(-1, 1) * F.relu(x_j + edge_attr)
63 |
64 | def update(self, aggr_out):
65 | return aggr_out
66 |
67 | ### GNN to generate node embedding
68 | class GNN_node_delete(nn.Module):
69 | def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, mask_1hop=None, mask_2hop=None):
70 | super().__init__()
71 | self.num_layer = num_layer
72 | self.drop_ratio = drop_ratio
73 | self.JK = JK
74 | ### add residual connection or not
75 | self.residual = residual
76 |
77 | self.atom_encoder = AtomEncoder(emb_dim)
78 |
79 | ###List of GNNs
80 | self.deletes = nn.ModuleList([
81 | DeletionLayer(emb_dim, None),
82 | DeletionLayer(emb_dim, None)
83 | ])
84 | self.convs = nn.ModuleList()
85 | self.batch_norms = nn.ModuleList()
86 |
87 | for layer in range(num_layer):
88 | self.convs.append(GCNConv(emb_dim))
89 | self.batch_norms.append(nn.BatchNorm1d(emb_dim))
90 |
91 | def forward(self, batched_data):
92 | x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
93 | edge_index, edge_attr = remove_edges(edge_index, edge_attr)
94 |
95 | ### computing input node embedding
96 |
97 | h_list = [self.atom_encoder(x)]
98 | for layer in range(self.num_layer):
99 |
100 | h = self.convs[layer](h_list[layer], edge_index, edge_attr)
101 | h = self.deletes[layer](h)
102 | h = self.batch_norms[layer](h)
103 |
104 | if layer == self.num_layer - 1:
105 | #remove relu for the last layer
106 | h = F.dropout(h, self.drop_ratio, training=self.training)
107 | else:
108 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
109 |
110 | if self.residual:
111 | h += h_list[layer]
112 |
113 | h_list.append(h)
114 |
115 | ### Different implementations of Jk-concat
116 | if self.JK == "last":
117 | node_representation = h_list[-1]
118 | elif self.JK == "sum":
119 | node_representation = 0
120 | for layer in range(self.num_layer + 1):
121 | node_representation += h_list[layer]
122 |
123 | return node_representation
124 |
125 | class GNN(torch.nn.Module):
126 | def __init__(self, num_tasks, num_layer=2, emb_dim=300, virtual_node=True, residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
127 | super().__init__()
128 |
129 | self.num_layer = num_layer
130 | self.drop_ratio = drop_ratio
131 | self.JK = JK
132 | self.emb_dim = emb_dim
133 | self.num_tasks = num_tasks
134 | self.graph_pooling = graph_pooling
135 |
136 | ### GNN to generate node embeddings
137 | self.gnn_node = GNN_node_delete(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
138 |
139 | ### Pooling function to generate whole-graph embeddings
140 | if self.graph_pooling == "sum":
141 | self.pool = global_add_pool
142 | elif self.graph_pooling == "mean":
143 | self.pool = global_mean_pool
144 | elif self.graph_pooling == "max":
145 | self.pool = global_max_pool
146 | elif self.graph_pooling == "attention":
147 | self.pool = GlobalAttention(gate_nn=nn.Sequential(nn.Linear(emb_dim, 2*emb_dim), nn.BatchNorm1d(2*emb_dim), nn.ReLU(), nn.Linear(2*emb_dim, 1)))
148 | elif self.graph_pooling == "set2set":
149 | self.pool = Set2Set(emb_dim, processing_steps = 2)
150 | else:
151 | raise ValueError("Invalid graph pooling type.")
152 |
153 | if graph_pooling == "set2set":
154 | self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
155 | else:
156 | self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
157 |
158 | def forward(self, batched_data):
159 | h_node = self.gnn_node(batched_data)
160 |
161 | h_graph = self.pool(h_node, batched_data.batch)
162 |
163 | return self.graph_pred_linear(h_graph)
164 |
--------------------------------------------------------------------------------
/framework/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from sklearn.metrics import roc_auc_score, average_precision_score
4 | from .utils import get_link_labels
5 |
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 |
9 | @torch.no_grad()
10 | def eval_lp(model, stage, data=None, loader=None):
11 | model.eval()
12 |
13 | # For full batch
14 | if data is not None:
15 | pos_edge_index = data[f'{stage}_pos_edge_index']
16 | neg_edge_index = data[f'{stage}_neg_edge_index']
17 |
18 | if hasattr(data, 'dtrain_mask') and data.dtrain_mask is not None:
19 | embedding = model(data.x.to(device), data.train_pos_edge_index[:, data.dtrain_mask].to(device))
20 | else:
21 | embedding = model(data.x.to(device), data.train_pos_edge_index.to(device))
22 |
23 | logits = model.decode(embedding, pos_edge_index, neg_edge_index).sigmoid()
24 | label = get_link_labels(pos_edge_index, neg_edge_index)
25 |
26 | # For mini batch
27 | if loader is not None:
28 | logits = []
29 | label = []
30 | for batch in loader:
31 | edge_index = batch.edge_index.to(device)
32 |
33 | if hasattr(batch, 'edge_type'):
34 | edge_type = batch.edge_type.to(device)
35 |
36 | embedding1 = model1(edge_index, edge_type)
37 | embedding2 = model2(edge_index, edge_type)
38 |
39 | s1 = model.decode(embedding1, edge_index, edge_type)
40 | s2 = model.decode(embedding2, edge_index, edge_type)
41 |
42 | else:
43 | embedding1 = model1(edge_index)
44 | embedding2 = model2(edge_index)
45 |
46 | s1 = model.decode(embedding1, edge_index)
47 | s2 = model.decode(embedding2, edge_index)
48 |
49 | embedding = model(data.train_pos_edge_index.to(device))
50 |
51 | lg = model.decode(embedding, pos_edge_index, neg_edge_index).sigmoid()
52 | lb = get_link_labels(pos_edge_index, neg_edge_index)
53 |
54 | logits.append(lg)
55 | label.append(lb)
56 |
57 | loss = F.binary_cross_entropy_with_logits(logits, label)
58 | auc = roc_auc_score(label.cpu(), logits.cpu())
59 | aup = average_precision_score(label.cpu(), logits.cpu())
60 |
61 | return loss, auc, aup
62 |
63 | @torch.no_grad()
64 | def verification_error(model1, model2):
65 | '''L2 distance between approximate model and re-trained model'''
66 |
67 | model1 = model1.to('cpu')
68 | model2 = model2.to('cpu')
69 |
70 | modules1 = {n: p for n, p in model1.named_parameters()}
71 | modules2 = {n: p for n, p in model2.named_parameters()}
72 |
73 | all_names = set(modules1.keys()) & set(modules2.keys())
74 |
75 | print(all_names)
76 |
77 | diff = torch.tensor(0.0).float()
78 | for n in all_names:
79 | diff += torch.norm(modules1[n] - modules2[n])
80 |
81 | return diff
82 |
83 | @torch.no_grad()
84 | def member_infer_attack(target_model, attack_model, data, logits=None):
85 | '''Membership inference attack'''
86 |
87 | edge = data.train_pos_edge_index[:, data.df_mask]
88 | z = target_model(data.x, data.train_pos_edge_index[:, data.dr_mask])
89 | feature1 = target_model.decode(z, edge).sigmoid()
90 | feature0 = 1 - feature1
91 | feature = torch.stack([feature0, feature1], dim=1)
92 | # feature = torch.cat([z[edge[0]], z[edge][1]], dim=-1)
93 | logits = attack_model(feature)
94 | _, pred = torch.max(logits, 1)
95 | suc_rate = 1 - pred.float().mean()
96 |
97 | return torch.softmax(logits, dim=-1).squeeze().tolist(), suc_rate.cpu().item()
98 |
99 | @torch.no_grad()
100 | def member_infer_attack_node(target_model, attack_model, data, logits=None):
101 | '''Membership inference attack'''
102 |
103 | edge = data.train_pos_edge_index[:, data.df_mask]
104 | z = target_model(data.x, data.train_pos_edge_index[:, data.dr_mask])
105 | feature = torch.cat([z[edge[0]], z[edge[1]]], dim=-1)
106 | logits = attack_model(feature)
107 | _, pred = torch.max(logits, 1)
108 | suc_rate = 1 - pred.float().mean()
109 |
110 | return torch.softmax(logits, dim=-1).squeeze().tolist(), suc_rate.cpu().item()
111 |
112 | @torch.no_grad()
113 | def get_node_embedding_data(model, data):
114 | model.eval()
115 |
116 | if hasattr(data, 'dtrain_mask') and data.dtrain_mask is not None:
117 | node_embedding = model(data.x.to(device), data.train_pos_edge_index[:, data.dtrain_mask].to(device))
118 | else:
119 | node_embedding = model(data.x.to(device), data.train_pos_edge_index.to(device))
120 |
121 | return node_embedding
122 |
123 | @torch.no_grad()
124 | def output_kldiv(model1, model2, data=None, loader=None):
125 | '''KL-Divergence between output distribution of model and re-trained model'''
126 |
127 | model1.eval()
128 | model2.eval()
129 |
130 | # For full batch
131 | if data is not None:
132 | embedding1 = get_node_embedding_data(model1, data).to(device)
133 | embedding2 = get_node_embedding_data(model2, data).to(device)
134 |
135 | if data.edge_index is not None:
136 | edge_index = data.edge_index.to(device)
137 | if data.train_pos_edge_index is not None:
138 | edge_index = data.train_pos_edge_index.to(device)
139 |
140 |
141 | if hasattr(data, 'edge_type'):
142 | edge_type = data.edge_type.to(device)
143 | score1 = model1.decode(embedding1, edge_index, edge_type)
144 | score2 = model2.decode(embedding2, edge_index, edge_type)
145 | else:
146 | score1 = model1.decode(embedding1, edge_index)
147 | score2 = model2.decode(embedding2, edge_index)
148 |
149 | # For mini batch
150 | if loader is not None:
151 | score1 = []
152 | score2 = []
153 | for batch in loader:
154 | edge_index = batch.edge_index.to(device)
155 |
156 | if hasattr(batch, 'edge_type'):
157 | edge_type = batch.edge_type.to(device)
158 |
159 |                 embedding1 = model1(batch.x.to(device), edge_index, edge_type)
160 |                 embedding2 = model2(batch.x.to(device), edge_index, edge_type)
161 | 
162 |                 s1 = model1.decode(embedding1, edge_index, edge_type)
163 |                 s2 = model2.decode(embedding2, edge_index, edge_type)
164 | 
165 |             else:
166 |                 embedding1 = model1(batch.x.to(device), edge_index)
167 |                 embedding2 = model2(batch.x.to(device), edge_index)
168 | 
169 |                 s1 = model1.decode(embedding1, edge_index)
170 |                 s2 = model2.decode(embedding2, edge_index)
171 |
172 | score1.append(s1)
173 | score2.append(s2)
174 |
175 | score1 = torch.hstack(score1)
176 | score2 = torch.hstack(score2)
177 |
178 | kldiv = F.kl_div(
179 | F.log_softmax(score1, dim=-1),
180 | F.softmax(score2, dim=-1)
181 | )
182 |
183 | return kldiv
184 |
185 |
--------------------------------------------------------------------------------
/framework/training_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | num_edge_type_mapping = {
5 | 'FB15k-237': 237,
6 | 'WordNet18': 18,
7 | 'WordNet18RR': 11,
8 | 'ogbl-biokg': 51
9 | }
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 |
14 | # Model
15 | parser.add_argument('--unlearning_model', type=str, default='retrain',
16 | help='unlearning method')
17 | parser.add_argument('--gnn', type=str, default='gcn',
18 | help='GNN architecture')
19 | parser.add_argument('--in_dim', type=int, default=128,
20 | help='input dimension')
21 | parser.add_argument('--hidden_dim', type=int, default=128,
22 | help='hidden dimension')
23 | parser.add_argument('--out_dim', type=int, default=64,
24 | help='output dimension')
25 |
26 | # Data
27 | parser.add_argument('--data_dir', type=str, default='./data',
28 | help='data dir')
29 | parser.add_argument('--df', type=str, default='none',
30 | help='Df set to use')
31 | parser.add_argument('--df_idx', type=str, default='none',
32 | help='indices of data to be deleted')
33 | parser.add_argument('--df_size', type=float, default=0.5,
34 |                         help='Df size: values >= 100 are absolute counts of edges/nodes, smaller values are percentages')
35 | parser.add_argument('--dataset', type=str, default='Cora',
36 | help='dataset')
37 | parser.add_argument('--random_seed', type=int, default=42,
38 | help='random seed')
39 | parser.add_argument('--batch_size', type=int, default=8192,
40 | help='batch size for GraphSAINTRandomWalk sampler')
41 | parser.add_argument('--walk_length', type=int, default=2,
42 | help='random walk length for GraphSAINTRandomWalk sampler')
43 | parser.add_argument('--num_steps', type=int, default=32,
44 | help='number of steps for GraphSAINTRandomWalk sampler')
45 |
46 | # Training
47 | parser.add_argument('--lr', type=float, default=1e-3,
48 | help='initial learning rate')
49 | parser.add_argument('--weight_decay', type=float, default=0.0005,
50 | help='weight decay')
51 | parser.add_argument('--optimizer', type=str, default='Adam',
52 | help='optimizer to use')
53 | parser.add_argument('--epochs', type=int, default=3000,
54 | help='number of epochs to train')
55 | parser.add_argument('--valid_freq', type=int, default=100,
56 | help='# of epochs to do validation')
57 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint',
58 | help='checkpoint folder')
59 | parser.add_argument('--alpha', type=float, default=0.5,
60 | help='alpha in loss function')
61 | parser.add_argument('--neg_sample_random', type=str, default='non_connected',
62 | help='type of negative samples for randomness')
63 | parser.add_argument('--loss_fct', type=str, default='mse_mean',
64 | help='loss function. one of {mse, kld, cosine}')
65 | parser.add_argument('--loss_type', type=str, default='both_layerwise',
66 | help='type of loss. one of {both_all, both_layerwise, only2_layerwise, only2_all, only1}')
67 |
68 | # GraphEraser
69 | parser.add_argument('--num_clusters', type=int, default=10,
70 |                         help='number of shards (clusters) for GraphEraser')
71 | parser.add_argument('--kmeans_max_iters', type=int, default=1,
72 |                         help='max iterations of balanced k-means for GraphEraser')
73 | parser.add_argument('--shard_size_delta', type=float, default=0.005)
74 | parser.add_argument('--terminate_delta', type=int, default=0)
75 |
76 | # GraphEditor
77 | parser.add_argument('--eval_steps', type=int, default=1)
78 | parser.add_argument('--runs', type=int, default=1)
79 |
80 | parser.add_argument('--num_remove_links', type=int, default=11)
81 | parser.add_argument('--parallel_unlearning', type=int, default=4)
82 |
83 | parser.add_argument('--lam', type=float, default=0)
84 | parser.add_argument('--regen_feats', action='store_true')
85 | parser.add_argument('--regen_neighbors', action='store_true')
86 | parser.add_argument('--regen_links', action='store_true')
87 | parser.add_argument('--regen_subgraphs', action='store_true')
88 | parser.add_argument('--hop_neighbors', type=int, default=20)
89 |
90 |
91 | # Evaluation
92 | parser.add_argument('--topk', type=int, default=500,
93 | help='top k for evaluation')
94 | parser.add_argument('--eval_on_cpu', type=bool, default=False,
95 | help='whether to evaluate on CPU')
96 |
97 | # KG
98 | parser.add_argument('--num_edge_type', type=int, default=None,
99 |                         help='number of edge types')
100 |
101 | args = parser.parse_args()
102 |
103 | if 'ogbl' in args.dataset:
104 | args.eval_on_cpu = True
105 |
106 | # For KG
107 | if args.gnn in ['rgcn', 'rgat']:
108 | args.lr = 1e-3
109 | args.epochs = 3000
110 | args.valid_freq = 500
111 | args.batch_size //= 2
112 | args.num_edge_type = num_edge_type_mapping[args.dataset]
113 | args.eval_on_cpu = True
114 | # args.in_dim = 512
115 | # args.hidden_dim = 256
116 | # args.out_dim = 128
117 |
118 | if args.unlearning_model in ['original', 'retrain']:
119 | args.epochs = 2000
120 | args.valid_freq = 500
121 |
122 | # For large graphs
123 | if args.gnn not in ['rgcn', 'rgat'] and 'ogbl' in args.dataset:
124 | args.epochs = 600
125 | args.valid_freq = 200
126 | if args.gnn in ['rgcn', 'rgat'] and 'ogbl' in args.dataset:
127 | args.batch_size = 1024
128 |
129 | if 'gnndelete' in args.unlearning_model:
130 | if args.gnn not in ['rgcn', 'rgat'] and 'ogbl' in args.dataset:
131 | args.epochs = 600
132 | args.valid_freq = 100
133 | if args.gnn in ['rgcn', 'rgat']:
134 | if args.dataset == 'WordNet18':
135 | args.epochs = 50
136 | args.valid_freq = 2
137 | args.batch_size = 1024
138 | if args.dataset == 'ogbl-biokg':
139 | args.epochs = 50
140 | args.valid_freq = 10
141 | args.batch_size = 64
142 |
143 | elif args.unlearning_model == 'gradient_ascent':
144 | args.epochs = 10
145 | args.valid_freq = 1
146 |
147 | elif args.unlearning_model == 'descent_to_delete':
148 | args.epochs = 1
149 |
150 | elif args.unlearning_model == 'graph_editor':
151 | args.epochs = 400
152 | args.valid_freq = 200
153 |
154 |
155 | if args.dataset == 'ogbg-molhiv':
156 | args.epochs = 100
157 | args.valid_freq = 5
158 |
159 | return args
160 |
--------------------------------------------------------------------------------
/framework/models/deletion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | from . import GCN, GAT, GIN, RGCN, RGAT
6 |
7 |
8 | class DeletionLayer(nn.Module):
9 | def __init__(self, dim, mask):
10 | super().__init__()
11 | self.dim = dim
12 | self.mask = mask
13 | self.deletion_weight = nn.Parameter(torch.ones(dim, dim) / 1000)
14 | # self.deletion_weight = nn.Parameter(torch.eye(dim, dim))
15 | # init.xavier_uniform_(self.deletion_weight)
16 |
17 | def forward(self, x, mask=None):
18 | '''Only apply deletion operator to the local nodes identified by mask'''
19 |
20 | if mask is None:
21 | mask = self.mask
22 |
23 | if mask is not None:
24 | new_rep = x.clone()
25 | new_rep[mask] = torch.matmul(new_rep[mask], self.deletion_weight)
26 |
27 | return new_rep
28 |
29 | return x
30 |
31 | class DeletionLayerKG(nn.Module):
32 | def __init__(self, dim, mask):
33 | super().__init__()
34 | self.dim = dim
35 | self.mask = mask
36 | self.deletion_weight = nn.Parameter(torch.ones(dim, dim) / 1000)
37 |
38 | def forward(self, x, mask=None):
39 | '''Only apply deletion operator to the local nodes identified by mask'''
40 |
41 | if mask is None:
42 | mask = self.mask
43 |
44 | if mask is not None:
45 | new_rep = x.clone()
46 | new_rep[mask] = torch.matmul(new_rep[mask], self.deletion_weight)
47 |
48 | return new_rep
49 |
50 | return x
51 |
52 | class GCNDelete(GCN):
53 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs):
54 | super().__init__(args)
55 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop)
56 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop)
57 |
58 |         self.conv1.requires_grad_(False)  # freeze the pretrained GNN weights
59 |         self.conv2.requires_grad_(False)  # only the deletion operators are trained
60 |
61 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False):
62 | # with torch.no_grad():
63 | x1 = self.conv1(x, edge_index)
64 |
65 | x1 = self.deletion1(x1, mask_1hop)
66 |
67 | x = F.relu(x1)
68 |
69 | x2 = self.conv2(x, edge_index)
70 | x2 = self.deletion2(x2, mask_2hop)
71 |
72 | if return_all_emb:
73 | return x1, x2
74 |
75 | return x2
76 |
77 | def get_original_embeddings(self, x, edge_index, return_all_emb=False):
78 | return super().forward(x, edge_index, return_all_emb)
79 |
80 | class GATDelete(GAT):
81 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs):
82 | super().__init__(args)
83 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop)
84 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop)
85 |
86 |         self.conv1.requires_grad_(False)  # freeze the pretrained GNN weights
87 |         self.conv2.requires_grad_(False)  # only the deletion operators are trained
88 |
89 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False):
90 | with torch.no_grad():
91 | x1 = self.conv1(x, edge_index)
92 |         x1 = self.deletion1(x1, mask_1hop)  # outside no_grad so the deletion weights receive gradients
93 |
94 | x = F.relu(x1)
95 |
96 | x2 = self.conv2(x, edge_index)
97 | x2 = self.deletion2(x2, mask_2hop)
98 |
99 | if return_all_emb:
100 | return x1, x2
101 |
102 | return x2
103 |
104 | def get_original_embeddings(self, x, edge_index, return_all_emb=False):
105 | return super().forward(x, edge_index, return_all_emb)
106 |
107 | class GINDelete(GIN):
108 | def __init__(self, args, mask_1hop=None, mask_2hop=None, **kwargs):
109 | super().__init__(args)
110 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop)
111 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop)
112 |
113 |         self.conv1.requires_grad_(False)  # freeze the pretrained GNN weights
114 |         self.conv2.requires_grad_(False)  # only the deletion operators are trained
115 |
116 | def forward(self, x, edge_index, mask_1hop=None, mask_2hop=None, return_all_emb=False):
117 | with torch.no_grad():
118 | x1 = self.conv1(x, edge_index)
119 |
120 | x1 = self.deletion1(x1, mask_1hop)
121 |
122 | x = F.relu(x1)
123 |
124 | x2 = self.conv2(x, edge_index)
125 | x2 = self.deletion2(x2, mask_2hop)
126 |
127 | if return_all_emb:
128 | return x1, x2
129 |
130 | return x2
131 |
132 | def get_original_embeddings(self, x, edge_index, return_all_emb=False):
133 | return super().forward(x, edge_index, return_all_emb)
134 |
135 | class RGCNDelete(RGCN):
136 | def __init__(self, args, num_nodes, num_edge_type, mask_1hop=None, mask_2hop=None, **kwargs):
137 | super().__init__(args, num_nodes, num_edge_type)
138 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop)
139 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop)
140 |
141 |         self.node_emb.requires_grad_(False)  # freeze the pretrained node embeddings and GNN weights
142 |         self.conv1.requires_grad_(False)
143 |         self.conv2.requires_grad_(False)  # only the deletion operators are trained
144 |
145 | def forward(self, x, edge_index, edge_type, mask_1hop=None, mask_2hop=None, return_all_emb=False):
146 | with torch.no_grad():
147 | x = self.node_emb(x)
148 | x1 = self.conv1(x, edge_index, edge_type)
149 |
150 | x1 = self.deletion1(x1, mask_1hop)
151 |
152 | x = F.relu(x1)
153 |
154 | x2 = self.conv2(x, edge_index, edge_type)
155 | x2 = self.deletion2(x2, mask_2hop)
156 |
157 | if return_all_emb:
158 | return x1, x2
159 |
160 | return x2
161 |
162 | def get_original_embeddings(self, x, edge_index, edge_type, return_all_emb=False):
163 | return super().forward(x, edge_index, edge_type, return_all_emb)
164 |
165 | class RGATDelete(RGAT):
166 | def __init__(self, args, num_nodes, num_edge_type, mask_1hop=None, mask_2hop=None, **kwargs):
167 | super().__init__(args, num_nodes, num_edge_type)
168 | self.deletion1 = DeletionLayer(args.hidden_dim, mask_1hop)
169 | self.deletion2 = DeletionLayer(args.out_dim, mask_2hop)
170 |
171 |         self.node_emb.requires_grad_(False)  # freeze the pretrained node embeddings and GNN weights
172 |         self.conv1.requires_grad_(False)
173 |         self.conv2.requires_grad_(False)  # only the deletion operators are trained
174 |
175 | def forward(self, x, edge_index, edge_type, mask_1hop=None, mask_2hop=None, return_all_emb=False):
176 | with torch.no_grad():
177 | x = self.node_emb(x)
178 | x1 = self.conv1(x, edge_index, edge_type)
179 |
180 | x1 = self.deletion1(x1, mask_1hop)
181 |
182 | x = F.relu(x1)
183 |
184 | x2 = self.conv2(x, edge_index, edge_type)
185 | x2 = self.deletion2(x2, mask_2hop)
186 |
187 | if return_all_emb:
188 | return x1, x2
189 |
190 | return x2
191 |
192 | def get_original_embeddings(self, x, edge_index, edge_type, return_all_emb=False):
193 | return super().forward(x, edge_index, edge_type, return_all_emb)
194 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # GNNDelete: A General Unlearning Strategy for Graph Neural Networks
3 |
4 | #### Authors:
5 | - [Jiali Cheng]() (jiali.cheng.ccnr@gmail.com)
6 | - [George Dasoulas](https://gdasoulas.github.io/) (george.dasoulas1@gmail.com)
7 | - [Huan He](https://github.com/mims-harvard/Raindrop) (huan_he@hms.harvard.edu)
8 | - [Chirag Agarwal](https://chirag126.github.io/) (chiragagarwall12@gmail.com)
9 | - [Marinka Zitnik](https://zitniklab.hms.harvard.edu/) (marinka@hms.harvard.edu)
10 |
11 | #### [Project website](https://zitniklab.hms.harvard.edu/projects/GNNDelete/)
12 |
13 | #### GNNDelete Paper: [ICLR 2023](https://openreview.net/forum?id=X9yCkmT5Qrl), [Preprint]()
14 |
15 |
16 | ## Overview
17 |
18 | This repository contains the code to preprocess datasets, train GNN models, and perform data deletion on trained GNN models for the manuscript *GNNDelete: A General Unlearning Strategy for Graph Neural Networks*. We propose GNNDelete, a model-agnostic, layer-wise deletion operator that optimizes two key properties for graph unlearning: Deleted Edge Consistency and Neighborhood Influence. GNNDelete updates latent representations to delete nodes and edges from the model while keeping the rest of the learned knowledge intact.
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## Key idea of GNNDelete
26 |
27 | To unlearn information from a trained GNN, its influence on both the GNN model weights and the representations of neighboring nodes in the graph must be removed from the model. However, existing methods based on retraining or weight modification either degrade model weights shared across all nodes or are ineffective because of the strong dependency of deleted edges on their local graph neighborhood.
28 |
29 | Our model formulates the unlearning problem as a representation learning task. It formalizes the required properties for graph unlearning in the form of Deleted Edge Consistency and Neighborhood Influence. GNNDelete updates latent representations to delete nodes and edges from the model while keeping the rest of the learned knowledge intact.
30 |
31 | **Overview of GNNDelete approach.** Our model extends the standard (Msg, Agg, Upd) GNN framework into (Msg, Agg, Upd, Del). Upon unlearning, GNNDelete inserts trainable deletion operators after the GNN layers. The **Del** operator updates the node representations of the affected nodes of deletion (based on the local enclosing subgraph of the deleted information). The updated representations are optimized to meet objectives of **Deleted Edge Consistency** and **Neighborhood Influence**. We only train the deletion operators, while freezing the rest of the GNN weights.
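
For intuition, here is a minimal sketch of the layer-wise deletion operator (mirroring `DeletionLayer` in `framework/models/deletion.py`; the class and variable names below are illustrative). A small trainable `dim x dim` matrix is applied only to the representations of nodes inside the enclosing subgraph of the deleted information; all other node representations and all pretrained GNN weights are left untouched.

```
import torch
import torch.nn as nn

class DeletionOperator(nn.Module):
    """Sketch of the Del operator: a trainable map applied only to affected nodes."""
    def __init__(self, dim, mask):
        super().__init__()
        self.mask = mask                                          # nodes in the enclosing subgraph of the deleted edges
        self.weight = nn.Parameter(torch.ones(dim, dim) / 1000)   # near-zero initialization, as in DeletionLayer

    def forward(self, x):
        out = x.clone()                                  # unaffected nodes keep their original representation
        out[self.mask] = out[self.mask] @ self.weight    # only affected nodes are updated
        return out
```

During unlearning, one such operator is inserted after each frozen GNN layer, and only the operator weights are trained against the Deleted Edge Consistency and Neighborhood Influence objectives.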
32 |
33 |
34 | ## Datasets
35 |
36 | We prepared seven commonly used datasets of different scales, including homogeneous and heterogeneous graphs.
37 |
38 | Please run the following command to perform the train-test split and sample the information to be deleted (two sampling strategies are described in the paper):
39 | ```
40 | python prepare_dataset.py
41 | ```
42 |
43 | The following table summarizes the statistics of all these seven datasets:
44 | | Dataset | # Nodes | # Edges | # Unique edge types | Max # deleted edges |
45 | |------------|:-------:|:---------:|:-------------------:|:-------------------:|
46 | | Cora | 19,793 | 126,842 | N/A | 6,342 |
47 | | PubMed | 19,717 | 88,648 | N/A | 4,432 |
48 | | DBLP | 17,716 | 105,734 | N/A | 5,286 |
49 | | CS | 18,333 | 163,788 | N/A | 8,189 |
50 | | OGB-Collab | 235,368 | 1,285,465 | N/A | 117,905 |
51 | | WordNet18 | 40,943 | 151,442 | 18 | 7,072 |
52 | | BioKG | 93,773 | 5,088,434 | 51 | 127,210 |
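
For reference, `prepare_dataset.py` writes per-seed artifacts under `--data_dir` that the deletion scripts read back. The sketch below follows how `delete_gnn.py` loads them; the dataset name, seed, and mask key are illustrative values.

```
import os
import pickle
import torch

data_dir, name, seed = './data', 'Cora', 42   # illustrative values

# Processed dataset and train/val/test split
with open(os.path.join(data_dir, name, f'd_{seed}.pkl'), 'rb') as f:
    dataset, data = pickle.load(f)

# Candidate edges to delete, keyed by the sampling strategy passed via --df
df_masks = torch.load(os.path.join(data_dir, name, f'df_{seed}.pt'))
df_mask = df_masks['in']   # key name is illustrative
```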
53 |
54 |
55 | ## Experimental setups
56 |
57 | We evaluated our model on three unlearning tasks against six baselines. The baselines include three state-of-the-art models designed for graph unlearning (GraphEraser, GraphEditor, Certified Graph Unlearning) and three general unlearning methods (retraining, gradient ascent, Descent-to-Delete). The three unlearning tasks are:
58 |
59 | **Unlearning task 1: delete edges.** We delete a set of edges from a trained GNN model.
60 |
61 | **Unlearning task 2: delete nodes.** We delete a set of nodes from a trained GNN model.
62 |
63 | **Unlearning task 3: delete node features.** We delete the node features of a set of nodes from a trained GNN model.
64 |
65 | The deleted information can be sampled with two different strategies:
66 |
67 | **A simpler setting: Out setting** The deleted information is sampled from **outside** the enclosing subgraph of the test set.
68 | 
69 | **A harder setting: In setting** The deleted information is sampled from **within** the enclosing subgraph of the test set. A rough sketch of both strategies is shown below.
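
As a rough sketch (the authoritative procedure lives in `prepare_dataset.py`; this only mirrors how the deletion scripts build enclosing subgraphs with `k_hop_subgraph`), the two strategies differ in whether deleted edges are drawn from inside or outside the 2-hop enclosing subgraph of the test edges:

```
import torch
from torch_geometric.utils import k_hop_subgraph

def sample_df_mask(data, df_size, setting='out'):
    """Illustrative sketch: sample df_size training edges from inside ('in')
    or outside ('out') the 2-hop enclosing subgraph of the test edges."""
    _, _, _, in_subgraph = k_hop_subgraph(
        data.test_pos_edge_index.flatten().unique(), 2,
        data.train_pos_edge_index, num_nodes=data.num_nodes)

    candidates = (in_subgraph if setting == 'in' else ~in_subgraph).nonzero().squeeze()
    picked = candidates[torch.randperm(candidates.numel())[:df_size]]

    df_mask = torch.zeros(data.train_pos_edge_index.size(1), dtype=torch.bool)
    df_mask[picked] = True
    return df_mask
```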
70 |
71 |
72 | ## Requirements
73 |
74 | GNNDelete has been tested using Python >=3.6.
75 |
76 | Please install the required packages by running
77 |
78 | ```
79 | pip install -r requirements.txt
80 | ```
81 |
82 |
83 | ## Running the code
84 |
85 | **Train GNN** The first step is to train a GNN model on either link prediction or node classification:
86 |
87 | ```python train_gnn.py```
88 |
89 | Or
90 |
91 | ```python train_node.py```
92 |
93 | **Train Membership Inference attacker (_Optional_)** We use the model in _Membership Inference Attack on Graph Neural Networks_ as our MI model. Please refer to the [official implementation](https://github.com/iyempissy/rebMIGraph).
94 |
95 | **Unlearn** Then we can delete information from the trained GNN model. Based on what you want to delete, run one of the three scripts below.
96 |
97 | To unlearn edges, please run
98 | ```
99 | python delete_gnn.py
100 | ```
101 |
102 | To unlearn nodes, please run
103 | ```
104 | python delete_node.py
105 | ```
106 |
107 | To unlearn node features, please run
108 | ```
109 | python delete_node_feature.py
110 | ```
111 |
112 |
113 | **Baselines**
114 | We compare GNNDelete to several baselines:
115 | - Retraining from scratch, please run the above unlearning scripts with `--unlearning_model retrain`
116 | - Gradient ascent, please run the above unlearning scripts with `--unlearning_model gradient_ascent`
117 | - Descent-to-Delete, please run the above unlearning scripts with `--unlearning_model descent_to_delete`
118 | - GraphEditor, please run the above unlearning scripts with `--unlearning_model graph_editor`
119 | - GraphEraser, please refer to the [official implementation](https://github.com/MinChen00/Graph-Unlearning)
120 | - Certified Graph Unlearning, please refer to the [official implementation](https://github.com/thupchnsky/sgc_unlearn)
121 |
122 |
123 | ## Citation
124 |
125 | If you find *GNNDelete* useful for your research, please consider citing this paper:
126 |
127 | ```
128 | @inproceedings{cheng2023gnndelete,
129 | title={{GNND}elete: A General Unlearning Strategy for Graph Neural Networks},
130 | author={Jiali Cheng and George Dasoulas and Huan He and Chirag Agarwal and Marinka Zitnik},
131 | booktitle={International Conference on Learning Representations},
132 | year={2023},
133 | url={https://openreview.net/forum?id=X9yCkmT5Qrl}
134 | }
135 | ```
136 |
137 |
138 | ## Miscellaneous
139 |
140 | Please send any questions you might have about the code and/or the algorithm to the authors listed above.
141 |
142 |
143 |
144 | ## License
145 |
146 | GNNDelete codebase is released under the MIT license.
147 |
--------------------------------------------------------------------------------
/framework/trainer/member_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import wandb
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import trange, tqdm
8 | from torch_geometric.utils import negative_sampling
9 | from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, f1_score
10 |
11 | from .base import Trainer
12 | from ..evaluation import *
13 | from ..utils import *
14 |
15 |
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |
18 | class MIAttackTrainer(Trainer):
19 | '''This code is adapted from https://github.com/iyempissy/rebMIGraph'''
20 |
21 | def __init__(self, args):
22 | self.args = args
23 | self.trainer_log = {
24 | 'unlearning_model': 'member_infer',
25 | 'dataset': args.dataset,
26 | 'seed': args.random_seed,
27 | 'shadow_log': [],
28 | 'attack_log': []}
29 | self.logit_all_pair = None
30 |
31 | with open(os.path.join(self.args.checkpoint_dir, 'training_args.json'), 'w') as f:
32 | json.dump(vars(args), f)
33 |
34 | def train_shadow(self, model, data, optimizer, args):
35 | best_valid_loss = 1000000
36 |
37 | all_neg = []
38 | # Train shadow model using the test data
39 | for epoch in trange(args.epochs, desc='Train shadow model'):
40 | model.train()
41 |
42 | # Positive and negative sample
43 | neg_edge_index = negative_sampling(
44 | edge_index=data.test_pos_edge_index,
45 | num_nodes=data.num_nodes,
46 | num_neg_samples=data.test_pos_edge_index.shape[1])
47 |
48 | z = model(data.x, data.test_pos_edge_index)
49 | logits = model.decode(z, data.test_pos_edge_index, neg_edge_index)
50 | label = get_link_labels(data.test_pos_edge_index, neg_edge_index)
51 | loss = F.binary_cross_entropy_with_logits(logits, label)
52 |
53 | loss.backward()
54 | optimizer.step()
55 | optimizer.zero_grad()
56 |
57 | all_neg.append(neg_edge_index.cpu())
58 |
59 | if (epoch+1) % args.valid_freq == 0:
60 | valid_loss, auc, aup, df_logit, logit_all_pair = self.eval_shadow(model, data, 'val')
61 |
62 | log = {
63 | 'shadow_epoch': epoch,
64 | 'shadow_train_loss': loss.item(),
65 | 'shadow_valid_loss': valid_loss,
66 | 'shadow_valid_auc': auc,
67 | 'shadow_valid_aup': aup,
68 | 'shadow_df_logit': df_logit
69 | }
70 | wandb.log(log)
71 | self.trainer_log['shadow_log'].append(log)
72 |
73 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
74 | tqdm.write(' | '.join(msg))
75 |
76 | if valid_loss < best_valid_loss:
77 | best_valid_loss = valid_loss
78 | best_epoch = epoch
79 |
80 | self.trainer_log['shadow_best_epoch'] = best_epoch
81 | self.trainer_log['shadow_best_valid_loss'] = best_valid_loss
82 |
83 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
84 | ckpt = {
85 | 'model_state': model.state_dict(),
86 | 'optimizer_state': optimizer.state_dict(),
87 | }
88 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'shadow_model_best.pt'))
89 |
90 | return torch.cat(all_neg, dim=-1)
91 |
92 | @torch.no_grad()
93 | def eval_shadow(self, model, data, stage='val'):
94 | model.eval()
95 | pos_edge_index = data[f'{stage}_pos_edge_index']
96 | neg_edge_index = data[f'{stage}_neg_edge_index']
97 |
98 | z = model(data.x, data.val_pos_edge_index)
99 | logits = model.decode(z, pos_edge_index, neg_edge_index).sigmoid()
100 | label = self.get_link_labels(pos_edge_index, neg_edge_index)
101 |
102 | loss = F.binary_cross_entropy_with_logits(logits, label).cpu().item()
103 | auc = roc_auc_score(label.cpu(), logits.cpu())
104 | aup = average_precision_score(label.cpu(), logits.cpu())
105 | df_logit = float('nan')
106 |
107 | logit_all_pair = (z @ z.t()).cpu()
108 |
109 | log = {
110 | f'{stage}_loss': loss,
111 | f'{stage}_auc': auc,
112 | f'{stage}_aup': aup,
113 | f'{stage}_df_logit': df_logit,
114 | }
115 | wandb.log(log)
116 | msg = [f'{i}: {j:.4f}' if isinstance(j, (np.floating, float)) else f'{i}: {j:>4d}' for i, j in log.items()]
117 | tqdm.write(' | '.join(msg))
118 |
119 | return loss, auc, aup, df_logit, logit_all_pair
120 |
121 | def train_attack(self, model, train_loader, valid_loader, optimizer, args):
122 | loss_fct = nn.CrossEntropyLoss()
123 | best_auc = 0
124 | best_epoch = 0
125 | for epoch in trange(50, desc='Train attack model'):
126 | model.train()
127 |
128 | train_loss = 0
129 | for x, y in train_loader:
130 | logits = model(x.to(device))
131 | loss = loss_fct(logits, y.to(device))
132 |
133 | loss.backward()
134 | optimizer.step()
135 | optimizer.zero_grad()
136 |
137 | train_loss += loss.item()
138 |
139 | valid_loss, valid_acc, valid_auc, valid_f1 = self.eval_attack(model, valid_loader)
140 |
141 | log = {
142 | 'attack_train_loss': train_loss / len(train_loader),
143 | 'attack_valid_loss': valid_loss,
144 | 'attack_valid_acc': valid_acc,
145 | 'attack_valid_auc': valid_auc,
146 | 'attack_valid_f1': valid_f1}
147 | wandb.log(log)
148 | self.trainer_log['attack_log'].append(log)
149 |
150 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
151 | tqdm.write(' | '.join(msg))
152 |
153 |
154 | if valid_auc > best_auc:
155 | best_auc = valid_auc
156 | best_epoch = epoch
157 | self.trainer_log['attack_best_auc'] = valid_auc
158 | self.trainer_log['attack_best_epoch'] = epoch
159 |
160 | ckpt = {
161 | 'model_state': model.state_dict(),
162 | 'optimizer_state': optimizer.state_dict(),
163 | }
164 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'attack_model_best.pt'))
165 |
166 | @torch.no_grad()
167 | def eval_attack(self, model, eval_loader):
168 | loss_fct = nn.CrossEntropyLoss()
169 | pred = []
170 | label = []
171 | for x, y in eval_loader:
172 | logits = model(x.to(device))
173 | loss = loss_fct(logits, y.to(device))
174 | _, p = torch.max(logits, 1)
175 |
176 | pred.extend(p.cpu())
177 | label.extend(y)
178 |
179 | pred = torch.stack(pred)
180 | label = torch.stack(label)
181 |
182 | return loss.item(), accuracy_score(label.numpy(), pred.numpy()), roc_auc_score(label.numpy(), pred.numpy()), f1_score(label.numpy(), pred.numpy(), average='macro')
183 |
184 | @torch.no_grad()
185 | def prepare_attack_training_data(self, model, data, all_neg=None):
186 | '''Prepare the training data of attack model (Present vs. Absent)
187 | Present edges (label = 1): training data of shadow model (Test pos and neg edges)
188 | Absent edges (label = 0): validation data of shadow model (Valid pos and neg edges)
189 | '''
190 |
191 | z = model(data.x, data.test_pos_edge_index)
192 |
193 | # Sample same size of neg as pos
194 | sample_idx = torch.randperm(all_neg.shape[1])[:data.test_pos_edge_index.shape[1]]
195 | neg_subset = all_neg[:, sample_idx]
196 |
197 | present_edge_index = torch.cat([data.test_pos_edge_index, data.test_neg_edge_index], dim=-1)
198 |
199 | if 'sub' in self.args.unlearning_model:
200 | absent_edge_index = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index], dim=-1)
201 | else: #if 'all' in self.args.unlearning_model:
202 | absent_edge_index = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index, data.train_pos_edge_index, neg_subset.to(device)], dim=-1)
203 |
204 | edge_index = torch.cat([present_edge_index, absent_edge_index], dim=-1)
205 |
206 | feature = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1).cpu()
207 | label = get_link_labels(present_edge_index, absent_edge_index).long().cpu()
208 |
209 | return feature, label
210 |
--------------------------------------------------------------------------------
/delete_gnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import wandb
5 | import pickle
6 | import argparse
7 | import torch
8 | import torch.nn as nn
9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected
10 | from torch_geometric.data import Data
11 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
12 | from torch_geometric.seed import seed_everything
13 |
14 | from framework import get_model, get_trainer
15 | from framework.models.gcn import GCN
16 | from framework.training_args import parse_args
17 | from framework.utils import *
18 | from train_mi import MLPAttacker
19 |
20 |
21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 |
23 | def load_args(path):
24 | with open(path, 'r') as f:
25 | d = json.load(f)
26 | parser = argparse.ArgumentParser()
27 | for k, v in d.items():
28 | parser.add_argument('--' + k, default=v)
29 | try:
30 | parser.add_argument('--df_size', default=0.5)
31 | except:
32 | pass
33 | args = parser.parse_args()
34 |
35 | for k, v in d.items():
36 | setattr(args, k, v)
37 |
38 | return args
39 |
40 | @torch.no_grad()
41 | def get_node_embedding(model, data):
42 | model.eval()
43 | node_embedding = model(data.x.to(device), data.edge_index.to(device))
44 |
45 | return node_embedding
46 |
47 | @torch.no_grad()
48 | def get_output(model, node_embedding, data):
49 | model.eval()
50 | node_embedding = node_embedding.to(device)
51 |     edge = data.edge_index.to(device)
52 |     output = model.decode(node_embedding, edge, data.edge_type.to(device)) if hasattr(data, 'edge_type') else model.decode(node_embedding, edge)
53 |
54 | return output
55 |
56 | torch.autograd.set_detect_anomaly(True)
57 | def main():
58 | args = parse_args()
59 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed))
60 | attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed))
61 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed))
62 | seed_everything(args.random_seed)
63 |
64 | if 'gnndelete' in args.unlearning_model:
65 | args.checkpoint_dir = os.path.join(
66 | args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model,
67 | '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]),
68 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
69 | else:
70 | args.checkpoint_dir = os.path.join(
71 | args.checkpoint_dir, args.dataset, args.gnn, args.unlearning_model,
72 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
73 | os.makedirs(args.checkpoint_dir, exist_ok=True)
74 |
75 | # Dataset
76 | with open(os.path.join(args.data_dir, args.dataset, f'd_{args.random_seed}.pkl'), 'rb') as f:
77 | dataset, data = pickle.load(f)
78 | print('Directed dataset:', dataset, data)
79 | if args.gnn not in ['rgcn', 'rgat']:
80 | args.in_dim = dataset.num_features
81 |
82 | print('Training args', args)
83 | wandb.init(config=args)
84 |
85 | # Df and Dr
86 | assert args.df != 'none'
87 |
88 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted
89 | df_size = int(args.df_size)
90 | else: # df_size is the ratio
91 | df_size = int(args.df_size / 100 * data.train_pos_edge_index.shape[1])
92 | print(f'Original size: {data.train_pos_edge_index.shape[1]:,}')
93 | print(f'Df size: {df_size:,}')
94 |
95 | df_mask_all = torch.load(os.path.join(args.data_dir, args.dataset, f'df_{args.random_seed}.pt'))[args.df]
96 | df_nonzero = df_mask_all.nonzero().squeeze()
97 |
98 | idx = torch.randperm(df_nonzero.shape[0])[:df_size]
99 | df_global_idx = df_nonzero[idx]
100 |
101 | print('Deleting the following edges:', df_global_idx)
102 |
103 | # df_idx = [int(i) for i in args.df_idx.split(',')]
104 | # df_idx_global = df_mask.nonzero()[df_idx]
105 |
106 | dr_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)
107 | dr_mask[df_global_idx] = False
108 |
109 | df_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool)
110 | df_mask[df_global_idx] = True
111 |
112 | # For testing
113 | data.directed_df_edge_index = data.train_pos_edge_index[:, df_mask]
114 | if args.gnn in ['rgcn', 'rgat']:
115 | data.directed_df_edge_type = data.train_edge_type[df_mask]
116 |
117 |
118 | # data.dr_mask = dr_mask
119 | # data.df_mask = df_mask
120 | # data.edge_index = data.train_pos_edge_index[:, dr_mask]
121 |
122 | # assert df_mask.sum() == len(df_global_idx)
123 | # assert dr_mask.shape[0] - len(df_global_idx) == data.train_pos_edge_index[:, dr_mask].shape[1]
124 | # data.dtrain_mask = dr_mask
125 |
126 |
127 | # Edges in S_Df
128 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph(
129 | data.train_pos_edge_index[:, df_mask].flatten().unique(),
130 | 2,
131 | data.train_pos_edge_index,
132 | num_nodes=data.num_nodes)
133 | data.sdf_mask = two_hop_mask
134 |
135 | # Nodes in S_Df
136 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph(
137 | data.train_pos_edge_index[:, df_mask].flatten().unique(),
138 | 1,
139 | data.train_pos_edge_index,
140 | num_nodes=data.num_nodes)
141 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool)
142 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool)
143 |
144 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True
145 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True
146 |
147 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique())
148 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique())
149 |
150 | data.sdf_node_1hop_mask = sdf_node_1hop
151 | data.sdf_node_2hop_mask = sdf_node_2hop
152 |
153 |
154 | # To undirected for message passing
155 |     # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
156 | assert not is_undirected(data.train_pos_edge_index)
157 |
158 | if args.gnn in ['rgcn', 'rgat']:
159 | r, c = data.train_pos_edge_index
160 | rev_edge_index = torch.stack([c, r], dim=0)
161 | rev_edge_type = data.train_edge_type + args.num_edge_type
162 |
163 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1)
164 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0)
165 |
166 | if hasattr(data, 'train_mask'):
167 | data.train_mask = data.train_mask.repeat(2).view(-1)
168 |
169 | two_hop_mask = two_hop_mask.repeat(2).view(-1)
170 | df_mask = df_mask.repeat(2).view(-1)
171 | dr_mask = dr_mask.repeat(2).view(-1)
172 | assert is_undirected(data.edge_index)
173 |
174 | else:
175 | train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()])
176 | two_hop_mask = two_hop_mask.bool()
177 | df_mask = df_mask.bool()
178 | dr_mask = ~df_mask
179 |
180 | data.train_pos_edge_index = train_pos_edge_index
181 | data.edge_index = train_pos_edge_index
182 | assert is_undirected(data.train_pos_edge_index)
183 |
184 |
185 | print('Undirected dataset:', data)
186 |
187 | data.sdf_mask = two_hop_mask
188 | data.df_mask = df_mask
189 | data.dr_mask = dr_mask
190 | # data.dtrain_mask = dr_mask
191 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
192 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.df_mask.shape, )
193 | # raise
194 |
195 | # Model
196 | model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
197 |
198 | if args.unlearning_model != 'retrain': # Start from trained GNN model
199 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')):
200 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt'))
201 | if logits_ori is not None:
202 | logits_ori = logits_ori.to(device)
203 | else:
204 | logits_ori = None
205 |
206 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device)
207 | model.load_state_dict(model_ckpt['model_state'], strict=False)
208 |
209 | else: # Initialize a new GNN model
210 | retrain = None
211 | logits_ori = None
212 |
213 | model = model.to(device)
214 |
215 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model:
216 | parameters_to_optimize = [
217 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
218 | ]
219 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
220 |
221 | if 'layerwise' in args.loss_type:
222 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr)
223 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr)
224 | optimizer = [optimizer1, optimizer2]
225 | else:
226 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)
227 |
228 | else:
229 | if 'gnndelete' in args.unlearning_model:
230 | parameters_to_optimize = [
231 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
232 | ]
233 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
234 |
235 | else:
236 | parameters_to_optimize = [
237 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0}
238 | ]
239 | print('parameters_to_optimize', [n for n, p in model.named_parameters()])
240 |
241 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay)
242 |
243 | wandb.watch(model, log_freq=100)
244 |
245 | # MI attack model
246 | attack_model_all = None
247 | # attack_model_all = MLPAttacker(args)
248 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt'))
249 | # attack_model_all.load_state_dict(attack_ckpt['model_state'])
250 | # attack_model_all = attack_model_all.to(device)
251 |
252 | attack_model_sub = None
253 | # attack_model_sub = MLPAttacker(args)
254 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt'))
255 | # attack_model_sub.load_state_dict(attack_ckpt['model_state'])
256 | # attack_model_sub = attack_model_sub.to(device)
257 |
258 | # Train
259 | trainer = get_trainer(args)
260 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
261 |
262 | # Test
263 | if args.unlearning_model != 'retrain':
264 | retrain_path = os.path.join(
265 | 'checkpoint', args.dataset, args.gnn, 'retrain',
266 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]),
267 | 'model_best.pt')
268 | if os.path.exists(retrain_path):
269 | retrain_ckpt = torch.load(retrain_path, map_location=device)
270 | retrain_args = copy.deepcopy(args)
271 | retrain_args.unlearning_model = 'retrain'
272 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
273 | retrain.load_state_dict(retrain_ckpt['model_state'])
274 | retrain = retrain.to(device)
275 | retrain.eval()
276 | else:
277 | retrain = None
278 | else:
279 | retrain = None
280 |
281 | test_results = trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub)
282 | print(test_results[-1])
283 | trainer.save_log()
284 |
285 |
286 | if __name__ == "__main__":
287 | main()
288 |
--------------------------------------------------------------------------------
/delete_node.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import wandb
5 | import pickle
6 | import argparse
7 | import torch
8 | import torch.nn as nn
9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected
10 | from torch_geometric.data import Data
11 | import torch_geometric.transforms as T
12 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR
13 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
14 | from torch_geometric.seed import seed_everything
15 |
16 | from framework import get_model, get_trainer
17 | from framework.models.gcn import GCN
18 | from framework.models.deletion import GCNDelete
19 | from framework.training_args import parse_args
20 | from framework.utils import *
21 | from framework.trainer.gnndelete_nodeemb import GNNDeleteNodeClassificationTrainer
22 | from train_mi import MLPAttacker
23 |
24 |
25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26 | torch.autograd.set_detect_anomaly(True)
27 |
28 | def to_directed(edge_index):
29 | row, col = edge_index
30 | mask = row < col
31 |     return torch.stack([row[mask], col[mask]], dim=0)  # keep the 2 x E shape of an edge_index
32 |
33 | def main():
34 | args = parse_args()
35 | args.checkpoint_dir = 'checkpoint_node'
36 | args.dataset = 'DBLP'
37 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed))
38 | attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed))
39 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed))
40 | seed_everything(args.random_seed)
41 |
42 | if 'gnndelete' in args.unlearning_model:
43 | args.checkpoint_dir = os.path.join(
44 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion',
45 | '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]),
46 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
47 | else:
48 | args.checkpoint_dir = os.path.join(
49 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion',
50 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
51 | os.makedirs(args.checkpoint_dir, exist_ok=True)
52 |
53 | # Dataset
54 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures())
55 | data = dataset[0]
56 | print('Original data', data)
57 |
58 | split = T.RandomNodeSplit()
59 | data = split(data)
60 | assert is_undirected(data.edge_index)
61 |
62 | print('Split data', data)
63 | args.in_dim = data.x.shape[1]
64 | args.out_dim = dataset.num_classes
65 |
66 | wandb.init(config=args)
67 |
68 | # Df and Dr
69 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted
70 | df_size = int(args.df_size)
71 | else: # df_size is the ratio
72 |         df_size = int(args.df_size / 100 * data.num_nodes)
73 | print(f'Original size: {data.num_nodes:,}')
74 | print(f'Df size: {df_size:,}')
75 |
76 | # Delete nodes
77 | df_nodes = torch.randperm(data.num_nodes)[:df_size]
78 | global_node_mask = torch.ones(data.num_nodes, dtype=torch.bool)
79 | global_node_mask[df_nodes] = False
80 |
81 | dr_mask_node = global_node_mask
82 | df_mask_node = ~global_node_mask
83 | assert df_mask_node.sum() == df_size
84 |
85 | # Delete edges associated with deleted nodes from training set
86 | res = [torch.eq(data.edge_index, aelem).logical_or_(torch.eq(data.edge_index, aelem)) for aelem in df_nodes]
87 | df_mask_edge = torch.any(torch.stack(res, dim=0), dim = 0)
88 | df_mask_edge = df_mask_edge.sum(0).bool()
89 | dr_mask_edge = ~df_mask_edge
90 |
91 | df_edge = data.edge_index[:, df_mask_edge]
92 | data.directed_df_edge_index = to_directed(df_edge)
93 | # print(df_edge.shape, directed_df_edge_index.shape)
94 | # raise
95 |
96 | print('Deleting the following nodes:', df_nodes)
97 |
98 | # # Delete edges associated with deleted nodes from valid and test set
99 | # res = [torch.eq(data.val_pos_edge_index, aelem).logical_or_(torch.eq(data.val_pos_edge_index, aelem)) for aelem in df_nodes]
100 | # mask = torch.any(torch.stack(res, dim=0), dim = 0)
101 | # mask = mask.sum(0).bool()
102 | # mask = ~mask
103 | # data.val_pos_edge_index = data.val_pos_edge_index[:, mask]
104 | # data.val_neg_edge_index = data.val_neg_edge_index[:, :data.val_pos_edge_index.shape[1]]
105 |
106 | # res = [torch.eq(data.test_pos_edge_index, aelem).logical_or_(torch.eq(data.test_pos_edge_index, aelem)) for aelem in df_nodes]
107 | # mask = torch.any(torch.stack(res, dim=0), dim = 0)
108 | # mask = mask.sum(0).bool()
109 | # mask = ~mask
110 | # data.test_pos_edge_index = data.test_pos_edge_index[:, mask]
111 | # data.test_neg_edge_index = data.test_neg_edge_index[:, :data.test_pos_edge_index.shape[1]]
112 |
113 |
114 | # For testing
115 | # data.directed_df_edge_index = data.train_pos_edge_index[:, df_mask_edge]
116 | # if args.gnn in ['rgcn', 'rgat']:
117 | # data.directed_df_edge_type = data.train_edge_type[df_mask]
118 |
119 | # Edges in S_Df
120 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph(
121 | data.edge_index[:, df_mask_edge].flatten().unique(),
122 | 2,
123 | data.edge_index,
124 | num_nodes=data.num_nodes)
125 |
126 | # Nodes in S_Df
127 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph(
128 | data.edge_index[:, df_mask_edge].flatten().unique(),
129 | 1,
130 | data.edge_index,
131 | num_nodes=data.num_nodes)
132 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool)
133 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool)
134 |
135 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True
136 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True
137 |
138 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique())
139 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique())
140 |
141 | data.sdf_node_1hop_mask = sdf_node_1hop
142 | data.sdf_node_2hop_mask = sdf_node_2hop
143 |
144 |
145 | # To undirected for message passing
146 |     # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
147 | # assert not is_undirected(data.edge_index)
148 | print(is_undirected(data.edge_index))
149 |
150 | if args.gnn in ['rgcn', 'rgat']:
151 | r, c = data.train_pos_edge_index
152 | rev_edge_index = torch.stack([c, r], dim=0)
153 | rev_edge_type = data.train_edge_type + args.num_edge_type
154 |
155 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1)
156 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0)
157 | # data.train_mask = data.train_mask.repeat(2)
158 |
159 | two_hop_mask = two_hop_mask.repeat(2).view(-1)
160 |         df_mask_edge = df_mask_edge.repeat(2).view(-1)
161 |         dr_mask_edge = dr_mask_edge.repeat(2).view(-1)
162 | assert is_undirected(data.edge_index)
163 |
164 | else:
165 | # train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()])
166 | two_hop_mask = two_hop_mask.bool()
167 | df_mask_edge = df_mask_edge.bool()
168 | dr_mask_edge = ~df_mask_edge
169 |
170 | # data.train_pos_edge_index = train_pos_edge_index
171 | # assert is_undirected(data.train_pos_edge_index)
172 |
173 |
174 | print('Undirected dataset:', data)
175 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
176 |
177 | data.sdf_mask = two_hop_mask
178 | data.df_mask = df_mask_edge
179 | data.dr_mask = dr_mask_edge
180 | data.dtrain_mask = dr_mask_edge
181 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.two_hop_mask.shape, data.df_mask.shape, data.two_hop_mask.shape)
182 | # raise
183 |
184 | # Model
185 | model = GCNDelete(args)
186 | # model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
187 |
188 | if args.unlearning_model != 'retrain': # Start from trained GNN model
189 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')):
190 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt'))
191 | if logits_ori is not None:
192 | logits_ori = logits_ori.to(device)
193 | else:
194 | logits_ori = None
195 |
196 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device)
197 | model.load_state_dict(model_ckpt['model_state'], strict=False)
198 |
199 | else: # Initialize a new GNN model
200 | retrain = None
201 | logits_ori = None
202 |
203 | model = model.to(device)
204 |
205 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model:
206 | parameters_to_optimize = [
207 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
208 | ]
209 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
210 |
211 | if 'layerwise' in args.loss_type:
212 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr)
213 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr)
214 | optimizer = [optimizer1, optimizer2]
215 | else:
216 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)
217 |
218 | else:
219 | if 'gnndelete' in args.unlearning_model:
220 | parameters_to_optimize = [
221 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
222 | ]
223 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
224 |
225 | else:
226 | parameters_to_optimize = [
227 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0}
228 | ]
229 | print('parameters_to_optimize', [n for n, p in model.named_parameters()])
230 |
231 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay)
232 |
233 | wandb.watch(model, log_freq=100)
234 |
235 | # MI attack model
236 | attack_model_all = None
237 | # attack_model_all = MLPAttacker(args)
238 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt'))
239 | # attack_model_all.load_state_dict(attack_ckpt['model_state'])
240 | # attack_model_all = attack_model_all.to(device)
241 |
242 | attack_model_sub = None
243 | # attack_model_sub = MLPAttacker(args)
244 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt'))
245 | # attack_model_sub.load_state_dict(attack_ckpt['model_state'])
246 | # attack_model_sub = attack_model_sub.to(device)
247 |
248 | # Train
249 | trainer = GNNDeleteNodeClassificationTrainer(args)
250 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
251 |
252 | # Test
253 | if args.unlearning_model != 'retrain':
254 | retrain_path = os.path.join(
255 | 'checkpoint', args.dataset, args.gnn, 'retrain',
256 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
257 | retrain_ckpt = torch.load(os.path.join(retrain_path, 'model_best.pt'), map_location=device)
258 | retrain_args = copy.deepcopy(args)
259 | retrain_args.unlearning_model = 'retrain'
260 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
261 | retrain.load_state_dict(retrain_ckpt['model_state'])
262 | retrain = retrain.to(device)
263 | retrain.eval()
264 |
265 | else:
266 | retrain = None
267 |
268 | trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub)
269 | trainer.save_log()
270 |
271 |
272 | if __name__ == "__main__":
273 | main()
274 |
--------------------------------------------------------------------------------
/delete_node_feature.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import wandb
5 | import pickle
6 | import argparse
7 | import torch
8 | import torch.nn as nn
9 | from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected
10 | from torch_geometric.data import Data
11 | import torch_geometric.transforms as T
12 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR
13 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
14 | from torch_geometric.seed import seed_everything
15 |
16 | from framework import get_model, get_trainer
17 | from framework.models.gcn import GCN
18 | from framework.models.deletion import GCNDelete
19 | from framework.training_args import parse_args
20 | from framework.utils import *
21 | from framework.trainer.gnndelete_nodeemb import GNNDeleteNodeClassificationTrainer
22 | from train_mi import MLPAttacker
23 |
24 |
25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26 | torch.autograd.set_detect_anomaly(True)
27 |
28 | def to_directed(edge_index):
29 | row, col = edge_index
30 | mask = row < col
31 |     return torch.stack([row[mask], col[mask]], dim=0)  # keep the 2 x E shape of an edge_index
32 |
33 | def main():
34 | args = parse_args()
35 | args.checkpoint_dir = 'checkpoint_node_feature'
36 | args.dataset = 'DBLP'
37 | original_path = os.path.join(args.checkpoint_dir, args.dataset, args.gnn, 'original', str(args.random_seed))
38 | attack_path_all = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_all', str(args.random_seed))
39 | attack_path_sub = os.path.join(args.checkpoint_dir, args.dataset, 'member_infer_sub', str(args.random_seed))
40 | seed_everything(args.random_seed)
41 |
42 | if 'gnndelete' in args.unlearning_model:
43 | args.checkpoint_dir = os.path.join(
44 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion',
45 | '-'.join([str(i) for i in [args.loss_fct, args.loss_type, args.alpha, args.neg_sample_random]]),
46 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
47 | else:
48 | args.checkpoint_dir = os.path.join(
49 | args.checkpoint_dir, args.dataset, args.gnn, f'{args.unlearning_model}-node_deletion',
50 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
51 | os.makedirs(args.checkpoint_dir, exist_ok=True)
52 |
53 | # Dataset
54 | dataset = CitationFull(os.path.join(args.data_dir, args.dataset), args.dataset, transform=T.NormalizeFeatures())
55 | data = dataset[0]
56 | print('Original data', data)
57 |
58 | split = T.RandomNodeSplit()
59 | data = split(data)
60 | assert is_undirected(data.edge_index)
61 |
62 | print('Split data', data)
63 | args.in_dim = data.x.shape[1]
64 | args.out_dim = dataset.num_classes
65 |
66 | wandb.init(config=args)
67 |
68 | # Df and Dr
69 | if args.df_size >= 100: # df_size is number of nodes/edges to be deleted
70 | df_size = int(args.df_size)
71 | else: # df_size is the ratio
72 |         df_size = int(args.df_size / 100 * data.num_nodes)
73 | print(f'Original size: {data.num_nodes:,}')
74 | print(f'Df size: {df_size:,}')
75 |
76 | # Delete node feature
77 | df_nodes = torch.randperm(data.num_nodes)[:df_size]
78 | global_node_mask = torch.ones(data.num_nodes, dtype=torch.bool)
79 | # global_node_mask[df_nodes] = False
80 | data.x[df_nodes] = 0
81 | assert data.x[df_nodes].sum() == 0
82 |
83 | dr_mask_node = torch.ones(data.num_nodes, dtype=torch.bool)
84 | df_mask_node = ~global_node_mask
85 | # assert df_mask_node.sum() == df_size
86 |
87 | # Delete edges associated with deleted nodes from training set
88 | res = [torch.eq(data.edge_index, aelem) for aelem in df_nodes]
89 | df_mask_edge = torch.any(torch.stack(res, dim=0), dim=0)
90 | df_mask_edge = df_mask_edge.sum(0).bool()
91 | dr_mask_edge = ~df_mask_edge
92 |
93 | df_edge = data.edge_index[:, df_mask_edge]
94 | data.directed_df_edge_index = to_directed(df_edge)
95 | # print(df_edge.shape, directed_df_edge_index.shape)
96 | # raise
97 |
98 | print('Deleting the following nodes:', df_nodes)
99 |
100 | # # Delete edges associated with deleted nodes from valid and test set
101 | # res = [torch.eq(data.val_pos_edge_index, aelem).logical_or_(torch.eq(data.val_pos_edge_index, aelem)) for aelem in df_nodes]
102 | # mask = torch.any(torch.stack(res, dim=0), dim = 0)
103 | # mask = mask.sum(0).bool()
104 | # mask = ~mask
105 | # data.val_pos_edge_index = data.val_pos_edge_index[:, mask]
106 | # data.val_neg_edge_index = data.val_neg_edge_index[:, :data.val_pos_edge_index.shape[1]]
107 |
108 | # res = [torch.eq(data.test_pos_edge_index, aelem).logical_or_(torch.eq(data.test_pos_edge_index, aelem)) for aelem in df_nodes]
109 | # mask = torch.any(torch.stack(res, dim=0), dim = 0)
110 | # mask = mask.sum(0).bool()
111 | # mask = ~mask
112 | # data.test_pos_edge_index = data.test_pos_edge_index[:, mask]
113 | # data.test_neg_edge_index = data.test_neg_edge_index[:, :data.test_pos_edge_index.shape[1]]
114 |
115 |
116 | # For testing
117 | # data.directed_df_edge_index = data.train_pos_edge_index[:, df_mask_edge]
118 | # if args.gnn in ['rgcn', 'rgat']:
119 | # data.directed_df_edge_type = data.train_edge_type[df_mask]
120 |
121 | # Edges in S_Df: 2-hop enclosing subgraph around the deleted edges
122 | _, two_hop_edge, _, two_hop_mask = k_hop_subgraph(
123 | data.edge_index[:, df_mask_edge].flatten().unique(),
124 | 2,
125 | data.edge_index,
126 | num_nodes=data.num_nodes)
127 |
128 | # Nodes in S_Df: 1-hop neighborhood around the deleted edges
129 | _, one_hop_edge, _, one_hop_mask = k_hop_subgraph(
130 | data.edge_index[:, df_mask_edge].flatten().unique(),
131 | 1,
132 | data.edge_index,
133 | num_nodes=data.num_nodes)
134 | sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool)
135 | sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool)
136 |
137 | sdf_node_1hop[one_hop_edge.flatten().unique()] = True
138 | sdf_node_2hop[two_hop_edge.flatten().unique()] = True
139 |
140 | assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique())
141 | assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique())
142 |
143 | data.sdf_node_1hop_mask = sdf_node_1hop
144 | data.sdf_node_2hop_mask = sdf_node_2hop
145 |
146 |
147 | # To undirected for message passing
148 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
149 | # assert not is_undirected(data.edge_index)
150 | print(is_undirected(data.edge_index))
151 |
152 | if args.gnn in ['rgcn', 'rgat']:
153 | r, c = data.train_pos_edge_index
154 | rev_edge_index = torch.stack([c, r], dim=0)
155 | rev_edge_type = data.train_edge_type + args.num_edge_type
156 |
157 | data.edge_index = torch.cat((data.train_pos_edge_index, rev_edge_index), dim=1)
158 | data.edge_type = torch.cat([data.train_edge_type, rev_edge_type], dim=0)
159 | # data.train_mask = data.train_mask.repeat(2)
160 |
161 | two_hop_mask = two_hop_mask.repeat(2).view(-1)
162 | df_mask_edge = df_mask_edge.repeat(2).view(-1)
163 | dr_mask_edge = dr_mask_edge.repeat(2).view(-1)
164 | assert is_undirected(data.edge_index)
165 |
166 | else:
167 | # train_pos_edge_index, [df_mask, two_hop_mask] = to_undirected(data.train_pos_edge_index, [df_mask.int(), two_hop_mask.int()])
168 | two_hop_mask = two_hop_mask.bool()
169 | df_mask_edge = df_mask_edge.bool()
170 | dr_mask_edge = ~df_mask_edge
171 |
172 | # data.train_pos_edge_index = train_pos_edge_index
173 | # assert is_undirected(data.train_pos_edge_index)
174 |
175 |
176 | print('Undirected dataset:', data)
177 | # print(is_undirected(train_pos_edge_index), train_pos_edge_index.shape, two_hop_mask.shape, df_mask.shape, two_hop_mask.shape)
178 |
179 | data.sdf_mask = two_hop_mask
180 | data.df_mask = df_mask_edge
181 | data.dr_mask = dr_mask_edge
182 | data.dtrain_mask = dr_mask_edge
183 | # print(is_undirected(data.train_pos_edge_index), data.train_pos_edge_index.shape, data.two_hop_mask.shape, data.df_mask.shape, data.two_hop_mask.shape)
184 | # raise
185 |
186 | # Model
187 | model = GCNDelete(args)
188 | # model = get_model(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
189 |
190 | if args.unlearning_model != 'retrain': # Start from trained GNN model
191 | if os.path.exists(os.path.join(original_path, 'pred_proba.pt')):
192 | logits_ori = torch.load(os.path.join(original_path, 'pred_proba.pt'))
193 | if logits_ori is not None:
194 | logits_ori = logits_ori.to(device)
195 | else:
196 | logits_ori = None
197 |
198 | model_ckpt = torch.load(os.path.join(original_path, 'model_best.pt'), map_location=device)
199 | model.load_state_dict(model_ckpt['model_state'], strict=False)
200 |
201 | else: # Initialize a new GNN model
202 | retrain = None
203 | logits_ori = None
204 |
205 | model = model.to(device)
206 |
207 | if 'gnndelete' in args.unlearning_model and 'nodeemb' in args.unlearning_model:
208 | parameters_to_optimize = [
209 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
210 | ]
211 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
212 |
213 | if 'layerwise' in args.loss_type:
214 | optimizer1 = torch.optim.Adam(model.deletion1.parameters(), lr=args.lr)
215 | optimizer2 = torch.optim.Adam(model.deletion2.parameters(), lr=args.lr)
216 | optimizer = [optimizer1, optimizer2]
217 | else:
218 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)
219 |
220 | else:
221 | if 'gnndelete' in args.unlearning_model:
222 | parameters_to_optimize = [
223 | {'params': [p for n, p in model.named_parameters() if 'del' in n], 'weight_decay': 0.0}
224 | ]
225 | print('parameters_to_optimize', [n for n, p in model.named_parameters() if 'del' in n])
226 |
227 | else:
228 | parameters_to_optimize = [
229 | {'params': [p for n, p in model.named_parameters()], 'weight_decay': 0.0}
230 | ]
231 | print('parameters_to_optimize', [n for n, p in model.named_parameters()])
232 |
233 | optimizer = torch.optim.Adam(parameters_to_optimize, lr=args.lr)#, weight_decay=args.weight_decay)
234 |
235 | wandb.watch(model, log_freq=100)
236 |
237 | # MI attack model
238 | attack_model_all = None
239 | # attack_model_all = MLPAttacker(args)
240 | # attack_ckpt = torch.load(os.path.join(attack_path_all, 'attack_model_best.pt'))
241 | # attack_model_all.load_state_dict(attack_ckpt['model_state'])
242 | # attack_model_all = attack_model_all.to(device)
243 |
244 | attack_model_sub = None
245 | # attack_model_sub = MLPAttacker(args)
246 | # attack_ckpt = torch.load(os.path.join(attack_path_sub, 'attack_model_best.pt'))
247 | # attack_model_sub.load_state_dict(attack_ckpt['model_state'])
248 | # attack_model_sub = attack_model_sub.to(device)
249 |
250 | # Train
251 | trainer = GNNDeleteNodeClassificationTrainer(args)
252 | trainer.train(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
253 |
254 | # Test
255 | if args.unlearning_model != 'retrain':
256 | retrain_path = os.path.join(
257 | 'checkpoint', args.dataset, args.gnn, 'retrain',
258 | '-'.join([str(i) for i in [args.df, args.df_size, args.random_seed]]))
259 | retrain_ckpt = torch.load(os.path.join(retrain_path, 'model_best.pt'), map_location=device)
260 | retrain_args = copy.deepcopy(args)
261 | retrain_args.unlearning_model = 'retrain'
262 | retrain = get_model(retrain_args, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
263 | retrain.load_state_dict(retrain_ckpt['model_state'])
264 | retrain = retrain.to(device)
265 | retrain.eval()
266 |
267 | else:
268 | retrain = None
269 |
270 | trainer.test(model, data, model_retrain=retrain, attack_model_all=attack_model_all, attack_model_sub=attack_model_sub)
271 | trainer.save_log()
272 |
273 |
274 | if __name__ == "__main__":
275 | main()
276 |
--------------------------------------------------------------------------------
/framework/trainer/gradient_ascent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import wandb
4 | from tqdm import tqdm, trange
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch_geometric.utils import negative_sampling
9 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
10 |
11 | from .base import Trainer, KGTrainer
12 | from ..evaluation import *
13 | from ..utils import *
14 |
15 |
16 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
17 |
18 | def weight(model):
19 | t = 0
20 | for p in model.parameters():
21 | t += torch.norm(p)
22 |
23 | return t
24 |
25 | class GradientAscentTrainer(Trainer):
26 |
27 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
28 | if 'ogbl' in self.args.dataset:
29 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
30 |
31 | else:
32 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
33 |
34 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
35 | model = model.to('cuda')
36 | data = data.to('cuda')
37 |
38 | start_time = time.time()
39 | best_metric = 0
40 |
41 | # MI Attack before unlearning
42 | if attack_model_all is not None:
43 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
44 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
45 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
46 | if attack_model_sub is not None:
47 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
48 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
49 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
50 |
51 |
52 | for epoch in trange(args.epochs, desc='Unlearning'):
53 | model.train()
54 |
55 | epoch_start_time = time.time()
56 |
57 | # Positive and negative sample
58 | neg_edge_index = negative_sampling(
59 | edge_index=data.train_pos_edge_index[:, data.df_mask],
60 | num_nodes=data.num_nodes,
61 | num_neg_samples=data.df_mask.sum())
62 |
63 | z = model(data.x, data.train_pos_edge_index)
64 | logits = model.decode(z, data.train_pos_edge_index[:, data.df_mask])
65 | label = torch.ones_like(logits, dtype=torch.float, device='cuda')
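    | # The BCE loss on the Df edges (all labelled 1) is negated, so optimizer.step()
    | # performs gradient *ascent* on the edges to be forgotten.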
66 | loss = -F.binary_cross_entropy_with_logits(logits, label)
67 |
68 | # print('aaaaaaaaaaaaaa', data.df_mask.sum(), weight(model))
69 |
70 | loss.backward()
71 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
72 | optimizer.step()
73 | optimizer.zero_grad()
74 |
75 | end_time = time.time()
76 | epoch_time = end_time - epoch_start_time
77 |
78 | step_log = {
79 | 'Epoch': epoch,
80 | 'train_loss': loss.item(),
81 | 'train_time': epoch_time
82 | }
83 | wandb.log(step_log)
84 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()]
85 | tqdm.write(' | '.join(msg))
86 |
87 | if (epoch + 1) % self.args.valid_freq == 0:
88 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
89 | valid_log['epoch'] = epoch
90 |
91 | train_log = {
92 | 'epoch': epoch,
93 | 'train_loss': loss.item(),
94 | 'train_time': epoch_time,
95 | }
96 |
97 | for log in [train_log, valid_log]:
98 | wandb.log(log)
99 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
100 | tqdm.write(' | '.join(msg))
101 | self.trainer_log['log'].append(log)
102 |
103 | if dt_auc + df_auc > best_metric:
104 | best_metric = dt_auc + df_auc
105 | best_epoch = epoch
106 |
107 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
108 | ckpt = {
109 | 'model_state': model.state_dict(),
110 | 'optimizer_state': optimizer.state_dict(),
111 | }
112 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
113 |
114 | self.trainer_log['training_time'] = time.time() - start_time
115 |
116 | # Save
117 | ckpt = {
118 | 'model_state': {k: v.cpu() for k, v in model.state_dict().items()},
119 | 'optimizer_state': optimizer.state_dict(),
120 | }
121 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
122 |
123 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
124 | best_metric = 0
125 |
126 | # MI Attack before unlearning
127 | if attack_model_all is not None:
128 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
129 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
130 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
131 | if attack_model_sub is not None:
132 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
133 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
134 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
135 |
136 | data.edge_index = data.train_pos_edge_index
137 | data.node_id = torch.arange(data.x.shape[0])
138 | loader = GraphSAINTRandomWalkSampler(
139 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps,
140 | )
141 | for epoch in trange(args.epochs, desc='Unlearning'):
142 | model.train()
143 |
144 | epoch_loss = 0
145 | epoch_time = 0
146 | for step, batch in enumerate(tqdm(loader, leave=False)):
147 | start_time = time.time()
148 | batch = batch.to(device)
149 |
150 | z = model(batch.x, batch.edge_index[:, batch.dr_mask])
151 |
152 | # Positive and negative sample
153 | neg_edge_index = negative_sampling(
154 | edge_index=batch.edge_index[:, batch.df_mask],
155 | num_nodes=z.size(0))
156 |
157 | logits = model.decode(z, batch.edge_index[:, batch.df_mask])
158 | label = torch.ones_like(logits, dtype=torch.float, device=device)
159 | loss = -F.binary_cross_entropy_with_logits(logits, label)
160 |
161 | loss.backward()
162 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
163 | optimizer.step()
164 | optimizer.zero_grad()
165 |
166 | end_time = time.time()
167 | epoch_loss += loss.item()
168 | epoch_time += end_time - start_time
169 |
170 | epoch_loss /= (step + 1)
171 | epoch_time /= (step + 1)
172 |
173 | if (epoch+1) % args.valid_freq == 0:
174 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
175 |
176 | train_log = {
177 | 'epoch': epoch,
178 | 'train_loss': epoch_loss,
179 | 'train_time': epoch_time,
180 | }
181 |
182 | for log in [train_log, valid_log]:
183 | wandb.log(log)
184 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
185 | tqdm.write(' | '.join(msg))
186 | self.trainer_log['log'].append(log)
187 |
188 | if dt_auc + df_auc > best_metric:
189 | best_metric = dt_auc + df_auc
190 | best_epoch = epoch
191 |
192 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
193 | ckpt = {
194 | 'model_state': model.state_dict(),
195 | 'optimizer_state': optimizer.state_dict(),
196 | }
197 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
198 | torch.save(z, os.path.join(args.checkpoint_dir, 'node_embeddings.pt'))
199 |
200 | # Save
201 | ckpt = {
202 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()},
203 | 'optimizer_state': optimizer.state_dict(),
204 | }
205 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
206 |
207 | class KGGradientAscentTrainer(KGTrainer):
208 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
209 | model = model.to(device)
210 | start_time = time.time()
211 | best_metric = 0
212 |
213 | # MI Attack before unlearning
214 | if attack_model_all is not None:
215 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
216 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
217 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
218 | if attack_model_sub is not None:
219 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
220 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
221 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
222 |
223 | loader = GraphSAINTRandomWalkSampler(
224 | data, batch_size=128, walk_length=args.walk_length, num_steps=args.num_steps,
225 | )
226 | for epoch in trange(args.epochs, desc='Epoch'):
227 | model.train()
228 |
229 | epoch_loss = 0
230 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)):
231 | batch = batch.to(device)
232 |
233 | # Message passing
234 | edge_index = batch.edge_index[:, batch.dr_mask]
235 | edge_type = batch.edge_type[batch.dr_mask]
236 | z = model(batch.x, edge_index, edge_type)
237 |
238 | # Positive and negative sample
239 | decoding_edge_index = batch.edge_index[:, batch.df_mask]
240 | decoding_edge_type = batch.edge_type[batch.df_mask]
241 | decoding_mask = (decoding_edge_type < args.num_edge_type) # Only select directed edges for link prediction
242 | decoding_edge_index = decoding_edge_index[:, decoding_mask]
243 | decoding_edge_type = decoding_edge_type[decoding_mask]
244 |
245 | logits = model.decode(z, decoding_edge_index, decoding_edge_type)
246 | label = torch.ones_like(logits, dtype=torch.float, device=device)
247 | loss = -F.binary_cross_entropy_with_logits(logits, label)
248 |
249 | loss.backward()
250 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
251 | optimizer.step()
252 | optimizer.zero_grad()
253 |
254 | log = {
255 | 'epoch': epoch,
256 | 'step': step,
257 | 'train_loss': loss.item(),
258 | }
259 | wandb.log(log)
260 | # msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
261 | # tqdm.write(' | '.join(msg))
262 |
263 | epoch_loss += loss.item()
264 |
265 | if (epoch + 1) % args.valid_freq == 0:
266 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
267 |
268 | train_log = {
269 | 'epoch': epoch,
270 | 'train_loss': epoch_loss / step
271 | }
272 |
273 | for log in [train_log, valid_log]:
274 | wandb.log(log)
275 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
276 | tqdm.write(' | '.join(msg))
277 |
278 | self.trainer_log['log'].append(train_log)
279 | self.trainer_log['log'].append(valid_log)
280 |
281 | if dt_auc + df_auc > best_metric:
282 | best_metric = dt_auc + df_auc
283 | best_epoch = epoch
284 |
285 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
286 | ckpt = {
287 | 'model_state': model.state_dict(),
288 | 'optimizer_state': optimizer.state_dict(),
289 | }
290 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
291 |
292 | self.trainer_log['training_time'] = time.time() - start_time
293 |
294 | # Save models and node embeddings
295 | print('Saving final checkpoint')
296 | ckpt = {
297 | 'model_state': model.state_dict(),
298 | 'optimizer_state': optimizer.state_dict(),
299 | }
300 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
301 |
302 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}')
303 |
304 | self.trainer_log['best_epoch'] = best_epoch
305 | self.trainer_log['best_metric'] = best_metric
306 | self.trainer_log['training_time'] = np.mean([i['epoch_time'] for i in self.trainer_log['log'] if 'epoch_time' in i])
307 |
--------------------------------------------------------------------------------
/framework/trainer/gnndelete_embdis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import wandb
4 | from tqdm import tqdm, trange
5 | import torch
6 | import torch.nn as nn
7 | from torch_geometric.utils import negative_sampling, k_hop_subgraph
8 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
9 |
10 | from .base import Trainer
11 | from ..evaluation import *
12 | from ..utils import *
13 |
14 |
15 | def BoundedKLD(logits, truth):
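    | # Bounded KL: 1 - exp(-KL) maps the divergence into [0, 1); it is 0 for identical
    | # distributions and approaches 1 as they diverge.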
16 | return 1 - torch.exp(-F.kl_div(F.log_softmax(logits, -1), truth.softmax(-1), None, None, 'batchmean'))
17 |
18 | class GNNDeleteEmbeddingDistanceTrainer(Trainer):
19 |
20 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
21 | if 'ogbl' in self.args.dataset:
22 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
23 |
24 | else:
25 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
26 |
27 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
28 | model = model.to('cuda')
29 | data = data.to('cuda')
30 |
31 | best_metric = 0
32 | if 'kld' in args.unlearning_model:
33 | loss_fct = BoundedKLD
34 | else:
35 | loss_fct = nn.MSELoss()
36 | # neg_size = 10
37 |
38 | # MI Attack before unlearning
39 | if attack_model_all is not None:
40 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
41 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
42 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
43 | if attack_model_sub is not None:
44 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
45 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
46 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
47 |
48 | # All node pairs in S_Df except Df itself, used for the local causality loss
49 | ## S_Df all pair mask
50 | sdf_all_pair_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool)
51 | idx = torch.combinations(torch.arange(data.num_nodes)[data.sdf_node_2hop_mask], with_replacement=True).t()
52 | sdf_all_pair_mask[idx[0], idx[1]] = True
53 | sdf_all_pair_mask[idx[1], idx[0]] = True
54 |
55 | # print(data.sdf_node_2hop_mask.sum())
56 | # print(sdf_all_pair_mask.nonzero())
57 | # print(data.train_pos_edge_index[:, data.df_mask][0], data.train_pos_edge_index[:, data.df_mask][1])
58 |
59 | assert sdf_all_pair_mask.sum().cpu() == data.sdf_node_2hop_mask.sum().cpu() * data.sdf_node_2hop_mask.sum().cpu()
60 |
61 | ## Remove Df itself
62 | sdf_all_pair_mask[data.train_pos_edge_index[:, data.df_mask][0], data.train_pos_edge_index[:, data.df_mask][1]] = False
63 | sdf_all_pair_mask[data.train_pos_edge_index[:, data.df_mask][1], data.train_pos_edge_index[:, data.df_mask][0]] = False
64 |
65 | ## Lower triangular mask
66 | idx = torch.tril_indices(data.num_nodes, data.num_nodes, -1)
67 | lower_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool)
68 | lower_mask[idx[0], idx[1]] = True
69 |
70 | ## The final mask is the intersection
71 | sdf_all_pair_without_df_mask = sdf_all_pair_mask & lower_mask
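    | # The strict lower triangle keeps each unordered node pair exactly once and drops
    | # self-pairs; Df pairs were already removed above.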
72 |
73 | # print('aaaaaaaaaaaa', data.sdf_node_2hop_mask.sum(), a, sdf_all_pair_mask.sum())
74 | # print('aaaaaaaaaaaa', lower_mask.sum())
75 | # print('aaaaaaaaaaaa', sdf_all_pair_without_df_mask.sum())
76 | # print('aaaaaaaaaaaa', data.sdf_node_2hop_mask.sum())
77 | # assert sdf_all_pair_without_df_mask.sum() == \
78 | # data.sdf_node_2hop_mask.sum().cpu() * (data.sdf_node_2hop_mask.sum().cpu() - 1) // 2 - data.df_mask.sum().cpu()
79 |
80 |
81 | # Node representation for local causality
82 | with torch.no_grad():
83 | z1_ori, z2_ori = model.get_original_embeddings(data.x, data.train_pos_edge_index[:, data.dtrain_mask], return_all_emb=True)
84 |
85 | total_time = 0
86 | for epoch in trange(args.epochs, desc='Unlearning'):
87 | model.train()
88 | start_time = time.time()
89 |
90 | z1, z2 = model(data.x, data.train_pos_edge_index[:, data.sdf_mask], return_all_emb=True)
91 | print('current deletion weight', model.deletion1.deletion_weight.sum(), model.deletion2.deletion_weight.sum())
92 |
93 | # Effectiveness and Randomness
94 | neg_size = data.df_mask.sum()
95 | neg_edge_index = negative_sampling(
96 | edge_index=data.train_pos_edge_index,
97 | num_nodes=data.num_nodes,
98 | num_neg_samples=neg_size)
99 |
100 | df_logits = model.decode(z2, data.train_pos_edge_index[:, data.df_mask], neg_edge_index)
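    | # df_logits holds the Df-edge logits first, followed by the sampled non-edge logits;
    | # loss_e pulls the deleted-edge scores toward those of random node pairs.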
101 | loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:])
102 | # df_logits = model.decode(
103 | # z,
104 | # data.train_pos_edge_index[:, data.df_mask].repeat(1, neg_size),
105 | # neg_edge_index).sigmoid()
106 |
107 | # loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:])
108 | # print('df_logits', df_logits)
109 | # raise
110 |
111 | # Local causality
112 | if sdf_all_pair_without_df_mask.sum() != 0:
113 | loss_l = loss_fct(z1_ori[data.sdf_node_1hop_mask], z1[data.sdf_node_1hop_mask]) + \
114 | loss_fct(z2_ori[data.sdf_node_2hop_mask], z2[data.sdf_node_2hop_mask])
115 | print('local proba', loss_l.item())
116 |
117 | else:
118 | loss_l = torch.tensor(0)
119 | print('local proba', 0)
120 |
121 |
122 | alpha = 0.5
123 | if 'ablation_random' in self.args.unlearning_model:
124 | loss_l = torch.tensor(0)
125 | loss = loss_e
126 | elif 'ablation_locality' in self.args.unlearning_model:
127 | loss_e = torch.tensor(0)
128 | loss = loss_l
129 | else:
130 | loss = alpha * loss_e + (1 - alpha) * loss_l
131 |
132 | loss.backward()
133 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
134 | optimizer.step()
135 | optimizer.zero_grad()
136 |
137 | end_time = time.time()
138 |
139 | log = {
140 | 'epoch': epoch,
141 | 'train_loss': loss.item(),
142 | 'train_loss_l': loss_l.item(),
143 | 'train_loss_e': loss_e.item(),
144 | 'train_time': end_time - start_time,
145 | }
146 | # wandb.log(log)
147 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
148 | tqdm.write(' | '.join(msg))
149 |
150 | if (epoch+1) % args.valid_freq == 0:
151 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
152 |
153 | train_log = {
154 | 'epoch': epoch,
155 | 'train_loss': loss.item(),
156 | 'train_loss_l': loss_l.item(),
157 | 'train_loss_e': loss_e.item(),
158 | 'train_time': end_time - start_time,
159 | }
160 |
161 | for log in [train_log, valid_log]:
162 | wandb.log(log)
163 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
164 | tqdm.write(' | '.join(msg))
165 | self.trainer_log['log'].append(log)
166 |
167 | if dt_auc + df_auc > best_metric:
168 | best_metric = dt_auc + df_auc
169 | best_epoch = epoch
170 |
171 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
172 | ckpt = {
173 | 'model_state': model.state_dict(),
174 | 'optimizer_state': optimizer.state_dict(),
175 | }
176 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
177 |
178 | # Save
179 | ckpt = {
180 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()},
181 | 'optimizer_state': optimizer.state_dict(),
182 | }
183 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
184 |
185 | # Save
186 | ckpt = {
187 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()},
188 | 'optimizer_state': optimizer.state_dict(),
189 | }
190 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
191 |
192 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
193 | start_time = time.time()
194 | best_loss = 100000
195 | if 'kld' in args.unlearning_model:
196 | loss_fct = BoundedKLD
197 | else:
198 | loss_fct = nn.MSELoss()
199 | # neg_size = 10
200 |
201 | # MI Attack before unlearning
202 | if attack_model_all is not None:
203 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
204 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
205 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
206 | if attack_model_sub is not None:
207 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
208 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
209 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
210 |
211 | z_ori = self.get_embedding(model, data, on_cpu=True)
212 | z_ori_two_hop = z_ori[data.sdf_node_2hop_mask]
213 |
214 | data.edge_index = data.train_pos_edge_index
215 | data.node_id = torch.arange(data.x.shape[0])
216 | loader = GraphSAINTRandomWalkSampler(
217 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps,
218 | )
219 | for epoch in trange(args.epochs, desc='Unlearning'):
220 | model.train()
221 |
222 | print('current deletion weight', model.deletion1.deletion_weight.sum(), model.deletion2.deletion_weight.sum())
223 |
224 | epoch_loss_e = 0
225 | epoch_loss_l = 0
226 | epoch_loss = 0
227 | for step, batch in enumerate(tqdm(loader, leave=False)):
228 | # print('data', batch)
229 | # print('two hop nodes', batch.sdf_node_2hop_mask.sum())
230 | batch = batch.to('cuda')
231 |
232 | train_pos_edge_index = batch.edge_index
233 | z = model(batch.x, train_pos_edge_index[:, batch.sdf_mask], batch.sdf_node_1hop_mask, batch.sdf_node_2hop_mask)
234 | z_two_hop = z[batch.sdf_node_2hop_mask]
235 |
236 | # Effectiveness and Randomness
237 | neg_size = batch.df_mask.sum()
238 | neg_edge_index = negative_sampling(
239 | edge_index=train_pos_edge_index,
240 | num_nodes=z.size(0),
241 | num_neg_samples=neg_size)
242 |
243 | df_logits = model.decode(z, train_pos_edge_index[:, batch.df_mask], neg_edge_index)
244 | loss_e = loss_fct(df_logits[:neg_size], df_logits[neg_size:])
245 |
246 | # Local causality
247 | mask = torch.zeros(data.x.shape[0], dtype=torch.bool)
248 | mask[batch.node_id[batch.sdf_node_2hop_mask]] = True
249 | z_ori_subset = z_ori[mask].to('cuda')
250 |
251 | # Only take the lower triangular part
252 | num_nodes = z_ori_subset.shape[0]
253 | idx = torch.tril_indices(num_nodes, num_nodes, -1)
254 | local_lower_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
255 | local_lower_mask[idx[0], idx[1]] = True
256 |
257 | logits_ori = (z_ori_subset @ z_ori_subset.t())[local_lower_mask].sigmoid()
258 | logits = (z_two_hop @ z_two_hop.t())[local_lower_mask].sigmoid()
259 |
260 | loss_l = loss_fct(logits, logits_ori)
261 |
262 |
263 | alpha = 0.5
264 | if 'ablation_random' in self.args.unlearning_model:
265 | loss_l = torch.tensor(0)
266 | loss = loss_e
267 | elif 'ablation_locality' in self.args.unlearning_model:
268 | loss_e = torch.tensor(0)
269 | loss = loss_l
270 | else:
271 | loss = alpha * loss_e + (1 - alpha) * loss_l
272 |
273 | loss.backward()
274 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
275 | optimizer.step()
276 | optimizer.zero_grad()
277 |
278 | epoch_loss_e += loss_e.item()
279 | epoch_loss_l += loss_l.item()
280 | epoch_loss += loss.item()
281 |
282 | epoch_loss_e /= step
283 | epoch_loss_l /= step
284 | epoch_loss /= step
285 |
286 |
287 | if (epoch+1) % args.valid_freq == 0:
288 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
289 |
290 | log = {
291 | 'epoch': epoch,
292 | 'train_loss': epoch_loss,
293 | 'train_loss_e': epoch_loss_e,
294 | 'train_loss_l': epoch_loss_l,
295 | 'valid_dt_loss': valid_loss,
296 | 'valid_dt_auc': dt_auc,
297 | 'valid_dt_aup': dt_aup,
298 | }
299 | wandb.log(log)
300 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
301 | tqdm.write(' | '.join(msg))
302 |
303 | self.trainer_log['log'].append(log)
304 |
305 | self.trainer_log['training_time'] = time.time() - start_time
306 |
307 | # Save
308 | ckpt = {
309 | 'model_state': {k: v.to('cpu') for k, v in model.state_dict().items()},
310 | 'optimizer_state': optimizer.state_dict(),
311 | }
312 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
313 |
--------------------------------------------------------------------------------
/framework/trainer/retrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import wandb
4 | from tqdm import tqdm, trange
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch_geometric.utils import negative_sampling
10 | from torch_geometric.loader import GraphSAINTRandomWalkSampler
11 |
12 | from .base import Trainer, KGTrainer
13 | from ..evaluation import *
14 | from ..utils import *
15 |
16 |
17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 |
19 | class RetrainTrainer(Trainer):
20 |
21 | def freeze_unused_mask(self, model, edge_to_delete, subgraph, h):
22 | gradient_mask = torch.zeros_like(model.operator)
23 |
24 | edges = subgraph[h]
25 | for s, t in edges:
26 | if s < t:
27 | gradient_mask[s, t] = 1
28 | gradient_mask = gradient_mask.to(device)
29 | model.operator.register_hook(lambda grad: grad.mul_(gradient_mask))
30 |
31 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
32 | if 'ogbl' in self.args.dataset:
33 | return self.train_minibatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
34 |
35 | else:
36 | return self.train_fullbatch(model, data, optimizer, args, logits_ori, attack_model_all, attack_model_sub)
37 |
38 | def train_fullbatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
39 | model = model.to('cuda')
40 | data = data.to('cuda')
41 |
42 | best_metric = 0
43 | loss_fct = nn.MSELoss()
44 |
45 | # MI Attack before unlearning
46 | if attack_model_all is not None:
47 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
48 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
49 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
50 | if attack_model_sub is not None:
51 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
52 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
53 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
54 |
55 | for epoch in trange(args.epochs, desc='Unlearning'):
56 | model.train()
57 |
58 | start_time = time.time()
59 | total_step = 0
60 | total_loss = 0
61 |
62 | neg_edge_index = negative_sampling(
63 | edge_index=data.train_pos_edge_index[:, data.dr_mask],
64 | num_nodes=data.num_nodes,
65 | num_neg_samples=data.dr_mask.sum())
66 |
67 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask])
68 | logits = model.decode(z, data.train_pos_edge_index[:, data.dr_mask], neg_edge_index)
69 | label = self.get_link_labels(data.train_pos_edge_index[:, data.dr_mask], neg_edge_index)
70 | loss = F.binary_cross_entropy_with_logits(logits, label)
71 |
72 | loss.backward()
73 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
74 | optimizer.step()
75 | optimizer.zero_grad()
76 |
77 | total_step += 1
78 | total_loss += loss.item()
79 |
80 | end_time = time.time()
81 | epoch_time = end_time - start_time
82 |
83 | step_log = {
84 | 'Epoch': epoch,
85 | 'train_loss': loss.item(),
86 | 'train_time': epoch_time
87 | }
88 | wandb.log(step_log)
89 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()]
90 | tqdm.write(' | '.join(msg))
91 |
92 | if (epoch + 1) % self.args.valid_freq == 0:
93 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
94 | valid_log['epoch'] = epoch
95 |
96 | train_log = {
97 | 'epoch': epoch,
98 | 'train_loss': loss.item(),
99 | 'train_time': epoch_time,
100 | }
101 |
102 | for log in [train_log, valid_log]:
103 | wandb.log(log)
104 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
105 | tqdm.write(' | '.join(msg))
106 | self.trainer_log['log'].append(log)
107 |
108 | if dt_auc + df_auc > best_metric:
109 | best_metric = dt_auc + df_auc
110 | best_epoch = epoch
111 |
112 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
113 | ckpt = {
114 | 'model_state': model.state_dict(),
115 | 'optimizer_state': optimizer.state_dict(),
116 | }
117 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
118 |
119 | self.trainer_log['training_time'] = time.time() - start_time
120 |
121 | # Save
122 | ckpt = {
123 | 'model_state': model.state_dict(),
124 | 'optimizer_state': optimizer.state_dict(),
125 | }
126 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
127 |
128 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}')
129 |
130 | self.trainer_log['best_epoch'] = best_epoch
131 | self.trainer_log['best_metric'] = best_metric
132 |
133 |
134 | def train_minibatch(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
135 | start_time = time.time()
136 | best_metric = 0
137 |
138 | # MI Attack before unlearning
139 | if attack_model_all is not None:
140 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
141 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
142 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
143 | if attack_model_sub is not None:
144 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
145 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
146 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
147 |
148 | data.edge_index = data.train_pos_edge_index
149 | loader = GraphSAINTRandomWalkSampler(
150 | data, batch_size=args.batch_size, walk_length=2, num_steps=args.num_steps,
151 | )
152 | for epoch in trange(args.epochs, desc='Epoch'):
153 | model.train()
154 |
155 | epoch_start_time = time.time()
156 | epoch_loss = 0
157 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)):
158 | batch = batch.to(device)
159 |
160 | # Positive and negative sample
161 | train_pos_edge_index = batch.edge_index[:, batch.dr_mask]
162 | z = model(batch.x, train_pos_edge_index)
163 |
164 | neg_edge_index = negative_sampling(
165 | edge_index=train_pos_edge_index,
166 | num_nodes=z.size(0))
167 |
168 | logits = model.decode(z, train_pos_edge_index, neg_edge_index)
169 | label = get_link_labels(train_pos_edge_index, neg_edge_index)
170 | loss = F.binary_cross_entropy_with_logits(logits, label)
171 |
172 | loss.backward()
173 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
174 | optimizer.step()
175 | optimizer.zero_grad()
176 |
177 | step_log = {
178 | 'epoch': epoch,
179 | 'step': step,
180 | 'train_loss': loss.item(),
181 | }
182 | wandb.log(step_log)
183 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in step_log.items()]
184 | tqdm.write(' | '.join(msg))
185 |
186 | epoch_loss += loss.item()
187 |
188 | end_time = time.time()
189 | epoch_time = end_time - epoch_start_time
190 |
191 | if (epoch+1) % args.valid_freq == 0:
192 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
193 | valid_log['epoch'] = epoch
194 |
195 | train_log = {
196 | 'epoch': epoch,
197 | 'train_loss': loss.item(),
198 | 'train_time': epoch_time,
199 | }
200 |
201 | for log in [train_log, valid_log]:
202 | wandb.log(log)
203 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
204 | tqdm.write(' | '.join(msg))
205 | self.trainer_log['log'].append(log)
206 |
207 | if dt_auc + df_auc > best_metric:
208 | best_metric = dt_auc + df_auc
209 | best_epoch = epoch
210 |
211 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
212 | ckpt = {
213 | 'model_state': model.state_dict(),
214 | 'optimizer_state': optimizer.state_dict(),
215 | }
216 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
217 | torch.save(z, os.path.join(args.checkpoint_dir, 'node_embeddings.pt'))
218 |
219 | self.trainer_log['training_time'] = time.time() - start_time
220 |
221 | # Save models and node embeddings
222 | print('Saving final checkpoint')
223 | ckpt = {
224 | 'model_state': model.state_dict(),
225 | 'optimizer_state': optimizer.state_dict(),
226 | }
227 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
228 |
229 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}')
230 |
231 | self.trainer_log['best_epoch'] = best_epoch
232 | self.trainer_log['best_metric'] = best_metric
233 |
234 |
235 | class KGRetrainTrainer(KGTrainer):
236 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
237 | model = model.to(device)
238 | start_time = time.time()
239 | best_metric = 0
240 |
241 | # MI Attack before unlearning
242 | if attack_model_all is not None:
243 | mi_logit_all_before, mi_sucrate_all_before = member_infer_attack(model, attack_model_all, data)
244 | self.trainer_log['mi_logit_all_before'] = mi_logit_all_before
245 | self.trainer_log['mi_sucrate_all_before'] = mi_sucrate_all_before
246 | if attack_model_sub is not None:
247 | mi_logit_sub_before, mi_sucrate_sub_before = member_infer_attack(model, attack_model_sub, data)
248 | self.trainer_log['mi_logit_sub_before'] = mi_logit_sub_before
249 | self.trainer_log['mi_sucrate_sub_before'] = mi_sucrate_sub_before
250 |
251 | loader = GraphSAINTRandomWalkSampler(
252 | data, batch_size=128, walk_length=2, num_steps=args.num_steps,
253 | )
254 | for epoch in trange(args.epochs, desc='Epoch'):
255 | model.train()
256 |
257 | epoch_loss = 0
258 | for step, batch in enumerate(tqdm(loader, desc='Step', leave=False)):
259 | batch = batch.to(device)
260 |
261 | # Message passing
262 | edge_index = batch.edge_index[:, batch.dr_mask]
263 | edge_type = batch.edge_type[batch.dr_mask]
264 | z = model(batch.x, edge_index, edge_type)
265 |
266 | # Positive and negative sample
267 | decoding_mask = (edge_type < args.num_edge_type) # Only select directed edges for link prediction
268 | decoding_edge_index = edge_index[:, decoding_mask]
269 | decoding_edge_type = edge_type[decoding_mask]
270 |
271 | neg_edge_index = negative_sampling_kg(
272 | edge_index=decoding_edge_index,
273 | edge_type=decoding_edge_type)
274 |
275 | pos_logits = model.decode(z, decoding_edge_index, decoding_edge_type)
276 | neg_logits = model.decode(z, neg_edge_index, decoding_edge_type)
277 | logits = torch.cat([pos_logits, neg_logits], dim=-1)
278 | label = get_link_labels(decoding_edge_index, neg_edge_index)
279 | # reg_loss = z.pow(2).mean() + model.W.pow(2).mean()
280 | loss = F.binary_cross_entropy_with_logits(logits, label)# + 1e-2 * reg_loss
281 |
282 | loss.backward()
283 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
284 | optimizer.step()
285 | optimizer.zero_grad()
286 |
287 | log = {
288 | 'epoch': epoch,
289 | 'step': step,
290 | 'train_loss': loss.item(),
291 | }
292 | wandb.log(log)
293 | # msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
294 | # tqdm.write(' | '.join(msg))
295 |
296 | epoch_loss += loss.item()
297 |
298 | if (epoch + 1) % args.valid_freq == 0:
299 | valid_loss, dt_auc, dt_aup, df_auc, df_aup, df_logit, logit_all_pair, valid_log = self.eval(model, data, 'val')
300 |
301 | train_log = {
302 | 'epoch': epoch,
303 | 'train_loss': epoch_loss / step
304 | }
305 |
306 | for log in [train_log, valid_log]:
307 | wandb.log(log)
308 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
309 | tqdm.write(' | '.join(msg))
310 |
311 | self.trainer_log['log'].append(train_log)
312 | self.trainer_log['log'].append(valid_log)
313 |
314 | if dt_auc + df_auc > best_metric:
315 | best_metric = dt_auc + df_auc
316 | best_epoch = epoch
317 |
318 | print(f'Save best checkpoint at epoch {epoch:04d}. Valid loss = {valid_loss:.4f}')
319 | ckpt = {
320 | 'model_state': model.state_dict(),
321 | 'optimizer_state': optimizer.state_dict(),
322 | }
323 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_best.pt'))
324 |
325 | self.trainer_log['training_time'] = time.time() - start_time
326 |
327 | # Save models and node embeddings
328 | print('Saving final checkpoint')
329 | ckpt = {
330 | 'model_state': model.state_dict(),
331 | 'optimizer_state': optimizer.state_dict(),
332 | }
333 | torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_final.pt'))
334 |
335 | print(f'Training finished. Best checkpoint at epoch = {best_epoch:04d}, best metric = {best_metric:.4f}')
336 |
337 | self.trainer_log['best_epoch'] = best_epoch
338 | self.trainer_log['best_metric'] = best_metric
339 | self.trainer_log['training_time'] = np.mean([i['epoch_time'] for i in self.trainer_log['log'] if 'epoch_time' in i])
340 |
--------------------------------------------------------------------------------
/framework/trainer/graph_eraser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import math
5 | from tqdm import tqdm, trange
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_geometric.utils import negative_sampling, subgraph
10 |
11 | from .base import Trainer
12 | from ..evaluation import *
13 | from ..utils import *
14 |
15 |
16 | class ConstrainedKmeans:
17 | '''This code is from https://github.com/MinChen00/Graph-Unlearning'''
18 |
19 | def __init__(self, args, data_feat, num_clusters, node_threshold, terminate_delta, max_iteration=20):
20 | self.args = args
21 | self.data_feat = data_feat
22 | self.num_clusters = num_clusters
23 | self.node_threshold = node_threshold
24 | self.terminate_delta = terminate_delta
25 | self.max_iteration = max_iteration
26 |
27 | def initialization(self):
28 | centroids = np.random.choice(np.arange(self.data_feat.shape[0]), self.num_clusters, replace=False)
29 | self.centroid = {}
30 | for i in range(self.num_clusters):
31 | self.centroid[i] = self.data_feat[centroids[i]]
32 |
33 | def clustering(self):
34 | centroid = copy.deepcopy(self.centroid)
35 | km_delta = []
36 |
37 | # pbar = tqdm(total=self.max_iteration)
38 | # pbar.set_description('Clustering')
39 |
40 | for i in trange(self.max_iteration, desc='Graph partition'):
41 | # self.logger.info('iteration %s' % (i,))
42 |
43 | self._node_reassignment()
44 | self._centroid_updating()
45 |
46 | # record the average change of centroids, if the change is smaller than a very small value, then terminate
47 | delta = self._centroid_delta(centroid, self.centroid)
48 | km_delta.append(delta)
49 | centroid = copy.deepcopy(self.centroid)
50 |
51 | if delta <= self.terminate_delta:
52 | break
53 | print("delta: %s" % delta)
54 | # pbar.close()
55 | return self.clusters, km_delta
56 |
57 | def _node_reassignment(self):
58 | self.clusters = {}
59 | for i in range(self.num_clusters):
60 | self.clusters[i] = np.zeros(0, dtype=np.uint64)
61 |
62 | distance = np.zeros([self.num_clusters, self.data_feat.shape[0]])
63 |
64 | for i in range(self.num_clusters):
65 | distance[i] = np.sum(np.power((self.data_feat - self.centroid[i]), 2), axis=1)
66 |
67 | sort_indices = np.unravel_index(np.argsort(distance, axis=None), distance.shape)
68 | clusters = sort_indices[0]
69 | users = sort_indices[1]
70 | selected_nodes = np.zeros(0, dtype=np.int64)
71 | counter = 0
72 |
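    | # Greedy constrained assignment: visit (cluster, node) pairs in order of increasing
    | # distance and skip clusters that already hold node_threshold nodes.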
73 | while len(selected_nodes) < self.data_feat.shape[0]:
74 | cluster = int(clusters[counter])
75 | user = users[counter]
76 | if self.clusters[cluster].size < self.node_threshold:
77 | self.clusters[cluster] = np.append(self.clusters[cluster], np.array(int(user)))
78 | selected_nodes = np.append(selected_nodes, np.array(int(user)))
79 |
80 | # delete all the following pairs for the selected user
81 | user_indices = np.where(users == user)[0]
82 | a = np.arange(users.size)
83 | b = user_indices[user_indices > counter]
84 | remain_indices = a[np.where(np.logical_not(np.isin(a, b)))[0]]
85 | clusters = clusters[remain_indices]
86 | users = users[remain_indices]
87 |
88 | counter += 1
89 |
90 | def _centroid_updating(self):
91 | for i in range(self.num_clusters):
92 | self.centroid[i] = np.mean(self.data_feat[self.clusters[i].astype(int)], axis=0)
93 |
94 | def _centroid_delta(self, centroid_pre, centroid_cur):
95 | delta = 0.0
96 | for i in range(len(centroid_cur)):
97 | delta += np.sum(np.abs(centroid_cur[i] - centroid_pre[i]))
98 |
99 | return delta
100 |
101 | def generate_shard_data(self, data):
102 | shard_data = {}
103 | for shard in trange(self.args['num_shards'], desc='Generate shard data'):
104 | train_shard_indices = list(self.community_to_node[shard])
105 | shard_indices = np.union1d(train_shard_indices, self.test_indices)
106 |
107 | x = data.x[shard_indices]
108 | y = data.y[shard_indices]
109 | edge_index = utils.filter_edge_index_1(data, shard_indices)
110 |
111 | shard_graph = Data(x=x, edge_index=torch.from_numpy(edge_index), y=y)  # local name avoids shadowing the input graph
112 | shard_graph.train_mask = torch.from_numpy(np.isin(shard_indices, train_shard_indices))
113 | shard_graph.test_mask = torch.from_numpy(np.isin(shard_indices, self.test_indices))
114 |
115 | shard_data[shard] = shard_graph
116 |
117 | self.data_store.save_shard_data(shard_data)
118 |
119 | class OptimalAggregator:
120 | def __init__(self, run, target_model, data, args):
121 | self.args = args
122 |
123 | self.run = run
124 | self.target_model = target_model
125 | self.data = data
126 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
127 |
128 | self.num_shards = args.num_clusters
129 |
130 | def generate_train_data(self):
131 | data_store = DataStore(self.args)
132 | train_indices, _ = data_store.load_train_test_split()
133 |
134 | # sample a set of nodes from train_indices
135 | if self.args["num_opt_samples"] == 1000:
136 | train_indices = np.random.choice(train_indices, size=1000, replace=False)
137 | elif self.args["num_opt_samples"] == 10000:
138 | train_indices = np.random.choice(train_indices, size=int(train_indices.shape[0] * 0.1), replace=False)
139 | elif self.args["num_opt_samples"] == 1:
140 | train_indices = np.random.choice(train_indices, size=int(train_indices.shape[0]), replace=False)
141 |
142 | train_indices = np.sort(train_indices)
143 | self.logger.info("Using %s samples for optimization" % (int(train_indices.shape[0])))
144 |
145 | x = self.data.x[train_indices]
146 | y = self.data.y[train_indices]
147 | edge_index = utils.filter_edge_index(self.data.edge_index, train_indices)
148 |
149 | train_data = Data(x=x, edge_index=torch.from_numpy(edge_index), y=y)
150 | train_data.train_mask = torch.zeros(train_indices.shape[0], dtype=torch.bool)
151 | train_data.test_mask = torch.ones(train_indices.shape[0], dtype=torch.bool)
152 | self.true_labels = y
153 |
154 | self.posteriors = {}
155 | for shard in range(self.num_shards):
156 | self.target_model.data = train_data
157 | data_store.load_target_model(self.run, self.target_model, shard)
158 | self.posteriors[shard] = self.target_model.posterior().to(self.device)
159 |
160 | def optimization(self):
161 | weight_para = nn.Parameter(torch.full((self.num_shards,), fill_value=1.0 / self.num_shards), requires_grad=True)
162 | optimizer = optim.Adam([weight_para], lr=self.args['opt_lr'])
163 | scheduler = MultiStepLR(optimizer, milestones=[500, 1000], gamma=self.args['opt_lr'])
164 |
165 | train_dset = OptDataset(self.posteriors, self.true_labels)
166 | train_loader = DataLoader(train_dset, batch_size=32, shuffle=True, num_workers=0)
167 |
168 | min_loss = 1000.0
169 | for epoch in range(self.args.epochs):
170 | loss_all = 0.0
171 |
172 | for posteriors, labels in train_loader:
173 | labels = labels.to(self.device)
174 |
175 | optimizer.zero_grad()
176 | loss = self._loss_fn(posteriors, labels, weight_para)
177 | loss.backward()
178 | loss_all += loss
179 |
180 | optimizer.step()
181 | with torch.no_grad():
182 | weight_para[:] = torch.clamp(weight_para, min=0.0)
183 |
184 | scheduler.step()
185 |
186 | if loss_all < min_loss:
187 | ret_weight_para = copy.deepcopy(weight_para)
188 | min_loss = loss_all
189 |
190 | self.logger.info('epoch: %s, loss: %s' % (epoch, loss_all))
191 |
192 | return ret_weight_para / torch.sum(ret_weight_para)
193 |
194 | def _loss_fn(self, posteriors, labels, weight_para):
195 | aggregate_posteriors = torch.zeros_like(posteriors[0])
196 | for shard in range(self.num_shards):
197 | aggregate_posteriors += weight_para[shard] * posteriors[shard]
198 |
199 | aggregate_posteriors = F.softmax(aggregate_posteriors, dim=1)
200 | loss_1 = F.cross_entropy(aggregate_posteriors, labels)
201 | loss_2 = torch.sqrt(torch.sum(weight_para ** 2))
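    | # loss_1: classification loss of the weighted-average shard posterior;
    | # loss_2: L2 norm of the shard weights, acting as a regularizer.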
202 |
203 | return loss_1 + loss_2
204 |
205 | class Aggregator:
206 | def __init__(self, run, target_model, data, shard_data, args):
207 | self.args = args
208 |
209 | self.run = run
210 | self.target_model = target_model
211 | self.data = data
212 | self.shard_data = shard_data
213 |
214 | self.num_shards = args.num_clusters
215 |
216 | def generate_posterior(self, suffix=""):
217 | self.true_label = self.shard_data[0].y[self.shard_data[0]['test_mask']].detach().cpu().numpy()
218 | self.posteriors = {}
219 |
220 | for shard in range(self.args.num_clusters):
221 | self.target_model.data = self.shard_data[shard]
222 | self.data_store.load_target_model(self.run, self.target_model, shard, suffix)
223 | self.posteriors[shard] = self.target_model.posterior()
224 |
225 | def _optimal_aggregator(self):
226 | optimal = OptimalAggregator(self.run, self.target_model, self.data, self.args)
227 | optimal.generate_train_data()
228 | weight_para = optimal.optimization()
229 | self.data_store.save_optimal_weight(weight_para, run=self.run)
230 |
231 | posterior = self.posteriors[0] * weight_para[0]
232 | for shard in range(1, self.num_shards):
233 | posterior += self.posteriors[shard] * weight_para[shard]
234 |
235 | return f1_score(self.true_label, posterior.argmax(axis=1).cpu().numpy(), average="micro")
236 |
237 | class GraphEraserTrainer(Trainer):
238 |
239 | def train(self, model, data, optimizer, args, logits_ori=None, attack_model_all=None, attack_model_sub=None):
240 |
241 | with torch.no_grad():
242 | z = model(data.x, data.train_pos_edge_index[:, data.dr_mask])
243 |
244 | # Retrain the model
245 | for c in model.children():
246 | print('before', torch.norm(c.lin.weight), torch.norm(c.bias))
247 | for c in model.children():
248 | c.reset_parameters()
249 | for c in model.children():
250 | print('after', torch.norm(c.lin.weight), torch.norm(c.bias))
251 | model = model.cpu()
252 |
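    | # Per-shard node cap: the balanced size num_nodes / num_clusters, relaxed
    | # toward num_nodes by the shard_size_delta factor.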
253 | num_nodes = data.num_nodes
254 | node_threshold = math.ceil(
255 | num_nodes / args.num_clusters + args.shard_size_delta * (num_nodes - num_nodes / args.num_clusters))
256 | print(f'Number of nodes: {num_nodes}. Shard threshold: {node_threshold}')
257 |
258 | cluster = ConstrainedKmeans(
259 | args,
260 | z.cpu().numpy(),
261 | args.num_clusters,
262 | node_threshold,
263 | args.terminate_delta,
264 | args.kmeans_max_iters)
265 | cluster.initialization()
266 |
267 | community, km_deltas = cluster.clustering()
268 | # with open(os.path.join(args.checkpoint_dir, 'kmeans_delta.pkl'), 'wb') as f:
269 | # pickle.dump(km_deltas, f)
270 |
271 | community_to_node = {}
272 | for i in range(args.num_clusters):
273 | community_to_node[i] = np.array(community[i].astype(int))
274 |
275 | models = {}
276 | test_result = []
277 | for shard_id in trange(args.num_clusters, desc='Sharded retraining'):
278 | model_shard_id = copy.deepcopy(model).to('cuda')
279 | optimizer = torch.optim.Adam(model_shard_id.parameters(), lr=args.lr)
280 |
281 | subset_train, _ = subgraph(
282 | torch.tensor(community[shard_id], dtype=torch.long, device=device),
283 | data.train_pos_edge_index,
284 | num_nodes=data.num_nodes)
285 |
286 | self.train_model(model_shard_id, data, subset_train, optimizer, args, shard_id)
287 |
288 | with torch.no_grad():
289 | z = model_shard_id(data.x, subset_train)
290 | logits = model_shard_id.decode(z, data.test_pos_edge_index, data.test_neg_edge_index)
291 |
292 | weight_para = nn.Parameter(torch.full((self.num_shards,), fill_value=1.0 / self.num_shards), requires_grad=True)
293 | optimizer = optim.Adam([weight_para], lr=self.args.lr)
294 |
295 |
296 | aggregator.generate_posterior()
297 | self.aggregate_f1_score = aggregator.aggregate()
298 | aggregate_time = time.time() - start_time
299 | self.logger.info("Aggregation cost %s seconds." % aggregate_time)
300 |
301 | self.logger.info("Final Test F1: %s" % (self.aggregate_f1_score,))
302 |
303 |
304 |
305 | def train_model(self, model, data, subset_train, optimizer, args, shard_id):
306 |
307 | best_loss = 100000
308 | for epoch in range(args.epochs):
309 | model.train()
310 |
311 | neg_edge_index = negative_sampling(
312 | edge_index=subset_train,
313 | num_nodes=data.num_nodes,
314 | num_neg_samples=subset_train.shape[1])
315 |
316 | z = model(data.x, subset_train)
317 | logits = model.decode(z, subset_train, neg_edge_index)
318 | label = self.get_link_labels(subset_train, neg_edge_index)
319 | loss = F.binary_cross_entropy_with_logits(logits, label)
320 |
321 | loss.backward()
322 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
323 | optimizer.step()
324 | optimizer.zero_grad()
325 |
326 | valid_loss, auc, aup, _, _ = self.eval_model(model, data, subset_train, 'val')
327 | log = {
328 | 'train_loss': loss.item(),
329 | 'valid_loss': valid_loss,
330 | 'valid_auc': auc,
331 | 'valid_aup': aup,
332 | }
333 | msg = [f'{i}: {j:>4d}' if isinstance(j, int) else f'{i}: {j:.4f}' for i, j in log.items()]
334 | tqdm.write(' | '.join(msg))
335 | self.trainer_log[f'shard_{shard_id}'] = log
336 |
337 | torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f'model_{shard_id}.pt'))
338 |
339 | @torch.no_grad()
340 | def eval_model(self, model, data, subset_train, stage='val', pred_all=False):
341 | model.eval()
342 | pos_edge_index = data[f'{stage}_pos_edge_index']
343 | neg_edge_index = data[f'{stage}_neg_edge_index']
344 |
345 | z = model(data.x, subset_train)
346 | logits = model.decode(z, pos_edge_index, neg_edge_index).sigmoid()
347 | label = self.get_link_labels(pos_edge_index, neg_edge_index)
348 |
349 | loss = F.binary_cross_entropy(logits, label).cpu().item()  # logits are already probabilities (sigmoid applied above)
350 | auc = roc_auc_score(label.cpu(), logits.cpu())
351 | aup = average_precision_score(label.cpu(), logits.cpu())
352 |
353 | if self.args.unlearning_model in ['original', 'retrain']:
354 | df_logit = float('nan')
355 | else:
356 | # df_logit = float('nan')
357 | df_logit = model.decode(z, subset_train).sigmoid().detach().cpu().mean().item()  # mean over edges: .item() needs a scalar
358 |
359 | if pred_all:
360 | logit_all_pair = (z @ z.t()).cpu()
361 | else:
362 | logit_all_pair = None
363 |
364 | log = {
365 | f'{stage}_loss': loss,
366 | f'{stage}_auc': auc,
367 | f'{stage}_aup': aup,
368 | f'{stage}_df_logit': df_logit,
369 | }
370 | wandb.log(log)
371 | msg = [f'{i}: {j:.4f}' if isinstance(j, (np.floating, float)) else f'{i}: {j:>4d}' for i, j in log.items()]
372 | tqdm.write(' | '.join(msg))
373 |
374 | return loss, auc, aup, df_logit, logit_all_pair
375 |
--------------------------------------------------------------------------------
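The two quantities computed above drive the sharded retraining: a balanced-shard node threshold for the constrained k-means partition, and a weighted sum of per-shard posteriors at inference time. A minimal, self-contained sketch of both computations, with hypothetical sizes and uniform weights (the real weights are learned by OptimalAggregator):

import math

import torch

# Balanced-shard threshold, as computed in GraphEraserTrainer.train:
# each shard holds roughly num_nodes / k nodes plus a slack term controlled by shard_size_delta.
num_nodes, num_clusters, shard_size_delta = 10_000, 10, 0.005   # hypothetical values
node_threshold = math.ceil(
    num_nodes / num_clusters
    + shard_size_delta * (num_nodes - num_nodes / num_clusters))
print(node_threshold)  # 1045 with these values

# Weighted aggregation of per-shard posteriors, as in Aggregator._optimal_aggregator:
# posteriors[s] is [num_test_nodes, num_classes]; weight_para is optimized in the real code.
posteriors = [torch.randn(5, 3).softmax(dim=-1) for _ in range(num_clusters)]
weight_para = torch.full((num_clusters,), 1.0 / num_clusters)
posterior = posteriors[0] * weight_para[0]
for shard in range(1, num_clusters):
    posterior += posteriors[shard] * weight_para[shard]
pred = posterior.argmax(dim=1)  # final predictions, scored with micro-F1 in the trainer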
/framework/models/rgat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from sklearn.metrics import roc_auc_score, average_precision_score
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import Tensor
12 | from torch.nn import Parameter, ReLU
13 | from torch_scatter import scatter_add
14 | from torch_sparse import SparseTensor
15 |
16 | from torch_geometric.nn.conv import MessagePassing
17 | from torch_geometric.nn.dense.linear import Linear
18 | from torch_geometric.nn.inits import glorot, ones, zeros
19 | from torch_geometric.typing import Adj, OptTensor, Size
20 | from torch_geometric.utils import softmax
21 |
22 |
23 | # Source: torch_geometric
24 | class RGATConv(MessagePassing):
25 | _alpha: OptTensor
26 |
27 | def __init__(
28 | self,
29 | in_channels: int,
30 | out_channels: int,
31 | num_relations: int,
32 | num_bases: Optional[int] = None,
33 | num_blocks: Optional[int] = None,
34 | mod: Optional[str] = None,
35 | attention_mechanism: str = "across-relation",
36 | attention_mode: str = "additive-self-attention",
37 | heads: int = 1,
38 | dim: int = 1,
39 | concat: bool = True,
40 | negative_slope: float = 0.2,
41 | dropout: float = 0.0,
42 | edge_dim: Optional[int] = None,
43 | bias: bool = True,
44 | **kwargs,
45 | ):
46 | kwargs.setdefault('aggr', 'add')
47 | super().__init__(node_dim=0, **kwargs)
48 |
49 | self.heads = heads
50 | self.negative_slope = negative_slope
51 | self.dropout = dropout
52 | self.mod = mod
53 | self.activation = ReLU()
54 | self.concat = concat
55 | self.attention_mode = attention_mode
56 | self.attention_mechanism = attention_mechanism
57 | self.dim = dim
58 | self.edge_dim = edge_dim
59 |
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 | self.num_relations = num_relations
63 | self.num_bases = num_bases
64 | self.num_blocks = num_blocks
65 |
66 | mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled']
67 |
68 | if (self.attention_mechanism != "within-relation"
69 | and self.attention_mechanism != "across-relation"):
70 | raise ValueError('attention mechanism must either be '
71 | '"within-relation" or "across-relation"')
72 |
73 | if (self.attention_mode != "additive-self-attention"
74 | and self.attention_mode != "multiplicative-self-attention"):
75 | raise ValueError('attention mode must either be '
76 | '"additive-self-attention" or '
77 | '"multiplicative-self-attention"')
78 |
79 | if self.attention_mode == "additive-self-attention" and self.dim > 1:
80 | raise ValueError('"additive-self-attention" mode cannot be '
81 | 'applied when value of d is greater than 1. '
82 | 'Use "multiplicative-self-attention" instead.')
83 |
84 | if self.dropout > 0.0 and self.mod in mod_types:
85 | raise ValueError('mod must be None with dropout value greater '
86 | 'than 0 in order to sample attention '
87 | 'coefficients stochastically')
88 |
89 | if num_bases is not None and num_blocks is not None:
90 | raise ValueError('Can not apply both basis-decomposition and '
91 | 'block-diagonal-decomposition at the same time.')
92 |
93 | # The learnable parameters to compute both attention logits and
94 | # attention coefficients:
95 | self.q = Parameter(
96 | torch.Tensor(self.heads * self.out_channels,
97 | self.heads * self.dim))
98 | self.k = Parameter(
99 | torch.Tensor(self.heads * self.out_channels,
100 | self.heads * self.dim))
101 |
102 | if bias and concat:
103 | self.bias = Parameter(
104 | torch.Tensor(self.heads * self.dim * self.out_channels))
105 | elif bias and not concat:
106 | self.bias = Parameter(torch.Tensor(self.dim * self.out_channels))
107 | else:
108 | self.register_parameter('bias', None)
109 |
110 | if edge_dim is not None:
111 | self.lin_edge = Linear(self.edge_dim,
112 | self.heads * self.out_channels, bias=False,
113 | weight_initializer='glorot')
114 | self.e = Parameter(
115 | torch.Tensor(self.heads * self.out_channels,
116 | self.heads * self.dim))
117 | else:
118 | self.lin_edge = None
119 | self.register_parameter('e', None)
120 |
121 | if num_bases is not None:
122 | self.att = Parameter(
123 | torch.Tensor(self.num_relations, self.num_bases))
124 | self.basis = Parameter(
125 | torch.Tensor(self.num_bases, self.in_channels,
126 | self.heads * self.out_channels))
127 | elif num_blocks is not None:
128 | assert (
129 | self.in_channels % self.num_blocks == 0
130 | and (self.heads * self.out_channels) % self.num_blocks == 0), (
131 | "both 'in_channels' and 'heads * out_channels' must be "
132 | "multiple of 'num_blocks' used")
133 | self.weight = Parameter(
134 | torch.Tensor(self.num_relations, self.num_blocks,
135 | self.in_channels // self.num_blocks,
136 | (self.heads * self.out_channels) //
137 | self.num_blocks))
138 | else:
139 | self.weight = Parameter(
140 | torch.Tensor(self.num_relations, self.in_channels,
141 | self.heads * self.out_channels))
142 |
143 | self.w = Parameter(torch.ones(self.out_channels))
144 | self.l1 = Parameter(torch.Tensor(1, self.out_channels))
145 | self.b1 = Parameter(torch.Tensor(1, self.out_channels))
146 | self.l2 = Parameter(torch.Tensor(self.out_channels, self.out_channels))
147 | self.b2 = Parameter(torch.Tensor(1, self.out_channels))
148 |
149 | self._alpha = None
150 |
151 | self.reset_parameters()
152 |
153 | def reset_parameters(self):
154 | if self.num_bases is not None:
155 | glorot(self.basis)
156 | glorot(self.att)
157 | else:
158 | glorot(self.weight)
159 | glorot(self.q)
160 | glorot(self.k)
161 | zeros(self.bias)
162 | ones(self.l1)
163 | zeros(self.b1)
164 | self.l2.data.fill_(1 / self.out_channels)  # torch.full() alone returns a new tensor and would leave l2 uninitialized
165 | zeros(self.b2)
166 | if self.lin_edge is not None:
167 | glorot(self.lin_edge)
168 | glorot(self.e)
169 |
170 | def forward(self, x: Tensor, edge_index: Adj, edge_type: OptTensor = None,
171 | edge_attr: OptTensor = None, size: Size = None,
172 | return_attention_weights=None):
173 | # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor) # noqa
174 | out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,
175 | size=size, edge_attr=edge_attr)
176 |
177 | alpha = self._alpha
178 | assert alpha is not None
179 | self._alpha = None
180 |
181 | if isinstance(return_attention_weights, bool):
182 | if isinstance(edge_index, Tensor):
183 | return out, (edge_index, alpha)
184 | elif isinstance(edge_index, SparseTensor):
185 | return out, edge_index.set_value(alpha, layout='coo')
186 | else:
187 | return out
188 |
189 |
190 | def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
191 | edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
192 | size_i: Optional[int]) -> Tensor:
193 |
194 | if self.num_bases is not None: # Basis-decomposition =================
195 | w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
196 | w = w.view(self.num_relations, self.in_channels,
197 | self.heads * self.out_channels)
198 | if self.num_blocks is not None: # Block-diagonal-decomposition =======
199 | if (x_i.dtype == torch.long and x_j.dtype == torch.long
200 | and self.num_blocks is not None):
201 | raise ValueError('Block-diagonal decomposition not supported '
202 | 'for non-continuous input features.')
203 | w = self.weight
204 | x_i = x_i.view(-1, 1, w.size(1), w.size(2))
205 | x_j = x_j.view(-1, 1, w.size(1), w.size(2))
206 | w = torch.index_select(w, 0, edge_type)
207 | outi = torch.einsum('abcd,acde->ace', x_i, w)
208 | outi = outi.contiguous().view(-1, self.heads * self.out_channels)
209 | outj = torch.einsum('abcd,acde->ace', x_j, w)
210 | outj = outj.contiguous().view(-1, self.heads * self.out_channels)
211 | else: # No regularization/Basis-decomposition ========================
212 | if self.num_bases is None:
213 | w = self.weight
214 | w = torch.index_select(w, 0, edge_type)
215 | outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)
216 | outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
217 |
218 | qi = torch.matmul(outi, self.q)
219 | kj = torch.matmul(outj, self.k)
220 |
221 | alpha_edge, alpha = 0, torch.tensor([0])
222 | if edge_attr is not None:
223 | if edge_attr.dim() == 1:
224 | edge_attr = edge_attr.view(-1, 1)
225 | assert self.lin_edge is not None, (
226 | "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
227 | "RGATConv layer")
228 | edge_attributes = self.lin_edge(edge_attr).view(
229 | -1, self.heads * self.out_channels)
230 | if edge_attributes.size(0) != edge_attr.size(0):
231 | edge_attributes = torch.index_select(edge_attributes, 0,
232 | edge_type)
233 | alpha_edge = torch.matmul(edge_attributes, self.e)
234 |
235 | if self.attention_mode == "additive-self-attention":
236 | if edge_attr is not None:
237 | alpha = torch.add(qi, kj) + alpha_edge
238 | else:
239 | alpha = torch.add(qi, kj)
240 | alpha = F.leaky_relu(alpha, self.negative_slope)
241 | elif self.attention_mode == "multiplicative-self-attention":
242 | if edge_attr is not None:
243 | alpha = (qi * kj) * alpha_edge
244 | else:
245 | alpha = qi * kj
246 |
247 | if self.attention_mechanism == "within-relation":
248 | across_out = torch.zeros_like(alpha)
249 | for r in range(self.num_relations):
250 | mask = edge_type == r
251 | across_out[mask] = softmax(alpha[mask], index[mask])
252 | alpha = across_out
253 | elif self.attention_mechanism == "across-relation":
254 | alpha = softmax(alpha, index, ptr, size_i)
255 |
256 | self._alpha = alpha
257 |
258 | if self.mod == "additive":
259 | if self.attention_mode == "additive-self-attention":
260 | ones = torch.ones_like(alpha)
261 | h = (outj.view(-1, self.heads, self.out_channels) *
262 | ones.view(-1, self.heads, 1))
263 | h = torch.mul(self.w, h)
264 |
265 | return (outj.view(-1, self.heads, self.out_channels) *
266 | alpha.view(-1, self.heads, 1) + h)
267 | elif self.attention_mode == "multiplicative-self-attention":
268 | ones = torch.ones_like(alpha)
269 | h = (outj.view(-1, self.heads, 1, self.out_channels) *
270 | ones.view(-1, self.heads, self.dim, 1))
271 | h = torch.mul(self.w, h)
272 |
273 | return (outj.view(-1, self.heads, 1, self.out_channels) *
274 | alpha.view(-1, self.heads, self.dim, 1) + h)
275 |
276 | elif self.mod == "scaled":
277 | if self.attention_mode == "additive-self-attention":
278 | ones = alpha.new_ones(index.size())
279 | degree = scatter_add(ones, index,
280 | dim_size=size_i)[index].unsqueeze(-1)
281 | degree = torch.matmul(degree, self.l1) + self.b1
282 | degree = self.activation(degree)
283 | degree = torch.matmul(degree, self.l2) + self.b2
284 |
285 | return torch.mul(
286 | outj.view(-1, self.heads, self.out_channels) *
287 | alpha.view(-1, self.heads, 1),
288 | degree.view(-1, 1, self.out_channels))
289 | elif self.attention_mode == "multiplicative-self-attention":
290 | ones = alpha.new_ones(index.size())
291 | degree = scatter_add(ones, index,
292 | dim_size=size_i)[index].unsqueeze(-1)
293 | degree = torch.matmul(degree, self.l1) + self.b1
294 | degree = self.activation(degree)
295 | degree = torch.matmul(degree, self.l2) + self.b2
296 |
297 | return torch.mul(
298 | outj.view(-1, self.heads, 1, self.out_channels) *
299 | alpha.view(-1, self.heads, self.dim, 1),
300 | degree.view(-1, 1, 1, self.out_channels))
301 |
302 | elif self.mod == "f-additive":
303 | alpha = torch.where(alpha > 0, alpha + 1, alpha)
304 |
305 | elif self.mod == "f-scaled":
306 | ones = alpha.new_ones(index.size())
307 | degree = scatter_add(ones, index,
308 | dim_size=size_i)[index].unsqueeze(-1)
309 | alpha = alpha * degree
310 |
311 | elif self.training and self.dropout > 0:
312 | alpha = F.dropout(alpha, p=self.dropout, training=True)
313 |
314 | else:
315 | alpha = alpha # original
316 |
317 | if self.attention_mode == "additive-self-attention":
318 | return alpha.view(-1, self.heads, 1) * outj.view(
319 | -1, self.heads, self.out_channels)
320 | else:
321 | return (alpha.view(-1, self.heads, self.dim, 1) *
322 | outj.view(-1, self.heads, 1, self.out_channels))
323 |
324 | def update(self, aggr_out: Tensor) -> Tensor:
325 | if self.attention_mode == "additive-self-attention":
326 | if self.concat is True:
327 | aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
328 | else:
329 | aggr_out = aggr_out.mean(dim=1)
330 |
331 | if self.bias is not None:
332 | aggr_out = aggr_out + self.bias
333 |
334 | return aggr_out
335 | else:
336 | if self.concat is True:
337 | aggr_out = aggr_out.view(
338 | -1, self.heads * self.dim * self.out_channels)
339 | else:
340 | aggr_out = aggr_out.mean(dim=1)
341 | aggr_out = aggr_out.view(-1, self.dim * self.out_channels)
342 |
343 | if self.bias is not None:
344 | aggr_out = aggr_out + self.bias
345 |
346 | return aggr_out
347 |
348 | def __repr__(self) -> str:
349 | return '{}({}, {}, heads={})'.format(self.__class__.__name__,
350 | self.in_channels,
351 | self.out_channels, self.heads)
352 |
353 | class RGAT(nn.Module):
354 | def __init__(self, args, num_nodes, num_edge_type, **kwargs):
355 | super().__init__()
356 | self.args = args
357 | self.num_edge_type = num_edge_type
358 |
359 | # Encoder: RGAT
360 | self.node_emb = nn.Embedding(num_nodes, args.in_dim)
361 | if num_edge_type > 20:
362 | self.conv1 = RGATConv(args.in_dim, args.hidden_dim, num_edge_type * 2, num_blocks=4)
363 | self.conv2 = RGATConv(args.hidden_dim, args.out_dim, num_edge_type * 2, num_blocks=4)
364 | else:
365 | self.conv1 = RGATConv(args.in_dim, args.hidden_dim, num_edge_type * 2)
366 | self.conv2 = RGATConv(args.hidden_dim, args.out_dim, num_edge_type * 2)
367 | self.relu = nn.ReLU()
368 |
369 | # Decoder: DistMult
370 | self.W = nn.Parameter(torch.Tensor(num_edge_type, args.out_dim))
371 | nn.init.xavier_uniform_(self.W, gain=nn.init.calculate_gain('relu'))
372 |
373 | def forward(self, x, edge, edge_type, return_all_emb=False):
374 | x = self.node_emb(x)
375 | x1 = self.conv1(x, edge, edge_type)
376 | x = self.relu(x1)
377 | x2 = self.conv2(x, edge, edge_type)
378 |
379 | if return_all_emb:
380 | return x1, x2
381 |
382 | return x2
383 |
384 | def decode(self, z, edge_index, edge_type):
385 | h = z[edge_index[0]]
386 | t = z[edge_index[1]]
387 | r = self.W[edge_type]
388 |
389 | logits = torch.sum(h * r * t, dim=1)
390 |
391 | return logits
392 |
--------------------------------------------------------------------------------
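The RGAT encoder above pairs two RGATConv layers with a DistMult decoder: a triple (h, r, t) is scored as sum(z_h * W_r * z_t). A small usage sketch with hypothetical dimensions (the args namespace below stands in for the parsed framework/training_args.py arguments):

from argparse import Namespace

import torch

from framework.models.rgat import RGAT

# Hypothetical hyper-parameters; the real values come from framework/training_args.py.
args = Namespace(in_dim=32, hidden_dim=64, out_dim=64)
num_nodes, num_edge_type = 100, 5

model = RGAT(args, num_nodes=num_nodes, num_edge_type=num_edge_type)

# Message-passing graph: edge types cover both directions, hence num_edge_type * 2 relations.
edge_index = torch.randint(0, num_nodes, (2, 40))
edge_type = torch.randint(0, num_edge_type * 2, (40,))

x = torch.arange(num_nodes)            # node IDs, looked up in the nn.Embedding table
z = model(x, edge_index, edge_type)    # node embeddings, shape [num_nodes, out_dim]

# DistMult decoding of candidate triples (single-direction relation IDs here).
cand_edge_index = torch.randint(0, num_nodes, (2, 8))
cand_edge_type = torch.randint(0, num_edge_type, (8,))
scores = model.decode(z, cand_edge_index, cand_edge_type)  # raw logits, shape [8]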
/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import pickle
4 | import torch
5 | import pandas as pd
6 | import networkx as nx
7 | from tqdm import tqdm
8 | from torch_geometric.seed import seed_everything
9 | import torch_geometric.transforms as T
10 | from torch_geometric.data import Data
11 | from torch_geometric.datasets import CitationFull, Coauthor, Flickr, RelLinkPredDataset, WordNet18, WordNet18RR
12 | from torch_geometric.utils import train_test_split_edges, k_hop_subgraph, negative_sampling, to_undirected, is_undirected, to_networkx
13 | from ogb.linkproppred import PygLinkPropPredDataset
14 | from framework.utils import *
15 |
16 |
17 | data_dir = './data'
18 | df_size = [i / 100 for i in range(10)] + [i / 10 for i in range(10)] + [i for i in range(10)] # Df_size in percentage
19 | seeds = [42, 21, 13, 87, 100]
20 | graph_datasets = ['Cora', 'PubMed', 'DBLP', 'CS', 'ogbl-citation2', 'ogbl-collab'][4:]
21 | kg_datasets = ['FB15k-237', 'WordNet18', 'WordNet18RR', 'ogbl-biokg'][-1:]
22 | os.makedirs(data_dir, exist_ok=True)
23 |
24 |
25 | num_edge_type_mapping = {
26 | 'FB15k-237': 237,
27 | 'WordNet18': 18,
28 | 'WordNet18RR': 11
29 | }
30 |
31 | def train_test_split_edges_no_neg_adj_mask(data, val_ratio: float = 0.05, test_ratio: float = 0.1, two_hop_degree=None, kg=False):
32 | '''Avoid adding neg_adj_mask'''
33 |
34 | num_nodes = data.num_nodes
35 | row, col = data.edge_index
36 | edge_attr = data.edge_attr
37 | if kg:
38 | edge_type = data.edge_type
39 | data.edge_index = data.edge_attr = data.edge_weight = data.edge_year = data.edge_type = None
40 |
41 | if not kg:
42 | # Return upper triangular portion.
43 | mask = row < col
44 | row, col = row[mask], col[mask]
45 |
46 | if edge_attr is not None:
47 | edge_attr = edge_attr[mask]
48 |
49 | n_v = int(math.floor(val_ratio * row.size(0)))
50 | n_t = int(math.floor(test_ratio * row.size(0)))
51 |
52 | if two_hop_degree is not None: # Use low degree edges for test sets
53 | low_degree_mask = two_hop_degree < 50
54 |
55 | low = low_degree_mask.nonzero().squeeze()
56 | high = (~low_degree_mask).nonzero().squeeze()
57 |
58 | low = low[torch.randperm(low.size(0))]
59 | high = high[torch.randperm(high.size(0))]
60 |
61 | perm = torch.cat([low, high])
62 |
63 | else:
64 | perm = torch.randperm(row.size(0))
65 |
66 | row, col = row[perm], col[perm]
67 | if edge_attr is not None: edge_attr = edge_attr[perm]  # keep edge attributes aligned with the permuted edges
68 | if kg: edge_type = edge_type[perm]  # keep relation types aligned with the permuted edges
69 | # Train
70 | r, c = row[n_v + n_t:], col[n_v + n_t:]
71 |
72 | if kg:
73 |
74 | # data.edge_index and data.edge_type has reverse edges and edge types for message passing
75 | pos_edge_index = torch.stack([r, c], dim=0)
76 | # rev_pos_edge_index = torch.stack([r, c], dim=0)
77 | train_edge_type = edge_type[n_v + n_t:]
78 | # train_rev_edge_type = edge_type[n_v + n_t:] + edge_type.unique().shape[0]
79 |
80 | # data.edge_index = torch.cat((torch.stack([r, c], dim=0), torch.stack([r, c], dim=0)), dim=1)
81 | # data.edge_type = torch.cat([train_edge_type, train_rev_edge_type], dim=0)
82 |
83 | data.edge_index = pos_edge_index
84 | data.edge_type = train_edge_type
85 |
86 | # data.train_pos_edge_index and data.train_edge_type only has one direction edges and edge types for decoding
87 | data.train_pos_edge_index = torch.stack([r, c], dim=0)
88 | data.train_edge_type = train_edge_type
89 |
90 | else:
91 | data.train_pos_edge_index = torch.stack([r, c], dim=0)
92 | if edge_attr is not None:
93 | # Keep the training edges directed (see the assert below); carry over the matching attributes
94 | data.train_pos_edge_attr = edge_attr[n_v + n_t:]
95 | # out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:])
96 | # data.train_pos_edge_index, data.train_pos_edge_attr = out
97 | # data.train_pos_edge_index = to_undirected(data.train_pos_edge_index)
98 |
99 | assert not is_undirected(data.train_pos_edge_index)
100 |
101 |
102 | # Test
103 | r, c = row[:n_t], col[:n_t]
104 | data.test_pos_edge_index = torch.stack([r, c], dim=0)
105 |
106 | if kg:
107 | data.test_edge_type = edge_type[:n_t]
108 | neg_edge_index = negative_sampling_kg(
109 | edge_index=data.test_pos_edge_index,
110 | edge_type=data.test_edge_type)
111 | else:
112 | neg_edge_index = negative_sampling(
113 | edge_index=data.test_pos_edge_index,
114 | num_nodes=data.num_nodes,
115 | num_neg_samples=data.test_pos_edge_index.shape[1])
116 |
117 | data.test_neg_edge_index = neg_edge_index
118 |
119 | # Valid
120 | r, c = row[n_t:n_t+n_v], col[n_t:n_t+n_v]
121 | data.val_pos_edge_index = torch.stack([r, c], dim=0)
122 |
123 | if kg:
124 | data.val_edge_type = edge_type[n_t:n_t+n_v]
125 | neg_edge_index = negative_sampling_kg(
126 | edge_index=data.val_pos_edge_index,
127 | edge_type=data.val_edge_type)
128 | else:
129 | neg_edge_index = negative_sampling(
130 | edge_index=data.val_pos_edge_index,
131 | num_nodes=data.num_nodes,
132 | num_neg_samples=data.val_pos_edge_index.shape[1])
133 |
134 | data.val_neg_edge_index = neg_edge_index
135 |
136 | return data
137 |
138 | def process_graph():
139 | for d in graph_datasets:
140 |
141 | if d in ['Cora', 'PubMed', 'DBLP']:
142 | dataset = CitationFull(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
143 | elif d in ['CS', 'Physics']:
144 | dataset = Coauthor(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
145 | elif d in ['Flickr']:
146 | dataset = Flickr(os.path.join(data_dir, d), transform=T.NormalizeFeatures())
147 | elif 'ogbl' in d:
148 | dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d)
149 | else:
150 | raise NotImplementedError
151 |
152 | print('Processing:', d)
153 | print(dataset)
154 | data = dataset[0]
155 | data.train_mask = data.val_mask = data.test_mask = None
156 | graph = to_networkx(data)
157 |
158 | # Get two hop degree for all nodes
159 | node_to_neighbors = {}
160 | for n in tqdm(graph.nodes(), desc='Two hop neighbors'):
161 | neighbor_1 = set(graph.neighbors(n))
162 | neighbor_2 = sum([list(graph.neighbors(i)) for i in neighbor_1], [])
163 | neighbor_2 = set(neighbor_2)
164 | neighbor = neighbor_1 | neighbor_2
165 |
166 | node_to_neighbors[n] = neighbor
167 |
168 | two_hop_degree = []
169 | row, col = data.edge_index
170 | mask = row < col
171 | row, col = row[mask], col[mask]
172 | for r, c in tqdm(zip(row, col), total=len(row)):
173 | neighbor_row = node_to_neighbors[r.item()]
174 | neighbor_col = node_to_neighbors[c.item()]
175 | neighbor = neighbor_row | neighbor_col
176 |
177 | num = len(neighbor)
178 |
179 | two_hop_degree.append(num)
180 |
181 | two_hop_degree = torch.tensor(two_hop_degree)
182 |
183 | for s in seeds:
184 | seed_everything(s)
185 |
186 | # D
187 | data = dataset[0]
188 | if 'ogbl' in d:
189 | data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05, two_hop_degree=two_hop_degree)
190 | else:
191 | data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05)
192 | print(s, data)
193 |
194 | with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'wb') as f:
195 | pickle.dump((dataset, data), f)
196 |
197 | # Two ways to sample Df from the training set
198 | ## 1. Df is within 2 hop local enclosing subgraph of Dtest
199 | ## 2. Df is outside of 2 hop local enclosing subgraph of Dtest
200 |
201 | # All the candidate edges (train edges)
202 | # graph = to_networkx(Data(edge_index=data.train_pos_edge_index, x=data.x))
203 |
204 | # Get the 2 hop local enclosing subgraph for all test edges
205 | _, local_edges, _, mask = k_hop_subgraph(
206 | data.test_pos_edge_index.flatten().unique(),
207 | 2,
208 | data.train_pos_edge_index,
209 | num_nodes=dataset[0].num_nodes)
210 | distant_edges = data.train_pos_edge_index[:, ~mask]
211 | print('Number of edges. Local: ', local_edges.shape[1], 'Distant:', distant_edges.shape[1])
212 |
213 | in_mask = mask
214 | out_mask = ~mask
215 |
216 | # df_in_mask = torch.zeros_like(mask)
217 | # df_out_mask = torch.zeros_like(mask)
218 |
219 | # df_in_all_idx = in_mask.nonzero().squeeze()
220 | # df_out_all_idx = out_mask.nonzero().squeeze()
221 | # df_in_selected_idx = df_in_all_idx[torch.randperm(df_in_all_idx.shape[0])[:df_size]]
222 | # df_out_selected_idx = df_out_all_idx[torch.randperm(df_out_all_idx.shape[0])[:df_size]]
223 |
224 | # df_in_mask[df_in_selected_idx] = True
225 | # df_out_mask[df_out_selected_idx] = True
226 |
227 | # assert (in_mask & out_mask).sum() == 0
228 | # assert (df_in_mask & df_out_mask).sum() == 0
229 |
230 |
231 | # local_edges = set()
232 | # for i in range(data.test_pos_edge_index.shape[1]):
233 | # edge = data.test_pos_edge_index[:, i].tolist()
234 | # subgraph = get_enclosing_subgraph(graph, edge)
235 | # local_edges = local_edges | set(subgraph[2])
236 |
237 | # distant_edges = graph.edges() - local_edges
238 |
239 | # print('aaaaaaa', len(local_edges), len(distant_edges))
240 | # local_edges = torch.tensor(sorted(list([i for i in local_edges if i[0] < i[1]])))
241 | # distant_edges = torch.tensor(sorted(list([i for i in distant_edges if i[0] < i[1]])))
242 |
243 |
244 | # df_in = torch.randperm(local_edges.shape[1])[:df_size]
245 | # df_out = torch.randperm(distant_edges.shape[1])[:df_size]
246 |
247 | # df_in = local_edges[:, df_in]
248 | # df_out = distant_edges[:, df_out]
249 |
250 | # df_in_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool)
251 | # df_out_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool)
252 |
253 | # for row in df_in:
254 | # i = (data.train_pos_edge_index.T == row).all(axis=1).nonzero()
255 | # df_in_mask[i] = True
256 |
257 | # for row in df_out:
258 | # i = (data.train_pos_edge_index.T == row).all(axis=1).nonzero()
259 | # df_out_mask[i] = True
260 |
261 | torch.save(
262 | {'out': out_mask, 'in': in_mask},
263 | os.path.join(data_dir, d, f'df_{s}.pt')
264 | )
265 |
266 | def process_kg():
267 | for d in kg_datasets:
268 |
269 | # Create the dataset to calculate node degrees
270 | if d in ['FB15k-237']:
271 | dataset = RelLinkPredDataset(os.path.join(data_dir, d), d, transform=T.NormalizeFeatures())
272 | data = dataset[0]
273 | data.x = torch.arange(data.num_nodes)
274 | edge_index = torch.cat([data.train_edge_index, data.valid_edge_index, data.test_edge_index], dim=1)
275 | edge_type = torch.cat([data.train_edge_type, data.valid_edge_type, data.test_edge_type])
276 | data = Data(edge_index=edge_index, edge_type=edge_type)
277 |
278 | elif d in ['WordNet18RR']:
279 | dataset = WordNet18RR(os.path.join(data_dir, d), transform=T.NormalizeFeatures())
280 | data = dataset[0]
281 | data.x = torch.arange(data.num_nodes)
282 | data.train_mask = data.val_mask = data.test_mask = None
283 |
284 | elif d in ['WordNet18']:
285 | dataset = WordNet18(os.path.join(data_dir, d), transform=T.NormalizeFeatures())
286 | data = dataset[0]
287 | data.x = torch.arange(data.num_nodes)
288 |
289 | # Use original split
290 | data.train_pos_edge_index = data.edge_index[:, data.train_mask]
291 | data.train_edge_type = data.edge_type[data.train_mask]
292 |
293 | data.val_pos_edge_index = data.edge_index[:, data.val_mask]
294 | data.val_edge_type = data.edge_type[data.val_mask]
295 | data.val_neg_edge_index = negative_sampling_kg(data.val_pos_edge_index, data.val_edge_type)
296 |
297 | data.test_pos_edge_index = data.edge_index[:, data.test_mask]
298 | data.test_edge_type = data.edge_type[data.test_mask]
299 | data.test_neg_edge_index = negative_sampling_kg(data.test_pos_edge_index, data.test_edge_type)
300 |
301 | elif 'ogbl' in d:
302 | dataset = PygLinkPropPredDataset(root=os.path.join(data_dir, d), name=d)
303 |
304 | split_edge = dataset.get_edge_split()
305 | train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"]
306 | entity_dict = dict()
307 | cur_idx = 0
308 | for key in dataset[0]['num_nodes_dict']:
309 | entity_dict[key] = (cur_idx, cur_idx + dataset[0]['num_nodes_dict'][key])
310 | cur_idx += dataset[0]['num_nodes_dict'][key]
311 | nentity = sum(dataset[0]['num_nodes_dict'].values())
312 |
313 | valid_head_neg = valid_edge.pop('head_neg')
314 | valid_tail_neg = valid_edge.pop('tail_neg')
315 | test_head_neg = test_edge.pop('head_neg')
316 | test_tail_neg = test_edge.pop('tail_neg')
317 |
318 | train = pd.DataFrame(train_edge)
319 | valid = pd.DataFrame(valid_edge)
320 | test = pd.DataFrame(test_edge)
321 |
322 | # Convert to global index
323 | train['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(train['head'], train['head_type'])]
324 | train['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(train['tail'], train['tail_type'])]
325 |
326 | valid['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(valid['head'], valid['head_type'])]
327 | valid['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(valid['tail'], valid['tail_type'])]
328 |
329 | test['head'] = [idx + entity_dict[tp][0] for idx, tp in zip(test['head'], test['head_type'])]
330 | test['tail'] = [idx + entity_dict[tp][0] for idx, tp in zip(test['tail'], test['tail_type'])]
331 |
332 | valid_pos_edge_index = torch.tensor([valid['head'], valid['tail']])
333 | valid_edge_type = torch.tensor(valid.relation)
334 | valid_neg_edge_index = torch.stack([valid_pos_edge_index[0], valid_tail_neg[:, 0]])
335 |
336 | test_pos_edge_index = torch.tensor([test['head'], test['tail']])
337 | test_edge_type = torch.tensor(test.relation)
338 | test_neg_edge_index = torch.stack([test_pos_edge_index[0], test_tail_neg[:, 0]])
339 |
340 | train_directed = train[train.head_type != train.tail_type]
341 | train_undirected = train[train.head_type == train.tail_type]
342 | train_undirected_uni = train_undirected[train_undirected['head'] < train_undirected['tail']]
343 | train_uni = pd.concat([train_directed, train_undirected_uni], ignore_index=True)
344 |
345 | train_pos_edge_index = torch.tensor([train_uni['head'], train_uni['tail']])
346 | train_edge_type = torch.tensor(train_uni.relation)
347 |
348 | r, c = train_pos_edge_index
349 | rev_edge_index = torch.stack([c, r])
350 | rev_edge_type = train_edge_type + 51  # offset by the 51 ogbl-biokg relation types to index reverse-edge relations
351 |
352 | edge_index = torch.cat([train_pos_edge_index, rev_edge_index], dim=1)
353 | edge_type = torch.cat([train_edge_type, rev_edge_type], dim=0)
354 |
355 | data = Data(
356 | x=torch.arange(nentity), edge_index=edge_index, edge_type=edge_type,
357 | train_pos_edge_index=train_pos_edge_index, train_edge_type=train_edge_type,
358 | val_pos_edge_index=valid_pos_edge_index, val_edge_type=valid_edge_type, val_neg_edge_index=valid_neg_edge_index,
359 | test_pos_edge_index=test_pos_edge_index, test_edge_type=test_edge_type, test_neg_edge_index=test_neg_edge_index)
360 |
361 | else:
362 | raise NotImplementedError
363 |
364 | print('Processing:', d)
365 | print(dataset)
366 |
367 | for s in seeds:
368 | seed_everything(s)
369 |
370 | # D
371 | # data = train_test_split_edges_no_neg_adj_mask(data, test_ratio=0.05, two_hop_degree=two_hop_degree, kg=True)
372 | print(s, data)
373 |
374 | with open(os.path.join(data_dir, d, f'd_{s}.pkl'), 'wb') as f:
375 | pickle.dump((dataset, data), f)
376 |
377 | # Two ways to sample Df from the training set
378 | ## 1. Df is within 2 hop local enclosing subgraph of Dtest
379 | ## 2. Df is outside of 2 hop local enclosing subgraph of Dtest
380 |
381 | # All the candidate edges (train edges)
382 | # graph = to_networkx(Data(edge_index=data.train_pos_edge_index, x=data.x))
383 |
384 | # Get the 2 hop local enclosing subgraph for all test edges
385 | _, local_edges, _, mask = k_hop_subgraph(
386 | data.test_pos_edge_index.flatten().unique(),
387 | 2,
388 | data.train_pos_edge_index,
389 | num_nodes=dataset[0].num_nodes)
390 | distant_edges = data.train_pos_edge_index[:, ~mask]
391 | print('Number of edges. Local: ', local_edges.shape[1], 'Distant:', distant_edges.shape[1])
392 |
393 | in_mask = mask
394 | out_mask = ~mask
395 |
396 | torch.save(
397 | {'out': out_mask, 'in': in_mask},
398 | os.path.join(data_dir, d, f'df_{s}.pt')
399 | )
400 |
401 |
402 | def main():
403 | process_graph()
404 | # process_kg()
405 |
406 | if __name__ == "__main__":
407 | main()
408 |
--------------------------------------------------------------------------------
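prepare_dataset.py leaves two artifacts per dataset and seed: d_{seed}.pkl (the dataset plus the split Data object) and df_{seed}.pt (boolean 'in'/'out' masks over the training edges marking the 2-hop-local and distant candidate pools for Df). A minimal loading sketch, assuming the script has already been run; the dataset name and the 0.5% Df ratio are illustrative:

import os
import pickle

import torch

data_dir, d, seed = './data', 'Cora', 42
with open(os.path.join(data_dir, d, f'd_{seed}.pkl'), 'rb') as f:
    dataset, data = pickle.load(f)

masks = torch.load(os.path.join(data_dir, d, f'df_{seed}.pt'))
in_mask, out_mask = masks['in'], masks['out']   # indexed over data.train_pos_edge_index columns

# Draw a Df of e.g. 0.5% of the training edges from the "in" (2-hop local) pool.
df_ratio = 0.5 / 100
num_df = int(df_ratio * data.train_pos_edge_index.shape[1])
candidates = in_mask.nonzero().squeeze()
df_idx = candidates[torch.randperm(candidates.shape[0])[:num_df]]

df_mask = torch.zeros_like(in_mask)
df_mask[df_idx] = True
dr_mask = ~df_mask   # Dr: the retained training edges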