├── LICENSE
├── README.md
├── constants.py
├── data
│   └── .DS_Store
├── datasets
│   ├── .DS_Store
│   ├── NCI1_dataset.py
│   ├── Syn_dataset.py
│   ├── __init__.py
│   ├── ba3motif_dataset.py
│   ├── mutag_dataset.py
│   ├── sup_dataset.py
│   └── web_dataset.py
├── evaluation
│   ├── .DS_Store
│   ├── __init__.py
│   ├── in_distribution
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── ood_stat.py
│   │   └── orca
│   │       ├── orca
│   │       ├── orca.cpp
│   │       ├── orca.exe
│   │       ├── orca.h
│   │       ├── test.txt
│   │       └── tmp.txt
│   ├── ood_evaluation.py
│   └── robustness.py
├── explainers
│   ├── .DS_Store
│   ├── __init__.py
│   ├── base.py
│   ├── diff_explainer.py
│   ├── diffusion
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── graph_utils.py
│   │   └── pgnn.py
│   ├── gnnexplainer.py
│   ├── meta_gnnexplainer.py
│   └── visual.py
├── gnns
│   ├── .DS_Store
│   ├── __init__.py
│   ├── ba3motif_gnn.py
│   ├── bbbp_gnn.py
│   ├── mutag_gnn.py
│   ├── nci1_gnn.py
│   ├── synthetic_gnn.py
│   ├── tree_grids_gnn.py
│   └── web_gnn.py
├── main.py
├── param
│   ├── .DS_Store
│   └── gnns
│       └── .DS_Store
├── requirements.txt
├── results
│   └── .DS_Store
└── utils
    ├── __init__.py
    ├── dataset.py
    ├── dist_helper.py
    ├── helper.py
    └── train_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Graph and Geometric Learning
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # D4Explainer: In-distribution Explanations of Graph Neural Network via Discrete Denoising Diffusion [NeurIPS 2023]
2 | This is the PyTorch implementation of "D4Explainer: In-distribution Explanations of Graph Neural Network via Discrete Denoising Diffusion".
3 | ## Requirements
4 |
5 | - `torch==1.10.1`
6 | - `torch-geometric==2.0.4`
7 | - `numpy==1.24.2`
8 | - `pandas==1.5.3`
9 | - `networkx==3.0`
10 |
11 | Refer to `requirements.txt` for more details.
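A minimal environment sketch, assuming Python 3.8+ and a CUDA 11.3 build of PyTorch 1.10.1 (the wheel index URL and the `torch-scatter`/`torch-sparse` companion packages are assumptions; adjust them for your CUDA version):

```
pip install torch==1.10.1
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.1+cu113.html
pip install torch-geometric==2.0.4 numpy==1.24.2 pandas==1.5.3 networkx==3.0
```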
12 |
13 |
14 | ## Dataset
15 |
16 | Download the datasets from [here](https://drive.google.com/drive/folders/1pwmeST3zBcSC34KbAL_Wvi-cFtufAOCE?usp=sharing) to `data/`
17 |
18 | **Datasets Included:**
19 |
20 | - Node classification: `BA_shapes`; `Tree_Cycle`; `Tree_Grids`; `cornell`
21 | - Graph classification: `mutag`; `ba3`; `bbbp`; `NCI1`
22 |
23 | ## Train Base GNNs
24 | ```
25 | cd gnns
26 | python ba3motif_gnn.py
27 | python bbbp_gnn.py
28 | python mutag_gnn.py
29 | python nci1_gnn.py
30 | python synthetic_gnn.py --data_name Tree_Cycle
31 | python synthetic_gnn.py --data_name BA_shapes
32 | python tree_grids_gnn.py
33 | python web_gnn.py
34 | ```
35 |
36 |
37 | ## Train and Evaluate D4Explainer
38 | For example, to train D4Explainer on Mutag, run:
39 | ```
40 | python main.py --dataset mutag
41 | ```
42 |
43 |
44 | ## Evaluation of Other Properties
45 |
46 | - In-distribution: `python -m evaluation.ood_evaluation`
47 | - Robustness: `python -m evaluation.robustness`
48 |
49 |
50 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | feature_dict = {
2 |     "BA_shapes": 10,
3 |     "Tree_Cycle": 10,
4 |     "Tree_Grids": 10,
5 |     "cornell": 1703,
6 |     "mutag": 14,
7 |     "ba3": 4,
8 |     "bbbp": 9,
9 |     "NCI1": 37,
10 | }
11 |
12 | task_type = {
13 |     "BA_shapes": "nc",
14 |     "Tree_Cycle": "nc",
15 |     "Tree_Grids": "nc",
16 |     "cornell": "nc",
17 |     "mutag": "gc",
18 |     "ba3": "gc",
19 |     "bbbp": "gc",
20 |     "NCI1": "gc",
21 | }
22 |
23 | dataset_choices = list(task_type.keys())
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/data/.DS_Store
--------------------------------------------------------------------------------
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/datasets/.DS_Store
--------------------------------------------------------------------------------
/datasets/NCI1_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import random
4 |
5 | import numpy as np
6 | import sklearn.preprocessing as preprocessing
7 | import torch
8 | from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
9 |
10 |
11 | class NCI1(InMemoryDataset):
12 |     url = (
13 |         "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/NCI1.zip"
14 |     )
15 |     splits = ["training", "evaluation", "testing"]
16 |
17 |     def __init__(
18 |         self, root, mode="testing", transform=None, pre_transform=None, pre_filter=None
19 |     ):
20 |         assert mode in self.splits
21 |         self.mode = mode
22 |         super(NCI1, self).__init__(root, transform, pre_transform, pre_filter)
23 |
24 |         idx = self.processed_file_names.index("{}.pt".format(mode))
25 |         self.data, self.slices = torch.load(self.processed_paths[idx])
26 |         self.url = (
27 |             "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/NCI1.zip"
28 |         )
29 |
30 |     @property
31 |     def raw_file_names(self):
32 |         return [
33 |             "NCI1/" + i
34 |             for i in [
35 |                 "NCI1_A.txt",
36 |                 "NCI1_graph_indicator.txt",
37 |                 "NCI1_graph_labels.txt",
38 |
"NCI1_node_labels.txt", 39 | ] 40 | ] 41 | 42 | @property 43 | def processed_file_names(self): 44 | return ["training.pt", "evaluation.pt", "testing.pt"] 45 | 46 | def download(self): 47 | if os.path.exists(osp.join(self.raw_dir, "NCI1")): 48 | print("Using existing data in folder NCI1") 49 | return 50 | 51 | path = download_url(self.url, self.raw_dir) 52 | extract_zip(path, self.raw_dir) 53 | os.unlink(path) 54 | 55 | def process(self): 56 | edge_index = np.loadtxt( 57 | osp.join(self.raw_dir, self.raw_file_names[0]), delimiter="," 58 | ).T 59 | edge_index = torch.from_numpy(edge_index - 1.0).to( 60 | torch.long 61 | ) # node idx from 0 62 | 63 | # edge_label = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[1])) 64 | # encoder = preprocessing.OneHotEncoder().fit(np.unique(edge_label).reshape(-1, 1)) 65 | # edge_attr = encoder.transform(edge_label.reshape(-1, 1)).toarray() 66 | # edge_attr = torch.Tensor(edge_attr) 67 | 68 | node_label = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[-1])) 69 | encoder = preprocessing.OneHotEncoder().fit( 70 | np.unique(node_label).reshape(-1, 1) 71 | ) 72 | x = encoder.transform(node_label.reshape(-1, 1)).toarray() 73 | x = torch.Tensor(x) 74 | 75 | z = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[1]), dtype=int) 76 | 77 | y = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[2])) 78 | y = torch.unsqueeze(torch.LongTensor(y), 1).long() 79 | num_graphs = len(y) 80 | total_edges = edge_index.size(1) 81 | begin = 0 82 | 83 | data_list = [] 84 | for i in range(num_graphs): 85 | perm = np.where(z == i + 1)[0] 86 | bound = max(perm) 87 | end = begin 88 | for end in range(begin, total_edges): 89 | if int(edge_index[0, end]) > bound: 90 | break 91 | 92 | data = Data( 93 | x=x[perm], 94 | y=y[i], 95 | z=node_label[perm], 96 | edge_index=edge_index[:, begin:end] - int(min(perm)), 97 | idx=i, 98 | ) 99 | 100 | if self.pre_filter is not None and not self.pre_filter(data): 101 | continue 102 | if self.pre_transform is not None: 103 | data = self.pre_transform(data) 104 | 105 | begin = end 106 | data_list.append(data) 107 | 108 | random.shuffle(data_list) 109 | torch.save(self.collate(data_list[1000:]), self.processed_paths[0]) 110 | torch.save(self.collate(data_list[500:1000]), self.processed_paths[1]) 111 | torch.save(self.collate(data_list[:500]), self.processed_paths[2]) 112 | -------------------------------------------------------------------------------- /datasets/Syn_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data import Data, InMemoryDataset 8 | from torch_geometric.utils import dense_to_sparse, k_hop_subgraph, subgraph 9 | 10 | 11 | def get_neighbourhood( 12 | node_idx, edge_index, features, labels, edge_label_matrix, n_hops 13 | ): 14 | edge_subset = k_hop_subgraph(node_idx, n_hops, edge_index) # Get all nodes involved 15 | edge_index_unlabel = edge_subset[1] 16 | ground_truth = torch.zeros((edge_index_unlabel.size(1))) 17 | for t, (i, j) in enumerate(zip(edge_index_unlabel[0], edge_index_unlabel[1])): 18 | if edge_label_matrix[i, j] == 1: 19 | ground_truth[t] = 1 20 | ground_truth = ground_truth.bool() 21 | edge_subset_relabel = subgraph(edge_subset[0], edge_index, relabel_nodes=True) 22 | edge_index_sub = edge_subset_relabel[0] # [2, edge_num_sub] 23 | sub_feat = features[edge_subset[0], :] # [node_num_sub, feature_dim] 24 | sub_labels = 
labels[edge_subset[0]] 25 | self_label = labels[node_idx] 26 | node_dict = torch.tensor(edge_subset[0]).reshape(-1, 1) # Maps orig labels to new 27 | mapping = edge_subset[2] 28 | mapping_mask = torch.zeros((sub_feat.shape[0])) 29 | mapping_mask[mapping] = 1 30 | mapping_mask = mapping_mask.bool() 31 | return ( 32 | sub_feat, 33 | edge_index_sub, 34 | sub_labels, 35 | self_label, 36 | node_dict, 37 | ground_truth, 38 | mapping_mask, 39 | ) 40 | 41 | 42 | class SynGraphDataset(InMemoryDataset): 43 | def __init__(self, root, name, mode="testing", transform=None, pre_transform=None): 44 | self.name = name 45 | self.mode = mode 46 | super(SynGraphDataset, self).__init__(root, transform, pre_transform) 47 | idx = self.processed_file_names.index("{}_sub.pt".format(mode)) 48 | self.data, self.slices = torch.load(self.processed_paths[idx]) 49 | 50 | @property 51 | def raw_dir(self): 52 | return osp.join(self.root, self.name, "raw") 53 | 54 | @property 55 | def processed_dir(self): 56 | return osp.join(self.root, self.name, "processed") 57 | 58 | @property 59 | def raw_file_names(self): 60 | return [f"{self.name}.pkl"] 61 | 62 | @property 63 | def processed_file_names(self): 64 | return ["training_sub.pt", "evaluating_sub.pt", "testing_sub.pt"] 65 | 66 | def process(self): 67 | # Read data into huge `Data` list. 68 | with open( 69 | os.path.join(f"./data/{self.name}/raw", f"{self.name}.pkl"), "rb" 70 | ) as f: 71 | ( 72 | adj, 73 | features, 74 | y_train, 75 | y_val, 76 | y_test, 77 | train_mask, 78 | val_mask, 79 | test_mask, 80 | edge_label_matrix, 81 | ) = pickle.load(f) 82 | x = torch.from_numpy(features).float() 83 | y = ( 84 | train_mask.reshape(-1, 1) * y_train 85 | + val_mask.reshape(-1, 1) * y_val 86 | + test_mask.reshape(-1, 1) * y_test 87 | ) 88 | y = torch.from_numpy(np.where(y)[1]) 89 | 90 | edge_index = dense_to_sparse(torch.from_numpy(adj))[0] 91 | data_whole = Data(x=x, edge_index=edge_index, y=y) 92 | data_whole.train_mask = torch.from_numpy(train_mask) 93 | data_whole.val_mask = torch.from_numpy(val_mask) 94 | data_whole.test_mask = torch.from_numpy(test_mask) 95 | torch.save(data_whole, f"./data/{self.name}/processed/whole_graph.pt") 96 | 97 | data_list = [] 98 | for id in range(x.shape[0]): 99 | ( 100 | sub_feat, 101 | edge_index_sub, 102 | sub_labels, 103 | self_label, 104 | node_dict, 105 | ground_truth, 106 | mapping_mask, 107 | ) = get_neighbourhood( 108 | id, 109 | edge_index, 110 | features=x, 111 | labels=y, 112 | edge_label_matrix=edge_label_matrix, 113 | n_hops=4, 114 | ) 115 | data = Data( 116 | x=sub_feat, 117 | edge_index=edge_index_sub, 118 | y=sub_labels, 119 | self_y=self_label, 120 | node_dict=node_dict, 121 | ground_truth=ground_truth, 122 | mapping=mapping_mask, 123 | idx=id, 124 | ) 125 | print(data) 126 | if self.pre_filter is not None and not self.pre_filter(data): 127 | continue 128 | if self.pre_transform is not None: 129 | data = self.pre_transform(data) 130 | data_list.append(data) 131 | 132 | # data_list = np.array(data_list) 133 | train_mask = list(np.where(train_mask)[0]) 134 | val_mask = list(np.where(val_mask)[0]) 135 | test_mask = list(np.where(test_mask)[0]) 136 | torch.save( 137 | self.collate([data_list[i] for i in train_mask]), 138 | f"./data/{self.name}/processed/training_sub.pt", 139 | ) 140 | torch.save( 141 | self.collate([data_list[i] for i in val_mask]), 142 | f"./data/{self.name}/processed/evaluating_sub.pt", 143 | ) 144 | torch.save( 145 | self.collate([data_list[i] for i in test_mask]), 146 | 
f"./data/{self.name}/processed/testing_sub.pt", 147 | ) 148 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ba3motif_dataset import BA3Motif 2 | from .mutag_dataset import Mutagenicity 3 | from .NCI1_dataset import NCI1 4 | from .sup_dataset import bbbp 5 | from .Syn_dataset import SynGraphDataset 6 | from .web_dataset import WebDataset 7 | 8 | __all__ = [ 9 | "BA3Motif", 10 | "Mutagenicity", 11 | "NCI1", 12 | "bbbp", 13 | "SynGraphDataset", 14 | "WebDataset", 15 | ] 16 | -------------------------------------------------------------------------------- /datasets/ba3motif_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | import sys 4 | sys.path.append("..") 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data import Data, InMemoryDataset 8 | 9 | 10 | class BA3Motif(InMemoryDataset): 11 | splits = ["training", "evaluation", "testing"] 12 | 13 | def __init__( 14 | self, root, mode="testing", transform=None, pre_transform=None, pre_filter=None 15 | ): 16 | assert mode in self.splits 17 | self.mode = mode 18 | super(BA3Motif, self).__init__(root, transform, pre_transform, pre_filter) 19 | 20 | idx = self.processed_file_names.index("{}.pt".format(mode)) 21 | self.data, self.slices = torch.load(self.processed_paths[idx]) 22 | 23 | @property 24 | def raw_file_names(self): 25 | return ["BA-3motif.npy"] 26 | 27 | @property 28 | def processed_file_names(self): 29 | return ["training.pt", "evaluation.pt", "testing.pt"] 30 | 31 | def download(self): 32 | if not osp.exists(osp.join(self.raw_dir, "raw", "BA-3motif.npy")): 33 | print( 34 | "raw data of `BA-3motif.npy` doesn't exist, please redownload from our github." 
35 |             )
36 |             raise FileNotFoundError
37 |
38 |     def process(self):
39 |         edge_index_list, label_list, ground_truth_list, role_id_list, pos = np.load(
40 |             osp.join(self.raw_dir, self.raw_file_names[0]), allow_pickle=True
41 |         )
42 |
43 |         data_list = []
44 |         alpha = 0.25
45 |         for idx, (edge_index, y, ground_truth, z, p) in enumerate(
46 |             zip(edge_index_list, label_list, ground_truth_list, role_id_list, pos)
47 |         ):
48 |             edge_index = torch.from_numpy(edge_index)
49 |             edge_index = edge_index.to(torch.long)
50 |             node_idx = torch.unique(edge_index)
51 |             assert node_idx.max() == node_idx.size(0) - 1
52 |             x = torch.zeros(node_idx.size(0), 4)
53 |             index = [i for i in range(node_idx.size(0))]
54 |             x[index, z] = 1
55 |             x = alpha * x + (1 - alpha) * torch.rand((node_idx.size(0), 4))
56 |             edge_attr = torch.ones(edge_index.size(1), 1)
57 |             y = torch.tensor(y, dtype=torch.long).unsqueeze(dim=0)
58 |             p = np.array(list(p.values()))
59 |
60 |             data = Data(
61 |                 x=x,
62 |                 y=y,
63 |                 z=z,
64 |                 edge_index=edge_index,
65 |                 edge_attr=edge_attr,
66 |                 pos=p,
67 |                 ground_truth_mask=ground_truth,
68 |                 name=f"BA-3motif{idx}",
69 |                 idx=idx,
70 |             )
71 |
72 |             if self.pre_filter is not None and not self.pre_filter(data):
73 |                 continue
74 |             if self.pre_transform is not None:
75 |                 data = self.pre_transform(data)
76 |
77 |             data_list.append(data)
78 |
79 |         random.shuffle(data_list)
80 |         torch.save(self.collate(data_list[800:]), self.processed_paths[0])
81 |         torch.save(self.collate(data_list[400:800]), self.processed_paths[1])
82 |         torch.save(self.collate(data_list[:400]), self.processed_paths[2])
83 |
--------------------------------------------------------------------------------
/datasets/mutag_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import random
4 |
5 | import numpy as np
6 | import sklearn.preprocessing as preprocessing
7 | import torch
8 | from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
9 |
10 |
11 | class Mutagenicity(InMemoryDataset):
12 |     url = "https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/Mutagenicity.zip"
13 |
14 |     splits = ["training", "evaluation", "testing"]
15 |
16 |     def __init__(
17 |         self, root, mode="testing", transform=None, pre_transform=None, pre_filter=None
18 |     ):
19 |         assert mode in self.splits
20 |         self.mode = mode
21 |         super(Mutagenicity, self).__init__(root, transform, pre_transform, pre_filter)
22 |
23 |         idx = self.processed_file_names.index("{}.pt".format(mode))
24 |         self.data, self.slices = torch.load(self.processed_paths[idx])
25 |
26 |     @property
27 |     def raw_file_names(self):
28 |         return [
29 |             "Mutagenicity/" + i
30 |             for i in [
31 |                 "Mutagenicity_A.txt",
32 |                 "Mutagenicity_edge_labels.txt",
33 |                 "Mutagenicity_graph_indicator.txt",
34 |                 "Mutagenicity_graph_labels.txt",
35 |                 "Mutagenicity_node_labels.txt",
36 |             ]
37 |         ]
38 |
39 |     @property
40 |     def processed_file_names(self):
41 |         return ["training.pt", "evaluation.pt", "testing.pt"]
42 |
43 |     def download(self):
44 |         if os.path.exists(osp.join(self.raw_dir, "Mutagenicity")):
45 |             print("Using existing data in folder Mutagenicity")
46 |             return
47 |
48 |         path = download_url(self.url, self.raw_dir)
49 |         extract_zip(path, self.raw_dir)
50 |         os.unlink(path)
51 |
52 |     def process(self):
53 |         edge_index = np.loadtxt(
54 |             osp.join(self.raw_dir, self.raw_file_names[0]), delimiter=","
55 |         ).T
56 |         edge_index = torch.from_numpy(edge_index - 1.0).to(
57 |             torch.long
58 |         )  # node idx
from 0 59 | 60 | edge_label = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[1])) 61 | encoder = preprocessing.OneHotEncoder().fit( 62 | np.unique(edge_label).reshape(-1, 1) 63 | ) 64 | edge_attr = encoder.transform(edge_label.reshape(-1, 1)).toarray() 65 | edge_attr = torch.Tensor(edge_attr) 66 | 67 | node_label = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[-1])) 68 | encoder = preprocessing.OneHotEncoder().fit( 69 | np.unique(node_label).reshape(-1, 1) 70 | ) 71 | x = encoder.transform(node_label.reshape(-1, 1)).toarray() 72 | x = torch.Tensor(x) 73 | 74 | z = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[2]), dtype=int) 75 | 76 | y = np.loadtxt(osp.join(self.raw_dir, self.raw_file_names[3])) 77 | y = torch.unsqueeze(torch.LongTensor(y), 1).long() 78 | num_graphs = len(y) 79 | total_edges = edge_index.size(1) 80 | begin = 0 81 | 82 | data_list = [] 83 | for i in range(num_graphs): 84 | perm = np.where(z == i + 1)[0] 85 | bound = max(perm) 86 | end = begin 87 | for end in range(begin, total_edges): 88 | if int(edge_index[0, end]) > bound: 89 | break 90 | 91 | data = Data( 92 | x=x[perm], 93 | y=y[i], 94 | z=node_label[perm], 95 | edge_index=edge_index[:, begin:end] - int(min(perm)), 96 | edge_attr=edge_attr[begin:end], 97 | name="mutag_%d" % i, 98 | idx=i, 99 | ) 100 | 101 | if self.pre_filter is not None and not self.pre_filter(data): 102 | continue 103 | if self.pre_transform is not None: 104 | data = self.pre_transform(data) 105 | 106 | begin = end 107 | data_list.append(data) 108 | 109 | assert len(data_list) == 4337 110 | 111 | random.shuffle(data_list) 112 | torch.save(self.collate(data_list[1000:]), self.processed_paths[0]) 113 | torch.save(self.collate(data_list[500:1000]), self.processed_paths[1]) 114 | torch.save(self.collate(data_list[:500]), self.processed_paths[2]) 115 | -------------------------------------------------------------------------------- /datasets/sup_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | from torch_geometric.data import InMemoryDataset 5 | 6 | 7 | class bbbp(InMemoryDataset): 8 | splits = ["training", "evaluation", "testing"] 9 | 10 | def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): 11 | super(bbbp, self).__init__(root, transform, pre_transform, pre_filter) 12 | 13 | self.data, self.slices = torch.load(osp.join(self.root, "processed/data.pt")) 14 | self.data.x = self.data.x.float() 15 | edge_attr = self.data.edge_attr 16 | loc = torch.where(edge_attr != edge_attr) 17 | edge_attr[loc] = 0 18 | self.data.edge_attr = edge_attr.float() 19 | self.data.y = self.data.y.view(-1).to(torch.int64) 20 | -------------------------------------------------------------------------------- /datasets/web_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import Data, InMemoryDataset 6 | from torch_geometric.datasets import WebKB 7 | from torch_geometric.utils import k_hop_subgraph, subgraph 8 | 9 | 10 | def get_neighbourhood(node_idx, edge_index, features, labels, n_hops): 11 | edge_subset = k_hop_subgraph(node_idx, n_hops, edge_index) # Get all nodes involved 12 | edge_subset_relabel = subgraph(edge_subset[0], edge_index, relabel_nodes=True) 13 | edge_index_sub = edge_subset_relabel[0] # [2, edge_num_sub] 14 | sub_feat = features[edge_subset[0], :] # [node_num_sub, feature_dim] 15 | 
sub_labels = labels[edge_subset[0]] 16 | self_label = labels[node_idx] 17 | node_dict = torch.tensor(edge_subset[0]).reshape(-1, 1) # Maps orig labels to new 18 | mapping = edge_subset[2] 19 | mapping_mask = torch.zeros((sub_feat.shape[0])) 20 | mapping_mask[mapping] = 1 21 | mapping_mask = mapping_mask.bool() 22 | return sub_feat, edge_index_sub, sub_labels, self_label, node_dict, mapping_mask 23 | 24 | 25 | class WebDataset(InMemoryDataset): 26 | def __init__(self, root, name, mode="testing", transform=None, pre_transform=None): 27 | self.name = name 28 | self.mode = mode 29 | super(WebDataset, self).__init__(root, transform, pre_transform) 30 | idx = self.processed_file_names.index("{}_sub.pt".format(mode)) 31 | self.data, self.slices = torch.load(self.processed_paths[idx]) 32 | 33 | @property 34 | def raw_dir(self): 35 | return osp.join(self.root, self.name, "raw") 36 | 37 | @property 38 | def processed_dir(self): 39 | return osp.join(self.root, self.name, "processed") 40 | 41 | @property 42 | def raw_file_names(self): 43 | return [f"{self.name}.pkl"] 44 | 45 | @property 46 | def processed_file_names(self): 47 | return ["training_sub.pt", "evaluating_sub.pt", "testing_sub.pt"] 48 | 49 | def process(self): 50 | # Read data into huge `Data` list. 51 | webdata = WebKB(root=self.root, name=self.name) 52 | data = webdata[0] 53 | train_mask_data = data.train_mask[:, 0] 54 | val_mask_data = data.val_mask[:, 0] 55 | test_mask_data = data.test_mask[:, 0] 56 | edge_index = data.edge_index 57 | x = data.x 58 | y = data.y 59 | data_whole = Data(x=x, edge_index=edge_index, y=y) 60 | data_whole.train_mask = train_mask_data 61 | data_whole.val_mask = val_mask_data 62 | data_whole.test_mask = test_mask_data 63 | torch.save(data_whole, f"./data/{self.name}/processed/whole_graph.pt") 64 | 65 | data_list = [] 66 | for id in range(x.shape[0]): 67 | ( 68 | sub_feat, 69 | edge_index_sub, 70 | sub_labels, 71 | self_label, 72 | node_dict, 73 | mapping_mask, 74 | ) = get_neighbourhood(id, edge_index, features=x, labels=y, n_hops=6) 75 | data = Data( 76 | x=sub_feat, 77 | edge_index=edge_index_sub, 78 | y=sub_labels, 79 | self_y=self_label, 80 | node_dict=node_dict, 81 | mapping=mapping_mask, 82 | idx=id, 83 | ) 84 | if self.pre_filter is not None and not self.pre_filter(data): 85 | continue 86 | if self.pre_transform is not None: 87 | data = self.pre_transform(data) 88 | data_list.append(data) 89 | 90 | train_mask = list(np.where(train_mask_data)[0]) 91 | val_mask = list(np.where(val_mask_data)[0]) 92 | test_mask = list(np.where(test_mask_data)[0]) 93 | torch.save( 94 | self.collate([data_list[i] for i in train_mask]), 95 | f"./data/{self.name}/processed/training_sub.pt", 96 | ) 97 | torch.save( 98 | self.collate([data_list[i] for i in val_mask]), 99 | f"./data/{self.name}/processed/evaluating_sub.pt", 100 | ) 101 | torch.save( 102 | self.collate([data_list[i] for i in test_mask]), 103 | f"./data/{self.name}/processed/testing_sub.pt", 104 | ) 105 | -------------------------------------------------------------------------------- /evaluation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/.DS_Store -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/in_distribution/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/in_distribution/.DS_Store -------------------------------------------------------------------------------- /evaluation/in_distribution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/in_distribution/__init__.py -------------------------------------------------------------------------------- /evaluation/in_distribution/ood_stat.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import copy 3 | import os 4 | import subprocess as sp 5 | from datetime import datetime 6 | 7 | import networkx as nx 8 | import numpy as np 9 | from scipy.linalg import eigvalsh 10 | 11 | from utils import compute_mmd, gaussian, gaussian_emd, process_tensor 12 | 13 | PRINT_TIME = False 14 | ORCA_DIR = "orca" # the relative path to the orca dir 15 | 16 | 17 | def degree_worker(G): 18 | """ 19 | Compute the degree distribution of a graph. 20 | :param G: a networkx graph 21 | :return: a numpy array of the degree distribution 22 | """ 23 | return np.array(nx.degree_histogram(G)) 24 | 25 | 26 | def add_tensor(x, y): 27 | """ 28 | Add two tensors. If unequal shape, pads the smaller one with zeros. 29 | :param x: a tensor 30 | :param y: a tensor 31 | :return: x + y 32 | """ 33 | x, y = process_tensor(x, y) 34 | return x + y 35 | 36 | 37 | def degree_stats(graph_ref_list, graph_pred_list, is_parallel=True): 38 | """ 39 | Compute the distance between the degree distributions of two unordered sets of graphs. 
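    The distance is a maximum mean discrepancy (MMD) over per-graph degree histograms,
    computed with the Gaussian-EMD kernel (`compute_mmd` and `gaussian_emd` imported from `utils`).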
40 |     :param graph_ref_list: a list of networkx graphs
41 |     :param graph_pred_list: a list of networkx graphs
42 |     :param is_parallel: whether to use parallel computing
43 |     :return: the distance between the two degree distributions
44 |     """
45 |     sample_ref = []
46 |     sample_pred = []
47 |     # in case an empty graph is generated
48 |     graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
49 |
50 |     prev = datetime.now()
51 |     if is_parallel:
52 |         with concurrent.futures.ThreadPoolExecutor() as executor:
53 |             for deg_hist in executor.map(degree_worker, graph_ref_list):
54 |                 sample_ref.append(deg_hist)
55 |         with concurrent.futures.ThreadPoolExecutor() as executor:
56 |             for deg_hist in executor.map(degree_worker, graph_pred_list_remove_empty):
57 |                 sample_pred.append(deg_hist)
58 |
59 |     else:
60 |         for i in range(len(graph_ref_list)):
61 |             degree_temp = np.array(nx.degree_histogram(graph_ref_list[i]))
62 |             sample_ref.append(degree_temp)
63 |         for i in range(len(graph_pred_list_remove_empty)):
64 |             degree_temp = np.array(nx.degree_histogram(graph_pred_list_remove_empty[i]))
65 |             sample_pred.append(degree_temp)
66 |     mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd)
67 |     elapsed = datetime.now() - prev
68 |     if PRINT_TIME:
69 |         print("Time computing degree mmd: ", elapsed)
70 |     return mmd_dist
71 |
72 |
73 | ###############################################################################
74 |
75 |
76 | def spectral_worker(G):
77 |     """
78 |     Compute the spectral pmf of a graph.
79 |     :param G: a networkx graph
80 |     :return: a numpy array of the spectral pmf
81 |     """
82 |     eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
83 |     spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False)
84 |     spectral_pmf = spectral_pmf / spectral_pmf.sum()
85 |     return spectral_pmf
86 |
87 |
88 | def spectral_stats(graph_ref_list, graph_pred_list, is_parallel=True):
89 |     """
90 |     Compute the distance between the spectral distributions of two unordered sets of graphs.
91 |     :param graph_ref_list: a list of networkx graphs
92 |     :param graph_pred_list: a list of networkx graphs
93 |     :param is_parallel: whether to use parallel computing
94 |     :return: the distance between the two spectral distributions
95 |     """
96 |     sample_ref = []
97 |     sample_pred = []
98 |     # in case an empty graph is generated
99 |     graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
100 |
101 |     prev = datetime.now()
102 |     if is_parallel:
103 |         with concurrent.futures.ThreadPoolExecutor() as executor:
104 |             for spectral_density in executor.map(spectral_worker, graph_ref_list):
105 |                 sample_ref.append(spectral_density)
106 |         with concurrent.futures.ThreadPoolExecutor() as executor:
107 |             for spectral_density in executor.map(spectral_worker, graph_pred_list_remove_empty):
108 |                 sample_pred.append(spectral_density)
109 |     else:
110 |         for i in range(len(graph_ref_list)):
111 |             spectral_temp = spectral_worker(graph_ref_list[i])
112 |             sample_ref.append(spectral_temp)
113 |         for i in range(len(graph_pred_list_remove_empty)):
114 |             spectral_temp = spectral_worker(graph_pred_list_remove_empty[i])
115 |             sample_pred.append(spectral_temp)
116 |
117 |     mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd)
118 |
119 |     elapsed = datetime.now() - prev
120 |     if PRINT_TIME:
121 |         print("Time computing spectral mmd: ", elapsed)
122 |     return mmd_dist
123 |
124 |
125 | ###############################################################################
126 |
127 |
128 | def clustering_worker(param):
129 |     """
130 |     Compute the clustering coefficient distribution of a graph.
131 |     :param param: a tuple of (graph, number of bins)
132 |     :return: a numpy array of the clustering coefficient distribution
133 |     """
134 |     G, bins = param
135 |     clustering_coeffs_list = list(nx.clustering(G).values())
136 |     hist, _ = np.histogram(clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
137 |     return hist
138 |
139 |
140 | def clustering_stats(graph_ref_list, graph_pred_list, bins=100, is_parallel=True):
141 |     """
142 |     Compute the distance between the clustering coefficient distributions of two unordered sets of graphs.
143 |     :param graph_ref_list: a list of networkx graphs
144 |     :param graph_pred_list: a list of networkx graphs
145 |     :param bins: number of bins for the histogram
146 |     :param is_parallel: whether to use parallel computing
147 |     :return: the distance between the two clustering coefficient distributions
148 |     """
149 |     sample_ref = []
150 |     sample_pred = []
151 |     graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
152 |
153 |     prev = datetime.now()
154 |     if is_parallel:
155 |         with concurrent.futures.ThreadPoolExecutor() as executor:
156 |             for clustering_hist in executor.map(clustering_worker, [(G, bins) for G in graph_ref_list]):
157 |                 sample_ref.append(clustering_hist)
158 |         with concurrent.futures.ThreadPoolExecutor() as executor:
159 |             for clustering_hist in executor.map(clustering_worker, [(G, bins) for G in graph_pred_list_remove_empty]):
160 |                 sample_pred.append(clustering_hist)
161 |     else:
162 |         for i in range(len(graph_ref_list)):
163 |             clustering_coeffs_list = list(nx.clustering(graph_ref_list[i]).values())
164 |             hist, _ = np.histogram(clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
165 |             sample_ref.append(hist)
166 |
167 |         for i in range(len(graph_pred_list_remove_empty)):
168 |             clustering_coeffs_list = list(nx.clustering(graph_pred_list_remove_empty[i]).values())
169 |             hist, _ = np.histogram(clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
170 |             sample_pred.append(hist)
171 |
172 |     mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd, sigma=1.0 / 10, distance_scaling=bins)
173 |     elapsed = datetime.now() - prev
174 |     if PRINT_TIME:
175 |         print("Time computing clustering mmd: ", elapsed)
176 |     return mmd_dist
177 |
178 |
179 | # maps motif/orbit name string to its corresponding list of indices from orca output
180 | motif_to_indices = {"3path": [1, 2], "4cycle": [8]}
181 | COUNT_START_STR = "orbit counts: \n"
182 |
183 |
184 | def edge_list_reindexed(G):
185 |     """
186 |     Convert a graph to a list of edges, where the nodes are reindexed to be integers from 0 to n-1.
187 |     :param G: a networkx graph
188 |     :return: a list of edges, where each edge is a tuple of integers
189 |     """
190 |     idx = 0
191 |     id2idx = dict()
192 |     for u in G.nodes():
193 |         id2idx[str(u)] = idx
194 |         idx += 1
195 |
196 |     edges = []
197 |     for u, v in G.edges():
198 |         edges.append((id2idx[str(u)], id2idx[str(v)]))
199 |     return edges
200 |
201 |
202 | def orca(graph):
203 |     """
204 |     Compute the orbit counts of a graph.
205 |     :param graph: a networkx graph
206 |     :return: a numpy array of shape (n, 15), where n is the number of nodes in the graph and each row holds that node's counts for the 15 graphlet orbits (graphlets on up to 4 nodes).
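    Relies on the pre-compiled `orca` executable in ORCA_DIR (evaluation/in_distribution/orca/);
    if the shipped binary does not run on your platform, rebuild it from the bundled orca.cpp.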
207 | """ 208 | tmp_file_path = os.path.join(ORCA_DIR, "tmp.txt") 209 | f = open(tmp_file_path, "w") 210 | f.write(str(graph.number_of_nodes()) + " " + str(graph.number_of_edges()) + "\n") 211 | for u, v in edge_list_reindexed(graph): 212 | f.write(str(u) + " " + str(v) + "\n") 213 | f.close() 214 | 215 | output = sp.check_output([os.path.join(ORCA_DIR, "orca"), "node", "4", tmp_file_path, "std"]) 216 | output = output.decode("utf8").strip() 217 | idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) 218 | output = output[idx:] 219 | node_orbit_counts = np.array( 220 | [list(map(int, node_cnts.strip().split(" "))) for node_cnts in output.strip("\n").split("\n")] 221 | ) 222 | 223 | return node_orbit_counts 224 | 225 | 226 | def orbit_stats_all(graph_ref_list, graph_pred_list): 227 | """ 228 | Compute the distance between the orbit counts of two unordered sets of graphs. 229 | :param graph_ref_list: a list of networkx graphs 230 | :param graph_pred_list: a list of networkx graphs 231 | :return: the distance between the two orbit counts 232 | """ 233 | total_counts_ref = [] 234 | total_counts_pred = [] 235 | 236 | for G in graph_ref_list: 237 | try: 238 | orbit_counts = orca(G) 239 | except Exception: 240 | continue 241 | 242 | orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() 243 | total_counts_ref.append(orbit_counts_graph) 244 | 245 | for G in graph_pred_list: 246 | try: 247 | orbit_counts = orca(G) 248 | except Exception: 249 | continue 250 | orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() 251 | total_counts_pred.append(orbit_counts_graph) 252 | 253 | total_counts_ref = np.array(total_counts_ref) 254 | total_counts_pred = np.array(total_counts_pred) 255 | mmd_dist = compute_mmd(total_counts_ref, total_counts_pred, kernel=gaussian, is_hist=False, sigma=30.0) 256 | 257 | return mmd_dist 258 | 259 | 260 | def adjs_to_graphs(adjs): 261 | """ 262 | Convert a list of adjacency matrices to a list of networkx graphs. 263 | :param adjs: a list of adjacency matrices 264 | :return: a list of networkx graphs 265 | """ 266 | graph_list = [] 267 | for adj in adjs: 268 | G = nx.from_numpy_matrix(adj) 269 | G.remove_edges_from(nx.selfloop_edges(G)) 270 | G.remove_nodes_from(list(nx.isolates(G))) 271 | if G.number_of_nodes() < 1: 272 | G.add_node(1) 273 | graph_list.append(G) 274 | return graph_list 275 | 276 | 277 | def is_lobster_graph(G): 278 | """ 279 | Check a given graph is a lobster graph or not (lobster -> caterpillar -> path) 280 | :param G: a networkx graph 281 | :return: True if the graph is a lobster graph, False otherwise 282 | """ 283 | # Check if G is a tree 284 | if nx.is_tree(G): 285 | leaves = [n for n, d in G.degree() if d == 1] 286 | G.remove_nodes_from(leaves) 287 | 288 | leaves = [n for n, d in G.degree() if d == 1] 289 | G.remove_nodes_from(leaves) 290 | 291 | num_nodes = len(G.nodes()) 292 | num_degree_one = [d for n, d in G.degree() if d == 1] 293 | num_degree_two = [d for n, d in G.degree() if d == 2] 294 | 295 | if sum(num_degree_one) == 2 and sum(num_degree_two) == 2 * (num_nodes - 2): 296 | return True 297 | elif sum(num_degree_one) == 0 and sum(num_degree_two) == 0: 298 | return True 299 | else: 300 | return False 301 | else: 302 | return False 303 | 304 | 305 | def eval_acc_lobster_graph(G_list): 306 | """ 307 | Compute the accuracy of a list of graphs being lobster graphs. 
308 | :param G_list: a list of networkx graphs 309 | :return: the accuracy of the list of graphs being lobster graphs 310 | """ 311 | G_list = [copy.deepcopy(gg) for gg in G_list] 312 | 313 | count = 0 314 | for gg in G_list: 315 | if is_lobster_graph(gg): 316 | count += 1 317 | 318 | return count / float(len(G_list)) 319 | 320 | 321 | METHOD_NAME_TO_FUNC = { 322 | "degree": degree_stats, 323 | "cluster": clustering_stats, 324 | "orbit": orbit_stats_all, 325 | "spectral": spectral_stats, 326 | } 327 | 328 | 329 | def eval_graph_list(graph_ref_list, grad_pred_list, methods=None): 330 | """ 331 | Compute the evaluation metrics for a list of graphs. 332 | :param graph_ref_list: a list of networkx graphs 333 | :param grad_pred_list: a list of networkx graphs 334 | :param methods: a list of evaluation methods to be used 335 | :return: a dictionary of evaluation results 336 | """ 337 | if methods is None: 338 | methods = ["degree", "cluster", "spectral", "orbit"] 339 | results = {} 340 | for method in methods: 341 | results[method] = METHOD_NAME_TO_FUNC[method](graph_ref_list, grad_pred_list) 342 | if "orbit" not in methods: 343 | results["orbit"] = 0.0 344 | print(results) 345 | return results 346 | 347 | 348 | def eval_torch_batch(ref_batch, pred_batch, methods=None): 349 | """ 350 | Compute the evaluation metrics for a batch of graphs. 351 | :param ref_batch: a batch of adjacency matrices 352 | :param pred_batch: a batch of adjacency matrices 353 | :param methods: a list of evaluation methods to be used 354 | :return: a dictionary of evaluation results 355 | """ 356 | graph_ref_list = adjs_to_graphs(ref_batch.detach().cpu().numpy()) 357 | grad_pred_list = adjs_to_graphs(pred_batch.detach().cpu().numpy()) 358 | results = eval_graph_list(graph_ref_list, grad_pred_list, methods=methods) 359 | return results 360 | -------------------------------------------------------------------------------- /evaluation/in_distribution/orca/orca: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/in_distribution/orca/orca -------------------------------------------------------------------------------- /evaluation/in_distribution/orca/orca.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/evaluation/in_distribution/orca/orca.exe -------------------------------------------------------------------------------- /evaluation/in_distribution/orca/test.txt: -------------------------------------------------------------------------------- 1 | 4 4 2 | 0 1 3 | 1 2 4 | 2 3 5 | 3 0 6 | 7 | -------------------------------------------------------------------------------- /evaluation/in_distribution/orca/tmp.txt: -------------------------------------------------------------------------------- 1 | 4 3 2 | 0 1 3 | 1 2 4 | 1 3 5 | -------------------------------------------------------------------------------- /evaluation/ood_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | sys.path.append("..") 5 | import torch 6 | from torch_geometric.data import DataLoader 7 | from torch_geometric.utils import to_networkx 8 | from tqdm import tqdm 9 | 10 | from constants import dataset_choices, feature_dict, task_type 11 | from 
evaluation.in_distribution.ood_stat import eval_graph_list 12 | from explainers import DiffExplainer 13 | from gnns import * 14 | from utils.dataset import get_datasets 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="in-distribution evaluation") 19 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 20 | parser.add_argument("--root", type=str, default="../results", help="Result directory.") 21 | parser.add_argument("--dataset", type=str, default="NCI1", choices=dataset_choices) 22 | parser.add_argument("--gnn_type", type=str, default="gcn") 23 | parser.add_argument("--task", type=str, default="nc") 24 | parser.add_argument("--num_test", type=int, default=50) 25 | parser.add_argument("--normalization", type=str, default="instance") 26 | parser.add_argument("--num_layers", type=int, default=6) 27 | parser.add_argument("--layers_per_conv", type=int, default=1) 28 | parser.add_argument("--n_hidden", type=int, default=64) 29 | parser.add_argument("--cat_output", type=bool, default=True) 30 | parser.add_argument("--residual", type=bool, default=False) 31 | parser.add_argument("--noise_mlp", type=bool, default=True) 32 | parser.add_argument("--simplified", type=bool, default=False) 33 | parser.add_argument("--dropout", type=float, default=0.001) 34 | parser.add_argument("--prob_low", type=float, default=0.0) 35 | parser.add_argument("--prob_high", type=float, default=0.4) 36 | parser.add_argument("--sigma_length", type=int, default=10) 37 | return parser.parse_args() 38 | 39 | 40 | args = parse_args() 41 | args.device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 42 | mr = [0.2] 43 | args.noise_list = None 44 | args.feature_in = feature_dict[args.dataset] 45 | args.task = task_type[args.dataset] 46 | train_dataset, val_dataset, test_dataset = get_datasets(name=args.dataset, root="../data/") 47 | test_loader = DataLoader(test_dataset[: args.num_test], batch_size=1, shuffle=False, drop_last=False) 48 | gnn_path = f"../param/gnns/{args.dataset}_{args.gnn_type}.pt" 49 | Explainer = DiffExplainer(args.device, gnn_path) 50 | test_graph = [] 51 | pred_graph = [] 52 | for graph in test_loader: 53 | graph.to(args.device) 54 | exp_subgraph,_,_,_ = Explainer.explain_evaluation(args, graph) 55 | G_ori = to_networkx(graph, to_undirected=True) 56 | G_pred = to_networkx(exp_subgraph, to_undirected=True) 57 | test_graph.append(G_ori) 58 | pred_graph.append(G_pred) 59 | MMD = eval_graph_list(test_graph, pred_graph) 60 | print(MMD) 61 | -------------------------------------------------------------------------------- /evaluation/robustness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import os 5 | import sys 6 | sys.path.append("..") 7 | import numpy as np 8 | import torch 9 | from torch_geometric.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from constants import feature_dict, task_type 13 | from explainers import * 14 | from explainers.base import Explainer as BaseExplainer 15 | from explainers.diff_explainer import Powerful, sparsity 16 | from explainers.diffusion.graph_utils import ( 17 | gen_list_of_data_single, 18 | generate_mask, 19 | graph2tensor, 20 | ) 21 | from gnns import * 22 | from utils.dataset import get_datasets 23 | 24 | 25 | class DiffExplainer(BaseExplainer): 26 | def __init__(self, device, gnn_model_path, task, args): 27 | super(DiffExplainer, self).__init__(device, gnn_model_path, task) 28 | 
        self.device = device
29 |         self.model = Powerful(args).to(args.device)
30 |         exp_dir = f"{args.root}/{args.dataset}/"
31 |         self.model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pth"), map_location=self.device)["model"])
32 |         self.model.eval()
33 |
34 |     def explain_graph(self, model, graph, adj_b, x_b, node_flag_b, sigma_list, args):
35 |         sigma_list = [sigma / 20 for sigma in sigma_list]
36 |         _, _, _, test_noise_adj_b, _ = gen_list_of_data_single(x_b, adj_b, node_flag_b, sigma_list, args)
37 |         test_noise_adj_b_chunked = test_noise_adj_b.chunk(len(sigma_list), dim=0)
38 |         score = []
39 |         mask = generate_mask(node_flag_b)
40 |         for i, sigma in enumerate(sigma_list):
41 |             # [1, N, N, 1]
42 |             score_batch = self.model(
43 |                 A=test_noise_adj_b_chunked[i].to(args.device),
44 |                 node_features=x_b,
45 |                 mask=mask,
46 |                 noiselevel=sigma,
47 |             ).to(args.device)
48 |             score.append(score_batch)
49 |         score_tensor = torch.stack(score, dim=0)  # [len_sigma_list, 1, N, N, 1]
50 |         score_tensor = torch.mean(score_tensor, dim=0)  # [1, N, N, 1]
51 |         y_exp = None  # output_prob_cont.argmax(dim=-1)
52 |
53 |         modif_r = sparsity([score_tensor], adj_b, mask)
54 |
55 |         score_tensor = score_tensor[0, :, :, 0]
56 |         score_tensor[score_tensor < 0] = 0
57 |         return score_tensor, modif_r, y_exp
58 |
59 |
60 | class RandomExplainer(BaseExplainer):
61 |     def explain_graph(self, graph):
62 |         return torch.randn(graph.edge_index.shape[1])
63 |
64 |
65 | def parse_args():
66 |     parser = argparse.ArgumentParser(description="Robustness Experiment")
67 |     parser.add_argument("--cuda", type=int, default=0, help="GPU device.")
68 |     parser.add_argument("--root", type=str, default="../results", help="Result directory.")
69 |     parser.add_argument("--dataset", type=str, default="NCI1")
70 |     parser.add_argument("--explainer", type=str, default="DiffExplainer")
71 |     parser.add_argument("--mod-ratio", type=float, default=0.2, help="Modification Ratio")
72 |     parser.add_argument("--k", type=int, default=8, help="Top-K")
73 |     # gflow explainer related parameters
74 |     parser.add_argument("--gnn_type", type=str, default="gcn")
75 |     parser.add_argument("--task", type=str, default="nc")
76 |     parser.add_argument("--normalization", type=str, default="instance")
77 |     parser.add_argument("--verbose", type=int, default=10)
78 |     parser.add_argument("--num_layers", type=int, default=6)
79 |     parser.add_argument("--layers_per_conv", type=int, default=1)
80 |     parser.add_argument("--train_batchsize", type=int, default=32)
81 |     parser.add_argument("--test_batchsize", type=int, default=32)
82 |     parser.add_argument("--sigma_length", type=int, default=5)
83 |     parser.add_argument("--epoch", type=int, default=3000)
84 |     parser.add_argument("--feature_in", type=int)
85 |     parser.add_argument("--n_hidden", type=int, default=64)
86 |     parser.add_argument("--data_size", type=int, default=-1)
87 |
88 |     parser.add_argument("--threshold", type=float, default=0.5)
89 |     parser.add_argument("--alpha_cf", type=float, default=0.05)
90 |     parser.add_argument("--dropout", type=float, default=0.001)
91 |     parser.add_argument("--learning_rate", type=float, default=1e-3)
92 |     parser.add_argument("--lr_decay", type=float, default=0.999)
93 |     parser.add_argument("--weight_decay", type=float, default=0)
94 |     parser.add_argument("--prob_low", type=float, default=0.0)
95 |     parser.add_argument("--prob_high", type=float, default=0.4)
96 |     parser.add_argument("--sparsity_level", type=float, default=2.5)
97 |
98 |     parser.add_argument("--cat_output", type=bool, default=True)
99 |
parser.add_argument("--residual", type=bool, default=False) 100 | parser.add_argument("--noise_mlp", type=bool, default=True) 101 | parser.add_argument("--simplified", type=bool, default=False) 102 | return parser.parse_args() 103 | 104 | 105 | args = parse_args() 106 | args.device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 107 | args.feature_in = feature_dict[args.dataset] 108 | args.task = task_type[args.dataset] 109 | 110 | train_dataset, val_dataset, test_dataset = get_datasets(name=args.dataset, root="../data") 111 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=False) 112 | gnn_path = f"../param/gnns/{args.dataset}_gcn.pt" 113 | model = torch.load(gnn_path, map_location=args.device).to(args.device) 114 | 115 | def get_graph_pred(graph): 116 | if args.task == "nc": 117 | output_prob, _ = model.get_node_pred_subgraph(x=graph.x, edge_index=graph.edge_index, mapping=graph.mapping) 118 | else: 119 | output_prob, _ = model.get_pred(x=graph.x, edge_index=graph.edge_index, batch=graph.batch) 120 | orig_pred = output_prob.argmax(dim=-1) 121 | return orig_pred.item() 122 | 123 | 124 | explainer = DiffExplainer(device=args.device, gnn_model_path=gnn_path, task=args.task, args=args) 125 | 126 | for sigma in range(0, 11): 127 | sigma /= 100 128 | acc_logger = [] 129 | orig_modif_r_arr = [] 130 | noisy_modif_r_arr = [] 131 | 132 | noisy_graph_same_class = [] 133 | for graph in tqdm(iter(test_loader), total=len(test_loader)): 134 | adj_b, x_b = graph2tensor(graph.to(args.device), device=args.device) 135 | x_b = x_b.to(args.device) 136 | node_flag_b = adj_b.sum(-1).gt(1e-5).to(dtype=torch.float32) 137 | _, _, _, noisy_adj_b, _ = gen_list_of_data_single(x_b, adj_b, node_flag_b, [sigma], args) 138 | 139 | noisy_graph = copy.deepcopy(graph) 140 | noisy_graph.edge_index = noisy_adj_b[0].nonzero().t() 141 | 142 | orig_pred = get_graph_pred(graph) 143 | noisy_pred = get_graph_pred(noisy_graph) 144 | 145 | # 2D arrays [N, N] of the importance of full adjacency matrix 146 | sigma_list = list(np.random.uniform(low=args.prob_low, high=args.prob_high, size=args.sigma_length)) 147 | orig_edge_imp, orig_modif_r, exp_pred = explainer.explain_graph( 148 | model, graph, adj_b, x_b, node_flag_b, sigma_list, args 149 | ) 150 | noisy_edge_imp, noisy_modif_r, noisy_exp_pred = explainer.explain_graph( 151 | model, noisy_graph, noisy_adj_b, x_b, node_flag_b, sigma_list, args 152 | ) 153 | 154 | orig_modif_r_arr.append(orig_modif_r.item()) 155 | noisy_modif_r_arr.append(noisy_modif_r.item()) 156 | 157 | n_nodes = orig_edge_imp.shape[0] 158 | 159 | # K edges with largest counterfactual importance 160 | try: 161 | _, indices = orig_edge_imp.flatten().topk(8) 162 | except Exception: 163 | t = int(n_nodes * n_nodes) 164 | _, indices = orig_edge_imp.flatten().topk(t) 165 | 166 | top_k_orig_exp_edges = torch.stack([indices // n_nodes, indices % n_nodes]) 167 | 168 | # all edges with positive counterfactual importance 169 | noisy_exp_edges = noisy_edge_imp.nonzero().T 170 | 171 | noisy_graph_same_class.append(1 if orig_pred == noisy_pred else 0) 172 | 173 | num_intersect = 0 174 | n_orig_edges = min(args.k, top_k_orig_exp_edges.shape[1]) 175 | for i in range(n_orig_edges): 176 | if (top_k_orig_exp_edges[:, i].unsqueeze(1) == noisy_exp_edges).all(0).any(): 177 | num_intersect += 1 178 | acc = num_intersect / n_orig_edges 179 | acc_logger.append(acc) 180 | 181 | print("Sigma", sigma) 182 | print("Top K Accuracy", round(np.array(acc_logger).mean(), 5)) 183 | 
print("Noisy Graph Same Class", np.array(noisy_graph_same_class).mean()) 184 | print( 185 | "Modification Ratio", 186 | np.array(orig_modif_r_arr).mean(), 187 | np.array(noisy_modif_r_arr).mean(), 188 | ) 189 | -------------------------------------------------------------------------------- /explainers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/explainers/.DS_Store -------------------------------------------------------------------------------- /explainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .diff_explainer import DiffExplainer 2 | 3 | __all__ = [ 4 | "DiffExplainer", 5 | ] 6 | -------------------------------------------------------------------------------- /explainers/base.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import networkx as nx 8 | import numpy as np 9 | import torch 10 | from matplotlib import cm 11 | from PIL import Image 12 | 13 | from .visual import graph_to_mol, vis_dict 14 | 15 | EPS = 1e-6 16 | 17 | 18 | class Explainer(object): 19 | def __init__(self, device, gnn_model_path, task="gc"): 20 | self.device = device 21 | self.model = torch.load(gnn_model_path, map_location=self.device).to(self.device) 22 | self.model.eval() 23 | self.model_name = self.model.__class__.__name__ 24 | self.name = self.__class__.__name__ 25 | 26 | self.path = gnn_model_path 27 | self.last_result = None 28 | self.vis_dict = None 29 | self.task = task 30 | 31 | def explain_graph(self, graph, **kwargs): 32 | """ 33 | Main part for different graph attribution methods 34 | :param graph: target graph instance to be explained 35 | :param kwargs: other parameters 36 | :return: edge_imp, i.e., attributions for edges, which are derived from the attribution methods. 37 | """ 38 | raise NotImplementedError 39 | 40 | @staticmethod 41 | def get_rank(lst, r=1): 42 | topk_idx = list(np.argsort(-lst)) 43 | top_pred = np.zeros_like(lst) 44 | n = len(lst) 45 | k = int(r * n) 46 | for i in range(k): 47 | top_pred[topk_idx[i]] = n - i 48 | return top_pred 49 | 50 | @staticmethod 51 | def norm_imp(imp): 52 | imp[imp < 0] = 0 53 | imp += 1e-16 54 | return imp / imp.sum() 55 | 56 | def __relabel__(self, g, edge_index): 57 | sub_nodes = torch.unique(edge_index) 58 | x = g.x[sub_nodes] 59 | batch = g.batch[sub_nodes] 60 | row, col = edge_index 61 | pos = None 62 | try: 63 | pos = g.pos[sub_nodes] 64 | except Exception: 65 | pass 66 | 67 | # remapping the nodes in the explanatory subgraph to new ids. 
68 | node_idx = row.new_full((g.num_nodes,), -1) 69 | node_idx[sub_nodes] = torch.arange(sub_nodes.size(0), device=row.device) 70 | edge_index = node_idx[edge_index] 71 | return x, edge_index, batch, pos 72 | 73 | def __reparameterize__(self, log_alpha, beta=0.1, training=True): 74 | if training: 75 | random_noise = torch.rand(log_alpha.size()).to(self.device) 76 | gate_inputs = torch.log2(random_noise) - torch.log2(1.0 - random_noise) 77 | gate_inputs = (gate_inputs + log_alpha) / beta + EPS 78 | gate_inputs = gate_inputs.sigmoid() 79 | else: 80 | gate_inputs = log_alpha.sigmoid() 81 | 82 | return gate_inputs 83 | 84 | def pack_explanatory_subgraph(self, top_ratio=0.2, graph=None, imp=None, relabel=False, if_cf=False): 85 | """ 86 | Pack the explanatory subgraph from the original graph 87 | :param top_ratio: the ratio of edges to be selected 88 | :param graph: the original graph 89 | :param imp: the attribution scores for edges 90 | :param relabel: whether to relabel the nodes in the explanatory subgraph 91 | :param if_cf: whether to use the CF method 92 | :return: the explanatory subgraph 93 | """ 94 | if graph is None: 95 | graph, imp = self.last_result 96 | assert len(imp) == graph.num_edges, "length mismatch" 97 | 98 | top_idx = torch.LongTensor([]) 99 | graph_map = graph.batch[graph.edge_index[0, :]] 100 | exp_subgraph = graph.clone() 101 | exp_subgraph.y = graph.y if self.task == "gc" else graph.self_y 102 | for i in range(graph.num_graphs): 103 | edge_indicator = torch.where(graph_map == i)[0].detach().cpu() 104 | Gi_n_edge = len(edge_indicator) 105 | topk = min(max(math.ceil(top_ratio * Gi_n_edge), 1), Gi_n_edge) 106 | if not if_cf: 107 | Gi_pos_edge_idx = np.argsort(-imp[edge_indicator])[:topk] 108 | else: 109 | Gi_pos_edge_idx = np.argsort(-imp[edge_indicator])[topk:] 110 | top_idx = torch.cat([top_idx, edge_indicator[Gi_pos_edge_idx]]) 111 | try: 112 | exp_subgraph.edge_attr = graph.edge_attr[top_idx] 113 | except Exception: 114 | pass 115 | exp_subgraph.edge_index = graph.edge_index[:, top_idx] 116 | 117 | exp_subgraph.x = graph.x 118 | if relabel: 119 | (exp_subgraph.x, exp_subgraph.edge_index, exp_subgraph.batch, exp_subgraph.pos) = self.__relabel__( 120 | exp_subgraph, exp_subgraph.edge_index 121 | ) 122 | return exp_subgraph 123 | 124 | def evaluate_acc(self, top_ratio_list, graph=None, imp=None, if_cf=False): 125 | """ 126 | Evaluate the accuracy of the explanatory subgraph 127 | :param top_ratio_list: the ratio of edges to be selected 128 | :param graph: the original graph 129 | :param imp: the attribution scores for edges 130 | :param if_cf: whether to generate cf explanation 131 | :return: the accuracy of the explanatory subgraph 132 | """ 133 | if graph is None: 134 | assert self.last_result is not None 135 | graph, imp = self.last_result 136 | acc = np.array([[]]) 137 | fidelity = np.array([[]]) 138 | if self.task == "nc": 139 | output_prob, _ = self.model.get_node_pred_subgraph( 140 | x=graph.x, edge_index=graph.edge_index, mapping=graph.mapping 141 | ) 142 | else: 143 | output_prob, _ = self.model.get_pred(x=graph.x, edge_index=graph.edge_index, batch=graph.batch) 144 | y_pred = output_prob.argmax(dim=-1) 145 | for idx, top_ratio in enumerate(top_ratio_list): 146 | exp_subgraph = self.pack_explanatory_subgraph(top_ratio, graph=graph, imp=imp, if_cf=if_cf) 147 | if self.task == "nc": 148 | soft_pred, _ = self.model.get_node_pred_subgraph( 149 | x=exp_subgraph.x, edge_index=exp_subgraph.edge_index, mapping=exp_subgraph.mapping 150 | ) 151 | else: 152 | soft_pred, _ = 
self.model.get_pred( 153 | x=exp_subgraph.x, edge_index=exp_subgraph.edge_index, batch=exp_subgraph.batch 154 | ) 155 | # soft_pred: [bsz, num_class] 156 | res_acc = (y_pred == soft_pred.argmax(dim=-1)).detach().cpu().float().view(-1, 1).numpy() 157 | labels = torch.LongTensor([[i] for i in y_pred]).to(y_pred.device) 158 | if not if_cf: 159 | res_fid = soft_pred.gather(1, labels).detach().cpu().float().view(-1, 1).numpy() 160 | else: 161 | res_fid = (1 - soft_pred.gather(1, labels)).detach().cpu().float().view(-1, 1).numpy() 162 | acc = np.concatenate([acc, res_acc], axis=1) # [bsz, len_ratio_list] 163 | fidelity = np.concatenate([fidelity, res_fid], axis=1) 164 | return acc, fidelity 165 | 166 | def visualize( 167 | self, graph=None, edge_imp=None, counter_edge_index=None, vis_ratio=0.2, save=False, layout=False, name=None 168 | ): 169 | """ 170 | Visualize the attribution scores for edges (xx-Motif / Mutag) 171 | # TODO: visualization for BBBP / node classification 172 | :param graph: the original graph 173 | :param edge_imp: the attribution scores for edges 174 | :param counter_edge_index: the counterfactual edges 175 | :param vis_ratio: the ratio of edges to be visualized 176 | :param save: whether to save the visualization 177 | :param layout: whether to use the layout 178 | :param name: the name of the visualization 179 | :return: None 180 | """ 181 | if graph is None: 182 | assert self.last_result is not None 183 | graph, edge_imp = self.last_result 184 | 185 | topk = max(int(vis_ratio * graph.num_edges), 1) 186 | idx = np.argsort(-edge_imp)[:topk] 187 | G = nx.DiGraph() 188 | G.add_nodes_from(range(graph.num_nodes)) 189 | G.add_edges_from(list(graph.edge_index.cpu().numpy().T)) 190 | 191 | if counter_edge_index is not None: 192 | G.add_edges_from(list(counter_edge_index.cpu().numpy().T)) 193 | if self.vis_dict is None: 194 | self.vis_dict = vis_dict[self.model_name] if self.model_name in vis_dict.keys() else vis_dict["default"] 195 | 196 | folder = Path(r"image/%s" % (self.model_name)) 197 | if save and not os.path.exists(folder): 198 | os.makedirs(folder) 199 | 200 | edge_pos_mask = np.zeros(graph.num_edges, dtype=np.bool_) 201 | edge_pos_mask[idx] = True 202 | vmax = sum(edge_pos_mask) 203 | node_pos_mask = np.zeros(graph.num_nodes, dtype=np.bool_) 204 | node_neg_mask = np.zeros(graph.num_nodes, dtype=np.bool_) 205 | node_pos_idx = np.unique(graph.edge_index[:, edge_pos_mask].cpu().numpy()).tolist() 206 | node_neg_idx = list(set([i for i in range(graph.num_nodes)]) - set(node_pos_idx)) 207 | node_pos_mask[node_pos_idx] = True 208 | node_neg_mask[node_neg_idx] = True 209 | 210 | if "Motif" in self.model_name: 211 | plt.figure(figsize=(8, 6), dpi=100) 212 | pos = graph.pos[0] 213 | nx.draw_networkx_nodes( 214 | G, 215 | pos={i: pos[i] for i in node_pos_idx}, 216 | nodelist=node_pos_idx, 217 | node_size=self.vis_dict["node_size"], 218 | node_color=graph.z[0][node_pos_idx], 219 | alpha=1, 220 | cmap="winter", 221 | linewidths=self.vis_dict["linewidths"], 222 | edgecolors="red", 223 | vmin=-max(graph.z[0]), 224 | vmax=max(graph.z[0]), 225 | ) 226 | nx.draw_networkx_nodes( 227 | G, 228 | pos={i: pos[i] for i in node_neg_idx}, 229 | nodelist=node_neg_idx, 230 | node_size=self.vis_dict["node_size"], 231 | node_color=graph.z[0][node_neg_idx], 232 | alpha=0.2, 233 | cmap="winter", 234 | linewidths=self.vis_dict["linewidths"], 235 | edgecolors="whitesmoke", 236 | vmin=-max(graph.z[0]), 237 | vmax=max(graph.z[0]), 238 | ) 239 | nx.draw_networkx_edges( 240 | G, 241 | pos=pos, 242 | 
edgelist=list(graph.edge_index.cpu().numpy().T), 243 | edge_color="whitesmoke", 244 | width=self.vis_dict["width"], 245 | arrows=False, 246 | ) 247 | nx.draw_networkx_edges( 248 | G, 249 | pos=pos, 250 | edgelist=list(graph.edge_index[:, edge_pos_mask].cpu().numpy().T), 251 | edge_color=self.get_rank(edge_imp[edge_pos_mask]), 252 | # np.ones(len(edge_imp[edge_pos_mask])), 253 | width=self.vis_dict["width"], 254 | edge_cmap=cm.get_cmap("bwr"), 255 | edge_vmin=-vmax, 256 | edge_vmax=vmax, 257 | arrows=False, 258 | ) 259 | if counter_edge_index is not None: 260 | nx.draw_networkx_edges( 261 | G, 262 | pos=pos, 263 | edgelist=list(counter_edge_index.cpu().numpy().T), 264 | edge_color="mediumturquoise", 265 | width=self.vis_dict["width"] / 3.0, 266 | arrows=False, 267 | ) 268 | 269 | if "Mutag" in self.model_name: 270 | from rdkit.Chem.Draw import rdMolDraw2D 271 | 272 | idx = [int(i / 2) for i in idx] 273 | x = graph.x.detach().cpu().tolist() 274 | edge_index = graph.edge_index.T.detach().cpu().tolist() 275 | edge_attr = graph.edge_attr.detach().cpu().tolist() 276 | mol = graph_to_mol(x, edge_index, edge_attr) 277 | d = rdMolDraw2D.MolDraw2DCairo(500, 500) 278 | hit_at = np.unique(graph.edge_index[:, idx].detach().cpu().numpy()).tolist() 279 | 280 | def add_atom_index(mol): 281 | atoms = mol.GetNumAtoms() 282 | for i in range(atoms): 283 | mol.GetAtomWithIdx(i).SetProp("molAtomMapNumber", str(mol.GetAtomWithIdx(i).GetIdx())) 284 | return mol 285 | 286 | hit_bonds = [] 287 | for u, v in graph.edge_index.T[idx]: 288 | hit_bonds.append(mol.GetBondBetweenAtoms(int(u), int(v)).GetIdx()) 289 | rdMolDraw2D.PrepareAndDrawMolecule( 290 | d, 291 | mol, 292 | highlightAtoms=hit_at, 293 | highlightBonds=hit_bonds, 294 | highlightAtomColors={i: (0, 1, 0) for i in hit_at}, 295 | highlightBondColors={i: (0, 1, 0) for i in hit_bonds}, 296 | ) 297 | d.FinishDrawing() 298 | bindata = d.GetDrawingText() 299 | iobuf = io.BytesIO(bindata) 300 | image = Image.open(iobuf) 301 | image.show() 302 | if save: 303 | if name: 304 | d.WriteDrawingText("image/%s/%s-%d-%s.png" % (self.model_name, name, int(graph.y[0]), self.name)) 305 | else: 306 | d.WriteDrawingText( 307 | "image/%s/%s-%d-%s.png" % (self.model_name, str(graph.name[0]), int(graph.y[0]), self.name) 308 | ) 309 | -------------------------------------------------------------------------------- /explainers/diff_explainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.loader import DataLoader 6 | from torch_geometric.utils import to_undirected 7 | 8 | from explainers.base import Explainer 9 | from explainers.diffusion.graph_utils import ( 10 | gen_full, 11 | gen_list_of_data_single, 12 | generate_mask, 13 | graph2tensor, 14 | tensor2graph, 15 | ) 16 | from explainers.diffusion.pgnn import Powerful 17 | 18 | 19 | def model_save(args, model, mean_train_loss, best_sparsity, mean_test_acc): 20 | """ 21 | Save the model to disk 22 | :param args: arguments 23 | :param model: model 24 | :param mean_train_loss: mean training loss 25 | :param best_sparsity: best sparsity 26 | :param mean_test_acc: mean test accuracy 27 | """ 28 | to_save = { 29 | "model": model.state_dict(), 30 | "train_loss": mean_train_loss, 31 | "eval sparsity": best_sparsity, 32 | "eval acc": mean_test_acc, 33 | } 34 | exp_dir = f"{args.root}/{args.dataset}/" 35 | os.makedirs(exp_dir, exist_ok=True) 36 | torch.save(to_save, os.path.join(exp_dir, "best_model.pth")) 37 | print(f"save 
model to {exp_dir}/best_model.pth") 38 | 39 | 40 | def loss_func_bce(score_list, groundtruth, sigma_list, mask, device, sparsity_level): 41 | """ 42 | Loss function for binary cross entropy 43 | param score_list: [len(sigma_list)*bsz, N, N] 44 | param groundtruth: [len(sigma_list)*bsz, N, N] 45 | param sigma_list: list of sigma values 46 | param mask: [len(sigma_list)*bsz, N, N] 47 | param device: device 48 | param sparsity_level: sparsity level 49 | return: BCE loss 50 | """ 51 | bsz = int(score_list.size(0) / len(sigma_list)) 52 | num_node = score_list.size(-1) 53 | score_list = score_list * mask 54 | groundtruth = groundtruth * mask 55 | pos_weight = torch.full([num_node * num_node], sparsity_level).to(device) 56 | BCE = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none") 57 | score_list_ = torch.flatten(score_list, start_dim=1, end_dim=-1) 58 | groundtruth_ = torch.flatten(groundtruth, start_dim=1, end_dim=-1) 59 | loss_matrix = BCE(score_list_, groundtruth_) 60 | loss_matrix = loss_matrix.view(groundtruth.size(0), num_node, num_node) 61 | loss_matrix = loss_matrix * ( 62 | 1 63 | - 2 64 | * torch.tensor(sigma_list) 65 | .repeat(bsz) 66 | .unsqueeze(-1) 67 | .unsqueeze(-1) 68 | .expand(groundtruth.size(0), num_node, num_node) 69 | .to(device) 70 | + 1.0 / len(sigma_list) 71 | ) 72 | loss_matrix = loss_matrix * mask 73 | loss_matrix = (loss_matrix + torch.transpose(loss_matrix, -2, -1)) / 2 74 | loss = torch.mean(loss_matrix) 75 | return loss 76 | 77 | 78 | def sparsity(score, groundtruth, mask, threshold=0.5): 79 | """ 80 | Calculate the sparsity of the predicted adjacency matrix 81 | :param score: [bsz, N, N, 1] 82 | :param groundtruth: [bsz, N, N] 83 | :param mask: [bsz, N, N] 84 | :param threshold: threshold for the predicted adjacency matrix 85 | :return: sparsity 86 | """ 87 | score_tensor = torch.stack(score, dim=0).squeeze(-1) # [len_sigma_list, bsz, N, N] 88 | score_tensor = torch.mean(score_tensor, dim=0) # [bsz, N, N] 89 | pred_adj = torch.where(torch.sigmoid(score_tensor) > threshold, 1, 0).to(groundtruth.device) 90 | pred_adj = pred_adj * mask 91 | groundtruth_ = groundtruth * mask 92 | adj_diff = torch.abs(groundtruth_ - pred_adj) # [bsz, N, N] 93 | num_edge_b = groundtruth_.sum(dim=(1, 2)) 94 | adj_diff_ratio = adj_diff.sum(dim=(1, 2)) / num_edge_b 95 | ratio_average = torch.mean(adj_diff_ratio) 96 | return ratio_average 97 | 98 | 99 | def gnn_pred(graph_batch, graph_batch_sub, gnn_model, ds, task): 100 | """ 101 | Predict the labels of the graph 102 | :param graph_batch: graph batch 103 | :param graph_batch_sub: subgraph batch 104 | :param gnn_model: GNN model 105 | :param ds: dataset 106 | :param task: task 107 | :return: predicted labels (full graph and subgraph) 108 | """ 109 | gnn_model.eval() 110 | if task == "nc": 111 | output_prob, _ = gnn_model.get_node_pred_subgraph( 112 | x=graph_batch.x, 113 | edge_index=graph_batch.edge_index, 114 | mapping=graph_batch.mapping, 115 | ) 116 | output_prob_sub, _ = gnn_model.get_node_pred_subgraph( 117 | x=graph_batch_sub.x, 118 | edge_index=graph_batch_sub.edge_index, 119 | mapping=graph_batch_sub.mapping, 120 | ) 121 | else: 122 | output_prob, _ = gnn_model.get_pred( 123 | x=graph_batch.x, 124 | edge_index=graph_batch.edge_index, 125 | batch=graph_batch.batch, 126 | ) 127 | output_prob_sub, _ = gnn_model.get_pred( 128 | x=graph_batch_sub.x, 129 | edge_index=graph_batch_sub.edge_index, 130 | batch=graph_batch_sub.batch, 131 | ) 132 | 133 | y_pred = output_prob.argmax(dim=-1) 134 | y_exp = 
output_prob_sub.argmax(dim=-1) 135 | return y_pred, y_exp 136 | 137 | 138 | def loss_cf_exp(gnn_model, graph_batch, score, y_pred, y_exp, full_edge, mask, ds, task="nc"): 139 | """ 140 | Loss function for counterfactual explanation 141 | :param gnn_model: GNN model 142 | :param graph_batch: graph batch 143 | :param score: list of scores 144 | :param y_pred: predicted labels 145 | :param y_exp: predicted labels for subgraph 146 | :param full_edge: full edge index 147 | :param mask: mask 148 | :param ds: dataset 149 | :param task: task 150 | :return: loss 151 | """ 152 | score_tensor = torch.stack(score, dim=0).squeeze(-1) 153 | score_tensor = torch.mean(score_tensor, dim=0).view(-1, 1) 154 | mask_bool = mask.bool().view(-1, 1) 155 | edge_mask_full = score_tensor[mask_bool] 156 | assert edge_mask_full.size(0) == full_edge.size(1) 157 | criterion = torch.nn.NLLLoss() 158 | if task == "nc": 159 | output_prob_cont, output_repr_cont = gnn_model.get_pred_explain( 160 | x=graph_batch.x, 161 | edge_index=full_edge, 162 | edge_mask=edge_mask_full, 163 | mapping=graph_batch.mapping, 164 | ) 165 | else: 166 | output_prob_cont, output_repr_cont = gnn_model.get_pred_explain( 167 | x=graph_batch.x, 168 | edge_index=full_edge, 169 | edge_mask=edge_mask_full, 170 | batch=graph_batch.batch, 171 | ) 172 | n = output_repr_cont.size(-1) 173 | bsz = output_repr_cont.size(0) 174 | y_exp = output_prob_cont.argmax(dim=-1) 175 | inf_diag = torch.diag(-torch.ones((n)) / 0).unsqueeze(0).repeat(bsz, 1, 1).to(y_pred.device) 176 | neg_prop = (output_repr_cont.unsqueeze(1).expand(bsz, n, n) + inf_diag).logsumexp(-1) 177 | neg_prop = neg_prop - output_repr_cont.logsumexp(-1).unsqueeze(1).repeat(1, n) 178 | loss_cf = criterion(neg_prop, y_pred) 179 | labels = torch.LongTensor([[i] for i in y_pred]).to(y_pred.device) 180 | fid_drop = (1 - output_prob_cont.gather(1, labels).view(-1)).detach().cpu().numpy() 181 | fid_drop = np.mean(fid_drop) 182 | acc_cf = float(y_exp.eq(y_pred).sum().item() / y_pred.size(0)) # less, better 183 | return loss_cf, fid_drop, acc_cf 184 | 185 | 186 | class DiffExplainer(Explainer): 187 | def __init__(self, device, gnn_model_path): 188 | super(DiffExplainer, self).__init__(device, gnn_model_path) 189 | 190 | def explain_graph_task(self, args, train_dataset, test_dataset): 191 | """ 192 | Explain the graph for a specific dataset and task 193 | :param args: arguments 194 | :param train_dataset: training dataset 195 | :param test_dataset: test dataset 196 | """ 197 | gnn_model = self.model.to(args.device) 198 | model = Powerful(args).to(args.device) 199 | self.train(args, model, gnn_model, train_dataset, test_dataset) 200 | 201 | def train(self, args, model, gnn_model, train_dataset, test_dataset): 202 | """ 203 | Train the model 204 | :param args: arguments 205 | :param model: Powerful (explanation) model 206 | :param gnn_model: GNN model 207 | :param train_dataset: training dataset 208 | :param test_dataset: test dataset 209 | """ 210 | best_sparsity = np.inf 211 | optimizer = torch.optim.Adam( 212 | model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay 213 | ) 214 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay) 215 | noise_list = args.noise_list 216 | for epoch in range(args.epoch): 217 | print(f"start epoch {epoch}") 218 | train_losses = [] 219 | train_loss_dist = [] 220 | train_loss_cf = [] 221 | train_acc = [] 222 | train_fid = [] 223 | train_sparsity = [] 224 | train_remain = [] 225 | model.train() 226 | 
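# --- descriptive note on the training step below ---
# Per batch: (1) the graph is made undirected and densified into
# train_adj_b [bsz, N, N] and train_x_b [bsz, N, C]; (2) a list of noise
# levels sigma_list is drawn uniformly from [args.prob_low, args.prob_high]
# (or taken from args.noise_list); (3) gen_list_of_data_single() replicates
# the batch once per sigma and applies the discrete forward diffusion;
# (4) the Powerful denoiser scores every node pair, and the scores feed the
# distribution loss (loss_func_bce) plus the counterfactual loss
# (loss_cf_exp), combined as loss_dist + args.alpha_cf * loss_cf.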
train_loader = DataLoader(train_dataset, batch_size=args.train_batchsize, shuffle=True) 227 | for graph in train_loader: 228 | if graph.is_directed(): 229 | edge_index_temp = graph.edge_index 230 | graph.edge_index = to_undirected(edge_index=edge_index_temp) 231 | graph.to(args.device) 232 | train_adj_b, train_x_b = graph2tensor(graph, device=args.device) 233 | # train_adj_b: [bsz, N, N]; train_x_b: [bsz, N, C] 234 | sigma_list = ( 235 | list(np.random.uniform(low=args.prob_low, high=args.prob_high, size=args.sigma_length)) 236 | if noise_list is None 237 | else noise_list 238 | ) 239 | train_node_flag_b = train_adj_b.sum(-1).gt(1e-5).to(dtype=torch.float32) # [bsz, N] 240 | # flags nodes with at least one incident edge; isolated/padding nodes are masked out 241 | if isinstance(sigma_list, float): 242 | sigma_list = [sigma_list] 243 | (train_x_b, train_ori_adj_b, train_node_flag_sigma, train_noise_adj_b, _) = gen_list_of_data_single( 244 | train_x_b, train_adj_b, train_node_flag_b, sigma_list, args 245 | ) 246 | optimizer.zero_grad() 247 | train_noise_adj_b_chunked = train_noise_adj_b.chunk(len(sigma_list), dim=0) 248 | train_x_b_chunked = train_x_b.chunk(len(sigma_list), dim=0) 249 | train_node_flag_sigma = train_node_flag_sigma.chunk(len(sigma_list), dim=0) 250 | score = [] 251 | masks = [] 252 | for i, sigma in enumerate(sigma_list): 253 | mask = generate_mask(train_node_flag_sigma[i]) 254 | score_batch = model( 255 | A=train_noise_adj_b_chunked[i].to(args.device), 256 | node_features=train_x_b_chunked[i].to(args.device), 257 | mask=mask.to(args.device), 258 | noiselevel=sigma, 259 | ) # [bsz, N, N, 1] 260 | score.append(score_batch) 261 | masks.append(mask) 262 | graph_batch_sub = tensor2graph(graph, score, mask) 263 | y_pred, y_exp = gnn_pred(graph, graph_batch_sub, gnn_model, ds=args.dataset, task=args.task) 264 | full_edge_index = gen_full(graph.batch, mask) 265 | score_b = torch.cat(score, dim=0).squeeze(-1).to(args.device) # [len(sigma_list)*bsz, N, N] 266 | masktens = torch.cat(masks, dim=0).to(args.device) # [len(sigma_list)*bsz, N, N] 267 | modif_r = sparsity(score, train_adj_b, mask) 268 | remain_r = sparsity(score, train_adj_b, train_adj_b) 269 | loss_cf, fid_drop, acc_cf = loss_cf_exp( 270 | gnn_model, graph, score, y_pred, y_exp, full_edge_index, mask, ds=args.dataset, task=args.task 271 | ) 272 | loss_dist = loss_func_bce( 273 | score_b, 274 | train_ori_adj_b, 275 | sigma_list, 276 | masktens, 277 | device=args.device, 278 | sparsity_level=args.sparsity_level, 279 | ) 280 | loss = loss_dist + args.alpha_cf * loss_cf 281 | loss.backward() 282 | optimizer.step() 283 | train_losses.append(loss.item()) 284 | train_loss_dist.append(loss_dist.item()) 285 | train_loss_cf.append(loss_cf.item()) 286 | train_acc.append(acc_cf) 287 | train_fid.append(fid_drop) 288 | train_sparsity.append(modif_r.item()) 289 | train_remain.append(remain_r.item()) 290 | scheduler.step(epoch) 291 | mean_train_loss = np.mean(train_losses) 292 | mean_train_acc = 1 - np.mean(train_acc) 293 | mean_train_fidelity = np.mean(train_fid) 294 | mean_train_sparsity = np.mean(train_sparsity) 295 | print( 296 | ( 297 | f"Training Epoch: {epoch} | " 298 | f"training loss: {mean_train_loss} | " 299 | f"training fidelity drop: {mean_train_fidelity} | " 300 | f"training cf acc: {mean_train_acc} | " 301 | f"training average modification: {mean_train_sparsity} | " 302 | ) 303 | ) 304 | # evaluation 305 | if (epoch + 1) % args.verbose == 0: 306 | test_losses = [] 307 | test_loss_dist = [] 308 | test_loss_cf = [] 309 | test_acc = [] 310 | test_fid = [] 311 | test_sparsity = [] 312 | test_remain = [] 313 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.test_batchsize, shuffle=False) 314 | model.eval() 315 | for graph in test_loader: 316 | if graph.is_directed(): 317 | edge_index_temp = graph.edge_index 318 | graph.edge_index = to_undirected(edge_index=edge_index_temp) 319 | 320 | graph.to(args.device) 321 | test_adj_b, test_x_b = graph2tensor(graph, device=args.device) 322 | test_x_b = test_x_b.to(args.device) 323 | test_node_flag_b = test_adj_b.sum(-1).gt(1e-5).to(dtype=torch.float32) 324 | sigma_list = ( 325 | list(np.random.uniform(low=args.prob_low, high=args.prob_high, size=args.sigma_length)) 326 | if noise_list is None 327 | else noise_list 328 | ) 329 | if isinstance(sigma_list, float): 330 | sigma_list = [sigma_list] 331 | (test_x_b, test_ori_adj_b, test_node_flag_sigma, test_noise_adj_b, _) = gen_list_of_data_single( 332 | test_x_b, test_adj_b, test_node_flag_b, sigma_list, args 333 | ) 334 | with torch.no_grad(): 335 | test_noise_adj_b_chunked = test_noise_adj_b.chunk(len(sigma_list), dim=0) 336 | test_x_b_chunked = test_x_b.chunk(len(sigma_list), dim=0) 337 | test_node_flag_sigma = test_node_flag_sigma.chunk(len(sigma_list), dim=0) 338 | score = [] 339 | masks = [] 340 | for i, sigma in enumerate(sigma_list): 341 | mask = generate_mask(test_node_flag_sigma[i]) 342 | score_batch = model( 343 | A=test_noise_adj_b_chunked[i].to(args.device), 344 | node_features=test_x_b_chunked[i].to(args.device), 345 | mask=mask.to(args.device), 346 | noiselevel=sigma, 347 | ).to(args.device) 348 | masks.append(mask) 349 | score.append(score_batch) 350 | graph_batch_sub = tensor2graph(graph, score, mask) 351 | y_pred, y_exp = gnn_pred(graph, graph_batch_sub, gnn_model, ds=args.dataset, task=args.task) 352 | full_edge_index = gen_full(graph.batch, mask) 353 | score_b = torch.cat(score, dim=0).squeeze(-1).to(args.device) 354 | masktens = torch.cat(masks, dim=0).to(args.device) 355 | modif_r = sparsity(score, test_adj_b, mask) 356 | loss_cf, fid_drop, acc_cf = loss_cf_exp( 357 | gnn_model, 358 | graph, 359 | score, 360 | y_pred, 361 | y_exp, 362 | full_edge_index, 363 | mask, 364 | ds=args.dataset, 365 | task=args.task, 366 | ) 367 | loss_dist = loss_func_bce( 368 | score_b, 369 | test_ori_adj_b, 370 | sigma_list, 371 | masktens, 372 | device=args.device, 373 | sparsity_level=args.sparsity_level, 374 | ) 375 | loss = loss_dist + args.alpha_cf * loss_cf 376 | test_losses.append(loss.item()) 377 | test_loss_dist.append(loss_dist.item()) 378 | test_loss_cf.append(loss_cf.item()) 379 | test_acc.append(acc_cf) 380 | test_fid.append(fid_drop) 381 | test_sparsity.append(modif_r.item()) 382 | mean_test_loss = np.mean(test_losses) 383 | mean_test_acc = 1 - np.mean(test_acc) 384 | mean_test_fid = np.mean(test_fid) 385 | mean_test_sparsity = np.mean(test_sparsity) 386 | print( 387 | ( 388 | f"Evaluation Epoch: {epoch} | " 389 | f"test loss: {mean_test_loss} | " 390 | f"test fidelity drop: {mean_test_fid} | " 391 | f"test cf acc: {mean_test_acc} | " 392 | f"test average modification: {mean_test_sparsity} | " 393 | ) 394 | ) 395 | if mean_test_sparsity < best_sparsity: 396 | best_sparsity = mean_test_sparsity 397 | model_save(args, model, mean_train_loss, best_sparsity, mean_test_acc) 398 | 399 | def explain_evaluation(self, args, graph): 400 | """ 401 | Explain the graph with the trained model 402 | :param args: arguments 403 | :param graph: graph to be explained 404 | :return: the explanation (explanatory subgraph, original prediction, explanation 
prediction, modification rate) 405 | """ 406 | model = Powerful(args).to(args.device) 407 | exp_dir = f"{args.root}/{args.dataset}/" 408 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pth"))["model"]) 409 | model.eval() 410 | graph.to(args.device) 411 | test_adj_b, test_x_b = graph2tensor(graph, device=args.device) # [bsz, N, N] 412 | test_x_b = test_x_b.to(args.device) 413 | test_node_flag_b = test_adj_b.sum(-1).gt(1e-5).to(dtype=torch.float32) 414 | sigma_list = ( 415 | list(np.random.uniform(low=args.prob_low, high=args.prob_high, size=args.sigma_length)) 416 | if args.noise_list is None 417 | else args.noise_list 418 | ) 419 | if isinstance(sigma_list, float): 420 | sigma_list = [sigma_list] 421 | (test_x_b, _, test_node_flag_sigma, test_noise_adj_b, _) = gen_list_of_data_single( 422 | test_x_b, test_adj_b, test_node_flag_b, sigma_list, args 423 | ) 424 | test_noise_adj_b_chunked = test_noise_adj_b.chunk(len(sigma_list), dim=0) 425 | test_x_b_chunked = test_x_b.chunk(len(sigma_list), dim=0) 426 | test_node_flag_sigma = test_node_flag_sigma.chunk(len(sigma_list), dim=0) 427 | score = [] 428 | masks = [] 429 | for i, sigma in enumerate(sigma_list): 430 | mask = generate_mask(test_node_flag_sigma[i]) 431 | score_batch = model( 432 | A=test_noise_adj_b_chunked[i].to(args.device), 433 | node_features=test_x_b_chunked[i].to(args.device), 434 | mask=mask.to(args.device), 435 | noiselevel=sigma, 436 | ).to(args.device) 437 | masks.append(mask) 438 | score.append(score_batch) 439 | graph_batch_sub = tensor2graph(graph, score, mask) 440 | full_edge_index = gen_full(graph.batch, mask) 441 | modif_r = sparsity(score, test_adj_b, mask) 442 | score_tensor = torch.stack(score, dim=0).squeeze(-1) # len_sigma_list, bsz, N, N] 443 | score_tensor = torch.mean(score_tensor, dim=0).view(-1, 1) # [bsz*N*N,1] 444 | mask_bool = mask.bool().view(-1, 1) 445 | edge_mask_full = score_tensor[mask_bool] 446 | if args.task == "nc": 447 | output_prob_cont, _ = self.model.get_pred_explain( 448 | x=graph.x, edge_index=full_edge_index, edge_mask=edge_mask_full, mapping=graph.mapping 449 | ) 450 | else: 451 | output_prob_cont, _ = self.model.get_pred_explain( 452 | x=graph.x, edge_index=full_edge_index, edge_mask=edge_mask_full, batch=graph.batch 453 | ) 454 | y_ori = graph.y if args.task == "gc" else graph.self_y 455 | y_exp = output_prob_cont.argmax(dim=-1) 456 | return graph_batch_sub, y_ori, y_exp, modif_r 457 | 458 | def one_step_model_level(self, args, random_adj, node_feature, sigma): 459 | """ 460 | One-step Model level explanation using the trained model 461 | Run multiple steps to get model-level explanation. 
462 | :param args: arguments 463 | :param random_adj: a random adjacency matrix seed 464 | :param node_feature: node features of the dataset 465 | :param sigma: noise level 466 | :return: A predicted adjacency matrix 467 | """ 468 | random_adj = random_adj.unsqueeze(0) # batchsize=1 469 | node_feature = node_feature.unsqueeze(0) # batchsize=1 470 | mask = torch.ones_like(random_adj).to(args.device) 471 | model = Powerful(args).to(args.device) 472 | exp_dir = f"{args.root}/{args.dataset}/" 473 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pth"))["model"]) 474 | model.eval() 475 | score = model(A=random_adj, node_features=node_feature, mask=mask, noiselevel=sigma).to(args.device) 476 | score = score.squeeze(0).squeeze(-1) 477 | pred_adj = torch.where(torch.sigmoid(score) > 0.5, 1, 0).to(score.device) 478 | return pred_adj # [N, N] 479 | 480 | -------------------------------------------------------------------------------- /explainers/diffusion/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/explainers/diffusion/.DS_Store -------------------------------------------------------------------------------- /explainers/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/explainers/diffusion/__init__.py -------------------------------------------------------------------------------- /explainers/diffusion/graph_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.utils import degree, to_dense_adj 3 | 4 | do_check_adjs_symmetry = False 5 | 6 | 7 | def mask_adjs(adjs, node_flags): 8 | """ 9 | Mask the adjs with node_flags 10 | :param adjs: Adjacencies (B x N x N or B x C x N x N) 11 | :param node_flags: Node Flags (B x N) 12 | :return: Masked Adjacencies (B x N x N or B x C x N x N) 13 | """ 14 | if len(adjs.shape) == 4: 15 | node_flags = node_flags.unsqueeze(1) # B x 1 x N 16 | adjs = adjs * node_flags.unsqueeze(-1) 17 | adjs = adjs * node_flags.unsqueeze(-2) 18 | return adjs 19 | 20 | 21 | def discretenoise_single(train_adj_b, node_flags, sigma, device): 22 | """ 23 | Applies discrete noise to the adjacency matrix 24 | :param train_adj_b: [batch_size, N, N], batch of original adjacency matrices 25 | :param node_flags: [batch_size, N], the flags for the existence of nodes 26 | :param sigma: noise level 27 | :param device: device 28 | :returns: 29 | train_adj: [batch_size, N, N], batch of noisy adjacency matrices 30 | noisediff: [batch_size, N, N], the noise added to graph 31 | """ 32 | train_adj_b = train_adj_b.to(device) 33 | # each pair (i, j) flips its state with probability sigma, so an existing edge survives with probability 1 - sigma; both torch.where branches are sigma because the flip probability does not depend on A_ij 34 | bernoulli_adj = torch.where( 35 | train_adj_b > 1 / 2, 36 | torch.full_like(train_adj_b, sigma).to(device), # edge present: removed with prob. sigma 37 | torch.full_like(train_adj_b, sigma).to(device), # edge absent: added with prob. sigma 38 | ) 39 | 40 | noise_upper = torch.bernoulli(bernoulli_adj).triu(diagonal=1).to(device) 41 | noise_lower = noise_upper.transpose(-1, -2) 42 | train_adj = torch.abs(-train_adj_b + noise_upper + noise_lower) 43 | noisediff = noise_upper + noise_lower 44 | train_adj = mask_adjs(train_adj, node_flags) 45 | noisediff = mask_adjs(noisediff, node_flags) 46 | return train_adj, noisediff 47 | 48 | 49 | def 
gen_list_of_data_single(train_x_b, train_adj_b, train_node_flag_b, sigma_list, args): 50 | """ 51 | Generate the list of data with different noise levels 52 | :param train_x_b: [batch_size, N, F_in], batch of feature vectors of nodes 53 | :param train_adj_b: [batch_size, N, N], batch of original adjacency matrices 54 | :param train_node_flag_b: [batch_size, N], the flags for the existence of nodes 55 | :param sigma_list: list of noise levels 56 | :returns: 57 | train_x_b: [len(sigma_list) * batch_size, N, F_in], batch of feature vectors of nodes 58 | train_ori_adj_b: [len(sigma_list) * batch_size, N, N], batch of original adjacency matrix (considered as the groundtruth) 59 | train_node_flag_b: [len(sigma_list) * batch_size, N], the flags for the existence of nodes 60 | train_noise_adj_b: [len(sigma_list) * batch_size, N, N], batch of noisy adjacency matrices 61 | noise_list: [len(sigma_list) * batch_size, N, N], the noise added to graph 62 | """ 63 | assert isinstance(sigma_list, list) 64 | train_noise_adj_b_list = [] 65 | noise_list = [] 66 | for i, sigma_i in enumerate(sigma_list): 67 | train_noise_adj_b, true_noise = discretenoise_single( 68 | train_adj_b, node_flags=train_node_flag_b, sigma=sigma_i, device=args.device 69 | ) 70 | 71 | train_noise_adj_b_list.append(train_noise_adj_b) 72 | noise_list.append(true_noise) 73 | 74 | train_noise_adj_b = torch.cat(train_noise_adj_b_list, dim=0).to(args.device) 75 | noise_list = torch.cat(noise_list, dim=0).to(args.device) 76 | train_x_b = train_x_b.repeat(len(sigma_list), 1, 1) 77 | train_ori_adj_b = train_adj_b.repeat(len(sigma_list), 1, 1) 78 | train_node_flag_sigma = train_node_flag_b.repeat(len(sigma_list), 1) 79 | return ( 80 | train_x_b, 81 | train_ori_adj_b, 82 | train_node_flag_sigma, 83 | train_noise_adj_b, 84 | noise_list, 85 | ) 86 | 87 | 88 | def generate_mask(node_flags): 89 | """ 90 | Generate the mask matrix for the existence of nodes 91 | :param node_flags: [bsz, N], the flags for the existence of nodes 92 | :return: groundtruth: [bsz, N, N] 93 | """ 94 | flag2 = node_flags.unsqueeze(1) # [bsz,1,N] 95 | flag1 = node_flags.unsqueeze(-1) # [bsz,N,1] 96 | mask_matrix = torch.bmm(flag1, flag2) # [bsz, N, N] 97 | groundtruth = torch.where(mask_matrix > 0.9, 1, 0).to(node_flags.device) 98 | return groundtruth 99 | 100 | 101 | def graph2tensor(graph, device): 102 | """ 103 | Convert graph batch to tensor batch 104 | :param graph: graph batch 105 | :param device: device 106 | :returns: 107 | adj: [bsz, N, N] 108 | x: [bsz, N, C] 109 | """ 110 | bsz = graph.num_graphs 111 | edge_index = graph.edge_index # [2, E_total] 112 | adj = to_dense_adj(edge_index, batch=graph.batch) # [bsz, max_num_node, max_num_node] 113 | max_num_node = adj.size(-1) 114 | node_features = graph.x # [N_total, C] 115 | feature_dim = node_features.size(-1) 116 | node_sizes = degree(graph.batch, dtype=torch.long).tolist() 117 | x_split = node_features.split(node_sizes, dim=0) # list of tensor 118 | x_tensor = torch.empty((bsz, max_num_node, feature_dim)).to(device) 119 | assert len(x_split) == bsz 120 | for i in range(bsz): 121 | Gi_x = x_split[i] 122 | num_node = Gi_x.size(0) 123 | zero_tensor = torch.zeros((max_num_node - num_node, feature_dim)).to(device) 124 | Gi_x = torch.cat((Gi_x, zero_tensor), dim=0) 125 | assert Gi_x.size(0) == max_num_node 126 | x_tensor[i] = Gi_x 127 | return adj, x_tensor 128 | 129 | 130 | def tensor2graph(graph_batch, score, mask_adj, threshold=0.5): 131 | """ 132 | Convert tensor batch to graph batch 133 | :param graph_batch: graph 
batch 134 | :param score: [bsz, N, N, 1] 135 | :param mask_adj: [bsz, N, N] 136 | :param threshold: threshold for the prediction 137 | :return: pred_adj: [bsz, N, N] 138 | """ 139 | score_tensor = torch.stack(score, dim=0).squeeze(-1) # len_sigma_list, bsz, N, N] 140 | score_tensor = torch.mean(score_tensor, dim=0) # [bsz, N, N] 141 | bsz = score_tensor.size(0) 142 | pred_adj = torch.where(torch.sigmoid(score_tensor) > threshold, 1, 0).to(score_tensor.device) 143 | pred_adj = pred_adj * mask_adj 144 | node_sizes = degree(graph_batch.batch, dtype=torch.long).detach().cpu().numpy() # list of node numbers 145 | sum_list = torch.tensor([node_sizes[:i].sum() for i in range(bsz)]).to(score_tensor.device) 146 | edge_indices = pred_adj.nonzero().t() 147 | batch = sum_list[edge_indices[0]] 148 | row = batch + edge_indices[1] 149 | col = batch + edge_indices[2] 150 | edge_index = torch.stack([row, col], dim=0) 151 | graph_batch_sub = graph_batch.clone() 152 | graph_batch_sub.edge_index = edge_index 153 | 154 | return graph_batch_sub 155 | 156 | 157 | def gen_full(batch, mask): 158 | """ 159 | Generate the full graph from the mask 160 | :param batch: graph.batch 161 | :param mask: [bsz, N, N] 162 | :return: edge_index: [2, E] 163 | """ 164 | bsz = mask.size(0) 165 | node_sizes = degree(batch, dtype=torch.long).detach().cpu().numpy() # list of node numbers 166 | sum_list = torch.tensor([node_sizes[:i].sum() for i in range(bsz)]).to(mask.device) 167 | edge_indices = mask.nonzero().t() 168 | batch = sum_list[edge_indices[0]] 169 | row = batch + edge_indices[1] 170 | col = batch + edge_indices[2] 171 | edge_index = torch.stack([row, col], dim=0) 172 | return edge_index 173 | -------------------------------------------------------------------------------- /explainers/diffusion/pgnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | SLOPE = 0.01 6 | 7 | 8 | def masked_instance_norm2D(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-5): 9 | """ 10 | Instance normalization for 2D feature maps with mask 11 | :param x: [batch_size (N), num_objects (L), num_objects (L), features(C)] 12 | :param mask: [batch_size (N), num_objects (L), num_objects (L), 1] 13 | return: [batch_size (N), num_objects (L), num_objects (L), features(C)] 14 | """ 15 | mask = mask.view(x.size(0), x.size(1), x.size(2), 1).expand_as(x) 16 | zero_indices = torch.where(torch.sum(mask, dim=[1, 2]) < 0.5)[0].squeeze(-1) # [N,] 17 | mean = torch.sum(x * mask, dim=[1, 2]) / (torch.sum(mask, dim=[1, 2])) # (N,C) 18 | var_term = ((x - mean.unsqueeze(1).unsqueeze(1).expand_as(x)) * mask) ** 2 # (N,L,L,C) 19 | var = torch.sum(var_term, dim=[1, 2]) / (torch.sum(mask, dim=[1, 2])) + 1e-5 # (N,C) 20 | mean = mean.unsqueeze(1).unsqueeze(1).expand_as(x) # (N, L, L, C) 21 | var = var.unsqueeze(1).unsqueeze(1).expand_as(x) # (N, L, L, C) 22 | instance_norm = (x - mean) / torch.sqrt(var + eps) # (N, L, L, C) 23 | instance_norm = instance_norm * mask 24 | instance_norm[zero_indices, :, :, :] = 0 25 | return instance_norm 26 | 27 | 28 | class PowerfulLayer(nn.Module): 29 | # in_feature will be given as hidden-dim and out as well hidden-dim 30 | def __init__(self, in_feat: int, out_feat: int, num_layers: int, spectral_norm=(lambda x: x)): 31 | super().__init__() 32 | self.in_feat = in_feat 33 | self.out_feat = out_feat 34 | activation = nn.LeakyReLU(negative_slope=SLOPE) 35 | # generate num_layers linear layers with the first one being 
of dim in_feat and all others out_feat, in between those layers we add an activation layer 36 | self.m1 = nn.Sequential( 37 | *[ 38 | spectral_norm(nn.Linear(in_feat if i == 0 else out_feat, out_feat)) if i % 2 == 0 else activation 39 | for i in range(num_layers * 2 - 1) 40 | ] 41 | ) 42 | self.m2 = nn.Sequential( 43 | *[ 44 | spectral_norm(nn.Linear(in_feat if i == 0 else out_feat, out_feat)) if i % 2 == 0 else activation 45 | for i in range(num_layers * 2 - 1) 46 | ] 47 | ) 48 | # a linear layer that has input dim in_feat + out_feat and outputdim out_feat and a possible bias 49 | # this is mentioned in paper as the layer after each concatenbation back to outputdim 50 | self.m4 = nn.Sequential(spectral_norm(nn.Linear(in_feat + out_feat, out_feat, bias=True))) 51 | 52 | # expects x as batch x N x N x in_feat and mask as batch x N x N x 1 53 | def forward(self, x, mask): 54 | """x: batch x N x N x in_feat""" 55 | 56 | # norm by taking uppermost col of mask gives nr of nodes active and then sqrt of that and put it in matching array dim 57 | # [bsz, N, N, 1] 58 | norm = mask.squeeze(-1).float().sum((-1, -2)).sqrt().sqrt().view(mask.size(0), 1, 1, 1) 59 | 60 | # here I will start to treat mask as in the edp paper, namely batch x N with then a boolean value 61 | # batch * N * N * 1 gets to batch * 1 * N * N 62 | mask = mask.unsqueeze(1).squeeze(-1) 63 | 64 | # run the two mlp on input and permute dimensions so that now matches dim of mask: batch * N * N * out_features 65 | out1 = self.m1(x).permute(0, 3, 1, 2) * mask # batch, out_feat, N, N 66 | out2 = self.m2(x).permute(0, 3, 1, 2) * mask # batch, out_feat, N, N 67 | 68 | # matrix multiply each matching layer of features as well as adjacencies 69 | out = out1 @ out2 70 | del out1, out2 71 | out = out / (norm + 1e-5) 72 | # permute back to correct dim and concat with the skip-mlp in last dim 73 | out_cat = torch.cat((out.permute(0, 2, 3, 1), x), dim=3) # batch, N, N, out_feat 74 | del x 75 | # run through last layer to go back to out_features dim 76 | out = self.m4(out_cat) 77 | 78 | return out 79 | 80 | 81 | # this is the class for the invariant layer that makes the whole thing invariant 82 | class FeatureExtractor(nn.Module): 83 | def __init__(self, in_features: int, out_features: int, spectral_norm=(lambda x: x)): 84 | super().__init__() 85 | self.lin1 = nn.Sequential(spectral_norm(nn.Linear(in_features, out_features, bias=True))) 86 | self.lin2 = nn.Sequential(spectral_norm(nn.Linear(in_features, out_features, bias=False))) 87 | self.lin3 = nn.Sequential(spectral_norm(nn.Linear(out_features, out_features, bias=False))) 88 | self.activation = nn.LeakyReLU(negative_slope=SLOPE) 89 | 90 | def forward(self, u, mask): 91 | """ 92 | Forward pass of the invariant layer. 93 | :param u: (batch_size, num_nodes, num_nodes, in_features) 94 | :param mask: (batch_size, num_nodes, num_nodes, 1) 95 | :return: (batch_size, out_features). 
96 | """ 97 | u = u * mask 98 | # tensor of batch * 1 that represernts nr of active nodes 99 | n = mask[:, 0].sum(1) + 1e-5 100 | # tensor of batches * features * their diagonal elements (this retrieves the node elements that are stored on the diagonal) 101 | diag = u.diagonal(dim1=1, dim2=2) # batch_size, channels, num_nodes 102 | 103 | # tensor of batch * features with storing the sum of diagonals 104 | trace = torch.sum(diag, dim=2) 105 | del diag 106 | 107 | out1 = self.lin1.forward(trace / n) 108 | s = (torch.sum(u, dim=[1, 2]) - trace) / (n * (n - 1)) 109 | del trace 110 | 111 | out2 = self.lin2.forward(s) # bs, out_feat 112 | del s 113 | 114 | out = out1 + out2 115 | out = out + self.lin3.forward(self.activation(out)) 116 | return out 117 | 118 | 119 | class Powerful(nn.Module): 120 | def __init__( 121 | self, 122 | args, 123 | spectral_norm=(lambda x: x), 124 | project_first: bool = False, 125 | node_out: bool = False, 126 | ): 127 | super().__init__() 128 | self.cat_output = args.cat_output 129 | self.normalization = args.normalization 130 | self.layers_per_conv = args.layers_per_conv # was 1 originally, try 2? 131 | self.layer_after_conv = args.simplified 132 | self.dropout_p = args.dropout 133 | self.residual = args.residual 134 | # self.activation = nn.LeakyReLU(negative_slope=SLOPE) 135 | self.activation = nn.ReLU() 136 | self.project_first = project_first 137 | self.node_out = node_out 138 | self.output_features = 1 139 | self.node_output_features = 1 140 | self.noise_mlp = args.noise_mlp 141 | self.device = args.device 142 | self.num_layers = args.num_layers 143 | self.hidden = args.n_hidden 144 | 145 | self.time_mlp = nn.Sequential(nn.Linear(1, 4), nn.GELU(), nn.Linear(4, 1)) 146 | self.input_features = 2 * args.feature_in + 2 147 | 148 | self.in_lin = nn.Sequential(spectral_norm(nn.Linear(self.input_features, self.hidden))) 149 | 150 | if self.cat_output: 151 | if self.project_first: 152 | self.layer_cat_lin = nn.Sequential( 153 | spectral_norm(nn.Linear(self.hidden * (self.num_layers + 1), self.hidden)) 154 | ) 155 | else: 156 | self.layer_cat_lin = nn.Sequential( 157 | spectral_norm(nn.Linear(self.hidden * self.num_layers + self.input_features, self.hidden)) 158 | ) 159 | 160 | self.convs = nn.ModuleList([]) 161 | self.bns = nn.ModuleList([]) 162 | for _ in range(self.num_layers): 163 | self.convs.append( 164 | PowerfulLayer(self.hidden, self.hidden, self.layers_per_conv, spectral_norm=spectral_norm) 165 | ) 166 | 167 | self.feature_extractors = torch.nn.ModuleList([]) 168 | for _ in range(self.num_layers): 169 | if self.normalization == "batch": 170 | self.bns.append(nn.BatchNorm2d(self.hidden)) 171 | else: 172 | self.bns.append(None) 173 | self.feature_extractors.append(FeatureExtractor(self.hidden, self.hidden, spectral_norm=spectral_norm)) 174 | if self.layer_after_conv: 175 | self.after_conv = nn.Sequential(spectral_norm(nn.Linear(self.hidden, self.hidden))) 176 | self.final_lin = nn.Sequential(spectral_norm(nn.Linear(self.hidden, self.output_features))) 177 | 178 | if self.node_out: 179 | if self.cat_output: 180 | if self.project_first: 181 | self.layer_cat_lin_node = nn.Sequential( 182 | spectral_norm(nn.Linear(self.hidden * (self.num_layers + 1), self.hidden)) 183 | ) 184 | else: 185 | self.layer_cat_lin_node = nn.Sequential( 186 | spectral_norm(nn.Linear(self.hidden * self.num_layers + self.input_features, self.hidden)) 187 | ) 188 | 189 | if self.layer_after_conv: 190 | self.after_conv_node = nn.Sequential(spectral_norm(nn.Linear(self.hidden, self.hidden))) 191 
| self.final_lin_node = nn.Sequential(spectral_norm(nn.Linear(self.hidden, self.node_output_features))) 192 | 193 | self.test_lin = nn.Sequential(spectral_norm(nn.Linear(self.input_features, self.output_features, bias=False))) 194 | 195 | def get_out_dim(self): 196 | """ 197 | returns the output dimension of the model 198 | :return: number of output features 199 | """ 200 | return self.output_features 201 | 202 | # expects the input as the adjacency tensor: batchsize x N x N 203 | # expects the node_features as tensor: batchsize x N x node_features 204 | # expects the mask as tensor: batchsize x N x N 205 | # expects noiselevel as the noislevel that was used as single float 206 | def forward(self, node_features, A, mask, noiselevel): 207 | """ 208 | forward pass of the model 209 | :param node_features: [batchsize, N, C] 210 | :param A: [batchsize, N, N] 211 | :param mask: [batchsize, N, N] 212 | :param noiselevel: single float 213 | :return: [batchsize, N, N, 1] 214 | """ 215 | if len(mask.shape) < 4: 216 | mask = mask[..., None] 217 | else: 218 | mask = mask 219 | if len(A.shape) < 4: 220 | u = A[..., None] # [batch, N, N, 1] 221 | else: 222 | u = A 223 | 224 | if self.noise_mlp: 225 | noiselevel = torch.tensor([float(noiselevel)]).to(self.device) 226 | noiselevel = self.time_mlp(noiselevel) 227 | noise_level_matrix = noiselevel.expand(u.size(0), u.size(1), u.size(3)).to(self.device) 228 | noise_level_matrix = torch.diag_embed(noise_level_matrix.transpose(-2, -1), dim1=1, dim2=2) 229 | else: 230 | noiselevel = torch.full([1], noiselevel).to(self.device) 231 | noise_level_matrix = noiselevel.expand(u.size(0), u.size(1), u.size(3)).to(self.device) # [bsz, N, 1] 232 | noise_level_matrix = torch.diag_embed(noise_level_matrix.transpose(-2, -1), dim1=1, dim2=2) 233 | 234 | node_feature1 = node_features.unsqueeze(1).repeat(1, node_features.size(1), 1, 1) 235 | node_feature2 = node_features.unsqueeze(2).repeat(1, 1, node_features.size(1), 1) 236 | u = torch.cat([u, node_feature1, node_feature2, noise_level_matrix], dim=-1).to(self.device) 237 | del node_features 238 | 239 | if self.project_first: 240 | u = self.in_lin(u) 241 | out = [u] 242 | else: 243 | out = [u] 244 | u1 = self.in_lin(u) 245 | for conv, bn in zip(self.convs, self.bns): 246 | u1 = conv(u1, mask) + (u1 if self.residual else 0) 247 | if self.normalization == "none": 248 | u1 = u1 249 | elif self.normalization == "instance": 250 | u1 = masked_instance_norm2D(u1, mask) 251 | elif self.normalization == "batch": 252 | u1 = bn(u1.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 253 | else: 254 | raise ValueError 255 | 256 | u1 = self.activation(u1) 257 | u2 = u1 * mask 258 | out.append(u2) 259 | 260 | out = torch.cat(out, dim=-1) 261 | if self.node_out: 262 | node_out = self.layer_cat_lin_node(out.diagonal(dim1=1, dim2=2).transpose(-2, -1)) 263 | if self.layer_after_conv: 264 | node_out = node_out + self.activation(self.after_conv_node(node_out)) 265 | node_out = F.dropout(node_out, p=self.dropout_p, training=self.training) 266 | node_out = self.final_lin_node(node_out) 267 | out = self.layer_cat_lin(out) 268 | out = masked_instance_norm2D(self.activation(out), mask) 269 | 270 | if self.layer_after_conv: 271 | out = out + self.activation(self.after_conv(out)) 272 | out = F.dropout(out, p=self.dropout_p, training=self.training) 273 | out = self.final_lin(out) 274 | out = out * mask 275 | if self.node_out: 276 | return out, node_out 277 | else: 278 | return out 279 | 
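# --- editor's sketch: a minimal smoke test for Powerful (not part of the
# original file). The hyperparameter values in `args` below are illustrative
# assumptions; in the repository they come from main.py's argument parser.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        cat_output=True, normalization="instance", layers_per_conv=1,
        simplified=False, dropout=0.0, residual=False, noise_mlp=True,
        device="cpu", num_layers=4, n_hidden=32, feature_in=14,
    )
    model = Powerful(args).eval()
    bsz, N = 2, 8
    A = torch.bernoulli(torch.full((bsz, N, N), 0.3))
    A = ((A + A.transpose(-1, -2)) > 0).float()  # symmetrize the random graph
    x = torch.randn(bsz, N, args.feature_in)     # random node features
    mask = torch.ones(bsz, N, N)                 # all nodes active
    with torch.no_grad():
        score = model(node_features=x, A=A, mask=mask, noiselevel=0.5)
    print(score.shape)  # torch.Size([2, 8, 8, 1]): one logit per node pair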
-------------------------------------------------------------------------------- /explainers/gnnexplainer.py: -------------------------------------------------------------------------------- 1 | from explainers.base import Explainer 2 | from explainers.meta_gnnexplainer import MetaGNNGExplainer 3 | 4 | 5 | class GNNExplainer(Explainer): 6 | def __init__(self, device, gnn_model_path, task): 7 | super(GNNExplainer, self).__init__(device, gnn_model_path, task) 8 | 9 | def explain_graph(self, graph, model=None, epochs=100, lr=1e-2, draw_graph=0, vis_ratio=0.2): 10 | """ 11 | Explain the graph using GNNExplainer 12 | :param graph: the graph to be explained. 13 | :param model: the model to be explained. 14 | :param epochs: the number of epochs to train the explainer. 15 | :param lr: the learning rate of the explainer. 16 | :param draw_graph: whether to draw the graph. 17 | :param vis_ratio: the ratio of edges to be visualized. 18 | :return: the explanation (edge_imp) 19 | """ 20 | if model is None: 21 | model = self.model 22 | 23 | explainer = MetaGNNGExplainer(model, epochs=epochs, lr=lr, task=self.task) 24 | edge_imp = explainer.explain_graph(graph) 25 | edge_imp = self.norm_imp(edge_imp.cpu().numpy()) 26 | 27 | if draw_graph: 28 | self.visualize(graph, edge_imp, self.name, vis_ratio=vis_ratio) 29 | self.last_result = (graph, edge_imp) 30 | 31 | return edge_imp 32 | -------------------------------------------------------------------------------- /explainers/meta_gnnexplainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified based on torch_geometric.nn.models.GNNExplainer 3 | which generates explainations in node prediction tasks. 4 | 5 | Citation: 6 | Ying et al. GNNExplainer: Generating Explanations for Graph Neural Networks. 7 | """ 8 | 9 | from math import sqrt 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch_geometric.nn import MessagePassing 14 | 15 | EPS = 1e-15 16 | 17 | 18 | class MetaGNNGExplainer(torch.nn.Module): 19 | coeffs = { 20 | "edge_size": 0.05, 21 | "edge_ent": 0.5, 22 | } 23 | 24 | def __init__(self, model, epochs=100, lr=0.01, log=True, task="gc"): 25 | super(MetaGNNGExplainer, self).__init__() 26 | self.model = model 27 | self.epochs = epochs 28 | self.lr = lr 29 | self.log = log 30 | self.task = task 31 | 32 | def __set_masks__(self, x, edge_index, init="normal"): 33 | N = x.size(0) 34 | E = edge_index.size(1) 35 | 36 | std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N)) 37 | self.edge_mask = torch.nn.Parameter(torch.randn(E) * std) 38 | 39 | for module in self.model.modules(): 40 | if isinstance(module, MessagePassing): 41 | module.__explain__ = True 42 | module.__edge_mask__ = self.edge_mask 43 | 44 | def __clear_masks__(self): 45 | for module in self.model.modules(): 46 | if isinstance(module, MessagePassing): 47 | module.__explain__ = False 48 | module.__edge_mask__ = None 49 | self.edge_mask = None 50 | 51 | def __loss__(self, log_logits, pred_label): 52 | criterion = torch.nn.NLLLoss() 53 | loss = criterion(log_logits, pred_label) 54 | m = self.edge_mask.sigmoid() 55 | loss = loss + self.coeffs["edge_size"] * m.sum() 56 | ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS) 57 | loss = loss + self.coeffs["edge_ent"] * ent.mean() 58 | return loss 59 | 60 | def explain_graph(self, graph, **kwargs): 61 | """ 62 | Explain a graph using the MetaGNNGExplainer. 63 | :param graph: the graph to be explained. 64 | :param kwargs: additional arguments. 
65 | :return: the explanation (edge_mask) 66 | """ 67 | self.__clear_masks__() 68 | 69 | # get the initial prediction. 70 | with torch.no_grad(): 71 | if self.task == "nc": 72 | soft_pred, _ = self.model.get_node_pred_subgraph( 73 | x=graph.x, edge_index=graph.edge_index, mapping=graph.mapping 74 | ) 75 | else: 76 | soft_pred, _ = self.model.get_pred(x=graph.x, edge_index=graph.edge_index, batch=graph.batch) 77 | pred_label = soft_pred.argmax(dim=-1) 78 | 79 | N = graph.x.size(0) 80 | E = graph.edge_index.size(1) 81 | 82 | std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N)) 83 | self.edge_mask = torch.nn.Parameter(torch.randn(E) * std) # moved to the right device by self.to() below 84 | self.to(graph.x.device) 85 | optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr) 86 | 87 | for _ in range(self.epochs): 88 | optimizer.zero_grad() 89 | if self.task == "nc": 90 | _, output_repr = self.model.get_pred_explain( 91 | x=graph.x, edge_index=graph.edge_index, edge_mask=self.edge_mask, mapping=graph.mapping 92 | ) 93 | else: 94 | _, output_repr = self.model.get_pred_explain( 95 | x=graph.x, edge_index=graph.edge_index, edge_mask=self.edge_mask, batch=graph.batch 96 | ) 97 | log_logits = F.log_softmax(output_repr, dim=-1) 98 | loss = self.__loss__(log_logits, pred_label) 99 | loss.backward() 100 | optimizer.step() 101 | 102 | edge_mask = self.edge_mask.detach().sigmoid() 103 | 104 | return edge_mask 105 | 106 | def __repr__(self): 107 | return f"{self.__class__.__name__}()" 108 | -------------------------------------------------------------------------------- /explainers/visual.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | n_class_dict = {"MutagNet": 2, "BA3MotifNet": 3} 4 | 5 | vis_dict = { 6 | "MutagNet": {"node_size": 400, "linewidths": 1, "font_size": 10, "width": 3}, 7 | "BA3MotifNet": {"node_size": 300, "linewidths": 1, "font_size": 10, "width": 3}, 8 | } 9 | 10 | chem_graph_label_dict = { 11 | "MutagNet": { 12 | 0: "C", 13 | 1: "O", 14 | 2: "Cl", 15 | 3: "H", 16 | 4: "N", 17 | 5: "F", 18 | 6: "Br", 19 | 7: "S", 20 | 8: "P", 21 | 9: "I", 22 | 10: "Na", 23 | 11: "K", 24 | 12: "Li", 25 | 13: "Ca", 26 | }, 27 | } 28 | 29 | 30 | def e_map_mutag(bond_type, reverse=False): 31 | from rdkit import Chem 32 | 33 | if not reverse: 34 | if bond_type == Chem.BondType.SINGLE: 35 | return 0 36 | elif bond_type == Chem.BondType.DOUBLE: 37 | return 1 38 | elif bond_type == Chem.BondType.AROMATIC: 39 | return 2 40 | elif bond_type == Chem.BondType.TRIPLE: 41 | return 3 42 | else: 43 | raise Exception("No bond type found") 44 | 45 | if bond_type == 0: 46 | return Chem.BondType.SINGLE 47 | elif bond_type == 1: 48 | return Chem.BondType.DOUBLE 49 | elif bond_type == 2: 50 | return Chem.BondType.AROMATIC 51 | elif bond_type == 3: 52 | return Chem.BondType.TRIPLE 53 | else: 54 | raise Exception("No bond type found") 55 | 56 | 57 | class x_map_mutag(Enum): 58 | C = 0 59 | O = 1 60 | Cl = 2 61 | H = 3 62 | N = 4 63 | F = 5 64 | Br = 6 65 | S = 7 66 | P = 8 67 | I = 9 68 | Na = 10 69 | K = 11 70 | Li = 12 71 | Ca = 13 72 | 73 | 74 | def graph_to_mol(X, edge_index, edge_attr): 75 | from rdkit import Chem 76 | 77 | mol = Chem.RWMol() 78 | X = [Chem.Atom(x_map_mutag(x.index(1)).name) for x in X] 79 | 80 | E = edge_index 81 | for x in X: 82 | mol.AddAtom(x) 83 | for (u, v), attr in zip(E, edge_attr): 84 | attr = e_map_mutag(attr.index(1), reverse=True) 85 | 86 | if mol.GetBondBetweenAtoms(u, v): 87 | continue 88 | mol.AddBond(u, v, attr) 89 | return mol 
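# --- editor's sketch: building a two-atom molecule with graph_to_mol (not part
# of the original file; requires rdkit). Indices follow x_map_mutag and
# e_map_mutag above: atom class 0 is C, class 1 is O; bond class 0 is a single bond.
if __name__ == "__main__":
    X = [[1] + [0] * 13, [0, 1] + [0] * 12]   # one-hot atom types: C, O
    edge_index = [[0, 1], [1, 0]]             # each undirected edge is stored twice
    edge_attr = [[1, 0, 0, 0], [1, 0, 0, 0]]  # one-hot bond types: single, single
    mol = graph_to_mol(X, edge_index, edge_attr)
    print(mol.GetNumAtoms(), mol.GetNumBonds())  # 2 1 (the reverse edge is deduplicated)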
-------------------------------------------------------------------------------- /gnns/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/gnns/.DS_Store -------------------------------------------------------------------------------- /gnns/__init__.py: -------------------------------------------------------------------------------- 1 | from .ba3motif_gnn import BA3MotifNet 2 | from .bbbp_gnn import BBBP_GCN 3 | from .mutag_gnn import Mutag_GCN 4 | from .nci1_gnn import NCI1GCN 5 | from .synthetic_gnn import Syn_GCN 6 | from .tree_grids_gnn import Syn_GCN_TG 7 | from .web_gnn import EGNN 8 | 9 | __all__ = [ 10 | "BA3MotifNet", 11 | "BBBP_GCN", 12 | "Mutag_GCN", 13 | "NCI1GCN", 14 | "Syn_GCN", 15 | "Syn_GCN_TG", 16 | "EGNN", 17 | ] 18 | -------------------------------------------------------------------------------- /gnns/ba3motif_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import random 5 | import time 6 | import sys 7 | sys.path.append("..") 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn import CrossEntropyLoss, Linear, ModuleList, ReLU, Softmax 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch_geometric.data import DataLoader 13 | from torch_geometric.nn import GCNConv, global_mean_pool 14 | 15 | from datasets.ba3motif_dataset import BA3Motif 16 | from utils import Gtest, Gtrain, set_seed 17 | 18 | EPS = 1 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Train BA-3Motif Model") 23 | 24 | parser.add_argument( 25 | "--data_path", nargs="?", default=osp.join(osp.dirname(__file__), "..", "data", "BA3"), help="Input data path." 
26 | ) 27 | parser.add_argument( 28 | "--model_path", 29 | nargs="?", 30 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 31 | help="path for saving trained model.", 32 | ) 33 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 34 | parser.add_argument("--epoch", type=int, default=300, help="Number of epoch.") 35 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 36 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") 37 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 38 | parser.add_argument("--num_unit", type=int, default=3, help="number of Convolution layers(units)") 39 | parser.add_argument( 40 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 41 | ) 42 | return parser.parse_args() 43 | 44 | 45 | class BA3MotifNet(torch.nn.Module): 46 | def __init__(self, num_unit): 47 | super().__init__() 48 | 49 | self.num_unit = num_unit 50 | 51 | self.node_emb = Linear(4, 64) 52 | 53 | self.convs = ModuleList() 54 | self.batch_norms = ModuleList() 55 | self.relus = ModuleList() 56 | 57 | for i in range(num_unit): 58 | conv = GCNConv(in_channels=64, out_channels=64) 59 | self.convs.append(conv) 60 | self.relus.append(ReLU()) 61 | 62 | self.lin1 = Linear(64, 64) 63 | self.relu = ReLU() 64 | self.lin2 = Linear(64, 3) 65 | self.softmax = Softmax(dim=1) 66 | 67 | def forward(self, x, edge_index, batch): 68 | edge_attr = torch.ones((edge_index.size(1),), device=edge_index.device) 69 | node_x = self.get_node_reps(x, edge_index, edge_attr, batch) 70 | graph_x = global_mean_pool(node_x, batch) 71 | pred = self.relu(self.lin1(graph_x)) 72 | pred = self.lin2(pred) 73 | return pred 74 | 75 | def get_node_reps(self, x, edge_index, edge_attr, batch): 76 | x = self.node_emb(x) 77 | x = F.dropout(x, p=0.4) 78 | for conv, relu in zip(self.convs, self.relus): 79 | x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr) 80 | x = relu(x) 81 | x = F.dropout(x, p=0.4) 82 | node_x = x 83 | return node_x 84 | 85 | def get_graph_rep(self, x, edge_index, batch): 86 | edge_attr = torch.ones((edge_index.size(1),), device=edge_index.device) 87 | node_x = self.get_node_reps(x, edge_index, edge_attr, batch) 88 | graph_x = global_mean_pool(node_x, batch) 89 | return graph_x 90 | 91 | def get_pred(self, x, edge_index, batch): 92 | graph_x = self.get_graph_rep(x, edge_index, batch) 93 | pred = self.relu(self.lin1(graph_x)) 94 | pred = self.lin2(pred) 95 | self.readout = self.softmax(pred) 96 | return self.readout, pred 97 | 98 | def get_pred_explain(self, x, edge_index, edge_mask, batch): 99 | edge_mask = (edge_mask * EPS).sigmoid() 100 | x = self.node_emb(x) 101 | x = F.dropout(x, p=0.4) 102 | for conv, relu in zip(self.convs, self.relus): 103 | x = conv(x=x, edge_index=edge_index, edge_weight=edge_mask) 104 | x = relu(x) 105 | x = F.dropout(x, p=0.4) 106 | node_x = x 107 | graph_x = global_mean_pool(node_x, batch) 108 | pred = self.relu(self.lin1(graph_x)) 109 | pred = self.lin2(pred) 110 | self.readout = self.softmax(pred) 111 | return self.readout, pred 112 | 113 | def reset_parameters(self): 114 | with torch.no_grad(): 115 | for param in self.parameters(): 116 | param.uniform_(-1.0, 1.0) 117 | 118 | 119 | if __name__ == "__main__": 120 | set_seed(0) 121 | args = parse_args() 122 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 123 | 124 | test_dataset = BA3Motif(args.data_path, mode="testing") 
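# BA3Motif ships with fixed "training" / "evaluation" / "testing" splits;
# the loaders below shuffle only the training set, keeping validation and
# test order deterministic.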
125 | val_dataset = BA3Motif(args.data_path, mode="evaluation") 126 | train_dataset = BA3Motif(args.data_path, mode="training") 127 | if args.random_label: 128 | for dataset in [test_dataset, val_dataset, train_dataset]: 129 | for g in dataset: 130 | g.y.fill_(random.choice([0, 1, 2])) 131 | 132 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 133 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 134 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 135 | model = BA3MotifNet(args.num_unit).to(device) 136 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 137 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-4) 138 | min_error = None 139 | for epoch in range(1, args.epoch + 1): 140 | t1 = time.time() 141 | lr = scheduler.optimizer.param_groups[0]["lr"] 142 | loss = Gtrain(train_loader, model, optimizer, device=device, criterion=CrossEntropyLoss()) 143 | 144 | _, train_acc = Gtest(train_loader, model, device=device, criterion=CrossEntropyLoss()) 145 | 146 | val_error, val_acc = Gtest(val_loader, model, device=device, criterion=CrossEntropyLoss()) 147 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=CrossEntropyLoss()) 148 | scheduler.step(val_error) 149 | if min_error is None or val_error <= min_error: 150 | min_error = val_error 151 | 152 | t2 = time.time() 153 | 154 | if epoch % args.verbose == 0: 155 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=CrossEntropyLoss()) 156 | t3 = time.time() 157 | print( 158 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Test Loss: {:.5f}, " 159 | "Test acc: {:.5f}".format(epoch, t3 - t1, lr, loss, test_error, test_acc) 160 | ) 161 | continue 162 | 163 | print( 164 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Train acc: {:.5f}, Validation Loss: {:.5f}, " 165 | "Validation acc: {:5f}".format(epoch, t2 - t1, lr, loss, train_acc, val_error, val_acc) 166 | ) 167 | save_path = "ba3_gcn.pt" 168 | if not osp.exists(args.model_path): 169 | os.makedirs(args.model_path) 170 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 171 | -------------------------------------------------------------------------------- /gnns/bbbp_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import sys 6 | sys.path.append("..") 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import Linear as Lin, ModuleList, ReLU, Softmax 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch_geometric.data import DataLoader 13 | from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool 14 | 15 | from datasets import bbbp 16 | from utils import Gtest, Gtrain, set_seed 17 | 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Train bbbp Model") 22 | 23 | parser.add_argument( 24 | "--model_path", 25 | nargs="?", 26 | default=osp.join(osp.dirname(__file__), "param", "gnns"), 27 | help="path for saving trained model.", 28 | ) 29 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 30 | parser.add_argument("--epoch", type=int, default=300, help="Number of epoch.") 31 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 32 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") 33 | 
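    # Note: unlike the sibling training scripts, the --model_path default above
    # omits the ".." component, so checkpoints land in gnns/param/gnns rather
    # than the repo-level param/gnns that main.py loads from; passing
    # --model_path ../param/gnns keeps the two in sync.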
parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 34 | parser.add_argument("--num_unit", type=int, default=4, help="number of Convolution layers(units)") 35 | parser.add_argument( 36 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 37 | ) 38 | 39 | return parser.parse_args() 40 | 41 | 42 | class BBBP_GCN(torch.nn.Module): 43 | def __init__(self, conv_unit=3): 44 | super(BBBP_GCN, self).__init__() 45 | self.convs = ModuleList() 46 | self.batch_norms = ModuleList() 47 | self.relus = ModuleList() 48 | self.edge_emb = Lin(3, 1) 49 | self.convs.append(GCNConv(in_channels=9, out_channels=128)) 50 | for i in range(conv_unit - 2): 51 | self.convs.append(GCNConv(in_channels=128, out_channels=128)) 52 | self.convs.append(GCNConv(in_channels=128, out_channels=128)) 53 | 54 | self.batch_norms.extend([BatchNorm(128)] * conv_unit) 55 | self.relus.extend([ReLU()] * conv_unit) 56 | self.edge_emb = Lin(3, 128) 57 | # self.lin1 = Lin(128, 128) 58 | self.ffn = nn.Sequential(*([nn.Linear(128, 128)] + [ReLU(), nn.Dropout(), nn.Linear(128, 2)])) 59 | 60 | self.softmax = Softmax(dim=1) 61 | 62 | def forward(self, x, edge_index, batch): 63 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 64 | # edge_weight = self.edge_emb(edge_attr).squeeze(-1) 65 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 66 | x = conv(x, edge_index, edge_weight=edge_weight) 67 | # x = relu(batch_norm(x)) 68 | x = relu(x) 69 | graph_x = global_mean_pool(x, batch) 70 | pred = self.ffn(graph_x) 71 | self.readout = self.softmax(pred) 72 | return pred 73 | 74 | def get_node_reps(self, x, edge_index): 75 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 76 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 77 | x = conv(x, edge_index, edge_weight) 78 | # x = relu(batch_norm(x)) 79 | x = relu(x) 80 | node_x = x 81 | return node_x 82 | 83 | def get_graph_rep(self, x, edge_index, batch): 84 | node_x = self.get_node_reps(x, edge_index) 85 | graph_x = global_mean_pool(node_x, batch) 86 | return graph_x 87 | 88 | def get_pred(self, x, edge_index, batch): 89 | graph_x = self.get_graph_rep(x, edge_index, batch) 90 | pred = self.ffn(graph_x) 91 | self.readout = self.softmax(pred) 92 | return self.readout, pred 93 | 94 | def get_emb(self, x, edge_index, batch): 95 | graph_x = self.get_graph_rep(x, edge_index, batch) 96 | pred = self.ffn[0](graph_x) 97 | pred = F.relu(pred) 98 | return pred 99 | 100 | def get_pred_explain(self, x, edge_index, edge_mask, batch): 101 | edge_mask = edge_mask.sigmoid() 102 | # edge_weight = edge_mask.unsqueeze(-1).repeat(1, 128) 103 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 104 | x = conv(x, edge_index, edge_weight=edge_mask) 105 | x = relu(x) 106 | node_x = x 107 | graph_x = global_mean_pool(node_x, batch) 108 | pred = self.ffn(graph_x) 109 | self.readout = self.softmax(pred) 110 | return self.readout, pred 111 | 112 | def reset_parameters(self): 113 | with torch.no_grad(): 114 | for param in self.parameters(): 115 | param.uniform_(-1.0, 1.0) 116 | 117 | if __name__ == "__main__": 118 | set_seed(0) 119 | args = parse_args() 120 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 121 | # folder = os.path.join("data", 'bbbp') 122 | folder = osp.join(osp.dirname(__file__), "..", "data", "bbbp") 123 | dataset = bbbp(folder) 124 | test_dataset = 
dataset[:200] 125 | val_dataset = dataset[200:400] 126 | train_dataset = dataset[400:] 127 | # train_dataset, val_dataset, test_dataset = get_datasets(name="bbbp", root="data/") 128 | 129 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 130 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 131 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 132 | model = BBBP_GCN(args.num_unit).to(device) 133 | 134 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 135 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 136 | min_error = None 137 | for epoch in range(1, args.epoch + 1): 138 | t1 = time.time() 139 | lr = scheduler.optimizer.param_groups[0]["lr"] 140 | 141 | loss = Gtrain(train_loader, model, optimizer, device=device, criterion=nn.CrossEntropyLoss()) 142 | 143 | _, train_acc = Gtest(train_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 144 | 145 | val_error, val_acc = Gtest(val_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 146 | scheduler.step(val_error) 147 | if min_error is None or val_error <= min_error: 148 | min_error = val_error 149 | 150 | t2 = time.time() 151 | 152 | if epoch % args.verbose == 0: 153 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 154 | t3 = time.time() 155 | print( 156 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Test Loss: {:.5f}, " 157 | "Test acc: {:.5f}".format(epoch, t3 - t1, lr, loss, test_error, test_acc) 158 | ) 159 | continue 160 | 161 | print( 162 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Train acc: {:.5f}, Validation Loss: {:.5f}, " 163 | "Validation acc: {:5f}".format(epoch, t2 - t1, lr, loss, train_acc, val_error, val_acc) 164 | ) 165 | 166 | save_path = "bbbp_gcn.pt" 167 | 168 | if not osp.exists(args.model_path): 169 | os.makedirs(args.model_path) 170 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 171 | -------------------------------------------------------------------------------- /gnns/mutag_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import random 5 | import time 6 | import sys 7 | sys.path.append("..") 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import Linear as Lin, ModuleList, ReLU, Sequential as Seq, Softmax 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch_geometric.data import DataLoader 13 | from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, global_mean_pool 14 | 15 | from datasets.mutag_dataset import Mutagenicity 16 | from utils import Gtest, Gtrain, set_seed 17 | 18 | EPS = 1 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Train Mutag Model") 23 | 24 | parser.add_argument( 25 | "--data_path", 26 | nargs="?", 27 | default=osp.join(osp.dirname(__file__), "..", "data", "MUTAG"), 28 | help="Input data path.", 29 | ) 30 | parser.add_argument( 31 | "--model_path", 32 | nargs="?", 33 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 34 | help="path for saving trained model.", 35 | ) 36 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 37 | parser.add_argument("--epoch", type=int, default=300, help="Number of epoch.") 38 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 39 | parser.add_argument("--batch_size", 
type=int, default=128, help="Batch size.") 40 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 41 | parser.add_argument("--num_unit", type=int, default=2, help="number of Convolution layers(units)") 42 | parser.add_argument( 43 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 44 | ) 45 | 46 | return parser.parse_args() 47 | 48 | class Mutag_GCN(torch.nn.Module): 49 | def __init__(self, conv_unit=3): 50 | super(Mutag_GCN, self).__init__() 51 | self.convs = ModuleList() 52 | self.batch_norms = ModuleList() 53 | self.relus = ModuleList() 54 | self.edge_emb = Lin(3, 1) 55 | self.convs.append(GCNConv(in_channels=14, out_channels=128)) 56 | for i in range(conv_unit - 2): 57 | self.convs.append(GCNConv(in_channels=128, out_channels=128)) 58 | self.convs.append(GCNConv(in_channels=128, out_channels=128)) 59 | 60 | self.batch_norms.extend([BatchNorm(128)] * conv_unit) 61 | self.relus.extend([ReLU()] * conv_unit) 62 | 63 | # self.lin1 = Lin(128, 128) 64 | self.ffn = nn.Sequential(*([nn.Linear(128, 128)] + [ReLU(), nn.Dropout(), nn.Linear(128, 2)])) 65 | 66 | self.softmax = Softmax(dim=1) 67 | 68 | def forward(self, x, edge_index, batch): 69 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 70 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 71 | x = conv(x, edge_index, edge_weight) 72 | # x = relu(batch_norm(x)) 73 | x = relu(x) 74 | graph_x = global_mean_pool(x, batch) 75 | pred = self.ffn(graph_x) 76 | self.readout = self.softmax(pred) 77 | return pred 78 | 79 | def get_node_reps(self, x, edge_index): 80 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 81 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 82 | x = conv(x, edge_index, edge_weight) 83 | # x = relu(batch_norm(x)) 84 | x = relu(x) 85 | node_x = x 86 | return node_x 87 | 88 | def get_graph_rep(self, x, edge_index, batch): 89 | node_x = self.get_node_reps(x, edge_index) 90 | graph_x = global_mean_pool(node_x, batch) 91 | return graph_x 92 | 93 | def get_pred(self, x, edge_index, batch): 94 | graph_x = self.get_graph_rep(x, edge_index, batch) 95 | pred = self.ffn(graph_x) 96 | self.readout = self.softmax(pred) 97 | return self.readout, pred 98 | 99 | def get_pred_explain(self, x, edge_index, edge_mask, batch): 100 | edge_mask = (edge_mask * EPS).sigmoid() 101 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 102 | x = conv(x, edge_index, edge_weight=edge_mask) 103 | x = relu(x) 104 | node_x = x 105 | graph_x = global_mean_pool(node_x, batch) 106 | pred = self.ffn(graph_x) 107 | self.readout = self.softmax(pred) 108 | return self.readout, pred 109 | 110 | def reset_parameters(self): 111 | with torch.no_grad(): 112 | for param in self.parameters(): 113 | param.uniform_(-1.0, 1.0) 114 | 115 | 116 | if __name__ == "__main__": 117 | set_seed(0) 118 | args = parse_args() 119 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 120 | 121 | test_dataset = Mutagenicity(args.data_path, mode="testing") 122 | val_dataset = Mutagenicity(args.data_path, mode="evaluation") 123 | train_dataset = Mutagenicity(args.data_path, mode="training") 124 | if args.random_label: 125 | for dataset in [test_dataset, val_dataset, train_dataset]: 126 | for g in dataset: 127 | g.y.fill_(random.choice([0, 1])) 128 | 129 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 
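    # Note: in get_pred_explain above, EPS (= 1) scales the raw edge-mask
    # logits before the sigmoid maps them to per-edge weights in (0, 1);
    # a larger EPS would sharpen the mask toward a hard 0/1 edge selection.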
130 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 131 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 132 | model = Mutag_GCN(args.num_unit).to(device) 133 | 134 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 135 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 136 | min_error = None 137 | for epoch in range(1, args.epoch + 1): 138 | t1 = time.time() 139 | lr = scheduler.optimizer.param_groups[0]["lr"] 140 | 141 | loss = Gtrain(train_loader, model, optimizer, device=device, criterion=nn.CrossEntropyLoss()) 142 | 143 | _, train_acc = Gtest(train_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 144 | 145 | val_error, val_acc = Gtest(val_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 146 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 147 | scheduler.step(val_error) 148 | if min_error is None or val_error <= min_error: 149 | min_error = val_error 150 | 151 | t2 = time.time() 152 | 153 | if epoch % args.verbose == 0: 154 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 155 | t3 = time.time() 156 | print( 157 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Test Loss: {:.5f}, " 158 | "Test acc: {:.5f}".format(epoch, t3 - t1, lr, loss, test_error, test_acc) 159 | ) 160 | continue 161 | 162 | print( 163 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Train acc: {:.5f}, Validation Loss: {:.5f}, " 164 | "Validation acc: {:5f}".format(epoch, t2 - t1, lr, loss, train_acc, val_error, val_acc) 165 | ) 166 | 167 | save_path = "mutag_gcn.pt" 168 | if not osp.exists(args.model_path): 169 | os.makedirs(args.model_path) 170 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 171 | -------------------------------------------------------------------------------- /gnns/nci1_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import random 5 | import time 6 | import sys 7 | sys.path.append("..") 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import ModuleList, ReLU, Softmax 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch_geometric.data import DataLoader 13 | from torch_geometric.nn import BatchNorm, LEConv, global_mean_pool 14 | 15 | from datasets import NCI1 16 | from utils import Gtest, Gtrain, set_seed 17 | 18 | EPS = 1 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Train NCI1 Model") 23 | 24 | parser.add_argument( 25 | "--data_path", nargs="?", default=osp.join(osp.dirname(__file__), "..", "data", "NCI1"), help="Input data path." 
26 | ) 27 | parser.add_argument( 28 | "--model_path", 29 | nargs="?", 30 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 31 | help="path for saving trained model.", 32 | ) 33 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 34 | parser.add_argument("--epoch", type=int, default=300, help="Number of epoch.") 35 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 36 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") 37 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 38 | parser.add_argument("--num_unit", type=int, default=2, help="number of Convolution layers(units)") 39 | parser.add_argument( 40 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 41 | ) 42 | return parser.parse_args() 43 | 44 | 45 | class NCI1GCN(torch.nn.Module): 46 | def __init__(self, conv_unit=3): 47 | super(NCI1GCN, self).__init__() 48 | self.convs = ModuleList() 49 | self.batch_norms = ModuleList() 50 | self.relus = ModuleList() 51 | # self.convs.append(GCNConv(in_channels=37, out_channels=128)) 52 | self.convs.append(LEConv(in_channels=37, out_channels=128)) 53 | for i in range(conv_unit - 2): 54 | # self.convs.append(GCNConv(in_channels=128, out_channels=128)) 55 | # self.convs.append(GCNConv(in_channels=128, out_channels=128)) 56 | self.convs.append(LEConv(in_channels=128, out_channels=128)) 57 | self.convs.append(LEConv(in_channels=128, out_channels=128)) 58 | self.batch_norms.extend([BatchNorm(128)] * conv_unit) 59 | self.relus.extend([ReLU()] * conv_unit) 60 | 61 | # self.lin1 = Lin(128, 128) 62 | self.ffn = nn.Sequential(*([nn.Linear(128, 128)] + [ReLU(), nn.Dropout(), nn.Linear(128, 2)])) 63 | 64 | self.softmax = Softmax(dim=1) 65 | 66 | def forward(self, x, edge_index, batch): 67 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 68 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 69 | x = conv(x, edge_index, edge_weight) 70 | # x = relu(batch_norm(x)) 71 | x = relu(x) 72 | graph_x = global_mean_pool(x, batch) 73 | pred = self.ffn(graph_x) 74 | self.readout = self.softmax(pred) 75 | return pred 76 | 77 | def get_node_reps(self, x, edge_index): 78 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 79 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 80 | x = conv(x, edge_index, edge_weight) 81 | # x = relu(batch_norm(x)) 82 | x = relu(x) 83 | node_x = x 84 | return node_x 85 | 86 | def get_graph_rep(self, x, edge_index, batch): 87 | node_x = self.get_node_reps(x, edge_index) 88 | graph_x = global_mean_pool(node_x, batch) 89 | return graph_x 90 | 91 | def get_pred(self, x, edge_index, batch): 92 | graph_x = self.get_graph_rep(x, edge_index, batch) 93 | pred = self.ffn(graph_x) 94 | self.readout = self.softmax(pred) 95 | return self.readout, pred 96 | 97 | def get_pred_explain(self, x, edge_index, edge_mask, batch): 98 | edge_mask = (edge_mask * EPS).sigmoid() 99 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 100 | x = conv(x, edge_index, edge_weight=edge_mask) 101 | x = relu(x) 102 | node_x = x 103 | graph_x = global_mean_pool(node_x, batch) 104 | pred = self.ffn(graph_x) 105 | self.readout = self.softmax(pred) 106 | return self.readout, pred 107 | 108 | def reset_parameters(self): 109 | with torch.no_grad(): 110 | for param in self.parameters(): 111 | param.uniform_(-1.0, 1.0) 
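    # Design note: the commented-out lines above show GCNConv was swapped for
    # LEConv (local-extremum convolution). LEConv also accepts per-edge
    # weights, so the masked forward pass in get_pred_explain works unchanged.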
112 | 113 | 114 | if __name__ == "__main__": 115 | set_seed(0) 116 | args = parse_args() 117 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 118 | 119 | test_dataset = NCI1(args.data_path, mode="testing") 120 | val_dataset = NCI1(args.data_path, mode="evaluation") 121 | train_dataset = NCI1(args.data_path, mode="training") 122 | if args.random_label: 123 | for dataset in [test_dataset, val_dataset, train_dataset]: 124 | for g in dataset: 125 | g.y.fill_(random.choice([0, 1])) 126 | 127 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 128 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 129 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 130 | model = NCI1GCN(args.num_unit).to(device) 131 | 132 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 133 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 134 | min_error = None 135 | for epoch in range(1, args.epoch + 1): 136 | t1 = time.time() 137 | lr = scheduler.optimizer.param_groups[0]["lr"] 138 | 139 | loss = Gtrain(train_loader, model, optimizer, device=device, criterion=nn.CrossEntropyLoss()) 140 | 141 | _, train_acc = Gtest(train_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 142 | 143 | val_error, val_acc = Gtest(val_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 144 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 145 | scheduler.step(val_error) 146 | if min_error is None or val_error <= min_error: 147 | min_error = val_error 148 | 149 | t2 = time.time() 150 | 151 | if epoch % args.verbose == 0: 152 | test_error, test_acc = Gtest(test_loader, model, device=device, criterion=nn.CrossEntropyLoss()) 153 | t3 = time.time() 154 | print( 155 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Test Loss: {:.5f}, " 156 | "Test acc: {:.5f}".format(epoch, t3 - t1, lr, loss, test_error, test_acc) 157 | ) 158 | continue 159 | 160 | print( 161 | "Epoch{:4d}[{:.3f}s]: LR: {:.5f}, Loss: {:.5f}, Train acc: {:.5f}, Validation Loss: {:.5f}, " 162 | "Validation acc: {:5f}".format(epoch, t2 - t1, lr, loss, train_acc, val_error, val_acc) 163 | ) 164 | 165 | save_path = "NCI1_gcn.pt" 166 | if not osp.exists(args.model_path): 167 | os.makedirs(args.model_path) 168 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 169 | -------------------------------------------------------------------------------- /gnns/synthetic_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import sys 6 | sys.path.append("..") 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import Linear as Lin, ModuleList, ReLU, Softmax 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch_geometric.nn import BatchNorm, GCNConv 14 | from torch_geometric.utils import accuracy 15 | 16 | from utils import set_seed 17 | 18 | EPS = 1 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Train Tree Cycle / BA shapes Model") 23 | 24 | parser.add_argument("--data_name", nargs="?", default="BA_shapes", help="Input data path.") 25 | parser.add_argument( 26 | "--model_path", 27 | nargs="?", 28 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 29 | help="path for saving 
trained model.", 30 | ) 31 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 32 | parser.add_argument("--epoch", type=int, default=10000, help="Number of epoch.") 33 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 34 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") 35 | parser.add_argument("--hidden", type=int, default=128, help="hiden size.") 36 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 37 | parser.add_argument("--num_unit", type=int, default=3, help="number of Convolution layers(units)") 38 | parser.add_argument( 39 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 40 | ) 41 | return parser.parse_args() 42 | 43 | 44 | class Syn_GCN(torch.nn.Module): 45 | def __init__(self, num_unit, n_input, n_out, n_hid): 46 | super(Syn_GCN, self).__init__() 47 | self.convs = ModuleList() 48 | self.batch_norms = ModuleList() 49 | self.relus = ModuleList() 50 | self.edge_emb = Lin(3, 1) 51 | self.relus.extend([ReLU()] * num_unit) 52 | self.convs.append(GCNConv(in_channels=n_input, out_channels=n_hid)) 53 | # self.convs.append(GINConv(nn=Seq(Lin(n_input, n_hid), self.relus[0], Lin(n_hid, n_hid)))) 54 | for i in range(num_unit - 2): 55 | # self.convs.append(GINConv(nn=Seq(Lin(n_hid, n_hid), self.relus[i+1], Lin(n_hid, n_hid)))) 56 | self.convs.append(GCNConv(in_channels=n_hid, out_channels=n_hid)) 57 | self.convs.append(GCNConv(in_channels=n_hid, out_channels=n_hid)) 58 | # self.convs.append(GINConv(nn=Seq(Lin(n_hid, n_hid), self.relus[-1], Lin(n_hid, n_hid)))) 59 | self.batch_norms.extend([BatchNorm(n_hid)] * num_unit) 60 | 61 | # self.lin1 = Lin(128, 128) 62 | self.ffn = nn.Sequential(*([nn.Linear(n_hid, n_hid)] + [ReLU(), nn.Dropout(), nn.Linear(n_hid, n_out)])) 63 | 64 | self.softmax = Softmax(dim=1) 65 | self.dropout = 0 66 | 67 | def forward(self, x, edge_index, edge_attr=None): 68 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 69 | x = conv(x, edge_index) 70 | x = F.dropout(x, self.dropout, training=self.training) 71 | # x = relu(batch_norm(x)) 72 | x = relu(x) 73 | # graph_x = global_mean_pool(x, batch) 74 | # x_res = torch.cat(xx, dim=1) 75 | pred = self.ffn(x) # [node_num, n_class] 76 | self.readout = self.softmax(pred) 77 | return pred 78 | 79 | def get_node_reps(self, x, edge_index, edge_attr=None): 80 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 81 | x = conv(x, edge_index) 82 | x = F.dropout(x, self.dropout, training=self.training) 83 | x = relu(x) 84 | node_x = self.ffn[0](x) 85 | node_x = F.relu(node_x) 86 | node_x = F.dropout(node_x) 87 | return node_x 88 | 89 | def get_node_pred_subgraph(self, x, edge_index, mapping=None): 90 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 91 | x = conv(x, edge_index) 92 | x = F.dropout(x, self.dropout, training=self.training) 93 | # x = relu(batch_norm(x)) 94 | x = relu(x) 95 | node_repr = self.ffn(x) # [node_num, n_class] 96 | node_prob = self.softmax(node_repr) 97 | output_prob = node_prob[mapping] # [bsz, n_classes] 98 | output_repr = node_repr[mapping] # [bsz, n_classes] 99 | return output_prob, output_repr 100 | 101 | def get_pred_explain(self, x, edge_index, edge_mask, mapping=None): 102 | edge_mask = (edge_mask * EPS).sigmoid() 103 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 104 | x = conv(x, edge_index, edge_weight=edge_mask) 
105 | x = F.dropout(x, self.dropout, training=self.training) 106 | # x = ReLU(batch_norm(x)) 107 | x = relu(x) 108 | node_repr = self.ffn(x) # [node_num, n_class] 109 | node_prob = self.softmax(node_repr) 110 | output_prob = node_prob[mapping] # [bsz, n_classes] 111 | output_repr = node_repr[mapping] # [bsz, n_classes] 112 | return output_prob, output_repr 113 | 114 | def reset_parameters(self): 115 | with torch.no_grad(): 116 | for param in self.parameters(): 117 | param.uniform_(-1.0, 1.0) 118 | 119 | if __name__ == "__main__": 120 | set_seed(44) 121 | args = parse_args() 122 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 123 | name = args.data_name 124 | file_dir = osp.join(osp.dirname(__file__), "..", "data", name, "processed/whole_graph.pt") 125 | data = torch.load(file_dir) 126 | data.to(device) 127 | # if args.random_label: 128 | # for dataset in [test_dataset, val_dataset, train_dataset]: 129 | # for g in dataset: 130 | # g.y.fill_(random.choice([0, 1])) 131 | n_input = data.x.size(1) 132 | n_labels = int(torch.unique(data.y).size(0)) 133 | model = Syn_GCN(args.num_unit, n_input=n_input, n_out=n_labels, n_hid=args.hidden).to(device) 134 | 135 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 136 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 137 | min_error = None 138 | criterion = nn.CrossEntropyLoss() 139 | 140 | def train(epoch): 141 | t = time.time() 142 | model.train() 143 | optimizer.zero_grad() 144 | output = model(x=data.x, edge_index=data.edge_index) 145 | loss_train = criterion(output[data.train_mask], data.y[data.train_mask]) 146 | y_pred = torch.argmax(output, dim=1) 147 | acc_train = accuracy(y_pred[data.train_mask], data.y[data.train_mask]) 148 | loss_train.backward() 149 | optimizer.step() 150 | print( 151 | "Epoch: {:04d}".format(epoch + 1), 152 | "loss_train: {:.4f}".format(loss_train.item()), 153 | "acc_train: {:.4f}".format(acc_train), 154 | "time: {:.4f}s".format(time.time() - t), 155 | ) 156 | 157 | def eval(): 158 | model.eval() 159 | output = model(x=data.x, edge_index=data.edge_index) 160 | loss_test = criterion(output[data.test_mask], data.y[data.test_mask]) 161 | y_pred = torch.argmax(output, dim=1) 162 | acc_test = accuracy(y_pred[data.test_mask], data.y[data.test_mask]) 163 | print("Test set results:", "loss= {:.4f}".format(loss_test.item()), "accuracy= {:.4f}".format(acc_test)) 164 | return loss_test, y_pred 165 | 166 | for epoch in range(1, args.epoch + 1): 167 | train(epoch) 168 | 169 | if epoch % args.verbose == 0: 170 | loss_test, y_pred = eval() 171 | scheduler.step(loss_test) 172 | 173 | save_path = f"{name}_gcn.pt" 174 | 175 | if not osp.exists(args.model_path): 176 | os.makedirs(args.model_path) 177 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 178 | labels = data.y[data.test_mask].cpu().numpy() 179 | pred = y_pred[data.test_mask].cpu().numpy() 180 | print("y_true counts: {}".format(np.unique(labels, return_counts=True))) 181 | print("y_pred_orig counts: {}".format(np.unique(pred, return_counts=True))) 182 | -------------------------------------------------------------------------------- /gnns/tree_grids_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import sys 6 | sys.path.append("..") 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import Linear as Lin, ModuleList, ReLU, 
Softmax 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch_geometric.nn import BatchNorm, GCNConv 13 | from torch_geometric.utils import accuracy 14 | 15 | from utils import set_seed 16 | 17 | EPS = 1 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Train Tree Grids Model") 22 | 23 | parser.add_argument("--data_name", nargs="?", default="Tree_Grids", help="Input data path.") 24 | parser.add_argument( 25 | "--model_path", 26 | nargs="?", 27 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 28 | help="path for saving trained model.", 29 | ) 30 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 31 | parser.add_argument("--epoch", type=int, default=10000, help="Number of epoch.") 32 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.") 33 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") 34 | parser.add_argument("--hidden", type=int, default=128, help="hiden size.") 35 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 36 | parser.add_argument("--num_unit", type=int, default=4, help="number of Convolution layers(units)") 37 | parser.add_argument( 38 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 39 | ) 40 | return parser.parse_args() 41 | 42 | 43 | class Syn_GCN_TG(torch.nn.Module): 44 | def __init__(self, num_unit, n_input, n_out, n_hid): 45 | super(Syn_GCN_TG, self).__init__() 46 | self.convs = ModuleList() 47 | self.batch_norms = ModuleList() 48 | self.relus = ModuleList() 49 | self.edge_emb = Lin(3, 1) 50 | self.relus.extend([ReLU()] * num_unit) 51 | self.convs.append(GCNConv(in_channels=n_input, out_channels=n_hid)) 52 | for i in range(num_unit - 2): 53 | self.convs.append(GCNConv(in_channels=n_hid, out_channels=n_hid)) 54 | self.convs.append(GCNConv(in_channels=n_hid, out_channels=n_hid)) 55 | self.batch_norms.extend([BatchNorm(n_hid)] * num_unit) 56 | 57 | # self.lin1 = Lin(128, 128) 58 | self.ffn = nn.Sequential(*([nn.Linear(n_hid, n_hid)] + [ReLU(), nn.Dropout(), nn.Linear(n_hid, n_out)])) 59 | 60 | self.softmax = Softmax(dim=1) 61 | self.dropout = 0.6 62 | 63 | def forward(self, x, edge_index, edge_attr=None): 64 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 65 | x = conv(x, edge_index) 66 | x = relu(x) 67 | pred = self.ffn(x) # [node_num, n_class] 68 | self.readout = self.softmax(pred) 69 | return pred 70 | 71 | def get_node_reps(self, x, edge_index, edge_attr=None): 72 | edge_weight = torch.ones((edge_index.size(1),), device=edge_index.device) 73 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 74 | x = conv(x, edge_index, edge_weight) 75 | x = relu(x) 76 | node_x = x 77 | return node_x 78 | 79 | def get_node_pred_subgraph(self, x, edge_index, mapping=None): 80 | for conv, batch_norm, relu in zip(self.convs, self.batch_norms, self.relus): 81 | x = conv(x, edge_index) 82 | # x = F.dropout(x, self.dropout, training=self.training) 83 | # x = relu(batch_norm(x)) 84 | x = relu(x) 85 | node_repr = self.ffn(x) # [node_num, n_class] 86 | node_prob = self.softmax(node_repr) 87 | output_prob = node_prob[mapping] # [bsz, n_classes] 88 | output_repr = node_repr[mapping] # [bsz, n_classes] 89 | return output_prob, output_repr 90 | 91 | def get_pred_explain(self, x, edge_index, edge_mask, mapping=None): 92 | edge_mask = (edge_mask * EPS).sigmoid() 93 | for conv, batch_norm, relu in 
zip(self.convs, self.batch_norms, self.relus): 94 | x = conv(x, edge_index, edge_weight=edge_mask) 95 | # x = F.dropout(x, self.dropout, training=self.training) 96 | # x = relu(batch_norm(x)) 97 | x = relu(x) 98 | node_repr = self.ffn(x) # [node_num, n_class] 99 | node_prob = self.softmax(node_repr) 100 | output_prob = node_prob[mapping] # [bsz, n_classes] 101 | output_repr = node_repr[mapping] # [bsz, n_classes] 102 | return output_prob, output_repr 103 | 104 | def reset_parameters(self): 105 | with torch.no_grad(): 106 | for param in self.parameters(): 107 | param.uniform_(-1.0, 1.0) 108 | 109 | 110 | 111 | if __name__ == "__main__": 112 | set_seed(33) 113 | args = parse_args() 114 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 115 | name = args.data_name 116 | file_dir = osp.join(osp.dirname(__file__), "..", "data", name, "processed/whole_graph.pt") 117 | data = torch.load(file_dir) 118 | data.to(device) 119 | n_input = data.x.size(1) 120 | n_labels = int(torch.unique(data.y).size(0)) 121 | model = Syn_GCN_TG(args.num_unit, n_input=n_input, n_out=n_labels, n_hid=args.hidden).to(device) 122 | 123 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 124 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 125 | min_error = None 126 | criterion = nn.CrossEntropyLoss() 127 | 128 | def train(epoch): 129 | t = time.time() 130 | model.train() 131 | optimizer.zero_grad() 132 | output = model(x=data.x, edge_index=data.edge_index) 133 | loss_train = criterion(output[data.train_mask], data.y[data.train_mask]) 134 | y_pred = torch.argmax(output, dim=1) 135 | acc_train = accuracy(y_pred[data.train_mask], data.y[data.train_mask]) 136 | loss_train.backward() 137 | optimizer.step() 138 | print( 139 | "Epoch: {:04d}".format(epoch + 1), 140 | "loss_train: {:.4f}".format(loss_train.item()), 141 | "acc_train: {:.4f}".format(acc_train), 142 | "time: {:.4f}s".format(time.time() - t), 143 | ) 144 | 145 | def eval(): 146 | model.eval() 147 | output = model(x=data.x, edge_index=data.edge_index) 148 | loss_test = criterion(output[data.test_mask], data.y[data.test_mask]) 149 | y_pred = torch.argmax(output, dim=1) 150 | acc_test = accuracy(y_pred[data.test_mask], data.y[data.test_mask]) 151 | print("Test set results:", "loss= {:.4f}".format(loss_test.item()), "accuracy= {:.4f}".format(acc_test)) 152 | return loss_test, y_pred 153 | 154 | for epoch in range(1, args.epoch + 1): 155 | train(epoch) 156 | 157 | if epoch % args.verbose == 0: 158 | loss_test, y_pred = eval() 159 | scheduler.step(loss_test) 160 | 161 | save_path = f"{name}_gcn.pt" 162 | 163 | if not osp.exists(args.model_path): 164 | os.makedirs(args.model_path) 165 | torch.save(model.cpu(), osp.join(args.model_path, save_path)) 166 | labels = data.y[data.test_mask].cpu().numpy() 167 | pred = y_pred[data.test_mask].cpu().numpy() 168 | print("y_true counts: {}".format(np.unique(labels, return_counts=True))) 169 | print("y_pred_orig counts: {}".format(np.unique(pred, return_counts=True))) 170 | -------------------------------------------------------------------------------- /gnns/web_gnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import os.path as osp 5 | import time 6 | import sys 7 | sys.path.append("..") 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import Parameter 13 | from 
torch.optim.lr_scheduler import ReduceLROnPlateau 14 | from torch_geometric.nn import BatchNorm, MessagePassing 15 | from torch_geometric.nn.inits import zeros 16 | from torch_geometric.utils import accuracy, add_remaining_self_loops 17 | from torch_scatter import scatter_add 18 | 19 | from utils import set_seed 20 | 21 | EPS = 1 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description="Train Cornell Model") 26 | 27 | parser.add_argument("--data_name", nargs="?", default="cornell", help="Input data path.") 28 | parser.add_argument( 29 | "--model_path", 30 | nargs="?", 31 | default=osp.join(osp.dirname(__file__), "..", "param", "gnns"), 32 | help="path for saving trained model.", 33 | ) 34 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 35 | parser.add_argument("--epoch", type=int, default=3000, help="Number of epoch.") 36 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate.") 37 | parser.add_argument("--batch_size", type=int, default=4, help="Batch size.") 38 | parser.add_argument("--hidden", type=int, default=32, help="hidden size.") 39 | parser.add_argument("--verbose", type=int, default=10, help="Interval of evaluation.") 40 | parser.add_argument("--num_unit", type=int, default=6, help="number of Convolution layers(units)") 41 | parser.add_argument( 42 | "--random_label", type=bool, default=False, help="train a model under label randomization for sanity check" 43 | ) 44 | return parser.parse_args() 45 | 46 | 47 | class SReLU(nn.Module): 48 | """Shifted ReLU""" 49 | 50 | def __init__(self, nc, bias): 51 | super(SReLU, self).__init__() 52 | self.srelu_bias = nn.Parameter(torch.Tensor(nc)) 53 | self.srelu_relu = nn.ReLU(inplace=True) 54 | nn.init.constant_(self.srelu_bias, bias) 55 | 56 | def forward(self, x): 57 | return self.srelu_relu(x - self.srelu_bias) + self.srelu_bias 58 | 59 | 60 | class EGNNConv(MessagePassing): 61 | def __init__( 62 | self, 63 | in_channels, 64 | out_channels, 65 | c_max=1.0, 66 | improved=False, 67 | cached=False, 68 | bias=True, 69 | **kwargs, 70 | ): 71 | super(EGNNConv, self).__init__(aggr="add", **kwargs) 72 | 73 | self.in_channels = in_channels 74 | self.out_channels = out_channels 75 | self.improved = improved 76 | self.cached = False 77 | self.weight = Parameter(torch.eye(in_channels) * math.sqrt(c_max)) 78 | if bias: 79 | self.bias = Parameter(torch.Tensor(out_channels)) 80 | else: 81 | self.register_parameter("bias", None) 82 | self.reset_parameters() 83 | 84 | def reset_parameters(self): 85 | zeros(self.bias) 86 | self.cached_result = None 87 | self.cached_num_edges = None 88 | 89 | @staticmethod 90 | def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None): 91 | if edge_weight is None: 92 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device) 93 | fill_value = 1 if not improved else 2 94 | edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes) 95 | row, col = edge_index 96 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 97 | deg_inv_sqrt = deg.pow(-0.5) 98 | deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 99 | 100 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 101 | 102 | def forward(self, x, edge_index, x_0=None, beta=0.0, residual_weight=0.0, edge_weight=None): 103 | """""" 104 | x_input = x 105 | if self.cached and self.cached_result is not None: 106 | if edge_index.size(1) != self.cached_num_edges: 107 | raise RuntimeError( 108 | 
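                    # (Caching is effectively disabled here: __init__ sets
                    # self.cached = False regardless of the `cached` argument,
                    # so this mismatch check is unreachable as written.)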
"Cached {} number of edges, but found {}. Please " 109 | "disable the caching behavior of this layer by removing " 110 | "the `cached=True` argument in its constructor.".format(self.cached_num_edges, edge_index.size(1)) 111 | ) 112 | 113 | if not self.cached or self.cached_result is None: 114 | self.cached_num_edges = edge_index.size(1) 115 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, self.improved, x.dtype) 116 | self.cached_result = edge_index, norm 117 | 118 | edge_index, norm = self.cached_result 119 | x = self.propagate(edge_index, x=x, norm=norm) 120 | x = (1 - residual_weight - beta) * x + residual_weight * x_input + beta * x_0 121 | x = torch.matmul(x, self.weight) 122 | return x 123 | 124 | def message(self, x_j, norm): 125 | return norm.view(-1, 1) * x_j 126 | 127 | def update(self, aggr_out): 128 | if self.bias is not None: 129 | aggr_out = aggr_out + self.bias 130 | return aggr_out 131 | 132 | def __repr__(self): 133 | return "{}({}, {})".format(self.__class__.__name__, self.in_channels, self.out_channels) 134 | 135 | 136 | class EGNN(nn.Module): 137 | def __init__(self, num_node_features, hidden_channels, num_classes, num_layers, dropout=0.6): 138 | super(EGNN, self).__init__() 139 | # self.dataset = args.dataset 140 | self.num_layers = num_layers 141 | self.num_feats = num_node_features 142 | self.num_classes = num_classes 143 | self.dim_hidden = hidden_channels 144 | 145 | self.cached = False 146 | self.layers_GCN = nn.ModuleList([]) 147 | self.layers_activation = nn.ModuleList([]) 148 | self.layers_bn = nn.ModuleList([]) 149 | self.layers_bn.extend([BatchNorm(self.dim_hidden)] * self.num_layers) 150 | self.c_min = 0.2 151 | self.c_max = 1 152 | self.beta = 0.1 153 | 154 | self.bias_SReLU = -10 155 | self.dropout = dropout 156 | self.output_dropout = 0.6 157 | 158 | self.reg_params = [] 159 | for i in range(self.num_layers): 160 | c_max = self.c_max if i == 0 else 1.0 161 | self.layers_GCN.append( 162 | EGNNConv(self.dim_hidden, self.dim_hidden, c_max=c_max, cached=self.cached, bias=False) 163 | ) 164 | self.layers_activation.append(SReLU(self.dim_hidden, self.bias_SReLU)) 165 | self.reg_params.append(self.layers_GCN[-1].weight) 166 | 167 | self.input_layer = torch.nn.Linear(self.num_feats, self.dim_hidden) 168 | self.output_layer = torch.nn.Linear(self.dim_hidden, self.num_classes) 169 | self.non_reg_params = list(self.input_layer.parameters()) + list(self.output_layer.parameters()) 170 | self.srelu_params = list(self.layers_activation[:-1].parameters()) 171 | 172 | def forward(self, x, edge_index): 173 | x = F.dropout(x, p=self.dropout, training=self.training) 174 | x = self.input_layer(x) 175 | x = F.relu(x) 176 | 177 | original_x = x 178 | for i in range(self.num_layers): 179 | x = F.dropout(x, p=self.dropout, training=self.training) 180 | residual_weight = self.c_min - self.beta 181 | 182 | x = self.layers_GCN[i](x, edge_index, original_x, beta=self.beta, residual_weight=residual_weight) 183 | # x = self.layers_bn[i](x) 184 | x = self.layers_activation[i](x) 185 | 186 | x = F.dropout(x, p=self.output_dropout, training=self.training) 187 | x = self.output_layer(x) 188 | return x 189 | 190 | def get_node_pred_subgraph(self, x, edge_index, mapping=None): 191 | x = F.dropout(x, p=self.dropout, training=self.training) 192 | x = self.input_layer(x) 193 | x = F.relu(x) 194 | 195 | original_x = x 196 | for i in range(self.num_layers): 197 | x = F.dropout(x, p=self.dropout, training=self.training) 198 | residual_weight = self.c_min - self.beta 199 | x = 
self.layers_GCN[i](x, edge_index, original_x, beta=self.beta, residual_weight=residual_weight) 200 | x = self.layers_bn[i](x) 201 | x = self.layers_activation[i](x) 202 | 203 | x = F.dropout(x, p=self.output_dropout, training=self.training) 204 | node_repr = self.output_layer(x) 205 | node_prob = F.softmax(node_repr, dim=-1) 206 | output_prob = node_prob[mapping] # [bsz, n_classes] 207 | output_repr = node_repr[mapping] # [bsz, n_classes] 208 | return output_prob, output_repr 209 | 210 | def get_pred_explain(self, x, edge_index, edge_mask, mapping=None): 211 | edge_mask = (edge_mask * EPS).sigmoid() 212 | x = F.dropout(x, p=self.dropout, training=self.training) 213 | x = self.input_layer(x) 214 | x = F.relu(x) 215 | 216 | original_x = x 217 | for i in range(self.num_layers): 218 | x = F.dropout(x, p=self.dropout, training=self.training) 219 | residual_weight = self.c_min - self.beta 220 | x = self.layers_GCN[i]( 221 | x, edge_index, original_x, beta=self.beta, residual_weight=residual_weight, edge_weight=edge_mask 222 | ) 223 | x = self.layers_activation[i](x) 224 | x = F.dropout(x, p=self.output_dropout, training=self.training) 225 | node_repr = self.output_layer(x) 226 | node_prob = F.softmax(node_repr, dim=-1) 227 | output_prob = node_prob[mapping] # [bsz, n_classes] 228 | output_repr = node_repr[mapping] # [bsz, n_classes] 229 | return output_prob, output_repr 230 | 231 | 232 | if __name__ == "__main__": 233 | set_seed(44) 234 | args = parse_args() 235 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 236 | name = args.data_name 237 | file_dir = osp.join(osp.dirname(__file__), "..", "data", name, "processed/whole_graph.pt") 238 | data = torch.load(file_dir) 239 | data.to(device) 240 | n_input = data.x.size(1) 241 | n_labels = int(torch.unique(data.y).size(0)) 242 | model = EGNN(n_input, hidden_channels=args.hidden, num_classes=n_labels, num_layers=args.num_unit).to(device) 243 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 244 | scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10, min_lr=1e-5) 245 | min_error = None 246 | criterion = nn.CrossEntropyLoss() 247 | 248 | def train(epoch): 249 | t = time.time() 250 | model.train() 251 | optimizer.zero_grad() 252 | output = model(x=data.x, edge_index=data.edge_index) 253 | loss_train = criterion(output[data.train_mask], data.y[data.train_mask]) 254 | y_pred = torch.argmax(output, dim=1) 255 | acc_train = accuracy(y_pred[data.train_mask], data.y[data.train_mask]) 256 | loss_train.backward() 257 | optimizer.step() 258 | print( 259 | "Epoch: {:04d}".format(epoch + 1), 260 | "loss_train: {:.4f}".format(loss_train.item()), 261 | "acc_train: {:.4f}".format(acc_train), 262 | "time: {:.4f}s".format(time.time() - t), 263 | ) 264 | 265 | def eval(): 266 | model.eval() 267 | output = model(x=data.x, edge_index=data.edge_index) 268 | loss_test = criterion(output[data.test_mask], data.y[data.test_mask]) 269 | y_pred = torch.argmax(output, dim=1) 270 | acc_test = accuracy(y_pred[data.test_mask], data.y[data.test_mask]) 271 | print("Test set results:", "loss= {:.4f}".format(loss_test.item()), "accuracy= {:.4f}".format(acc_test)) 272 | return loss_test, y_pred 273 | 274 | for epoch in range(1, args.epoch + 1): 275 | train(epoch) 276 | 277 | if epoch % args.verbose == 0: 278 | loss_test, y_pred = eval() 279 | scheduler.step(loss_test) 280 | 281 | save_path = f"{name}_gcn.pt" 282 | 283 | if not osp.exists(args.model_path): 284 | os.makedirs(args.model_path) 285 | 
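    # Saving the whole module object (rather than model.state_dict()) pickles
    # the class by reference, so loading the checkpoint later requires this
    # EGNN definition to be importable under the same module path.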
torch.save(model.cpu(), osp.join(args.model_path, save_path)) 286 | labels = data.y[data.test_mask].cpu().numpy() 287 | pred = y_pred[data.test_mask].cpu().numpy() 288 | print("y_true counts: {}".format(np.unique(labels, return_counts=True))) 289 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch_geometric.loader import DataLoader 5 | from constants import feature_dict, task_type, dataset_choices 6 | from explainers import * 7 | from gnns import * 8 | from utils.dataset import get_datasets 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Train explainers") 13 | parser.add_argument("--cuda", type=int, default=0, help="GPU device.") 14 | parser.add_argument("--root", type=str, default="results/", help="Result directory.") 15 | parser.add_argument("--dataset", type=str, default="Tree_Cycle", choices=dataset_choices) 16 | parser.add_argument("--verbose", type=int, default=10) 17 | parser.add_argument("--gnn_type", type=str, default="gcn") 18 | parser.add_argument("--task", type=str, default="nc") 19 | 20 | parser.add_argument("--train_batchsize", type=int, default=32) 21 | parser.add_argument("--test_batchsize", type=int, default=32) 22 | parser.add_argument("--sigma_length", type=int, default=10) 23 | parser.add_argument("--epoch", type=int, default=800) 24 | parser.add_argument("--feature_in", type=int) 25 | parser.add_argument("--data_size", type=int, default=-1) 26 | 27 | parser.add_argument("--threshold", type=float, default=0.5) 28 | parser.add_argument("--alpha_cf", type=float, default=0.5) 29 | parser.add_argument("--dropout", type=float, default=0.001) 30 | parser.add_argument("--learning_rate", type=float, default=1e-3) 31 | parser.add_argument("--lr_decay", type=float, default=0.999) 32 | parser.add_argument("--weight_decay", type=float, default=0) 33 | parser.add_argument("--prob_low", type=float, default=0.0) 34 | parser.add_argument("--prob_high", type=float, default=0.4) 35 | parser.add_argument("--sparsity_level", type=float, default=2.5) 36 | 37 | parser.add_argument("--normalization", type=str, default="instance") 38 | parser.add_argument("--num_layers", type=int, default=6) 39 | parser.add_argument("--layers_per_conv", type=int, default=1) 40 | parser.add_argument("--n_hidden", type=int, default=64) 41 | parser.add_argument("--cat_output", type=bool, default=True) 42 | parser.add_argument("--residual", type=bool, default=False) 43 | parser.add_argument("--noise_mlp", type=bool, default=True) 44 | parser.add_argument("--simplified", type=bool, default=False) 45 | 46 | return parser.parse_args() 47 | 48 | 49 | args = parse_args() 50 | args.noise_list = None 51 | 52 | args.device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 53 | args.feature_in = feature_dict[args.dataset] 54 | args.task = task_type[args.dataset] 55 | train_dataset, val_dataset, test_dataset = get_datasets(name=args.dataset) 56 | 57 | train_dataset = train_dataset[: args.data_size] 58 | gnn_path = f"param/gnns/{args.dataset}_{args.gnn_type}.pt" 59 | explainer = DiffExplainer(args.device, gnn_path) 60 | 61 | # Train D4Explainer over train_dataset and evaluate 62 | explainer.explain_graph_task(args, train_dataset, val_dataset) 63 | 64 | # Test D4Explainer on test_dataset 65 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 66 | for graph in 
test_loader: 67 | explanation, y_ori, y_exp, modif_r = explainer.explain_evaluation(args, graph) -------------------------------------------------------------------------------- /param/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/param/.DS_Store -------------------------------------------------------------------------------- /param/gnns/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/param/gnns/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2022.12.7 2 | charset-normalizer==3.0.1 3 | contourpy==1.0.7 4 | cycler==0.11.0 5 | fonttools==4.38.0 6 | idna==3.4 7 | importlib-resources==5.12.0 8 | Jinja2==3.1.2 9 | joblib==1.2.0 10 | kiwisolver==1.4.4 11 | MarkupSafe==2.1.2 12 | matplotlib==3.7.0 13 | networkx==3.0 14 | numpy==1.24.2 15 | opt-einsum==3.3.0 16 | packaging==23.0 17 | pandas==1.5.3 18 | patsy==0.5.3 19 | pgmpy==0.1.21 20 | Pillow==9.4.0 21 | pyemd==1.0.0 22 | pyparsing==3.0.9 23 | python-dateutil==2.8.2 24 | pytz==2022.7.1 25 | requests==2.28.2 26 | scikit-learn==1.2.1 27 | scipy==1.10.1 28 | six==1.16.0 29 | statsmodels==0.13.5 30 | threadpoolctl==3.1.0 31 | torch==1.10.1+cu102 32 | torch-cluster==1.6.0 33 | torch-geometric==2.0.4 34 | torch-scatter==2.0.9 35 | torch-sparse==0.6.13 36 | torchaudio==0.10.1+cu102 37 | torchvision==0.11.2+cu102 38 | tqdm==4.64.1 39 | typing_extensions==4.5.0 40 | urllib3==1.26.14 41 | zipp==3.15.0 42 | -------------------------------------------------------------------------------- /results/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-and-Geometric-Learning/D4Explainer/997b4c755fc19d0494c09bc3bd4925777ea53aca/results/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import get_datasets 2 | from .dist_helper import compute_mmd, gaussian, gaussian_emd, process_tensor 3 | from .helper import set_seed 4 | from .train_utils import Gtest, Gtrain 5 | 6 | __all__ = ["get_datasets", "compute_mmd", "gaussian", "gaussian_emd", "process_tensor", "set_seed", "Gtest", "Gtrain"] 7 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import NCI1, BA3Motif, Mutagenicity, SynGraphDataset, WebDataset, bbbp 4 | 5 | 6 | def get_datasets(name, root="data/"): 7 | """ 8 | Get preloaded datasets by name 9 | :param name: name of the dataset 10 | :param root: root path of the dataset 11 | :return: train_dataset, test_dataset, val_dataset 12 | """ 13 | if name == "mutag": 14 | folder = os.path.join(root, "MUTAG") 15 | train_dataset = Mutagenicity(folder, mode="training") 16 | test_dataset = Mutagenicity(folder, mode="testing") 17 | val_dataset = Mutagenicity(folder, mode="evaluation") 18 | elif name == "NCI1": 19 | folder = os.path.join(root, "NCI1") 20 | train_dataset = NCI1(folder, mode="training") 21 
| test_dataset = NCI1(folder, mode="testing") 22 | val_dataset = NCI1(folder, mode="evaluation") 23 | elif name == "ba3": 24 | folder = os.path.join(root, "BA3") 25 | train_dataset = BA3Motif(folder, mode="training") 26 | test_dataset = BA3Motif(folder, mode="testing") 27 | val_dataset = BA3Motif(folder, mode="evaluation") 28 | elif name == "BA_shapes": 29 | folder = os.path.join(root) 30 | test_dataset = SynGraphDataset(folder, mode="testing", name="BA_shapes") 31 | val_dataset = SynGraphDataset(folder, mode="evaluating", name="BA_shapes") 32 | train_dataset = SynGraphDataset(folder, mode="training", name="BA_shapes") 33 | elif name == "Tree_Cycle": 34 | folder = os.path.join(root) 35 | test_dataset = SynGraphDataset(folder, mode="testing", name="Tree_Cycle") 36 | val_dataset = SynGraphDataset(folder, mode="evaluating", name="Tree_Cycle") 37 | train_dataset = SynGraphDataset(folder, mode="training", name="Tree_Cycle") 38 | elif name == "Tree_Grids": 39 | folder = os.path.join(root) 40 | test_dataset = SynGraphDataset(folder, mode="testing", name="Tree_Grids") 41 | val_dataset = SynGraphDataset(folder, mode="evaluating", name="Tree_Grids") 42 | train_dataset = SynGraphDataset(folder, mode="training", name="Tree_Grids") 43 | elif name == "bbbp": 44 | folder = os.path.join(root, "bbbp") 45 | dataset = bbbp(folder) 46 | test_dataset = dataset[:200] 47 | val_dataset = dataset[200:400] 48 | train_dataset = dataset[400:] 49 | elif name == "cornell": 50 | folder = os.path.join(root) 51 | test_dataset = WebDataset(folder, mode="testing", name=name) 52 | val_dataset = WebDataset(folder, mode="evaluating", name=name) 53 | train_dataset = WebDataset(folder, mode="training", name=name) 54 | else: 55 | raise ValueError 56 | return train_dataset, val_dataset, test_dataset 57 | -------------------------------------------------------------------------------- /utils/dist_helper.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | from functools import partial 3 | 4 | import numpy as np 5 | import pyemd 6 | from scipy.linalg import toeplitz 7 | 8 | # NOTES: 9 | # EMD stands for earth move distance, i.e. 
11 | 
12 | 
13 | def emd(x, y, distance_scaling=1.0):
14 |     """
15 |     Earth Mover's Distance (EMD) between two 1D pmfs
16 |     :param x: 1D pmf
17 |     :param y: 1D pmf
18 |     :param distance_scaling: scaling factor for distance matrix
19 |     :return: EMD distance
20 |     """
21 |     x = x.astype(float)
22 |     y = y.astype(float)
23 |     support_size = max(len(x), len(y))
24 |     d_mat = toeplitz(range(support_size)).astype(float)  # diagonal-constant matrix
25 |     distance_mat = d_mat / distance_scaling
26 |     x, y = process_tensor(x, y)
27 | 
28 |     emd_value = pyemd.emd(x, y, distance_mat)
29 |     return np.abs(emd_value)
30 | 
31 | 
32 | def l2(x, y):
33 |     """
34 |     L2 distance between two 1D pmfs
35 |     :param x: 1D pmf
36 |     :param y: 1D pmf
37 |     :return: L2 distance
38 |     """
39 |     dist = np.linalg.norm(x - y, 2)
40 |     return dist
41 | 
42 | 
43 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
44 |     """
45 |     Gaussian kernel with the squared distance in the exponential term replaced by EMD
46 |     :param x: 1D pmf
47 |     :param y: 1D pmf
48 |     :param sigma: standard deviation
49 |     :param distance_scaling: scaling factor for distance matrix
50 |     :return: Gaussian kernel with EMD
51 |     """
52 |     emd_value = emd(x, y, distance_scaling)
53 |     return np.exp(-emd_value * emd_value / (2 * sigma * sigma))
54 | 
55 | 
56 | def gaussian(x, y, sigma=1.0):
57 |     x = x.astype(float)
58 |     y = y.astype(float)
59 |     x, y = process_tensor(x, y)
60 |     dist = np.linalg.norm(x - y, 2)
61 |     return np.exp(-dist * dist / (2 * sigma * sigma))
62 | 
63 | 
64 | def gaussian_tv(x, y, sigma=1.0):
65 |     # convert histogram values x and y to float, and make them equal length
66 |     x = x.astype(float)
67 |     y = y.astype(float)
68 |     x, y = process_tensor(x, y)
69 | 
70 |     dist = np.abs(x - y).sum() / 2.0
71 |     return np.exp(-dist * dist / (2 * sigma * sigma))
72 | 
73 | 
74 | def kernel_parallel_unpacked(x, samples2, kernel):
75 |     d = 0
76 |     for s2 in samples2:
77 |         d += kernel(x, s2)
78 |     return d
79 | 
80 | 
81 | def kernel_parallel_worker(t):
82 |     return kernel_parallel_unpacked(*t)
83 | 
84 | 
85 | def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs):
86 |     """
87 |     Discrepancy between 2 samples
88 |     :param samples1: list of samples
89 |     :param samples2: list of samples
90 |     :param kernel: kernel function
91 |     :param is_parallel: whether to use parallel computation
92 |     :param args: args for kernel
93 |     :param kwargs: kwargs for kernel
94 |     """
95 |     d = 0
96 |     if not is_parallel:
97 |         for s1 in samples1:
98 |             for s2 in samples2:
99 |                 d += kernel(s1, s2, *args, **kwargs)
100 |     else:
101 |         with concurrent.futures.ProcessPoolExecutor() as executor:
102 |             for dist in executor.map(
103 |                 kernel_parallel_worker,
104 |                 [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1],
105 |             ):
106 |                 d += dist
107 |     if len(samples1) * len(samples2) > 0:
108 |         d /= len(samples1) * len(samples2)
109 |     else:
110 |         d = 1e6
111 |     return d
112 | 
113 | 
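# Note: `disc` averages kernel values over all pairs, including each sample
# paired with itself, so `compute_mmd` below computes the biased (V-statistic)
# estimate of squared MMD:
#     MMD^2(X, Y) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]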
114 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
115 |     """
116 |     MMD between two samples
117 |     :param samples1: list of samples
118 |     :param samples2: list of samples
119 |     :param kernel: kernel function
120 |     :param is_hist: whether the samples are histograms or pmfs
121 |     :param args: args for kernel
122 |     :param kwargs: kwargs for kernel
123 |     """
124 |     # normalize histograms into pmfs
125 |     if is_hist:
126 |         samples1 = [s1 / np.sum(s1) for s1 in samples1]
127 |         samples2 = [s2 / np.sum(s2) for s2 in samples2]
128 |     return (
129 |         disc(samples1, samples1, kernel, *args, **kwargs)
130 |         + disc(samples2, samples2, kernel, *args, **kwargs)
131 |         - 2 * disc(samples1, samples2, kernel, *args, **kwargs)
132 |     )
133 | 
134 | 
135 | def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
136 |     """
137 |     EMD between the averages of two samples
138 |     :param samples1: list of samples
139 |     :param samples2: list of samples
140 |     :param kernel: kernel function
141 |     :param is_hist: whether the samples are histograms or pmfs
142 |     :param args: args for kernel
143 |     :param kwargs: kwargs for kernel
144 |     """
145 |     # normalize histograms into pmfs
146 |     if is_hist:
147 |         samples1 = [np.mean(samples1)]
148 |         samples2 = [np.mean(samples2)]
149 |     return disc(samples1, samples2, kernel, *args, **kwargs), [samples1[0], samples2[0]]
150 | 
151 | 
152 | def test():
153 |     s1 = np.array([0.2, 0.8])
154 |     s2 = np.array([0.3, 0.7])
155 |     samples1 = [s1, s2]
156 | 
157 |     s3 = np.array([0.25, 0.75])
158 |     s4 = np.array([0.35, 0.65])
159 |     samples2 = [s3, s4]
160 | 
161 |     s5 = np.array([0.8, 0.2])
162 |     s6 = np.array([0.7, 0.3])
163 |     samples3 = [s5, s6]
164 | 
165 |     # print(
166 |     #     "between samples1 and samples2: ",
167 |     #     compute_emd(samples1, samples2, kernel=gaussian_emd, is_parallel=False, sigma=1.0),
168 |     # )
169 |     # print(
170 |     #     "between samples1 and samples3: ",
171 |     #     compute_emd(samples1, samples3, kernel=gaussian_emd, is_parallel=False, sigma=1.0),
172 |     # )
173 |     print(
174 |         "between samples1 and samples2: ",
175 |         compute_mmd(samples1, samples2, kernel=gaussian, is_parallel=True, sigma=1.0),
176 |     )
177 |     print(
178 |         "between samples1 and samples3: ",
179 |         compute_mmd(samples1, samples3, kernel=gaussian, is_parallel=True, sigma=1.0),
180 |     )
181 |     print(
182 |         "between samples1 and samples2: ",
183 |         compute_mmd(samples1, samples2, kernel=gaussian, is_parallel=True, sigma=1.0),
184 |     )
185 |     print(
186 |         "between samples1 and samples3: ",
187 |         compute_mmd(samples1, samples3, kernel=gaussian, is_parallel=True, sigma=1.0),
188 |     )
189 | 
190 | 
191 | def process_tensor(x, y):
192 |     """
193 |     Helper function to zero-pad two arrays to the same length
194 |     :param x: array
195 |     :param y: array
196 |     :return: padded arrays
197 |     """
198 |     support_size = max(len(x), len(y))
199 |     if len(x) < len(y):
200 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
201 |     elif len(y) < len(x):
202 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
203 |     return x, y
204 | 
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | def set_seed(seed):
8 |     """
9 |     Set seed for reproducibility
10 |     :param seed: random seed value
11 |     """
12 |     random.seed(seed)
13 |     np.random.seed(seed)
14 |     torch.manual_seed(seed)
15 |     torch.cuda.manual_seed(seed)
16 |     torch.backends.cudnn.benchmark = False
17 |     torch.backends.cudnn.deterministic = True
18 | 
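# Usage sketch (illustrative): call set_seed once at program start, before
# models and dataloaders are built, e.g. set_seed(42). Forcing deterministic
# cuDNN kernels and disabling benchmarking trades some speed for
# run-to-run reproducibility.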
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | def Gtrain(train_loader, model, optimizer, device, criterion=nn.MSELoss()):
6 |     """
7 |     General training function for graph classification
8 |     :param train_loader: DataLoader
9 |     :param model: model
10 |     :param optimizer: optimizer
11 |     :param device: torch device
12 |     :param criterion: loss function (default: MSELoss)
13 |     """
14 |     model.train()
15 |     loss_all = 0
16 |     # criterion defaults to MSELoss; pass a task-appropriate loss for classification
17 | 
18 |     for data in train_loader:
19 |         data = data.to(device)
20 |         optimizer.zero_grad()
21 |         out = model(data.x, data.edge_index, data.batch)
22 |         loss = criterion(out, data.y)
23 |         loss.backward()
24 |         loss_all += loss.item() * data.num_graphs
25 |         optimizer.step()
26 | 
27 |     return loss_all / len(train_loader.dataset)
28 | 
29 | 
30 | def Gtest(test_loader, model, device, criterion=nn.L1Loss(reduction="mean")):
31 |     """
32 |     General test function for graph classification
33 |     :param test_loader: DataLoader
34 |     :param model: model
35 |     :param device: torch device
36 |     :param criterion: loss function (default: L1Loss)
37 |     :return: error, accuracy
38 |     """
39 |     model.eval()
40 |     error = 0
41 |     correct = 0
42 | 
43 |     with torch.no_grad():
44 |         for data in test_loader:
45 |             data = data.to(device)
46 |             output = model(
47 |                 data.x,
48 |                 data.edge_index,
49 |                 data.batch,
50 |             )
51 | 
52 |             error += criterion(output, data.y) * data.num_graphs
53 |             correct += float(output.argmax(dim=1).eq(data.y).sum().item())
54 | 
55 |     return error / len(test_loader.dataset), correct / len(test_loader.dataset)
56 | 
--------------------------------------------------------------------------------
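A minimal end-to-end sketch of how `set_seed`, `get_datasets`, `Gtrain`, and `Gtest` fit together. This is illustrative only: `TinyGNN` and all hyperparameters below are placeholders of ours, not code from this repository.
```
import torch
from torch import nn
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from utils import Gtest, Gtrain, get_datasets, set_seed


class TinyGNN(nn.Module):
    # Illustrative stand-in with the forward(x, edge_index, batch) signature
    # the helpers expect; the repository's actual GNNs live under gnns/.
    def __init__(self, in_dim, num_classes, hidden=32):
        super().__init__()
        self.conv = GCNConv(in_dim, hidden)
        self.lin = nn.Linear(hidden, num_classes)

    def forward(self, x, edge_index, batch):
        h = torch.relu(self.conv(x, edge_index))
        return self.lin(global_mean_pool(h, batch))  # one logit vector per graph


set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset, val_dataset, test_dataset = get_datasets("mutag")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

model = TinyGNN(train_dataset.num_node_features, train_dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    # Cross-entropy is passed explicitly since the helpers default to regression losses.
    loss = Gtrain(train_loader, model, optimizer, device, criterion=nn.CrossEntropyLoss())
    val_loss, val_acc = Gtest(val_loader, model, device, criterion=nn.CrossEntropyLoss())
    print(f"epoch {epoch:02d} | train loss {loss:.4f} | val loss {float(val_loss):.4f} | val acc {val_acc:.4f}")
```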