├── gnnexplainer
│   ├── __init__.py
│   └── explain.py
├── gnnexplainer_utils
│   ├── __init__.py
│   ├── math_utils.py
│   ├── train_utils.py
│   ├── parser_utils.py
│   ├── featgen.py
│   ├── graph_utils.py
│   ├── synthetic_structsim.py
│   └── io_utils.py
├── qualitative
│   ├── REDDITBINARY.model
│   ├── dataloader.py
│   └── gin.py
├── requirements.txt
├── evaluate.py
├── README.md
├── gnnexplainer_configs.py
├── .gitignore
├── gnnexplainer_gengraph.py
├── gen_dataset.py
├── explain.py
├── gnnexplainer_main.py
└── models.py

/gnnexplainer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnnexplainer_utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/qualitative/REDDITBINARY.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukasjf/contrastive-gnn-explanation/HEAD/qualitative/REDDITBINARY.model
--------------------------------------------------------------------------------
/gnnexplainer_utils/math_utils.py:
--------------------------------------------------------------------------------
1 | """ math_utils.py
2 | 
3 | Math utilities.
4 | """
5 | 
6 | import torch
7 | 
8 | def exp_moving_avg(x, decay=0.9):
9 |     '''Exponentially decaying moving average.
10 |     '''
11 |     shadow = x[0]
12 |     a = [shadow]
13 |     for v in x[1:]:
14 |         shadow = shadow - (1 - decay) * (shadow - v)  # rebind instead of updating in place, so tensor/array inputs are not mutated
15 |         a.append(shadow)
16 |     return a
17 | 
18 | def tv_norm(input, tv_beta):
19 |     '''Total variation norm
20 |     '''
21 |     img = input[0, 0, :]
22 |     row_grad = torch.mean(torch.abs((img[:-1, :] - img[1:, :])).pow(tv_beta))
23 |     col_grad = torch.mean(torch.abs((img[:, :-1] - img[:, 1:])).pow(tv_beta))
24 |     return row_grad + col_grad
--------------------------------------------------------------------------------
/gnnexplainer_utils/train_utils.py:
--------------------------------------------------------------------------------
1 | '''train_utils.py
2 | 
3 | Some training utilities.
4 | '''
5 | import torch.optim as optim
6 | 
7 | def build_optimizer(args, params, weight_decay=0.0):
8 |     filter_fn = filter(lambda p : p.requires_grad, params)
9 |     if args.opt == 'adam':
10 |         optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)
11 |     elif args.opt == 'sgd':
12 |         optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
13 |     elif args.opt == 'rmsprop':
14 |         optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)
15 |     elif args.opt == 'adagrad':
16 |         optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)
17 |     else:
18 |         raise ValueError('Unknown optimizer: %s' % args.opt)  # previously fell through and hit a NameError below
19 |     if args.opt_scheduler == 'none':
20 |         return None, optimizer
21 |     elif args.opt_scheduler == 'step':
22 |         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
23 |     elif args.opt_scheduler == 'cos':
24 |         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
25 |     else:
26 |         raise ValueError('Unknown scheduler: %s' % args.opt_scheduler)
27 |     return scheduler, optimizer
--------------------------------------------------------------------------------
/gnnexplainer_utils/parser_utils.py:
--------------------------------------------------------------------------------
1 | """ parser_utils.py
2 | 
3 | Parsing utilities.
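
Illustrative usage (a sketch, not part of the original module): the flags
registered here are consumed by build_optimizer from train_utils.py above;
`model` stands for any torch.nn.Module:

    parser = argparse.ArgumentParser()
    parse_optimizer(parser)
    args = parser.parse_args(['--opt', 'adam', '--lr', '0.001',
                              '--opt-scheduler', 'none'])
    scheduler, optimizer = build_optimizer(args, model.parameters())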
4 | """ 5 | import argparse 6 | 7 | def parse_optimizer(parser): 8 | '''Set optimizer parameters''' 9 | opt_parser = parser.add_argument_group() 10 | opt_parser.add_argument('--opt', dest='opt', type=str, 11 | help='Type of optimizer') 12 | opt_parser.add_argument('--opt-scheduler', dest='opt_scheduler', type=str, 13 | help='Type of optimizer scheduler. By default none') 14 | opt_parser.add_argument('--opt-restart', dest='opt_restart', type=int, 15 | help='Number of epochs before restart (by default set to 0 which means no restart)') 16 | opt_parser.add_argument('--opt-decay-step', dest='opt_decay_step', type=int, 17 | help='Number of epochs before decay') 18 | opt_parser.add_argument('--opt-decay-rate', dest='opt_decay_rate', type=float, 19 | help='Learning rate decay ratio') 20 | opt_parser.add_argument('--lr', dest='lr', type=float, 21 | help='Learning rate.') 22 | opt_parser.add_argument('--clip', dest='clip', type=float, 23 | help='Gradient clipping.') 24 | 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | attrs==19.3.0 3 | backcall==0.2.0 4 | bleach==3.1.5 5 | click==7.1.2 6 | cycler==0.10.0 7 | decorator==4.4.2 8 | defusedxml==0.6.0 9 | entrypoints==0.3 10 | future==0.18.2 11 | geomloss==0.2.3 12 | ipykernel==5.3.0 13 | ipython==7.16.1 14 | ipython-genutils==0.2.0 15 | ipywidgets==7.5.1 16 | jedi==0.17.1 17 | Jinja2==2.11.2 18 | joblib==0.16.0 19 | jsonschema==3.2.0 20 | jupyter==1.0.0 21 | jupyter-client==6.1.5 22 | jupyter-console==6.1.0 23 | jupyter-core==4.6.3 24 | kiwisolver==1.2.0 25 | MarkupSafe==1.1.1 26 | matplotlib==3.2.2 27 | mistune==0.8.4 28 | nbconvert==5.6.1 29 | nbformat==5.0.7 30 | networkx==2.4 31 | notebook==6.0.3 32 | numpy==1.19.0 33 | opencv-python==4.2.0.34 34 | packaging==20.4 35 | pandas==1.0.5 36 | pandocfilters==1.4.2 37 | parso==0.7.0 38 | pexpect==4.8.0 39 | pickleshare==0.7.5 40 | prometheus-client==0.8.0 41 | prompt-toolkit==3.0.5 42 | protobuf==3.12.2 43 | ptyprocess==0.6.0 44 | Pygments==2.6.1 45 | pyparsing==2.4.7 46 | pyrsistent==0.16.0 47 | python-dateutil==2.8.1 48 | pytz==2020.1 49 | pyzmq==19.0.1 50 | qtconsole==4.7.5 51 | QtPy==1.9.0 52 | scikit-learn==0.23.1 53 | scipy==1.5.0 54 | seaborn==0.10.1 55 | Send2Trash==1.5.0 56 | six==1.15.0 57 | sklearn==0.0 58 | tensorboardX==2.0 59 | terminado==0.8.3 60 | testpath==0.4.4 61 | threadpoolctl==2.1.0 62 | torch==1.5.1 63 | tornado==6.0.4 64 | tqdm==4.46.1 65 | traitlets==4.3.3 66 | typer==0.3.0 67 | wcwidth==0.2.5 68 | webencodings==0.5.1 69 | widgetsnbextension==3.5.1 70 | -------------------------------------------------------------------------------- /gnnexplainer_utils/featgen.py: -------------------------------------------------------------------------------- 1 | """ featgen.py 2 | 3 | Node feature generators. 
4 | 5 | """ 6 | import networkx as nx 7 | import numpy as np 8 | import random 9 | 10 | import abc 11 | 12 | 13 | class FeatureGen(metaclass=abc.ABCMeta): 14 | """Feature Generator base class.""" 15 | @abc.abstractmethod 16 | def gen_node_features(self, G): 17 | pass 18 | 19 | 20 | class ConstFeatureGen(FeatureGen): 21 | """Constant Feature class.""" 22 | def __init__(self, val): 23 | self.val = val 24 | 25 | def gen_node_features(self, G): 26 | feat_dict = {i:{'feat': np.array(self.val, dtype=np.float32)} for i in G.nodes()} 27 | print ('feat_dict[0]["feat"]:', feat_dict[0]['feat'].dtype) 28 | nx.set_node_attributes(G, feat_dict) 29 | print ('G.nodes[0]["feat"]:', G.nodes[0]['feat'].dtype) 30 | 31 | 32 | class GaussianFeatureGen(FeatureGen): 33 | """Gaussian Feature class.""" 34 | def __init__(self, mu, sigma): 35 | self.mu = mu 36 | if sigma.ndim < 2: 37 | self.sigma = np.diag(sigma) 38 | else: 39 | self.sigma = sigma 40 | 41 | def gen_node_features(self, G): 42 | feat = np.random.multivariate_normal(self.mu, self.sigma, G.number_of_nodes()) 43 | feat_dict = { 44 | i: {"feat": feat[i]} for i in range(feat.shape[0]) 45 | } 46 | nx.set_node_attributes(G, feat_dict) 47 | 48 | 49 | class GridFeatureGen(FeatureGen): 50 | """Grid Feature class.""" 51 | def __init__(self, mu, sigma, com_choices): 52 | self.mu = mu # Mean 53 | self.sigma = sigma # Variance 54 | self.com_choices = com_choices # List of possible community labels 55 | 56 | def gen_node_features(self, G): 57 | # Generate community assignment 58 | community_dict = { 59 | n: self.com_choices[0] if G.degree(n) < 4 else self.com_choices[1] 60 | for n in G.nodes() 61 | } 62 | 63 | # Generate random variable 64 | s = np.random.normal(self.mu, self.sigma, G.number_of_nodes()) 65 | 66 | # Generate features 67 | feat_dict = { 68 | n: {"feat": np.asarray([community_dict[n], s[i]])} 69 | for i, n in enumerate(G.nodes()) 70 | } 71 | 72 | nx.set_node_attributes(G, feat_dict) 73 | return community_dict 74 | 75 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import typer 7 | 8 | from explain import read_graphs 9 | 10 | 11 | def main(dataset_path: Path, explain_path: Path): 12 | nx_graphs, labels = read_graphs(dataset_path) 13 | 14 | graph_masked_adjs = {} 15 | 16 | name_map = {} 17 | for name in os.listdir(str(explain_path)): 18 | if not name.endswith('npy'): 19 | continue 20 | name_map[name.split('.')[0]] = name 21 | 22 | for i in nx_graphs: 23 | if str(i) not in name_map: 24 | continue 25 | masked_adj = np.load(str(explain_path / name_map[str(i)])) 26 | last_idx = len(nx_graphs[i].nodes) 27 | masked_adj = masked_adj[:last_idx, :last_idx] 28 | graph_masked_adjs[i] = masked_adj 29 | 30 | def explain(graph_num): 31 | edge_importance = {} 32 | for u, v in nx_graphs[graph_num].edges(): 33 | u = int(u) 34 | v = int(v) 35 | edge_importance[(u, v)] = graph_masked_adjs[graph_num][u, v] 36 | return edge_importance 37 | 38 | def get_correct_edges(g): 39 | nodes_by_label = defaultdict(list) 40 | for u, data in g.nodes(data=True): 41 | nodes_by_label[data['label']].append(u) 42 | edges = [] 43 | for label, nodes in nodes_by_label.items(): 44 | if label == '0' or label == 0: 45 | continue 46 | edges.extend([(int(u), int(v)) for u, v in g.subgraph(nodes).edges()]) 47 | return edges 48 | 49 | def 
get_accuracy(correct_edges, edge_importance):
50 | 
51 |         # take the k highest-importance edges as the prediction, where k is the number of ground-truth edges
52 |         predicted_edges = sorted(edge_importance.keys(), key=lambda e: -edge_importance[e])[:len(correct_edges)]
53 |         correct = 0
54 |         for u, v in predicted_edges:
55 |             if (u, v) in correct_edges or (v, u) in correct_edges:
56 |                 correct += 1
57 |         return correct / len(correct_edges)
58 | 
59 |     accs = []
60 |     accs_by_label = defaultdict(list)
61 |     for idx in nx_graphs:
62 |         if str(idx) not in name_map:
63 |             continue
64 |         correct_edges = get_correct_edges(nx_graphs[idx])
65 |         if len(correct_edges) == 0:
66 |             continue
67 |         edge_importance = explain(idx)
68 |         acc = get_accuracy(correct_edges, edge_importance)
69 |         accs.append(acc)
70 |         accs_by_label[labels[idx]].append(acc)
71 |     print('Total accuracy:')
72 |     print('Sample count:', len(accs), 'Mean accuracy:', np.mean(accs), 'standard deviation:', np.std(accs))
73 |     print('Accuracy by label:')
74 |     for k, v in sorted(accs_by_label.items()):
75 |         print('Accuracy for label', k)
76 |         print('Sample count:', len(v), 'Mean accuracy:', np.mean(v), 'standard deviation:', np.std(v))
77 |         print('-' * 40)
78 | 
79 | 
80 | if __name__ == '__main__':
81 |     app = typer.Typer(add_completion=False)
82 |     app.command()(main)
83 |     app()
--------------------------------------------------------------------------------
/qualitative/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch-compatible dataloader
3 | """
4 | 
5 | 
6 | import math
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.sampler import SubsetRandomSampler
11 | from sklearn.model_selection import StratifiedKFold
12 | import dgl
13 | 
14 | 
15 | # default collate function
16 | def collate(samples):
17 |     # The input `samples` is a list of pairs (graph, label).
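    # e.g. samples = [(g1, 0), (g2, 1)] for two DGLGraphs (illustrative);
    # dgl.batch merges the graphs into one batched graph and the labels
    # become a 1-D tensor such as tensor([0, 1])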
18 |     graphs, labels = map(list, zip(*samples))
19 |     for g in graphs:
20 |         # deal with node feats
21 |         for key in g.node_attr_schemes().keys():
22 |             g.ndata[key] = g.ndata[key].float()
23 |         # no edge feats
24 |     batched_graph = dgl.batch(graphs)
25 |     labels = torch.tensor(labels)
26 |     return batched_graph, labels
27 | 
28 | 
29 | class GraphDataLoader():
30 |     def __init__(self,
31 |                  dataset,
32 |                  batch_size,
33 |                  device,
34 |                  collate_fn=collate,
35 |                  seed=0,
36 |                  shuffle=True,
37 |                  split_name='fold10',
38 |                  fold_idx=0,
39 |                  split_ratio=0.7):
40 | 
41 |         self.shuffle = shuffle
42 |         self.seed = seed
43 |         self.kwargs = {'pin_memory': True} if 'cuda' in device.type else {}
44 | 
45 |         labels = [l for _, l in dataset]
46 | 
47 |         if split_name == 'fold10':
48 |             train_idx, valid_idx = self._split_fold10(
49 |                 labels, fold_idx, seed, shuffle)
50 |         elif split_name == 'rand':
51 |             train_idx, valid_idx = self._split_rand(
52 |                 labels, split_ratio, seed, shuffle)
53 |         else:
54 |             raise NotImplementedError()
55 | 
56 |         train_sampler = SubsetRandomSampler(train_idx)
57 |         valid_sampler = SubsetRandomSampler(valid_idx)
58 | 
59 |         self.train_loader = DataLoader(
60 |             dataset, sampler=train_sampler,
61 |             batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
62 |         self.valid_loader = DataLoader(
63 |             dataset, sampler=valid_sampler,
64 |             batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
65 | 
66 |     def train_valid_loader(self):
67 |         return self.train_loader, self.valid_loader
68 | 
69 |     def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
70 |         '''10-fold split'''
71 |         assert 0 <= fold_idx < 10, \
72 |             "fold_idx must be from 0 to 9."
73 | 
74 |         skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
75 |         idx_list = []
76 |         for idx in skf.split(np.zeros(len(labels)), labels):  # split(x, y)
77 |             idx_list.append(idx)
78 |         train_idx, valid_idx = idx_list[fold_idx]
79 | 
80 |         print(
81 |             "train_set : test_set = %d : %d" %
82 |             (len(train_idx), len(valid_idx)))
83 | 
84 |         return train_idx, valid_idx
85 | 
86 |     def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
87 |         num_entries = len(labels)
88 |         indices = list(range(num_entries))
89 |         np.random.seed(seed)
90 |         np.random.shuffle(indices)
91 |         split = int(math.floor(split_ratio * num_entries))
92 |         train_idx, valid_idx = indices[:split], indices[split:]
93 | 
94 |         print(
95 |             "train_set : test_set = %d : %d" %
96 |             (len(train_idx), len(valid_idx)))
97 | 
98 |         return train_idx, valid_idx
99 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Graph Neural Network Explanation
2 | 
3 | This is the source code for the paper _Contrastive Graph Neural Network Explanation_.
4 | 
5 | ## Required libraries
6 | You can install the required libraries by running:
7 | ```shell
8 | pip install -r requirements.txt
9 | ```
10 | 
11 | ## Recreating the benchmark dataset
12 | Run the following command to recreate a dataset:
13 | ```shell
14 | python gen_dataset.py DATASET_NAME
15 | ```
16 | `DATASET_NAME` can be one of the following:
17 | * `CYCLIQ`: the dataset described in the paper
18 | * `CYCLIQ-MULTI`: `CYCLIQ` plus two additional classes, base trees without any attachment and base trees with both cycles and cliques
19 | * `TRISQ`: 1000 base trees with attached triangles and 1000 base trees with attached squares (cycles of length 4)
20 | * `HOUSE_CLIQ`: 1000 base trees with attached cliques of size 5 and 1000 base trees with attached [house graphs](https://mathworld.wolfram.com/HouseGraph.html)
21 | 
22 | You can add a `--sample-size` option to change the number of samples for each label (default is 1000).
23 | 
24 | 
25 | ## Training a GNN model
26 | We use the GNNExplainer source code for training the model, generating node embeddings, and producing the GNNExplainer explanations.
27 | Any file or folder with the `gnnexplainer` prefix is ported from the [GNNExplainer repo](https://github.com/RexYing/gnn-model-explainer) with small modifications to make it compatible with our framework.
28 | To train a model, you can run the following command:
29 | ```sh
30 | python gnnexplainer_train.py --bmname=DATASET_NAME --epochs=20
31 | ```
32 | This command will output a model in the `ckpt` folder, which you will use in subsequent commands.
33 | 
34 | ## Generating node embeddings and GNNExplainer output
35 | Run the following command to generate node embeddings for all graphs as well as the GNNExplainer explanations:
36 | ```sh
37 | python gnnexplainer_main.py --bmname=DATASET_NAME --graph-mode --explain-all
38 | ```
39 | This will create a new folder named `embeddings-DATASET_NAME` and write explanations to the `explanations/gnnexplainer` folder.
40 | 
41 | ## Running other explanation methods
42 | You can run all the other explanation methods via the `explain.py` script:
43 | ```sh
44 | python explain.py {contrast|sensitivity|occlusion|random}
45 | ```
46 | Use the `--help` option to see all the available options for each command. For example:
47 | ```sh
48 | python explain.py contrast --help
49 | ```
50 | 
51 | ## Evaluation
52 | Run the following command to see the accuracy of explanations:
53 | ```sh
54 | python evaluate.py DATASET_PATH EXPLAIN_PATH
55 | ```
56 | 
57 | ## Complete example
58 | Here is the complete list of commands needed to reproduce the paper results:
59 | ```sh
60 | python gen_dataset.py CYCLIQ
61 | 
62 | python gnnexplainer_train.py --bmname=CYCLIQ --epochs=100
63 | python gnnexplainer_main.py --bmname=CYCLIQ --graph-mode --explain-all
64 | 
65 | python explain.py random data/CYCLIQ/ explanations/random
66 | python explain.py sensitivity data/CYCLIQ/ ckpt/CYCLIQ_base_h20_o20.pth.tar explanations/sensitivity
67 | python explain.py occlusion data/CYCLIQ/ ckpt/CYCLIQ_base_h20_o20.pth.tar explanations/occlusion
68 | python explain.py contrast data/CYCLIQ/ embeddings-CYCLIQ explanations/contrast
69 | 
70 | python evaluate.py data/CYCLIQ/ explanations/gnnexplainer/
71 | python evaluate.py data/CYCLIQ/ explanations/random/
72 | python evaluate.py data/CYCLIQ/ explanations/sensitivity/
73 | python evaluate.py data/CYCLIQ/ explanations/occlusion/
74 | python evaluate.py data/CYCLIQ/ explanations/contrast/
75 | ```
76 | 
77 | ## Visualizing Explanations
78 | You can run the `Visualize.ipynb` notebook to visualize each method's explanations.
79 | 
80 | ## Qualitative analysis
81 | To reproduce the qualitative experiments, see the notebooks in the `qualitative` folder.
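
## Dataset format
For reference, the layout below is derived from `gen_dataset.py` (it is not a separate spec). For example, `python gen_dataset.py CYCLIQ` writes one `.gexf` file per graph, named `<graph_num>.<label>.gexf`, plus TU-style adjacency files:
```
data/CYCLIQ/
├── 1.0.gexf
├── 2.0.gexf
├── ...
├── CYCLIQ_A.txt                 # each edge as a "u, v" line, written in both directions
├── CYCLIQ_graph_indicator.txt   # graph id of every node
└── CYCLIQ_graph_labels.txt      # one label per graph
```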
82 | 
--------------------------------------------------------------------------------
/gnnexplainer_configs.py:
--------------------------------------------------------------------------------
1 | # This file is copied from GNNExplainer repo with small modifications
2 | # https://github.com/RexYing/gnn-model-explainer
3 | 
4 | import argparse
5 | import gnnexplainer_utils.parser_utils as parser_utils
6 | 
7 | def arg_parse():
8 |     parser = argparse.ArgumentParser(description='GraphPool arguments.')
9 |     io_parser = parser.add_mutually_exclusive_group(required=False)
10 |     io_parser.add_argument('--dataset', dest='dataset',
11 |                            help='Input dataset.')
12 |     benchmark_parser = io_parser.add_argument_group()
13 |     benchmark_parser.add_argument('--bmname', dest='bmname',
14 |                                   help='Name of the benchmark dataset')
15 |     io_parser.add_argument('--pkl', dest='pkl_fname',
16 |                            help='Name of the pkl data file')
17 | 
18 |     softpool_parser = parser.add_argument_group()
19 |     softpool_parser.add_argument('--assign-ratio', dest='assign_ratio', type=float,
20 |                                  help='ratio of number of nodes in consecutive layers')
21 |     softpool_parser.add_argument('--num-pool', dest='num_pool', type=int,
22 |                                  help='number of pooling layers')
23 |     parser.add_argument('--linkpred', dest='linkpred', action='store_const',
24 |                         const=True, default=False,
25 |                         help='Whether link prediction side objective is used')
26 | 
27 |     parser_utils.parse_optimizer(parser)
28 | 
29 |     parser.add_argument('--datadir', dest='datadir',
30 |                         help='Directory where benchmark is located')
31 |     parser.add_argument('--logdir', dest='logdir',
32 |                         help='Tensorboard log directory')
33 |     parser.add_argument('--ckptdir', dest='ckptdir',
34 |                         help='Model checkpoint directory')
35 |     parser.add_argument('--cuda', dest='cuda',
36 |                         help='CUDA.')
37 |     parser.add_argument('--gpu', dest='gpu', action='store_const',
38 |                         const=True, default=False,
39 |                         help='whether to use GPU.')
40 |     parser.add_argument('--max_nodes', dest='max_nodes', type=int,
41 |                         help='Maximum number of nodes (ignore graphs with more nodes than this number).')
42 |     parser.add_argument('--batch_size', dest='batch_size', type=int,
43 |                         help='Batch size.')
44 |     parser.add_argument('--epochs', dest='num_epochs', type=int,
45 |                         help='Number of epochs to train.')
46 |     parser.add_argument('--train_ratio', dest='train_ratio', type=float,
47 |                         help='Ratio of the number of training graphs to all graphs.')
48 |     parser.add_argument('--num_workers', dest='num_workers', type=int,
49 |                         help='Number of workers to load data.')
50 |     parser.add_argument('--feature', dest='feature_type',
51 |                         help='Feature used for encoder. 
Can be: id, deg') 52 | parser.add_argument('--input_dim', dest='input_dim', type=int, 53 | help='Input feature dimension') 54 | parser.add_argument('--hidden_dim', dest='hidden_dim', type=int, 55 | help='Hidden dimension') 56 | parser.add_argument('--output_dim', dest='output_dim', type=int, 57 | help='Output dimension') 58 | parser.add_argument('--num_classes', dest='num_classes', type=int, 59 | help='Number of label classes') 60 | parser.add_argument('--num_gc_layers', dest='num_gc_layers', type=int, 61 | help='Number of graph convolution layers before each pooling') 62 | parser.add_argument('--bn', dest='bn', action='store_const', 63 | const=True, default=False, 64 | help='Whether batch normalization is used') 65 | parser.add_argument('--dropout', dest='dropout', type=float, 66 | help='Dropout rate.') 67 | parser.add_argument('--nobias', dest='bias', action='store_const', 68 | const=False, default=True, 69 | help='Whether to add bias. Default to True.') 70 | parser.add_argument('--weight_decay', dest='weight_decay', type=float, 71 | help='Weight decay regularization constant.') 72 | 73 | parser.add_argument('--method', dest='method', 74 | help='Method. Possible values: base, ') 75 | parser.add_argument('--name-suffix', dest='name_suffix', 76 | help='suffix added to the output filename') 77 | 78 | parser.set_defaults(datadir='data', # io_parser 79 | logdir='log', 80 | ckptdir='ckpt', 81 | dataset='syn1', 82 | opt='adam', # opt_parser 83 | opt_scheduler='none', 84 | max_nodes=100, 85 | cuda='1', 86 | feature_type='default', 87 | lr=0.001, 88 | clip=2.0, 89 | batch_size=20, 90 | num_epochs=1000, 91 | train_ratio=0.8, 92 | test_ratio=0.1, 93 | num_workers=1, 94 | input_dim=10, 95 | hidden_dim=20, 96 | output_dim=20, 97 | num_classes=2, 98 | num_gc_layers=3, 99 | dropout=0.0, 100 | weight_decay=0.005, 101 | method='base', 102 | name_suffix='', 103 | assign_ratio=0.1, 104 | ) 105 | return parser.parse_args() 106 | 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,pycharm,macos 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,pycharm,macos 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | # Thumbnails 29 | ._* 30 | 31 | # Files that might appear in the root of a volume 32 | .DocumentRevisions-V100 33 | .fseventsd 34 | .Spotlight-V100 35 | .TemporaryItems 36 | .Trashes 37 | .VolumeIcon.icns 38 | .com.apple.timemachine.donotpresent 39 | 40 | # Directories potentially created on remote AFP share 41 | .AppleDB 42 | .AppleDesktop 43 | Network Trash Folder 44 | Temporary Items 45 | .apdisk 46 | 47 | ### PyCharm ### 48 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 49 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 50 | 51 | # User-specific stuff 52 | .idea/**/workspace.xml 53 | .idea/**/tasks.xml 54 | 
.idea/**/usage.statistics.xml 55 | .idea/**/dictionaries 56 | .idea/**/shelf 57 | 58 | # Generated files 59 | .idea/**/contentModel.xml 60 | 61 | # Sensitive or high-churn files 62 | .idea/**/dataSources/ 63 | .idea/**/dataSources.ids 64 | .idea/**/dataSources.local.xml 65 | .idea/**/sqlDataSources.xml 66 | .idea/**/dynamic.xml 67 | .idea/**/uiDesigner.xml 68 | .idea/**/dbnavigator.xml 69 | 70 | # Gradle 71 | .idea/**/gradle.xml 72 | .idea/**/libraries 73 | 74 | # Gradle and Maven with auto-import 75 | # When using Gradle or Maven with auto-import, you should exclude module files, 76 | # since they will be recreated, and may cause churn. Uncomment if using 77 | # auto-import. 78 | # .idea/artifacts 79 | # .idea/compiler.xml 80 | # .idea/jarRepositories.xml 81 | # .idea/modules.xml 82 | # .idea/*.iml 83 | # .idea/modules 84 | # *.iml 85 | # *.ipr 86 | 87 | # CMake 88 | cmake-build-*/ 89 | 90 | # Mongo Explorer plugin 91 | .idea/**/mongoSettings.xml 92 | 93 | # File-based project format 94 | *.iws 95 | 96 | # IntelliJ 97 | out/ 98 | 99 | # mpeltonen/sbt-idea plugin 100 | .idea_modules/ 101 | 102 | # JIRA plugin 103 | atlassian-ide-plugin.xml 104 | 105 | # Cursive Clojure plugin 106 | .idea/replstate.xml 107 | 108 | # Crashlytics plugin (for Android Studio and IntelliJ) 109 | com_crashlytics_export_strings.xml 110 | crashlytics.properties 111 | crashlytics-build.properties 112 | fabric.properties 113 | 114 | # Editor-based Rest Client 115 | .idea/httpRequests 116 | 117 | # Android studio 3.1+ serialized cache file 118 | .idea/caches/build_file_checksums.ser 119 | 120 | ### PyCharm Patch ### 121 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 122 | 123 | # *.iml 124 | # modules.xml 125 | # .idea/misc.xml 126 | # *.ipr 127 | 128 | # Sonarlint plugin 129 | .idea/**/sonarlint/ 130 | 131 | # SonarQube Plugin 132 | .idea/**/sonarIssues.xml 133 | 134 | # Markdown Navigator plugin 135 | .idea/**/markdown-navigator.xml 136 | .idea/**/markdown-navigator-enh.xml 137 | .idea/**/markdown-navigator/ 138 | 139 | # Cache file creation bug 140 | # See https://youtrack.jetbrains.com/issue/JBR-2257 141 | .idea/$CACHE_FILE$ 142 | 143 | ### Python ### 144 | # Byte-compiled / optimized / DLL files 145 | __pycache__/ 146 | *.py[cod] 147 | *$py.class 148 | 149 | # C extensions 150 | *.so 151 | 152 | # Distribution / packaging 153 | .Python 154 | build/ 155 | develop-eggs/ 156 | dist/ 157 | downloads/ 158 | eggs/ 159 | .eggs/ 160 | lib/ 161 | lib64/ 162 | parts/ 163 | sdist/ 164 | var/ 165 | wheels/ 166 | pip-wheel-metadata/ 167 | share/python-wheels/ 168 | *.egg-info/ 169 | .installed.cfg 170 | *.egg 171 | MANIFEST 172 | 173 | # PyInstaller 174 | # Usually these files are written by a python script from a template 175 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
176 | *.manifest
177 | *.spec
178 | 
179 | # Installer logs
180 | pip-log.txt
181 | pip-delete-this-directory.txt
182 | 
183 | # Unit test / coverage reports
184 | htmlcov/
185 | .tox/
186 | .nox/
187 | .coverage
188 | .coverage.*
189 | .cache
190 | nosetests.xml
191 | coverage.xml
192 | *.cover
193 | *.py,cover
194 | .hypothesis/
195 | .pytest_cache/
196 | 
197 | # Translations
198 | *.mo
199 | *.pot
200 | 
201 | # Django stuff:
202 | *.log
203 | local_settings.py
204 | db.sqlite3
205 | db.sqlite3-journal
206 | 
207 | # Flask stuff:
208 | instance/
209 | .webassets-cache
210 | 
211 | # Scrapy stuff:
212 | .scrapy
213 | 
214 | # Sphinx documentation
215 | docs/_build/
216 | 
217 | # PyBuilder
218 | target/
219 | 
220 | # Jupyter Notebook
221 | 
222 | # IPython
223 | 
224 | # pyenv
225 | .python-version
226 | 
227 | # pipenv
228 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
229 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
230 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
231 | # install all needed dependencies.
232 | #Pipfile.lock
233 | 
234 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
235 | __pypackages__/
236 | 
237 | # Celery stuff
238 | celerybeat-schedule
239 | celerybeat.pid
240 | 
241 | # SageMath parsed files
242 | *.sage.py
243 | 
244 | # Environments
245 | .env
246 | .venv
247 | env/
248 | venv/
249 | ENV/
250 | env.bak/
251 | venv.bak/
252 | 
253 | # Spyder project settings
254 | .spyderproject
255 | .spyproject
256 | 
257 | # Rope project settings
258 | .ropeproject
259 | 
260 | # mkdocs documentation
261 | /site
262 | 
263 | # mypy
264 | .mypy_cache/
265 | .dmypy.json
266 | dmypy.json
267 | 
268 | # Pyre type checker
269 | .pyre/
270 | 
271 | # pytype static type analyzer
272 | .pytype/
273 | 
274 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,pycharm,macos
275 | 
--------------------------------------------------------------------------------
/qualitative/gin.py:
--------------------------------------------------------------------------------
1 | """
2 | How Powerful are Graph Neural Networks
3 | https://arxiv.org/abs/1810.00826
4 | https://openreview.net/forum?id=ryGs6iA5Km
5 | Author's implementation: https://github.com/weihua916/powerful-gnns
6 | """
7 | 
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from dgl.nn.pytorch.conv import GINConv
13 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
14 | 
15 | 
16 | class ApplyNodeFunc(nn.Module):
17 |     """Update the node feature hv with MLP, BN and ReLU."""
18 |     def __init__(self, mlp):
19 |         super(ApplyNodeFunc, self).__init__()
20 |         self.mlp = mlp
21 |         self.bn = nn.BatchNorm1d(self.mlp.output_dim)
22 | 
23 |     def forward(self, h):
24 |         h = self.mlp(h)
25 |         h = self.bn(h)
26 |         h = F.relu(h)
27 |         return h
28 | 
29 | 
30 | class MLP(nn.Module):
31 |     """MLP with linear output"""
32 |     def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
33 |         """MLP layers construction
34 |         Parameters
35 |         ----------
36 |         num_layers: int
37 |             The number of linear layers
38 |         input_dim: int
39 |             The dimensionality of input features
40 |         hidden_dim: int
41 |             The dimensionality of hidden units at ALL layers
42 |         output_dim: int
43 |             The number of classes for prediction
44 |         """
45 |         super(MLP, self).__init__()
46 |         self.linear_or_not = True  # default is linear model
47 |         self.num_layers = num_layers
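        # output_dim is exposed as an attribute so that ApplyNodeFunc above
        # can size its BatchNorm1d via self.mlp.output_dim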
48 |         self.output_dim = output_dim
49 | 
50 |         if num_layers < 1:
51 |             raise ValueError("number of layers should be positive!")
52 |         elif num_layers == 1:
53 |             # Linear model
54 |             self.linear = nn.Linear(input_dim, output_dim)
55 |         else:
56 |             # Multi-layer model
57 |             self.linear_or_not = False
58 |             self.linears = torch.nn.ModuleList()
59 |             self.batch_norms = torch.nn.ModuleList()
60 | 
61 |             self.linears.append(nn.Linear(input_dim, hidden_dim))
62 |             for layer in range(num_layers - 2):
63 |                 self.linears.append(nn.Linear(hidden_dim, hidden_dim))
64 |             self.linears.append(nn.Linear(hidden_dim, output_dim))
65 | 
66 |             for layer in range(num_layers - 1):
67 |                 self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
68 | 
69 |     def forward(self, x):
70 |         if self.linear_or_not:
71 |             # If linear model
72 |             return self.linear(x)
73 |         else:
74 |             # If MLP
75 |             h = x
76 |             for i in range(self.num_layers - 1):
77 |                 h = F.relu(self.batch_norms[i](self.linears[i](h)))
78 |             return self.linears[-1](h)
79 | 
80 | 
81 | class GIN(nn.Module):
82 |     """GIN model"""
83 |     def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
84 |                  output_dim, final_dropout, learn_eps, graph_pooling_type,
85 |                  neighbor_pooling_type):
86 |         """Model parameter settings
87 |         Parameters
88 |         ----------
89 |         num_layers: int
90 |             The number of linear layers in the neural network
91 |         num_mlp_layers: int
92 |             The number of linear layers in mlps
93 |         input_dim: int
94 |             The dimensionality of input features
95 |         hidden_dim: int
96 |             The dimensionality of hidden units at ALL layers
97 |         output_dim: int
98 |             The number of classes for prediction
99 |         final_dropout: float
100 |             dropout ratio on the final linear layer
101 |         learn_eps: boolean
102 |             If True, learn epsilon to distinguish center nodes from neighbors
103 |             If False, aggregate neighbors and center nodes altogether.
104 | neighbor_pooling_type: str 105 | how to aggregate neighbors (sum, mean, or max) 106 | graph_pooling_type: str 107 | how to aggregate entire nodes in a graph (sum, mean or max) 108 | """ 109 | super(GIN, self).__init__() 110 | self.num_layers = num_layers 111 | self.learn_eps = learn_eps 112 | 113 | # List of MLPs 114 | self.ginlayers = torch.nn.ModuleList() 115 | self.batch_norms = torch.nn.ModuleList() 116 | 117 | for layer in range(self.num_layers - 1): 118 | if layer == 0: 119 | mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim) 120 | else: 121 | mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim) 122 | 123 | self.ginlayers.append( 124 | GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) 125 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 126 | 127 | # Linear function for graph poolings of output of each layer 128 | # which maps the output of different layers into a prediction score 129 | self.linears_prediction = torch.nn.ModuleList() 130 | 131 | for layer in range(num_layers): 132 | if layer == 0: 133 | self.linears_prediction.append( 134 | nn.Linear(input_dim, output_dim)) 135 | else: 136 | self.linears_prediction.append( 137 | nn.Linear(hidden_dim, output_dim)) 138 | 139 | self.drop = nn.Dropout(final_dropout) 140 | 141 | if graph_pooling_type == 'sum': 142 | self.pool = SumPooling() 143 | elif graph_pooling_type == 'mean': 144 | self.pool = AvgPooling() 145 | elif graph_pooling_type == 'max': 146 | self.pool = MaxPooling() 147 | else: 148 | raise NotImplementedError 149 | 150 | def forward(self, g, h): 151 | # list of hidden representation at each layer (including input) 152 | hidden_rep = [h] 153 | 154 | for i in range(self.num_layers - 1): 155 | h = self.ginlayers[i](g, h) 156 | h = self.batch_norms[i](h) 157 | h = F.relu(h) 158 | hidden_rep.append(h) 159 | 160 | score_over_layer = 0 161 | 162 | # perform pooling over all nodes in each graph in every layer 163 | for i, h in enumerate(hidden_rep): 164 | pooled_h = self.pool(g, h) 165 | score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) 166 | 167 | return score_over_layer 168 | -------------------------------------------------------------------------------- /gnnexplainer_utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | """graph_utils.py 2 | 3 | Utility for sampling graphs from a dataset. 
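
Illustrative usage (a sketch; it assumes every graph in G_list carries a
'feat' attribute on each node and 'label'/'idx' graph attributes, as the
training pipeline produces):

    sampler = GraphSampler(G_list, features="default", normalize=True)
    loader = torch.utils.data.DataLoader(sampler, batch_size=20, shuffle=True)
    batch = next(iter(loader))  # dict with 'adj', 'feats', 'label', 'num_nodes', ...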
4 | """ 5 | import networkx as nx 6 | import numpy as np 7 | import torch 8 | import torch.utils.data 9 | 10 | 11 | class GraphSampler(torch.utils.data.Dataset): 12 | """ Sample graphs and nodes in graph 13 | """ 14 | 15 | def __init__( 16 | self, 17 | G_list, 18 | features="default", 19 | normalize=True, 20 | assign_feat="default", 21 | max_num_nodes=0, 22 | ): 23 | self.adj_all = [] 24 | self.len_all = [] 25 | self.feature_all = [] 26 | self.label_all = [] 27 | self.idx_all = [] 28 | 29 | self.assign_feat_all = [] 30 | 31 | if max_num_nodes == 0: 32 | self.max_num_nodes = max([G.number_of_nodes() for G in G_list]) 33 | else: 34 | self.max_num_nodes = max_num_nodes 35 | 36 | existing_node = list(G_list[0].nodes())[-1] 37 | self.feat_dim = G_list[0].nodes[existing_node]["feat"].shape[0] 38 | 39 | for G in G_list: 40 | adj = np.array(nx.to_numpy_matrix(G, nodelist=sorted(G.nodes))) 41 | if normalize: 42 | sqrt_deg = np.diag( 43 | 1.0 / np.sqrt(np.sum(adj, axis=0, dtype=float).squeeze()) 44 | ) 45 | adj = np.matmul(np.matmul(sqrt_deg, adj), sqrt_deg) 46 | self.adj_all.append(adj) 47 | self.len_all.append(G.number_of_nodes()) 48 | self.label_all.append(G.graph["label"]) 49 | self.idx_all.append(G.graph["idx"]) 50 | # feat matrix: max_num_nodes x feat_dim 51 | if features == "default": 52 | f = np.zeros((self.max_num_nodes, self.feat_dim), dtype=float) 53 | for i, u in enumerate(G.nodes()): 54 | f[i, :] = G.nodes[u]["feat"] 55 | self.feature_all.append(f) 56 | elif features == "id": 57 | self.feature_all.append(np.identity(self.max_num_nodes)) 58 | elif features == "deg-num": 59 | degs = np.sum(np.array(adj), 1) 60 | degs = np.expand_dims( 61 | np.pad(degs, [0, self.max_num_nodes - G.number_of_nodes()], 0), 62 | axis=1, 63 | ) 64 | self.feature_all.append(degs) 65 | elif features == "deg": 66 | self.max_deg = 10 67 | degs = np.sum(np.array(adj), 1).astype(int) 68 | degs[degs > self.max_deg] = self.max_deg 69 | feat = np.zeros((len(degs), self.max_deg + 1)) 70 | feat[np.arange(len(degs)), degs] = 1 71 | feat = np.pad( 72 | feat, 73 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 74 | "constant", 75 | constant_values=0, 76 | ) 77 | 78 | f = np.zeros((self.max_num_nodes, self.feat_dim), dtype=float) 79 | for i, u in enumerate(G.nodes()): 80 | f[i, :] = G.nodes[u]["feat"] 81 | 82 | feat = np.concatenate((feat, f), axis=1) 83 | 84 | self.feature_all.append(feat) 85 | elif features == "struct": 86 | self.max_deg = 10 87 | degs = np.sum(np.array(adj), 1).astype(int) 88 | degs[degs > 10] = 10 89 | feat = np.zeros((len(degs), self.max_deg + 1)) 90 | feat[np.arange(len(degs)), degs] = 1 91 | degs = np.pad( 92 | feat, 93 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 94 | "constant", 95 | constant_values=0, 96 | ) 97 | 98 | clusterings = np.array(list(nx.clustering(G).values())) 99 | clusterings = np.expand_dims( 100 | np.pad( 101 | clusterings, 102 | [0, self.max_num_nodes - G.number_of_nodes()], 103 | "constant", 104 | ), 105 | axis=1, 106 | ) 107 | g_feat = np.hstack([degs, clusterings]) 108 | if "feat" in G.nodes[0]: 109 | node_feats = np.array( 110 | [G.nodes[i]["feat"] for i in range(G.number_of_nodes())] 111 | ) 112 | node_feats = np.pad( 113 | node_feats, 114 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 115 | "constant", 116 | ) 117 | g_feat = np.hstack([g_feat, node_feats]) 118 | 119 | self.feature_all.append(g_feat) 120 | 121 | if assign_feat == "id": 122 | self.assign_feat_all.append( 123 | np.hstack((np.identity(self.max_num_nodes), self.feature_all[-1])) 
124 |                 )
125 |             else:
126 |                 self.assign_feat_all.append(self.feature_all[-1])
127 | 
128 |         self.feat_dim = self.feature_all[0].shape[1]
129 |         self.assign_feat_dim = self.assign_feat_all[0].shape[1]
130 | 
131 |     def __len__(self):
132 |         return len(self.adj_all)
133 | 
134 |     def __getitem__(self, idx):
135 |         adj = self.adj_all[idx]
136 |         num_nodes = adj.shape[0]
137 |         adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
138 |         adj_padded[:num_nodes, :num_nodes] = adj
139 | 
140 |         # use all nodes for aggregation (baseline)
141 |         return {
142 |             "adj": adj_padded,
143 |             "feats": self.feature_all[idx].copy(),
144 |             "label": self.label_all[idx],
145 |             "num_nodes": num_nodes,
146 |             "assign_feats": self.assign_feat_all[idx].copy(),
147 |             "idx": self.idx_all[idx]
148 |         }
149 | 
150 | def neighborhoods(adj, n_hops, use_cuda):
151 |     """Returns the n-hop adjacency matrix of adj."""
152 |     adj = torch.tensor(adj, dtype=torch.float)
153 |     if use_cuda:
154 |         adj = adj.cuda()
155 |     hop_adj = power_adj = adj
156 |     for i in range(n_hops - 1):
157 |         power_adj = power_adj @ adj
158 |         prev_hop_adj = hop_adj
159 |         hop_adj = hop_adj + power_adj
160 |         hop_adj = (hop_adj > 0).float()
161 |     return hop_adj.cpu().numpy().astype(int)
--------------------------------------------------------------------------------
/gnnexplainer_gengraph.py:
--------------------------------------------------------------------------------
1 | """gengraph.py
2 | 
3 | Generating and manipulating the synthetic graphs needed for the paper's experiments.
4 | """
5 | 
6 | import os
7 | 
8 | from matplotlib import pyplot as plt
9 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
10 | from matplotlib.figure import Figure
11 | import matplotlib.colors as colors
12 | 
13 | # Set matplotlib backend to file writing
14 | plt.switch_backend("agg")
15 | 
16 | import networkx as nx
17 | 
18 | import numpy as np
19 | 
20 | from tensorboardX import SummaryWriter
21 | 
22 | from gnnexplainer_utils import synthetic_structsim
23 | from gnnexplainer_utils import featgen
24 | import gnnexplainer_utils.io_utils as io_utils
25 | 
26 | 
27 | ####################################
28 | #
29 | # Experiment utilities
30 | #
31 | ####################################
32 | def perturb(graph_list, p):
33 |     """ Perturb the list of (sparse) graphs by adding edges.
34 |     Args:
35 |         p: proportion of added edges based on current number of edges.
36 |     Returns:
37 |         A list of graphs that are perturbed from the original graphs.
38 |     """
39 |     perturbed_graph_list = []
40 |     for G_original in graph_list:
41 |         G = G_original.copy()
42 |         edge_count = int(G.number_of_edges() * p)
43 |         # randomly add the edges between a pair of nodes without an edge.
44 |         for _ in range(edge_count):
45 |             while True:
46 |                 u = np.random.randint(0, G.number_of_nodes())
47 |                 v = np.random.randint(0, G.number_of_nodes())
48 |                 if (not G.has_edge(u, v)) and (u != v):
49 |                     break
50 |             G.add_edge(u, v)
51 |         perturbed_graph_list.append(G)
52 |     return perturbed_graph_list
53 | 
54 | 
55 | def join_graph(G1, G2, n_pert_edges):
56 |     """ Join two graphs along matching nodes, then perturb the resulting graph.
57 |     Args:
58 |         G1, G2: Networkx graphs to be joined.
59 |         n_pert_edges: number of perturbed edges.
60 |     Returns:
61 |         A new graph, result of merging and perturbing G1 and G2.
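    Example (illustrative): join_graph(G1, G2, 5) composes G1 and G2 and then
    adds 5 random edges, each between a node of G1 and a node of G2.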
62 | """ 63 | assert n_pert_edges > 0 64 | F = nx.compose(G1, G2) 65 | edge_cnt = 0 66 | while edge_cnt < n_pert_edges: 67 | node_1 = np.random.choice(G1.nodes()) 68 | node_2 = np.random.choice(G2.nodes()) 69 | F.add_edge(node_1, node_2) 70 | edge_cnt += 1 71 | return F 72 | 73 | 74 | def preprocess_input_graph(G, labels, normalize_adj=False): 75 | """ Load an existing graph to be converted for the experiments. 76 | Args: 77 | G: Networkx graph to be loaded. 78 | labels: Associated node labels. 79 | normalize_adj: Should the method return a normalized adjacency matrix. 80 | Returns: 81 | A dictionary containing adjacency, node features and labels 82 | """ 83 | adj = np.array(nx.to_numpy_matrix(G)) 84 | if normalize_adj: 85 | sqrt_deg = np.diag(1.0 / np.sqrt(np.sum(adj, axis=0, dtype=float).squeeze())) 86 | adj = np.matmul(np.matmul(sqrt_deg, adj), sqrt_deg) 87 | 88 | existing_node = list(G.nodes)[-1] 89 | feat_dim = G.nodes[existing_node]["feat"].shape[0] 90 | f = np.zeros((G.number_of_nodes(), feat_dim), dtype=float) 91 | for i, u in enumerate(G.nodes()): 92 | f[i, :] = G.nodes[u]["feat"] 93 | 94 | # add batch dim 95 | adj = np.expand_dims(adj, axis=0) 96 | f = np.expand_dims(f, axis=0) 97 | labels = np.expand_dims(labels, axis=0) 98 | return {"adj": adj, "feat": f, "labels": labels} 99 | 100 | 101 | #################################### 102 | # 103 | # Generating synthetic graphs 104 | # 105 | ################################### 106 | def gen_syn1(nb_shapes=80, width_basis=300, feature_generator=None, m=5): 107 | """ Synthetic Graph #1: 108 | 109 | Start with Barabasi-Albert graph and attach house-shaped subgraphs. 110 | 111 | Args: 112 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 113 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 114 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 115 | m : number of edges to attach to existing node (for BA graph) 116 | 117 | Returns: 118 | G : A networkx graph 119 | role_id : A list with length equal to number of nodes in the entire graph (basis 120 | : + shapes). role_id[i] is the ID of the role of node i. It is the label. 121 | name : A graph identifier 122 | """ 123 | basis_type = "ba" 124 | list_shapes = [["house"]] * nb_shapes 125 | 126 | plt.figure(figsize=(8, 6), dpi=300) 127 | 128 | G, role_id, _ = synthetic_structsim.build_graph( 129 | width_basis, basis_type, list_shapes, start=0, m=5 130 | ) 131 | G = perturb([G], 0.01)[0] 132 | 133 | if feature_generator is None: 134 | feature_generator = featgen.ConstFeatureGen(1) 135 | feature_generator.gen_node_features(G) 136 | 137 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 138 | return G, role_id, name 139 | 140 | 141 | def gen_syn2(nb_shapes=100, width_basis=350): 142 | """ Synthetic Graph #2: 143 | 144 | Start with Barabasi-Albert graph and add node features indicative of a community label. 145 | 146 | Args: 147 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 148 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 
149 | 150 | Returns: 151 | G : A networkx graph 152 | label : Label of the nodes (determined by role_id and community) 153 | name : A graph identifier 154 | """ 155 | basis_type = "ba" 156 | 157 | random_mu = [0.0] * 8 158 | random_sigma = [1.0] * 8 159 | 160 | # Create two grids 161 | mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 162 | mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 163 | feat_gen_G1 = featgen.GaussianFeatureGen(mu=mu_1, sigma=sigma_1) 164 | feat_gen_G2 = featgen.GaussianFeatureGen(mu=mu_2, sigma=sigma_2) 165 | G1, role_id1, name = gen_syn1(feature_generator=feat_gen_G1, m=4) 166 | G2, role_id2, name = gen_syn1(feature_generator=feat_gen_G2, m=4) 167 | G1_size = G1.number_of_nodes() 168 | num_roles = max(role_id1) + 1 169 | role_id2 = [r + num_roles for r in role_id2] 170 | label = role_id1 + role_id2 171 | 172 | # Edit node ids to avoid collisions on join 173 | g1_map = {n: i for i, n in enumerate(G1.nodes())} 174 | G1 = nx.relabel_nodes(G1, g1_map) 175 | g2_map = {n: i + G1_size for i, n in enumerate(G2.nodes())} 176 | G2 = nx.relabel_nodes(G2, g2_map) 177 | 178 | # Join 179 | n_pert_edges = width_basis 180 | G = join_graph(G1, G2, n_pert_edges) 181 | 182 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) + "_2comm" 183 | 184 | return G, label, name 185 | 186 | 187 | def gen_syn3(nb_shapes=80, width_basis=300, feature_generator=None, m=5): 188 | """ Synthetic Graph #3: 189 | 190 | Start with Barabasi-Albert graph and attach grid-shaped subgraphs. 191 | 192 | Args: 193 | nb_shapes : The number of shapes (here 'grid') that should be added to the base graph. 194 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 195 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 196 | m : number of edges to attach to existing node (for BA graph) 197 | 198 | Returns: 199 | G : A networkx graph 200 | role_id : Role ID for each node in synthetic graph. 201 | name : A graph identifier 202 | """ 203 | basis_type = "ba" 204 | list_shapes = [["grid", 3]] * nb_shapes 205 | 206 | plt.figure(figsize=(8, 6), dpi=300) 207 | 208 | G, role_id, _ = synthetic_structsim.build_graph( 209 | width_basis, basis_type, list_shapes, start=0, m=5 210 | ) 211 | G = perturb([G], 0.01)[0] 212 | 213 | if feature_generator is None: 214 | feature_generator = featgen.ConstFeatureGen(1) 215 | feature_generator.gen_node_features(G) 216 | 217 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 218 | return G, role_id, name 219 | 220 | 221 | def gen_syn4(nb_shapes=60, width_basis=8, feature_generator=None, m=4): 222 | """ Synthetic Graph #4: 223 | 224 | Start with a tree and attach cycle-shaped subgraphs. 225 | 226 | Args: 227 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 228 | width_basis : The width of the basis graph (here a random 'Tree'). 229 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 230 | m : The tree depth. 
231 | 232 | Returns: 233 | G : A networkx graph 234 | role_id : Role ID for each node in synthetic graph 235 | name : A graph identifier 236 | """ 237 | basis_type = "tree" 238 | list_shapes = [["cycle", 6]] * nb_shapes 239 | 240 | fig = plt.figure(figsize=(8, 6), dpi=300) 241 | 242 | G, role_id, plugins = synthetic_structsim.build_graph( 243 | width_basis, basis_type, list_shapes, start=0 244 | ) 245 | G = perturb([G], 0.01)[0] 246 | 247 | if feature_generator is None: 248 | feature_generator = featgen.ConstFeatureGen(1) 249 | feature_generator.gen_node_features(G) 250 | 251 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 252 | 253 | path = os.path.join("log/syn4_base_h20_o20") 254 | writer = SummaryWriter(path) 255 | io_utils.log_graph(writer, G, "graph/full") 256 | 257 | return G, role_id, name 258 | 259 | 260 | def gen_syn5(nb_shapes=80, width_basis=8, feature_generator=None, m=3): 261 | """ Synthetic Graph #5: 262 | 263 | Start with a tree and attach grid-shaped subgraphs. 264 | 265 | Args: 266 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 267 | width_basis : The width of the basis graph (here a random 'grid'). 268 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 269 | m : The tree depth. 270 | 271 | Returns: 272 | G : A networkx graph 273 | role_id : Role ID for each node in synthetic graph 274 | name : A graph identifier 275 | """ 276 | basis_type = "tree" 277 | list_shapes = [["grid", m]] * nb_shapes 278 | 279 | plt.figure(figsize=(8, 6), dpi=300) 280 | 281 | G, role_id, _ = synthetic_structsim.build_graph( 282 | width_basis, basis_type, list_shapes, start=0 283 | ) 284 | G = perturb([G], 0.1)[0] 285 | 286 | if feature_generator is None: 287 | feature_generator = featgen.ConstFeatureGen(1) 288 | feature_generator.gen_node_features(G) 289 | 290 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 291 | 292 | path = os.path.join("log/syn5_base_h20_o20") 293 | writer = SummaryWriter(path) 294 | 295 | return G, role_id, name 296 | -------------------------------------------------------------------------------- /gen_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from enum import Enum 3 | from pathlib import Path 4 | 5 | import networkx as nx 6 | import typer 7 | from tqdm import tqdm as tq 8 | 9 | 10 | class Dataset(str, Enum): 11 | CYCLIQ = 'CYCLIQ' 12 | CYCLIQ_MULTI = 'CYCLIQ-MULTI' 13 | TRISQ = 'TRISQ' 14 | HOUSE_CLIQ = 'HOUSE_CLIQ' 15 | GRID_CLIQ = 'GRID_CLIQ' 16 | HOUSE_GRID = "HOUSE_GRID" 17 | 18 | 19 | def random_tree(n): 20 | g = nx.generators.trees.random_tree(n) 21 | for i in range(n): 22 | g.nodes[i]['label'] = 0 23 | return g 24 | 25 | 26 | def attach_cycle(g, cycle_len, label, is_clique): 27 | N = len(g.nodes()) 28 | host_cands = [k for k, v in g.nodes(data=True) if v['label'] == 0] 29 | host_node = random.choice(host_cands) 30 | neighbors = list(g.neighbors(host_node)) 31 | for u in neighbors: 32 | g.remove_edge(u, host_node) 33 | 34 | # add the cycle 35 | cycle_nodes = [host_node] 36 | for i in range(cycle_len - 1): 37 | g.add_edge(cycle_nodes[-1], N + i) 38 | cycle_nodes.append(N + i) 39 | g.add_edge(host_node, cycle_nodes[-1]) 40 | 41 | if is_clique: 42 | for u in cycle_nodes: 43 | for v in cycle_nodes: 44 | if u != v: 45 | g.add_edge(u, v) 46 | 47 | for u in cycle_nodes: 48 | g.nodes[u]['label'] = label 49 | 50 | # restore host_node edges 51 | for u in neighbors: 52 | 
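        # each former neighbor u of the host node is re-attached to a random
        # node of the new cycle/clique, so the motif takes over the host
        # node's position in the tree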
v = random.choice(cycle_nodes) 53 | g.add_edge(u, v) 54 | return g 55 | 56 | 57 | def attach_grid(g, label): 58 | N = len(g.nodes()) 59 | host_cands = [k for k, v in g.nodes(data=True) if v['label'] == 0] 60 | host_node = random.choice(host_cands) 61 | neighbors = list(g.neighbors(host_node)) 62 | for u in neighbors: 63 | g.remove_edge(u, host_node) 64 | # 0 - 1 - 2 65 | # | | | 66 | # 3 - 4 - 5 67 | # | | | 68 | # 6 - 7 - 8 69 | grid_nodes = [N + i for i in range(8)] 70 | # assign which type of node the host node would be 71 | grid_nodes.insert(random.randint(0, len(grid_nodes)), host_node) 72 | 73 | g.add_edge(grid_nodes[0], grid_nodes[1]) 74 | g.add_edge(grid_nodes[0], grid_nodes[3]) 75 | g.add_edge(grid_nodes[1], grid_nodes[2]) 76 | g.add_edge(grid_nodes[1], grid_nodes[4]) 77 | g.add_edge(grid_nodes[2], grid_nodes[5]) 78 | g.add_edge(grid_nodes[3], grid_nodes[4]) 79 | g.add_edge(grid_nodes[3], grid_nodes[6]) 80 | g.add_edge(grid_nodes[4], grid_nodes[5]) 81 | g.add_edge(grid_nodes[4], grid_nodes[7]) 82 | g.add_edge(grid_nodes[5], grid_nodes[8]) 83 | g.add_edge(grid_nodes[6], grid_nodes[7]) 84 | g.add_edge(grid_nodes[7], grid_nodes[8]) 85 | 86 | for u in grid_nodes: 87 | g.nodes[u]['label'] = label 88 | 89 | # restore host_node edges 90 | for u in neighbors: 91 | v = random.choice(grid_nodes) 92 | g.add_edge(u, v) 93 | return g 94 | 95 | def attach_house(g, label): 96 | N = len(g.nodes()) 97 | host_cands = [k for k, v in g.nodes(data=True) if v['label'] == 0] 98 | host_node = random.choice(host_cands) 99 | neighbors = list(g.neighbors(host_node)) 100 | for u in neighbors: 101 | g.remove_edge(u, host_node) 102 | # 4 103 | # / \ 104 | # 2---3 105 | # | | 106 | # 0---1 107 | house_nodes = [N + 0, N + 1, N + 2, N + 3] 108 | # assign which type of node the host node would be 109 | house_nodes.insert(random.randint(0, 3), host_node) 110 | 111 | g.add_edge(house_nodes[0], house_nodes[1]) 112 | g.add_edge(house_nodes[0], house_nodes[2]) 113 | g.add_edge(house_nodes[1], house_nodes[3]) 114 | g.add_edge(house_nodes[2], house_nodes[3]) 115 | g.add_edge(house_nodes[2], house_nodes[4]) 116 | g.add_edge(house_nodes[3], house_nodes[4]) 117 | 118 | for u in house_nodes: 119 | g.nodes[u]['label'] = label 120 | 121 | # restore host_node edges 122 | for u in neighbors: 123 | v = random.choice(house_nodes) 124 | g.add_edge(u, v) 125 | return g 126 | 127 | 128 | def attach_cycles(g, cycle_len, count, is_clique=False): 129 | for i in range(count): 130 | attach_cycle(g, cycle_len, '%d-%d-%d' % (cycle_len, is_clique, i), is_clique) 131 | return g 132 | 133 | 134 | def add_to_list(graph_list, g, label): 135 | graph_num = len(graph_list) + 1 136 | for u in g.nodes(): 137 | g.nodes()[u]['graph_num'] = graph_num 138 | g.graph['graph_num'] = graph_num 139 | graph_list.append((g, label)) 140 | 141 | 142 | def house_cliq(sample_size): 143 | all_graphs = [] 144 | label = 0 145 | random.seed(1) 146 | for i in range(sample_size): 147 | g = random_tree(random.randint(8, 15)) 148 | count = random.randint(1, 2) 149 | for i in range(count): 150 | attach_house(g, 'h-%d' % i) 151 | add_to_list(all_graphs, g, label) 152 | label += 1 153 | random.seed(2) 154 | for i in range(sample_size): 155 | g = random_tree(random.randint(8, 15)) 156 | count = random.randint(1, 2) 157 | attach_cycles(g, cycle_len=5, count=count, is_clique=True) 158 | add_to_list(all_graphs, g, label) 159 | return all_graphs 160 | 161 | 162 | def grid_cliq(sample_size): 163 | all_graphs = [] 164 | label = 0 165 | random.seed(1) 166 | for i in 
range(sample_size): 167 | g = random_tree(random.randint(8, 15)) 168 | count = random.randint(1, 2) 169 | for i in range(count): 170 | attach_grid(g, 'g-%d' % i) 171 | add_to_list(all_graphs, g, label) 172 | label += 1 173 | random.seed(2) 174 | for i in range(sample_size): 175 | g = random_tree(random.randint(8, 15)) 176 | count = random.randint(1, 2) 177 | attach_cycles(g, cycle_len=5, count=count, is_clique=True) 178 | add_to_list(all_graphs, g, label) 179 | return all_graphs 180 | 181 | 182 | def house_grid(sample_size): 183 | all_graphs = [] 184 | label = 0 185 | random.seed(1) 186 | for i in range(sample_size): 187 | g = random_tree(random.randint(8, 15)) 188 | count = random.randint(1, 2) 189 | for i in range(count): 190 | attach_grid(g, 'g-%d' % i) 191 | add_to_list(all_graphs, g, label) 192 | label += 1 193 | random.seed(2) 194 | for i in range(sample_size): 195 | g = random_tree(random.randint(8, 15)) 196 | count = random.randint(1, 2) 197 | for i in range(count): 198 | attach_house(g, 'h-%d' % i) 199 | add_to_list(all_graphs, g, label) 200 | return all_graphs 201 | 202 | 203 | def trisq(sample_size): 204 | all_graphs = [] 205 | random.seed(0) 206 | for i in range(sample_size): 207 | g = random_tree(random.randint(8, 15)) 208 | add_to_list(all_graphs, g, 0) 209 | random.seed(1) 210 | for i in range(sample_size): 211 | g = random_tree(random.randint(8, 15)) 212 | count = random.randint(1, 4) 213 | attach_cycles(g, cycle_len=3, count=count) 214 | add_to_list(all_graphs, g, 1) 215 | random.seed(2) 216 | for i in range(sample_size): 217 | g = random_tree(random.randint(8, 15)) 218 | count = random.randint(1, 4) 219 | attach_cycles(g, cycle_len=4, count=count) 220 | add_to_list(all_graphs, g, 2) 221 | random.seed(3) 222 | for i in range(sample_size): 223 | g = random_tree(random.randint(8, 15)) 224 | count_tri = random.randint(1, 4) 225 | count_sq = random.randint(1, 4) 226 | attach_cycles(g, cycle_len=3, count=count_tri) 227 | attach_cycles(g, cycle_len=4, count=count_sq) 228 | add_to_list(all_graphs, g, 3) 229 | return all_graphs 230 | 231 | 232 | def cycliq(sample_size, is_multi): 233 | all_graphs = [] 234 | label = 0 235 | if is_multi: 236 | random.seed(0) 237 | for i in range(sample_size): 238 | g = random_tree(random.randint(8, 15)) 239 | add_to_list(all_graphs, g, label) 240 | label += 1 241 | random.seed(1) 242 | for i in range(sample_size): 243 | g = random_tree(random.randint(8, 15)) 244 | count = random.randint(1, 2) 245 | attach_cycles(g, cycle_len=5, count=count) 246 | add_to_list(all_graphs, g, label) 247 | label += 1 248 | random.seed(2) 249 | for i in range(sample_size): 250 | g = random_tree(random.randint(8, 15)) 251 | count = random.randint(1, 2) 252 | attach_cycles(g, cycle_len=5, count=count, is_clique=True) 253 | add_to_list(all_graphs, g, label) 254 | label += 1 255 | if is_multi: 256 | random.seed(3) 257 | for i in range(sample_size): 258 | g = random_tree(random.randint(8, 15)) 259 | count = random.randint(1, 2) 260 | attach_cycles(g, cycle_len=5, count=count, is_clique=True) 261 | count = random.randint(1, 2) 262 | attach_cycles(g, cycle_len=5, count=count) 263 | add_to_list(all_graphs, g, label) 264 | return all_graphs 265 | 266 | 267 | def write_gexf(output_path: Path, graphs): 268 | print('Created .gexf files in %s' % output_path) 269 | for g, label in graphs: 270 | nx.write_gexf(g, output_path / ('%d.%d.gexf' % (g.graph['graph_num'], label))) 271 | 272 | 273 | def write_adjacency(output_path: Path, dataset: Dataset, graphs): 274 | relabled_gs = [] 275 | 
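    # nodes are renumbered consecutively across all graphs starting from 1,
    # matching the TU benchmark format consumed by the *_A.txt,
    # *_graph_indicator.txt and *_graph_labels.txt files written below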
first_label = 1 276 | graph_indicator = [] 277 | for g, label in tq(graphs): 278 | relabled_gs.append(nx.convert_node_labels_to_integers(g, first_label=first_label)) 279 | N = len(g.nodes()) 280 | first_label += N 281 | graph_indicator.extend([g.graph['graph_num']] * N) 282 | with open(output_path / ('%s_A.txt' % dataset.value), 'w') as f: 283 | for g in relabled_gs: 284 | for u, v in g.edges(): 285 | f.write(f'{u}, {v}\n{v}, {u}\n') 286 | with open(output_path / ('%s_graph_indicator.txt' % dataset.value), 'w') as f: 287 | f.write('\n'.join(map(str, graph_indicator))) 288 | with open(output_path / ('%s_graph_labels.txt' % dataset.value), 'w') as f: 289 | f.write('\n'.join([str(label) for g, label in graphs])) 290 | 291 | 292 | def main(dataset: Dataset, output_path: Path = typer.Argument('data', help='Output path for dataset'), 293 | sample_size: int = typer.Option(1000, help='Number of samples for each label to generate')): 294 | print('Generating %s dataset' % dataset.value) 295 | if dataset == Dataset.CYCLIQ: 296 | graphs = cycliq(sample_size, is_multi=False) 297 | elif dataset == Dataset.CYCLIQ_MULTI: 298 | graphs = cycliq(sample_size, is_multi=True) 299 | elif dataset == Dataset.TRISQ: 300 | graphs = trisq(sample_size) 301 | elif dataset == Dataset.HOUSE_CLIQ: 302 | graphs = house_cliq(sample_size) 303 | elif dataset == Dataset.GRID_CLIQ: 304 | graphs = grid_cliq(sample_size) 305 | elif dataset == Dataset.HOUSE_GRID: 306 | graphs = house_grid(sample_size) 307 | 308 | if not output_path.exists(): 309 | typer.confirm("Output path %s does not exist, do you want to create it?" % output_path, abort=True) 310 | output_path.mkdir() 311 | 312 | output_path = output_path / dataset.value 313 | output_path.mkdir(exist_ok=True) 314 | 315 | write_gexf(output_path, graphs) 316 | write_adjacency(output_path, dataset, graphs) 317 | 318 | 319 | if __name__ == '__main__': 320 | app = typer.Typer(add_completion=False) 321 | app.command()(main) 322 | app() 323 | -------------------------------------------------------------------------------- /explain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from enum import Enum 4 | from pathlib import Path 5 | 6 | import networkx as nx 7 | import numpy as np 8 | import torch 9 | import typer 10 | from geomloss import SamplesLoss 11 | from tqdm import tqdm as tq 12 | 13 | from models import GcnEncoderGraph 14 | 15 | app = typer.Typer(add_completion=False) 16 | 17 | 18 | class ExplainMethod(str, Enum): 19 | contrastive = 'contrastive' 20 | sa = 'sensitivity' 21 | occlusion = 'occlusion' 22 | random = 'random' 23 | 24 | 25 | def load_model(model_path: Path): 26 | ckpt = torch.load(model_path) 27 | cg_dict = ckpt["cg"] # get computation graph 28 | input_dim = cg_dict["feat"].shape[2] 29 | num_classes = cg_dict["pred"].shape[2] 30 | model = GcnEncoderGraph( 31 | input_dim=input_dim, 32 | hidden_dim=20, 33 | embedding_dim=20, 34 | label_dim=num_classes, 35 | num_layers=3, 36 | bn=False, 37 | args=None, 38 | ) 39 | model.load_state_dict(ckpt["model_state"]) 40 | model.eval() 41 | return model 42 | 43 | 44 | def check_path(output_path: Path): 45 | if not output_path.exists(): 46 | typer.confirm("Output path does not exist, do you want to create it?", abort=True) 47 | output_path.mkdir(parents=True) 48 | 49 | 50 | def read_graphs(dataset_path: Path): 51 | labels = {} 52 | nx_graphs = {} 53 | for name in os.listdir(str(dataset_path)): 54 | if not name.endswith('gexf'): 55 | continue 56 | idx, label = 
name.split('.')[-3:-1] 57 | idx, label = int(idx), int(label) 58 | nx_graphs[idx] = nx.read_gexf(dataset_path / name) 59 | labels[idx] = label 60 | print('Found %d samples' % len(nx_graphs)) 61 | return nx_graphs, labels 62 | 63 | 64 | @app.command(name='sensitivity', help='Run sensitivity analysis explanation') 65 | def sa(dataset_path: Path, model_path: Path, output_path: Path): 66 | check_path(output_path) 67 | nx_graphs, labels = read_graphs(dataset_path) 68 | model = load_model(model_path) 69 | 70 | def explain(graph_num): 71 | g = nx_graphs[graph_num] 72 | node_count = len(g.nodes) 73 | 74 | adj = np.zeros((1, 100, 100)) 75 | adj[0, :node_count, :node_count] = nx.to_numpy_matrix(g) 76 | adj = torch.tensor(adj, dtype=torch.float) 77 | x = torch.ones((1, 100, 10), requires_grad=True, dtype=torch.float) 78 | 79 | ypred, _ = model(x, adj) 80 | 81 | loss = model.loss(ypred, torch.LongTensor([labels[graph_num]])) 82 | loss.backward() 83 | node_importance = x.grad.detach().numpy()[0][:node_count] 84 | node_importance = (node_importance ** 2).sum(axis=1) 85 | N = nx_graphs[graph_num].number_of_nodes() 86 | masked_adj = np.zeros((N, N)) 87 | for u, v in nx_graphs[graph_num].edges(): 88 | u = int(u) 89 | v = int(v) 90 | masked_adj[u, v] = masked_adj[v, u] = node_importance[u] + node_importance[v] 91 | return masked_adj 92 | 93 | for gid in tq(nx_graphs): 94 | masked_adj = explain(gid) 95 | np.save(output_path / ('%s.npy' % gid), masked_adj) 96 | 97 | 98 | @app.command(help='Run occlusion explanation') 99 | def occlusion(dataset_path: Path, model_path: Path, output_path: Path): 100 | check_path(output_path) 101 | nx_graphs, labels = read_graphs(dataset_path) 102 | model = load_model(model_path) 103 | 104 | def prepare_input(g): 105 | node_count = len(g.nodes) 106 | adj = np.zeros((1, 100, 100)) 107 | adj[0, :node_count, :node_count] = nx.to_numpy_matrix(g) 108 | adj = torch.tensor(adj, dtype=torch.float) 109 | x = torch.ones((1, 100, 10), requires_grad=False, dtype=torch.float) 110 | return x, adj 111 | 112 | def explain(graph_num): 113 | model.eval() 114 | g = nx_graphs[graph_num] 115 | x, adj = prepare_input(g) 116 | 117 | ypred, _ = model(x, adj) 118 | true_label = labels[graph_num] 119 | before_occlusion = ypred[0].softmax(0) 120 | node_importance = {} 121 | 122 | for removed_node in g.nodes(): 123 | g2 = g.copy() 124 | g2.remove_node(removed_node) 125 | x, adj = prepare_input(g2) 126 | ypred, _ = model(x, adj) 127 | after_occlusion = ypred[0].softmax(0) 128 | importance = abs(after_occlusion[true_label] - before_occlusion[true_label]) 129 | node_importance[int(removed_node)] = importance.item() 130 | 131 | N = nx_graphs[graph_num].number_of_nodes() 132 | masked_adj = np.zeros((N, N)) 133 | for u, v in nx_graphs[graph_num].edges(): 134 | u = int(u) 135 | v = int(v) 136 | masked_adj[u, v] = masked_adj[v, u] = node_importance[u] + node_importance[v] 137 | return masked_adj 138 | 139 | for gid in tq(nx_graphs): 140 | masked_adj = explain(gid) 141 | np.save(output_path / ('%s.npy' % gid), masked_adj) 142 | 143 | 144 | @app.command(name='random', help='Run random explanation') 145 | def random_explain(dataset_path: Path, output_path: Path): 146 | check_path(output_path) 147 | nx_graphs, labels = read_graphs(dataset_path) 148 | 149 | def explain(graph_num): 150 | g = nx_graphs[graph_num] 151 | random_importance = list(range(len(g.edges()))) 152 | random.shuffle(random_importance) 153 | 154 | N = g.number_of_nodes() 155 | masked_adj = np.zeros((N, N)) 156 | for (u, v), importance in zip(g.edges(), 
random_importance): 157 | u = int(u) 158 | v = int(v) 159 | masked_adj[u, v] = masked_adj[v, u] = importance 160 | return masked_adj 161 | 162 | for gid in tq(nx_graphs): 163 | masked_adj = explain(gid) 164 | np.save(output_path / ('%s.npy' % gid), masked_adj) 165 | 166 | 167 | @app.command(name='contrast', help='Run contrastive explanation') 168 | def contrast(dataset_path: Path, 169 | embedding_path: Path = typer.Argument(..., help='path containing the graph embeddings'), 170 | output_path: Path = typer.Argument(..., help='output path for explanations'), 171 | loss_str: str = typer.Option('-+s', '--loss', 172 | help="add each of '-', '+' and 's' for different parts of loss. Order does not matter"), 173 | similar_size: int = typer.Option(10, 174 | help="number of similar graphs to use for positive and negative set"), 175 | distance_str: str = typer.Option('ot', '--distance', 176 | help="distance measure to use can be one of ['ot,'avg']") 177 | ): 178 | check_path(output_path) 179 | nx_graphs, labels = read_graphs(dataset_path) 180 | torch.set_num_threads(1) 181 | graph_embs = {} 182 | for name in os.listdir(str(embedding_path)): 183 | if not name.endswith('npy'): 184 | continue 185 | graph_num = int(name.split('.')[0]) 186 | embs = np.load(str(embedding_path / name)) 187 | last_idx = len(nx_graphs[graph_num].nodes) 188 | embs = embs[:last_idx, :] 189 | graph_embs[graph_num] = embs 190 | 191 | def closest(graph_num, dist, size=1, neg_label=None): 192 | cur_label = labels[graph_num] 193 | pos_dists = [] 194 | neg_dists = [] 195 | for i in graph_embs: 196 | if i == graph_num: 197 | continue 198 | # if pred_labels[i] != dataset[i][1]: # ignore those not predicted correct 199 | # continue 200 | d = dist(graph_num, i) 201 | if labels[i] != cur_label: 202 | if neg_label is None or labels[i] == neg_label: 203 | neg_dists.append((d, i)) 204 | else: 205 | pos_dists.append((d, i)) 206 | pos_dists = sorted(pos_dists) 207 | neg_dists = sorted(neg_dists) 208 | pos_indices = [i for d, i in pos_dists] 209 | neg_indices = [i for d, i in neg_dists] 210 | 211 | return pos_indices[:size], neg_indices[:size] 212 | 213 | def loss_verbose(loss_str): 214 | res = '' 215 | if '-' in loss_str: 216 | res = res + '+ loss_neg ' 217 | if '+' in loss_str: 218 | res = res + '- loss_pos ' 219 | if 's' in loss_str: 220 | res = res + '+ loss_self ' 221 | return res 222 | 223 | print('Using %s for loss function' % loss_verbose(loss_str)) 224 | 225 | if distance_str == 'ot': 226 | distance = SamplesLoss("sinkhorn", p=1, blur=.01) 227 | elif distance_str == 'avg': 228 | distance = lambda x, y: torch.dist(x.mean(axis=0), y.mean(axis=0)) 229 | 230 | def graph_distance(g1_num, g2_num): 231 | k = (min(g1_num, g2_num), max(g1_num, g2_num)) 232 | g1_embs = graph_embs[g1_num] 233 | g2_embs = graph_embs[g2_num] 234 | return distance(torch.Tensor(g1_embs), torch.Tensor(g2_embs)).item() 235 | 236 | def explain(graph_num): 237 | cur_embs = torch.Tensor(graph_embs[graph_num]) 238 | 239 | distance = SamplesLoss("sinkhorn", p=1, blur=.01) 240 | 241 | positive_ids, negative_ids = closest(graph_num, graph_distance, size=similar_size) 242 | 243 | positive_embs = [torch.Tensor(graph_embs[i]) for i in positive_ids] 244 | negative_embs = [torch.Tensor(graph_embs[i]) for i in negative_ids] 245 | 246 | mask = torch.nn.Parameter(torch.zeros(len(cur_embs))) 247 | 248 | learning_rate = 1e-1 249 | optimizer = torch.optim.Adam([mask], lr=learning_rate) 250 | 251 | if distance_str == 'ot': 252 | def mydist(mask, embs): 253 | return 
distance(mask.softmax(0), cur_embs, 254 | distance.generate_weights(embs), embs) 255 | else: 256 | def mydist(mask, embs): 257 | return torch.dist((cur_embs * mask.softmax(0).reshape(-1, 1)).sum(axis=0), embs.mean(axis=0)) 258 | # tq = tqdm(range(50)) 259 | history = [] 260 | for t in range(50): 261 | loss_pos = torch.mean(torch.stack([mydist(mask, x) for x in positive_embs])) 262 | loss_neg = torch.mean(torch.stack([mydist(mask, x) for x in negative_embs])) 263 | loss_self = mydist(mask, cur_embs) 264 | 265 | loss = 0 266 | if '-' in loss_str: 267 | loss = loss + loss_neg 268 | if '+' in loss_str: 269 | loss = loss - loss_pos 270 | if 's' in loss_str: 271 | loss = loss + loss_self 272 | 273 | hist_item = dict(loss_neg=loss_neg.item(), loss_self=loss_self.item(), loss_pos=loss_pos.item(), 274 | loss=loss.item()) 275 | history.append(hist_item) 276 | # tq.set_postfix(**hist_item) 277 | optimizer.zero_grad() 278 | loss.backward() 279 | optimizer.step() 280 | node_importance = list(1 - mask.softmax(0).detach().numpy().ravel()) 281 | N = nx_graphs[graph_num].number_of_nodes() 282 | masked_adj = np.zeros((N, N)) 283 | for u, v in nx_graphs[graph_num].edges(): 284 | u = int(u) 285 | v = int(v) 286 | masked_adj[u, v] = masked_adj[v, u] = node_importance[u] + node_importance[v] 287 | return masked_adj 288 | 289 | for gid in tq(graph_embs): 290 | masked_adj = explain(gid) 291 | np.save(output_path / ('%s.npy' % gid), masked_adj) 292 | 293 | 294 | if __name__ == "__main__": 295 | app() 296 | -------------------------------------------------------------------------------- /gnnexplainer_main.py: -------------------------------------------------------------------------------- 1 | # This file is copied from GNNExplainer repo with small modifications 2 | # https://github.com/RexYing/gnn-model-explainer 3 | 4 | """ explainer_main.py 5 | 6 | Main user interface for the explainer module. 7 | """ 8 | import argparse 9 | import os 10 | 11 | import sklearn.metrics as metrics 12 | 13 | from tensorboardX import SummaryWriter 14 | 15 | import pickle 16 | import shutil 17 | import torch 18 | 19 | import models 20 | import gnnexplainer_utils.io_utils as io_utils 21 | import gnnexplainer_utils.parser_utils as parser_utils 22 | from gnnexplainer import explain 23 | 24 | from pathlib import Path 25 | 26 | 27 | def arg_parse(): 28 | parser = argparse.ArgumentParser(description="GNN Explainer arguments.") 29 | io_parser = parser.add_mutually_exclusive_group(required=False) 30 | io_parser.add_argument("--dataset", dest="dataset", help="Input dataset.") 31 | benchmark_parser = io_parser.add_argument_group() 32 | benchmark_parser.add_argument( 33 | "--bmname", dest="bmname", help="Name of the benchmark dataset" 34 | ) 35 | io_parser.add_argument("--pkl", dest="pkl_fname", help="Name of the pkl data file") 36 | 37 | parser_utils.parse_optimizer(parser) 38 | 39 | parser.add_argument("--clean-log", action="store_true", help="If true, cleans the specified log directory before running.") 40 | parser.add_argument("--logdir", dest="logdir", help="Tensorboard log directory") 41 | parser.add_argument("--ckptdir", dest="ckptdir", help="Model checkpoint directory") 42 | parser.add_argument("--cuda", dest="cuda", help="CUDA.") 43 | parser.add_argument( 44 | "--gpu", 45 | dest="gpu", 46 | action="store_const", 47 | const=True, 48 | default=False, 49 | help="whether to use GPU.", 50 | ) 51 | parser.add_argument( 52 | "--epochs", dest="num_epochs", type=int, help="Number of epochs to train." 
53 | ) 54 | parser.add_argument( 55 | "--hidden-dim", dest="hidden_dim", type=int, help="Hidden dimension" 56 | ) 57 | parser.add_argument( 58 | "--output-dim", dest="output_dim", type=int, help="Output dimension" 59 | ) 60 | parser.add_argument( 61 | "--num-gc-layers", 62 | dest="num_gc_layers", 63 | type=int, 64 | help="Number of graph convolution layers before each pooling", 65 | ) 66 | parser.add_argument( 67 | "--bn", 68 | dest="bn", 69 | action="store_const", 70 | const=True, 71 | default=False, 72 | help="Whether batch normalization is used", 73 | ) 74 | parser.add_argument("--dropout", dest="dropout", type=float, help="Dropout rate.") 75 | parser.add_argument( 76 | "--nobias", 77 | dest="bias", 78 | action="store_const", 79 | const=False, 80 | default=True, 81 | help="Whether to add bias. Default to True.", 82 | ) 83 | parser.add_argument( 84 | "--no-writer", 85 | dest="writer", 86 | action="store_const", 87 | const=False, 88 | default=True, 89 | help="Whether to use a TensorBoard writer. Default to True.", 90 | ) 91 | # Explainer 92 | parser.add_argument("--mask-act", dest="mask_act", type=str, help="Mask activation: sigmoid or ReLU.") 93 | parser.add_argument( 94 | "--mask-bias", 95 | dest="mask_bias", 96 | action="store_const", 97 | const=True, 98 | default=False, 99 | help="Whether to add bias to the explainer mask. Default to False.", 100 | ) 101 | parser.add_argument( 102 | "--explain-node", dest="explain_node", type=int, help="Node to explain." 103 | ) 104 | parser.add_argument( 105 | "--graph-idx", dest="graph_idx", type=int, help="Graph to explain." 106 | ) 107 | parser.add_argument( 108 | "--graph-mode", 109 | dest="graph_mode", 110 | action="store_const", 111 | const=True, 112 | default=False, 113 | help="whether to run Explainer on Graph Classification task.", 114 | ) 115 | parser.add_argument( 116 | "--multigraph-class", 117 | dest="multigraph_class", 118 | type=int, 119 | help="whether to run Explainer on multiple Graphs from the Classification task for examples in the same class.", 120 | ) 121 | parser.add_argument( 122 | "--multinode-class", 123 | dest="multinode_class", 124 | type=int, 125 | help="whether to run Explainer on multiple nodes from the Classification task for examples in the same class.", 126 | ) 127 | parser.add_argument( 128 | "--align-steps", 129 | dest="align_steps", 130 | type=int, 131 | help="Number of iterations to find P, the alignment matrix.", 132 | ) 133 | 
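# Illustrative invocation (added for exposition; the dataset and checkpoint names
# are placeholders, not fixed by this file): explain every graph of a trained
# graph classifier using the checkpoint found under --ckptdir (default: ckpt):
#
#   python gnnexplainer_main.py --bmname=CYCLIQ --graph-mode --explain-all
134 | parser.add_argument( 135 | "--method", dest="method", type=str, help="Method. Possible values: base, att." 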
136 | ) 137 | parser.add_argument( 138 | "--name-suffix", dest="name_suffix", help="suffix added to the output filename" 139 | ) 140 | parser.add_argument( 141 | "--explainer-suffix", 142 | dest="explainer_suffix", 143 | help="suffix added to the explainer log", 144 | ) 145 | 146 | parser.add_argument( 147 | "--explain-all", 148 | dest="explain_all", 149 | action="store_true", 150 | help="explain all graphs", 151 | ) 152 | 153 | # TODO: Check argument usage 154 | parser.set_defaults( 155 | logdir="log", 156 | ckptdir="ckpt", 157 | dataset="syn1", 158 | opt="adam", 159 | opt_scheduler="none", 160 | cuda="0", 161 | lr=0.1, 162 | clip=2.0, 163 | batch_size=20, 164 | num_epochs=100, 165 | hidden_dim=20, 166 | output_dim=20, 167 | num_gc_layers=3, 168 | dropout=0.0, 169 | method="base", 170 | name_suffix="", 171 | explainer_suffix="", 172 | align_steps=1000, 173 | explain_node=None, 174 | graph_idx=-1, 175 | mask_act="sigmoid", 176 | multigraph_class=-1, 177 | multinode_class=-1, 178 | ) 179 | return parser.parse_args() 180 | 181 | 182 | def main(): 183 | # Load a configuration 184 | prog_args = arg_parse() 185 | 186 | if prog_args.gpu: 187 | os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda 188 | print("CUDA", prog_args.cuda) 189 | else: 190 | print("Using CPU") 191 | 192 | # Configure the logging directory 193 | if prog_args.writer: 194 | path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args)) 195 | if os.path.isdir(path) and prog_args.clean_log: 196 | print('Removing existing log dir: ', path) 197 | if not input("Are you sure you want to remove this directory? (y/n): ").lower().strip()[:1] == "y": sys.exit(1) 198 | shutil.rmtree(path) 199 | writer = SummaryWriter(path) 200 | else: 201 | writer = None 202 | 203 | # Load a model checkpoint 204 | ckpt = io_utils.load_ckpt(prog_args) 205 | cg_dict = ckpt["cg"] # get computation graph 206 | input_dim = cg_dict["feat"].shape[2] 207 | num_classes = cg_dict["pred"].shape[2] 208 | print("Loaded model from {}".format(prog_args.ckptdir)) 209 | print("input dim: ", input_dim, "; num classes: ", num_classes) 210 | 211 | # Determine explainer mode 212 | graph_mode = ( 213 | prog_args.graph_mode 214 | or prog_args.multigraph_class >= 0 215 | or prog_args.graph_idx >= 0 216 | ) 217 | 218 | # build model 219 | print("Method: ", prog_args.method) 220 | if graph_mode: 221 | # Explain Graph prediction 222 | model = models.GcnEncoderGraph( 223 | input_dim=input_dim, 224 | hidden_dim=prog_args.hidden_dim, 225 | embedding_dim=prog_args.output_dim, 226 | label_dim=num_classes, 227 | num_layers=prog_args.num_gc_layers, 228 | bn=prog_args.bn, 229 | args=prog_args, 230 | ) 231 | else: 232 | if prog_args.dataset == "ppi_essential": 233 | # class weight in CE loss for handling imbalanced label classes 234 | prog_args.loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float).cuda() 235 | # Explain Node prediction 236 | model = models.GcnEncoderNode( 237 | input_dim=input_dim, 238 | hidden_dim=prog_args.hidden_dim, 239 | embedding_dim=prog_args.output_dim, 240 | label_dim=num_classes, 241 | num_layers=prog_args.num_gc_layers, 242 | bn=prog_args.bn, 243 | args=prog_args, 244 | ) 245 | if prog_args.gpu: 246 | model = model.cuda() 247 | # load state_dict (obtained by model.state_dict() when saving checkpoint) 248 | model.load_state_dict(ckpt["model_state"]) 249 | 250 | # Create explainer 251 | explainer = explain.Explainer( 252 | model=model, 253 | adj=cg_dict["adj"], 254 | feat=cg_dict["feat"], 255 | label=cg_dict["label"], 256 | 
pred=cg_dict["pred"], 257 | train_idx=cg_dict["train_idx"], 258 | args=prog_args, 259 | writer=writer, 260 | print_training=True, 261 | graph_mode=graph_mode, 262 | graph_idx=prog_args.graph_idx, 263 | ) 264 | 265 | # TODO: API should definitely be cleaner 266 | # Let's define exactly which modes we support 267 | # We could even move each mode to a different method (even file) 268 | if prog_args.explain_node is not None: 269 | explainer.explain(prog_args.explain_node, unconstrained=False) 270 | elif graph_mode: 271 | if prog_args.explain_all: 272 | explain_path = Path('explanations/gnnexplainer/') 273 | explain_path.mkdir(exist_ok=True, parents=True) 274 | embeddings_path = Path('embeddings-%s/' % prog_args.bmname) 275 | embeddings_path.mkdir(exist_ok=True, parents=True) 276 | for i in range(len(cg_dict['all_idx'])): 277 | print('Explaining %s' % cg_dict['all_idx'][i]) 278 | explainer.explain( 279 | node_idx=0, 280 | graph_idx=i, 281 | graph_mode=True, 282 | unconstrained=False, 283 | original_idx=cg_dict['all_idx'][i] 284 | ) 285 | elif prog_args.multigraph_class >= 0: 286 | print(cg_dict["label"]) 287 | # only run for graphs with label specified by multigraph_class 288 | labels = cg_dict["label"].numpy() 289 | graph_indices = [] 290 | for i, l in enumerate(labels): 291 | if l == prog_args.multigraph_class: 292 | graph_indices.append(i) 293 | if len(graph_indices) > 30: 294 | break 295 | print( 296 | "Graph indices for label ", 297 | prog_args.multigraph_class, 298 | " : ", 299 | graph_indices, 300 | ) 301 | explainer.explain_graphs(graph_indices=graph_indices) 302 | 303 | elif prog_args.graph_idx == -1: 304 | # just run for a customized set of indices 305 | explainer.explain_graphs(graph_indices=[1, 2, 3, 4]) 306 | else: 307 | explainer.explain( 308 | node_idx=0, 309 | graph_idx=prog_args.graph_idx, 310 | graph_mode=True, 311 | unconstrained=False, 312 | original_idx=cg_dict['all_idx'][prog_args.graph_idx] 313 | ) 314 | # io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap") 315 | else: 316 | if prog_args.multinode_class >= 0: 317 | print(cg_dict["label"]) 318 | # only run for nodes with label specified by multinode_class 319 | labels = cg_dict["label"][0] # already numpy matrix 320 | 321 | node_indices = [] 322 | for i, l in enumerate(labels): 323 | if len(node_indices) > 4: 324 | break 325 | if l == prog_args.multinode_class: 326 | node_indices.append(i) 327 | print( 328 | "Node indices for label ", 329 | prog_args.multinode_class, 330 | " : ", 331 | node_indices, 332 | ) 333 | explainer.explain_nodes(node_indices, prog_args) 334 | 335 | else: 336 | # explain a set of nodes 337 | masked_adj = explainer.explain_nodes_gnn_stats( 338 | range(400, 700, 5), prog_args 339 | ) 340 | 341 | if __name__ == "__main__": 342 | main() 343 | 344 | -------------------------------------------------------------------------------- /gnnexplainer_utils/synthetic_structsim.py: -------------------------------------------------------------------------------- 1 | """synthetic_structsim.py 2 | 3 | Utilities for generating certain graph shapes. 
4 | """ 5 | import math 6 | 7 | import networkx as nx 8 | import numpy as np 9 | 10 | # Following GraphWave's representation of structural similarity 11 | 12 | 13 | def clique(start, nb_nodes, nb_to_remove=0, role_start=0): 14 | """ Defines a clique (complete graph on nb_nodes nodes, 15 | with nb_to_remove edges that will have to be removed), 16 | index of nodes starting at start 17 | and role_ids at role_start 18 | INPUT: 19 | ------------- 20 | start : starting index for the shape 21 | nb_nodes : int correspondingraph to the nb of nodes in the clique 22 | role_start : starting index for the roles 23 | nb_to_remove: int-- numb of edges to remove (unif at RDM) 24 | OUTPUT: 25 | ------------- 26 | graph : a house shape graph, with ids beginning at start 27 | roles : list of the roles of the nodes (indexed starting at 28 | role_start) 29 | """ 30 | a = np.ones((nb_nodes, nb_nodes)) 31 | np.fill_diagonal(a, 0) 32 | graph = nx.from_numpy_matrix(a) 33 | edge_list = graph.edges().keys() 34 | roles = [role_start] * nb_nodes 35 | if nb_to_remove > 0: 36 | lst = np.random.choice(len(edge_list), nb_to_remove, replace=False) 37 | print(edge_list, lst) 38 | to_delete = [edge_list[e] for e in lst] 39 | graph.remove_edges_from(to_delete) 40 | for e in lst: 41 | print(edge_list[e][0]) 42 | print(len(roles)) 43 | roles[edge_list[e][0]] += 1 44 | roles[edge_list[e][1]] += 1 45 | mapping_graph = {k: (k + start) for k in range(nb_nodes)} 46 | graph = nx.relabel_nodes(graph, mapping_graph) 47 | return graph, roles 48 | 49 | 50 | def cycle(start, len_cycle, role_start=0): 51 | """Builds a cycle graph, with index of nodes starting at start 52 | and role_ids at role_start 53 | INPUT: 54 | ------------- 55 | start : starting index for the shape 56 | role_start : starting index for the roles 57 | OUTPUT: 58 | ------------- 59 | graph : a house shape graph, with ids beginning at start 60 | roles : list of the roles of the nodes (indexed starting at 61 | role_start) 62 | """ 63 | graph = nx.Graph() 64 | graph.add_nodes_from(range(start, start + len_cycle)) 65 | for i in range(len_cycle - 1): 66 | graph.add_edges_from([(start + i, start + i + 1)]) 67 | graph.add_edges_from([(start + len_cycle - 1, start)]) 68 | roles = [role_start] * len_cycle 69 | return graph, roles 70 | 71 | 72 | def diamond(start, role_start=0): 73 | """Builds a diamond graph, with index of nodes starting at start 74 | and role_ids at role_start 75 | INPUT: 76 | ------------- 77 | start : starting index for the shape 78 | role_start : starting index for the roles 79 | OUTPUT: 80 | ------------- 81 | graph : a house shape graph, with ids beginning at start 82 | roles : list of the roles of the nodes (indexed starting at 83 | role_start) 84 | """ 85 | graph = nx.Graph() 86 | graph.add_nodes_from(range(start, start + 6)) 87 | graph.add_edges_from( 88 | [ 89 | (start, start + 1), 90 | (start + 1, start + 2), 91 | (start + 2, start + 3), 92 | (start + 3, start), 93 | ] 94 | ) 95 | graph.add_edges_from( 96 | [ 97 | (start + 4, start), 98 | (start + 4, start + 1), 99 | (start + 4, start + 2), 100 | (start + 4, start + 3), 101 | ] 102 | ) 103 | graph.add_edges_from( 104 | [ 105 | (start + 5, start), 106 | (start + 5, start + 1), 107 | (start + 5, start + 2), 108 | (start + 5, start + 3), 109 | ] 110 | ) 111 | roles = [role_start] * 6 112 | return graph, roles 113 | 114 | 115 | def tree(start, height, r=2, role_start=0): 116 | """Builds a balanced r-tree of height h 117 | INPUT: 118 | ------------- 119 | start : starting index for the shape 120 | height : 
int height of the tree 121 | r : int number of branches per node 122 | role_start : starting index for the roles 123 | OUTPUT: 124 | ------------- 125 | graph : a tree shape graph, with ids beginning at start 126 | roles : list of the roles of the nodes (indexed starting at role_start) 127 | """ 128 | graph = nx.balanced_tree(r, height) 129 | roles = [0] * graph.number_of_nodes() 130 | return graph, roles 131 | 132 | 133 | def fan(start, nb_branches, role_start=0): 134 | """Builds a fan-like graph, with index of nodes starting at start 135 | and role_ids at role_start 136 | INPUT: 137 | ------------- 138 | nb_branches : int corresponding to the nb of fan branches 139 | start : starting index for the shape 140 | role_start : starting index for the roles 141 | OUTPUT: 142 | ------------- 143 | graph : a fan graph, with ids beginning at start 144 | roles : list of the roles of the nodes (indexed starting at 145 | role_start) 146 | """ 147 | graph, roles = star(start, nb_branches, role_start=role_start) 148 | for k in range(1, nb_branches - 1): 149 | roles[k] += 1 150 | roles[k + 1] += 1 151 | graph.add_edges_from([(start + k, start + k + 1)]) 152 | return graph, roles 153 | 154 | 155 | def ba(start, width, role_start=0, m=5): 156 | """Builds a BA preferential attachment graph, with index of nodes starting at start 157 | and role_ids at role_start 158 | INPUT: 159 | ------------- 160 | start : starting index for the shape 161 | width : int size of the graph 162 | role_start : starting index for the roles 163 | OUTPUT: 164 | ------------- 165 | graph : a BA preferential attachment graph, with ids beginning at start 166 | roles : list of the roles of the nodes (indexed starting at 167 | role_start) 168 | """ 169 | graph = nx.barabasi_albert_graph(width, m) 170 | graph.add_nodes_from(range(start, start + width)) 171 | nids = sorted(graph) 172 | mapping = {nid: start + i for i, nid in enumerate(nids)} 173 | graph = nx.relabel_nodes(graph, mapping) 174 | roles = [role_start for i in range(width)] 175 | return graph, roles 176 | 177 | 178 | def house(start, role_start=0): 179 | """Builds a house-like graph, with index of nodes starting at start 180 | and role_ids at role_start 181 | INPUT: 182 | ------------- 183 | start : starting index for the shape 184 | role_start : starting index for the roles 185 | OUTPUT: 186 | ------------- 187 | graph : a house shape graph, with ids beginning at start 188 | roles : list of the roles of the nodes (indexed starting at 189 | role_start) 190 | """ 191 | graph = nx.Graph() 192 | graph.add_nodes_from(range(start, start + 5)) 193 | graph.add_edges_from( 194 | [ 195 | (start, start + 1), 196 | (start + 1, start + 2), 197 | (start + 2, start + 3), 198 | (start + 3, start), 199 | ] 200 | ) 201 | # graph.add_edges_from([(start, start + 2), (start + 1, start + 3)]) 202 | graph.add_edges_from([(start + 4, start), (start + 4, start + 1)]) 203 | roles = [role_start, role_start, role_start + 1, role_start + 1, role_start + 2] 204 | return graph, roles 205 | 206 | 207 | def grid(start, dim=2, role_start=0): 208 | """ Builds a dim-by-dim grid (2 by 2 by default) 209 | """ 210 | grid_G = nx.grid_graph([dim, dim]) 211 | grid_G = nx.convert_node_labels_to_integers(grid_G, first_label=start) 212 | roles = [role_start for i in grid_G.nodes()] 213 | return grid_G, roles 214 | 215 | 216 | def star(start, nb_branches, role_start=0): 217 | """Builds a star graph, with index of nodes starting at start 218 | and role_ids at role_start 219 | INPUT: 220 | ------------- 221 | nb_branches : int corresponding to 
the nb of star branches 222 | start : starting index for the shape 223 | role_start : starting index for the roles 224 | OUTPUT: 225 | ------------- 226 | graph : a star graph, with ids beginning at start 227 | roles : list of the roles of the nodes (indexed starting at 228 | role_start) 229 | """ 230 | graph = nx.Graph() 231 | graph.add_nodes_from(range(start, start + nb_branches + 1)) 232 | for k in range(1, nb_branches + 1): 233 | graph.add_edges_from([(start, start + k)]) 234 | roles = [role_start + 1] * (nb_branches + 1) 235 | roles[0] = role_start 236 | return graph, roles 237 | 238 | 239 | def path(start, width, role_start=0): 240 | """Builds a path graph, with index of nodes starting at start 241 | and role_ids at role_start 242 | INPUT: 243 | ------------- 244 | start : starting index for the shape 245 | width : int length of the path 246 | role_start : starting index for the roles 247 | OUTPUT: 248 | ------------- 249 | graph : a path graph, with ids beginning at start 250 | roles : list of the roles of the nodes (indexed starting at 251 | role_start) 252 | """ 253 | graph = nx.Graph() 254 | graph.add_nodes_from(range(start, start + width)) 255 | for i in range(width - 1): 256 | graph.add_edges_from([(start + i, start + i + 1)]) 257 | roles = [role_start] * width 258 | roles[0] = role_start + 1 259 | roles[-1] = role_start + 1 260 | return graph, roles 261 | 262 | 263 | def build_graph( 264 | width_basis, 265 | basis_type, 266 | list_shapes, 267 | start=0, 268 | rdm_basis_plugins=False, 269 | add_random_edges=0, 270 | m=5, 271 | ): 272 | """This function creates a basis (scale-free, path, or cycle) 273 | and attaches elements of the type in the list randomly along the basis. 274 | Possibility to add random edges afterwards. 275 | INPUT: 276 | -------------------------------------------------------------------------------------- 277 | width_basis : width (in terms of number of nodes) of the basis 278 | basis_type : name of a basis builder defined in this module (e.g. "ba", "tree", "cycle") 279 | list_shapes : list of shape lists (1st arg: type of shape, 280 | next args:args for building the shape, 281 | except for the start) 282 | start : initial nb for the first node 283 | rdm_basis_plugins: boolean. Should the shapes be randomly placed 284 | along the basis (True) or regularly (False)? 
285 | add_random_edges : nb of edges to randomly add on the structure 286 | m : number of edges to attach to existing node (for BA graph) 287 | OUTPUT: 288 | -------------------------------------------------------------------------------------- 289 | basis : a nx graph with the particular shape 290 | role_ids : labels for each role 291 | plugins : node ids with the attached shapes 292 | """ 293 | if basis_type == "ba": 294 | basis, role_id = eval(basis_type)(start, width_basis, m=m) 295 | else: 296 | basis, role_id = eval(basis_type)(start, width_basis) 297 | 298 | n_basis, n_shapes = nx.number_of_nodes(basis), len(list_shapes) 299 | start += n_basis # indicator of the id of the next node 300 | 301 | # Sample (without replacement) where to attach the new motifs 302 | if rdm_basis_plugins is True: 303 | plugins = np.random.choice(n_basis, n_shapes, replace=False) 304 | else: 305 | spacing = math.floor(n_basis / n_shapes) 306 | plugins = [int(k * spacing) for k in range(n_shapes)] 307 | seen_shapes = {"basis": [0, n_basis]} 308 | 309 | for shape_id, shape in enumerate(list_shapes): 310 | shape_type = shape[0] 311 | args = [start] 312 | if len(shape) > 1: 313 | args += shape[1:] 314 | args += [0] 315 | graph_s, roles_graph_s = eval(shape_type)(*args) 316 | n_s = nx.number_of_nodes(graph_s) 317 | try: 318 | col_start = seen_shapes[shape_type][0] 319 | except KeyError: 320 | col_start = np.max(role_id) + 1 321 | seen_shapes[shape_type] = [col_start, n_s] 322 | # Attach the shape to the basis 323 | basis.add_nodes_from(graph_s.nodes()) 324 | basis.add_edges_from(graph_s.edges()) 325 | basis.add_edges_from([(start, plugins[shape_id])]) 326 | if shape_type == "cycle": 327 | if np.random.random() > 0.5: 328 | a = np.random.randint(1, 4) 329 | b = np.random.randint(1, 4) 330 | basis.add_edges_from([(a + start, b + plugins[shape_id])]) 331 | temp_labels = [r + col_start for r in roles_graph_s] 332 | # temp_labels[0] += 100 * seen_shapes[shape_type][0] 333 | role_id += temp_labels 334 | start += n_s 335 | 336 | if add_random_edges > 0: 337 | # add random edges between nodes: 338 | for p in range(add_random_edges): 339 | src, dest = np.random.choice(nx.number_of_nodes(basis), 2, replace=False) 340 | print(src, dest) 341 | basis.add_edges_from([(src, dest)]) 342 | 343 | return basis, role_id, plugins 344 | 
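# --- Added usage sketch (illustrative, not part of the original file). Builds a
# small BA basis and attaches five house motifs, in the style of the BA-shapes
# synthetic benchmarks; all parameter values below are arbitrary examples. ---
if __name__ == "__main__":
    demo_basis, demo_roles, demo_plugins = build_graph(
        width_basis=30, basis_type="ba", list_shapes=[["house"]] * 5, m=5
    )
    # 30 basis nodes + 5 houses of 5 nodes each = 55 nodes in total
    print(demo_basis.number_of_nodes(), "nodes")
    print("attachment points:", demo_plugins)
    print("distinct structural roles:", sorted(set(demo_roles)))
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # This file is copied from GNNExplainer repo with small modifications 2 | # https://github.com/RexYing/gnn-model-explainer 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import init 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | 12 | # GCN basic operation 13 | class GraphConv(nn.Module): 14 | def __init__( 15 | self, 16 | input_dim, 17 | output_dim, 18 | add_self=False, 19 | normalize_embedding=False, 20 | dropout=0.0, 21 | bias=True, 22 | gpu=True, 23 | att=False, 24 | ): 25 | super(GraphConv, self).__init__() 26 | self.att = att 27 | self.add_self = add_self 28 | self.dropout = dropout 29 | if dropout > 0.001: 30 | self.dropout_layer = nn.Dropout(p=dropout) 31 | self.normalize_embedding = normalize_embedding 32 | self.input_dim = input_dim 33 | self.output_dim = output_dim 34 | if not gpu: 35 | self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim)) 36 | if add_self: 37 | self.self_weight = nn.Parameter( 38 | torch.FloatTensor(input_dim, output_dim) 39 | ) 40 | if att: 41 | 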
self.att_weight = nn.Parameter(torch.FloatTensor(input_dim, input_dim)) 42 | else: 43 | self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim)) 44 | if add_self: 45 | self.self_weight = nn.Parameter( 46 | torch.FloatTensor(input_dim, output_dim) 47 | ) 48 | if att: 49 | self.att_weight = nn.Parameter( 50 | torch.FloatTensor(input_dim, input_dim) 51 | ) 52 | if bias: 53 | if not gpu: 54 | self.bias = nn.Parameter(torch.FloatTensor(output_dim)) 55 | else: 56 | self.bias = nn.Parameter(torch.FloatTensor(output_dim)) 57 | else: 58 | self.bias = None 59 | 60 | # self.softmax = nn.Softmax(dim=-1) 61 | 62 | def forward(self, x, adj): 63 | if self.dropout > 0.001: 64 | x = self.dropout_layer(x) 65 | # deg = torch.sum(adj, -1, keepdim=True) 66 | if self.att: 67 | x_att = torch.matmul(x, self.att_weight) 68 | # import pdb 69 | # pdb.set_trace() 70 | att = x_att @ x_att.permute(0, 2, 1) 71 | # att = self.softmax(att) 72 | adj = adj * att 73 | 74 | y = torch.matmul(adj, x) 75 | y = torch.matmul(y, self.weight) 76 | if self.add_self: 77 | self_emb = torch.matmul(x, self.self_weight) 78 | y += self_emb 79 | if self.bias is not None: 80 | y = y + self.bias 81 | if self.normalize_embedding: 82 | y = F.normalize(y, p=2, dim=2) 83 | # print(y[0][0]) 84 | return y, adj 85 | 86 | 87 | class GcnEncoderGraph(nn.Module): 88 | def __init__( 89 | self, 90 | input_dim, 91 | hidden_dim, 92 | embedding_dim, 93 | label_dim, 94 | num_layers, 95 | pred_hidden_dims=[], 96 | concat=True, 97 | bn=True, 98 | dropout=0.0, 99 | add_self=False, 100 | args=None, 101 | ): 102 | super(GcnEncoderGraph, self).__init__() 103 | self.concat = concat 104 | add_self = add_self 105 | self.bn = bn 106 | self.num_layers = num_layers 107 | self.num_aggs = 1 108 | 109 | self.bias = True 110 | self.gpu = False 111 | self.att = False 112 | 113 | if args is not None: 114 | if args.method == "att": 115 | self.att = True 116 | else: 117 | self.att = False 118 | self.gpu = args.gpu 119 | self.bias = args.bias 120 | 121 | self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers( 122 | input_dim, 123 | hidden_dim, 124 | embedding_dim, 125 | num_layers, 126 | add_self, 127 | normalize=True, 128 | dropout=dropout, 129 | ) 130 | self.act = nn.ReLU() 131 | self.label_dim = label_dim 132 | 133 | if concat: 134 | self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim 135 | else: 136 | self.pred_input_dim = embedding_dim 137 | self.pred_model = self.build_pred_layers( 138 | self.pred_input_dim, pred_hidden_dims, label_dim, num_aggs=self.num_aggs 139 | ) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, GraphConv): 143 | init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain("relu")) 144 | if m.att: 145 | init.xavier_uniform_( 146 | m.att_weight.data, gain=nn.init.calculate_gain("relu") 147 | ) 148 | if m.add_self: 149 | init.xavier_uniform_( 150 | m.self_weight.data, gain=nn.init.calculate_gain("relu") 151 | ) 152 | if m.bias is not None: 153 | init.constant_(m.bias.data, 0.0) 154 | 155 | def build_conv_layers( 156 | self, 157 | input_dim, 158 | hidden_dim, 159 | embedding_dim, 160 | num_layers, 161 | add_self, 162 | normalize=False, 163 | dropout=0.0, 164 | ): 165 | conv_first = GraphConv( 166 | input_dim=input_dim, 167 | output_dim=hidden_dim, 168 | add_self=add_self, 169 | normalize_embedding=normalize, 170 | bias=self.bias, 171 | gpu=self.gpu, 172 | att=self.att, 173 | ) 174 | conv_block = nn.ModuleList( 175 | [ 176 | GraphConv( 177 | input_dim=hidden_dim, 178 | output_dim=hidden_dim, 179 | 
add_self=add_self, 180 | normalize_embedding=normalize, 181 | dropout=dropout, 182 | bias=self.bias, 183 | gpu=self.gpu, 184 | att=self.att, 185 | ) 186 | for i in range(num_layers - 2) 187 | ] 188 | ) 189 | conv_last = GraphConv( 190 | input_dim=hidden_dim, 191 | output_dim=embedding_dim, 192 | add_self=add_self, 193 | normalize_embedding=normalize, 194 | bias=self.bias, 195 | gpu=self.gpu, 196 | att=self.att, 197 | ) 198 | return conv_first, conv_block, conv_last 199 | 200 | def build_pred_layers( 201 | self, pred_input_dim, pred_hidden_dims, label_dim, num_aggs=1 202 | ): 203 | pred_input_dim = pred_input_dim * num_aggs 204 | if len(pred_hidden_dims) == 0: 205 | pred_model = nn.Linear(pred_input_dim, label_dim) 206 | else: 207 | pred_layers = [] 208 | for pred_dim in pred_hidden_dims: 209 | pred_layers.append(nn.Linear(pred_input_dim, pred_dim)) 210 | pred_layers.append(self.act) 211 | pred_input_dim = pred_dim 212 | pred_layers.append(nn.Linear(pred_dim, label_dim)) 213 | pred_model = nn.Sequential(*pred_layers) 214 | return pred_model 215 | 216 | def construct_mask(self, max_nodes, batch_num_nodes): 217 | """ For each num_nodes in batch_num_nodes, the first num_nodes entries of the 218 | corresponding column are 1's, and the rest are 0's (to be masked out). 219 | Dimension of mask: [batch_size x max_nodes x 1] 220 | """ 221 | # masks 222 | packed_masks = [torch.ones(int(num)) for num in batch_num_nodes] 223 | batch_size = len(batch_num_nodes) 224 | out_tensor = torch.zeros(batch_size, max_nodes) 225 | for i, mask in enumerate(packed_masks): 226 | out_tensor[i, : batch_num_nodes[i]] = mask 227 | return out_tensor.unsqueeze(2) 228 | 229 | def apply_bn(self, x): 230 | """ Batch normalization of 3D tensor x 231 | """ 232 | bn_module = nn.BatchNorm1d(x.size()[1]) 233 | if self.gpu: 234 | bn_module = bn_module 235 | return bn_module(x) 236 | 237 | def gcn_forward( 238 | self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None 239 | ): 240 | 241 | """ Perform forward prop with graph convolution. 
242 | Returns: 243 | Embedding matrix with dimension [batch_size x num_nodes x embedding] 244 | The embedding dim is self.pred_input_dim 245 | """ 246 | 247 | x, adj_att = conv_first(x, adj) 248 | x = self.act(x) 249 | if self.bn: 250 | x = self.apply_bn(x) 251 | x_all = [x] 252 | adj_att_all = [adj_att] 253 | # out_all = [] 254 | # out, _ = torch.max(x, dim=1) 255 | # out_all.append(out) 256 | for i in range(len(conv_block)): 257 | x, _ = conv_block[i](x, adj) 258 | x = self.act(x) 259 | if self.bn: 260 | x = self.apply_bn(x) 261 | x_all.append(x) 262 | adj_att_all.append(adj_att) 263 | x, adj_att = conv_last(x, adj) 264 | x_all.append(x) 265 | adj_att_all.append(adj_att) 266 | # x_tensor: [batch_size x num_nodes x embedding] 267 | x_tensor = torch.cat(x_all, dim=2) 268 | if embedding_mask is not None: 269 | x_tensor = x_tensor * embedding_mask 270 | self.embedding_tensor = x_tensor 271 | 272 | # adj_att_tensor: [batch_size x num_nodes x num_nodes x num_gc_layers] 273 | adj_att_tensor = torch.stack(adj_att_all, dim=3) 274 | return x_tensor, adj_att_tensor 275 | 276 | def forward(self, x, adj, batch_num_nodes=None, **kwargs): 277 | # mask 278 | max_num_nodes = adj.size()[1] 279 | if batch_num_nodes is not None: 280 | self.embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 281 | else: 282 | self.embedding_mask = None 283 | 284 | # conv 285 | x, adj_att = self.conv_first(x, adj) 286 | x = self.act(x) 287 | if self.bn: 288 | x = self.apply_bn(x) 289 | out_all = [] 290 | out, _ = torch.max(x, dim=1) 291 | out_all.append(out) 292 | adj_att_all = [adj_att] 293 | for i in range(self.num_layers - 2): 294 | x, adj_att = self.conv_block[i](x, adj) 295 | x = self.act(x) 296 | if self.bn: 297 | x = self.apply_bn(x) 298 | out, _ = torch.max(x, dim=1) 299 | out_all.append(out) 300 | if self.num_aggs == 2: 301 | out = torch.sum(x, dim=1) 302 | out_all.append(out) 303 | adj_att_all.append(adj_att) 304 | x, adj_att = self.conv_last(x, adj) 305 | adj_att_all.append(adj_att) 306 | # x = self.act(x) 307 | out, _ = torch.max(x, dim=1) 308 | out_all.append(out) 309 | if self.num_aggs == 2: 310 | out = torch.sum(x, dim=1) 311 | out_all.append(out) 312 | if self.concat: 313 | output = torch.cat(out_all, dim=1) 314 | else: 315 | output = out 316 | 317 | # adj_att_tensor: [batch_size x num_nodes x num_nodes x num_gc_layers] 318 | adj_att_tensor = torch.stack(adj_att_all, dim=3) 319 | 320 | self.embedding_tensor = output 321 | ypred = self.pred_model(output) 322 | # print(output.size()) 323 | return ypred, adj_att_tensor 324 | 325 | def final_node_embeddings(self, x, adj, batch_num_nodes=None, **kwargs): 326 | # mask 327 | max_num_nodes = adj.size()[1] 328 | if batch_num_nodes is not None: 329 | self.embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 330 | else: 331 | self.embedding_mask = None 332 | 333 | # conv 334 | x, adj_att = self.conv_first(x, adj) 335 | x = self.act(x) 336 | if self.bn: 337 | x = self.apply_bn(x) 338 | out_all = [] 339 | out, _ = torch.max(x, dim=1) 340 | out_all.append(out) 341 | adj_att_all = [adj_att] 342 | for i in range(self.num_layers - 2): 343 | x, adj_att = self.conv_block[i](x, adj) 344 | x = self.act(x) 345 | if self.bn: 346 | x = self.apply_bn(x) 347 | out, _ = torch.max(x, dim=1) 348 | out_all.append(out) 349 | if self.num_aggs == 2: 350 | out = torch.sum(x, dim=1) 351 | out_all.append(out) 352 | adj_att_all.append(adj_att) 353 | x, adj_att = self.conv_last(x, adj) 354 | return x 355 | 356 | def loss(self, pred, label, type="softmax"): 357 | # 
softmax + CE 358 | if type == "softmax": 359 | return F.cross_entropy(pred, label, size_average=True) 360 | elif type == "margin": 361 | batch_size = pred.size()[0] 362 | label_onehot = torch.zeros(batch_size, self.label_dim).long() 363 | label_onehot.scatter_(1, label.view(-1, 1), 1) 364 | return torch.nn.MultiLabelMarginLoss()(pred, label_onehot) 365 | 366 | # return F.binary_cross_entropy(F.sigmoid(pred[:,0]), label.float()) 367 | 368 | 369 | class GcnEncoderNode(GcnEncoderGraph): 370 | def __init__( 371 | self, 372 | input_dim, 373 | hidden_dim, 374 | embedding_dim, 375 | label_dim, 376 | num_layers, 377 | pred_hidden_dims=[], 378 | concat=True, 379 | bn=True, 380 | dropout=0.0, 381 | args=None, 382 | ): 383 | super(GcnEncoderNode, self).__init__( 384 | input_dim, 385 | hidden_dim, 386 | embedding_dim, 387 | label_dim, 388 | num_layers, 389 | pred_hidden_dims, 390 | concat, 391 | bn, 392 | dropout, 393 | args=args, 394 | ) 395 | if hasattr(args, "loss_weight"): 396 | print("Loss weight: ", args.loss_weight) 397 | self.celoss = nn.CrossEntropyLoss(weight=args.loss_weight) 398 | else: 399 | self.celoss = nn.CrossEntropyLoss() 400 | 401 | def forward(self, x, adj, batch_num_nodes=None, **kwargs): 402 | # mask 403 | max_num_nodes = adj.size()[1] 404 | if batch_num_nodes is not None: 405 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 406 | else: 407 | embedding_mask = None 408 | 409 | self.adj_atts = [] 410 | self.embedding_tensor, adj_att = self.gcn_forward( 411 | x, adj, self.conv_first, self.conv_block, self.conv_last, embedding_mask 412 | ) 413 | pred = self.pred_model(self.embedding_tensor) 414 | return pred, adj_att 415 | 416 | def loss(self, pred, label): 417 | pred = torch.transpose(pred, 1, 2) 418 | return self.celoss(pred, label) 419 | 420 | 421 | class SoftPoolingGcnEncoder(GcnEncoderGraph): 422 | def __init__( 423 | self, 424 | max_num_nodes, 425 | input_dim, 426 | hidden_dim, 427 | embedding_dim, 428 | label_dim, 429 | num_layers, 430 | assign_hidden_dim, 431 | assign_ratio=0.25, 432 | assign_num_layers=-1, 433 | num_pooling=1, 434 | pred_hidden_dims=[50], 435 | concat=True, 436 | bn=True, 437 | dropout=0.0, 438 | linkpred=True, 439 | assign_input_dim=-1, 440 | args=None, 441 | ): 442 | """ 443 | Args: 444 | num_layers: number of gc layers before each pooling 445 | num_nodes: number of nodes for each graph in batch 446 | linkpred: flag to turn on link prediction side objective 447 | """ 448 | 449 | super(SoftPoolingGcnEncoder, self).__init__( 450 | input_dim, 451 | hidden_dim, 452 | embedding_dim, 453 | label_dim, 454 | num_layers, 455 | pred_hidden_dims=pred_hidden_dims, 456 | concat=concat, 457 | args=args, 458 | ) 459 | add_self = not concat 460 | self.num_pooling = num_pooling 461 | self.linkpred = linkpred 462 | self.assign_ent = True 463 | 464 | # GC 465 | self.conv_first_after_pool = [] 466 | self.conv_block_after_pool = [] 467 | self.conv_last_after_pool = [] 468 | for i in range(num_pooling): 469 | # use self to register the modules in self.modules() 470 | self.conv_first2, self.conv_block2, self.conv_last2 = self.build_conv_layers( 471 | self.pred_input_dim, 472 | hidden_dim, 473 | embedding_dim, 474 | num_layers, 475 | add_self, 476 | normalize=True, 477 | dropout=dropout, 478 | ) 479 | self.conv_first_after_pool.append(self.conv_first2) 480 | self.conv_block_after_pool.append(self.conv_block2) 481 | self.conv_last_after_pool.append(self.conv_last2) 482 | 483 | # assignment 484 | assign_dims = [] 485 | if assign_num_layers == -1: 486 | 
assign_num_layers = num_layers 487 | if assign_input_dim == -1: 488 | assign_input_dim = input_dim 489 | 490 | self.assign_conv_first_modules = [] 491 | self.assign_conv_block_modules = [] 492 | self.assign_conv_last_modules = [] 493 | self.assign_pred_modules = [] 494 | assign_dim = int(max_num_nodes * assign_ratio) 495 | for i in range(num_pooling): 496 | assign_dims.append(assign_dim) 497 | self.assign_conv_first, self.assign_conv_block, self.assign_conv_last = self.build_conv_layers( 498 | assign_input_dim, 499 | assign_hidden_dim, 500 | assign_dim, 501 | assign_num_layers, 502 | add_self, 503 | normalize=True, 504 | ) 505 | assign_pred_input_dim = ( 506 | assign_hidden_dim * (num_layers - 1) + assign_dim 507 | if concat 508 | else assign_dim 509 | ) 510 | self.assign_pred = self.build_pred_layers( 511 | assign_pred_input_dim, [], assign_dim, num_aggs=1 512 | ) 513 | 514 | # next pooling layer 515 | assign_input_dim = embedding_dim 516 | assign_dim = int(assign_dim * assign_ratio) 517 | 518 | self.assign_conv_first_modules.append(self.assign_conv_first) 519 | self.assign_conv_block_modules.append(self.assign_conv_block) 520 | self.assign_conv_last_modules.append(self.assign_conv_last) 521 | self.assign_pred_modules.append(self.assign_pred) 522 | 523 | self.pred_model = self.build_pred_layers( 524 | self.pred_input_dim * (num_pooling + 1), 525 | pred_hidden_dims, 526 | label_dim, 527 | num_aggs=self.num_aggs, 528 | ) 529 | 530 | for m in self.modules(): 531 | if isinstance(m, GraphConv): 532 | m.weight.data = init.xavier_uniform( 533 | m.weight.data, gain=nn.init.calculate_gain("relu") 534 | ) 535 | if m.bias is not None: 536 | m.bias.data = init.constant(m.bias.data, 0.0) 537 | 538 | def forward(self, x, adj, batch_num_nodes, **kwargs): 539 | if "assign_x" in kwargs: 540 | x_a = kwargs["assign_x"] 541 | else: 542 | x_a = x 543 | 544 | # mask 545 | max_num_nodes = adj.size()[1] 546 | if batch_num_nodes is not None: 547 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 548 | else: 549 | embedding_mask = None 550 | 551 | out_all = [] 552 | 553 | # self.assign_tensor = self.gcn_forward(x_a, adj, 554 | # self.assign_conv_first_modules[0], self.assign_conv_block_modules[0], self.assign_conv_last_modules[0], 555 | # embedding_mask) 556 | ## [batch_size x num_nodes x next_lvl_num_nodes] 557 | # self.assign_tensor = nn.Softmax(dim=-1)(self.assign_pred(self.assign_tensor)) 558 | # if embedding_mask is not None: 559 | # self.assign_tensor = self.assign_tensor * embedding_mask 560 | # [batch_size x num_nodes x embedding_dim] 561 | embedding_tensor = self.gcn_forward( 562 | x, adj, self.conv_first, self.conv_block, self.conv_last, embedding_mask 563 | ) 564 | 565 | out, _ = torch.max(embedding_tensor, dim=1) 566 | out_all.append(out) 567 | if self.num_aggs == 2: 568 | out = torch.sum(embedding_tensor, dim=1) 569 | out_all.append(out) 570 | 571 | for i in range(self.num_pooling): 572 | if batch_num_nodes is not None and i == 0: 573 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 574 | else: 575 | embedding_mask = None 576 | 577 | self.assign_tensor = self.gcn_forward( 578 | x_a, 579 | adj, 580 | self.assign_conv_first_modules[i], 581 | self.assign_conv_block_modules[i], 582 | self.assign_conv_last_modules[i], 583 | embedding_mask, 584 | ) 585 | # [batch_size x num_nodes x next_lvl_num_nodes] 586 | self.assign_tensor = nn.Softmax(dim=-1)( 587 | self.assign_pred(self.assign_tensor) 588 | ) 589 | if embedding_mask is not None: 590 | self.assign_tensor = 
self.assign_tensor * embedding_mask 591 | 592 | # update pooled features and adj matrix 593 | x = torch.matmul( 594 | torch.transpose(self.assign_tensor, 1, 2), embedding_tensor 595 | ) 596 | adj = torch.transpose(self.assign_tensor, 1, 2) @ adj @ self.assign_tensor 597 | x_a = x 598 | 599 | embedding_tensor = self.gcn_forward( 600 | x, 601 | adj, 602 | self.conv_first_after_pool[i], 603 | self.conv_block_after_pool[i], 604 | self.conv_last_after_pool[i], 605 | ) 606 | 607 | out, _ = torch.max(embedding_tensor, dim=1) 608 | out_all.append(out) 609 | if self.num_aggs == 2: 610 | # out = torch.mean(embedding_tensor, dim=1) 611 | out = torch.sum(embedding_tensor, dim=1) 612 | out_all.append(out) 613 | 614 | if self.concat: 615 | output = torch.cat(out_all, dim=1) 616 | else: 617 | output = out 618 | ypred = self.pred_model(output) 619 | return ypred 620 | 621 | def loss(self, pred, label, adj=None, batch_num_nodes=None, adj_hop=1): 622 | """ 623 | Args: 624 | batch_num_nodes: numpy array of number of nodes in each graph in the minibatch. 625 | """ 626 | eps = 1e-7 627 | loss = super(SoftPoolingGcnEncoder, self).loss(pred, label) 628 | if self.linkpred: 629 | max_num_nodes = adj.size()[1] 630 | pred_adj0 = self.assign_tensor @ torch.transpose(self.assign_tensor, 1, 2) 631 | tmp = pred_adj0 632 | pred_adj = pred_adj0 633 | for adj_pow in range(adj_hop - 1): 634 | tmp = tmp @ pred_adj0 635 | pred_adj = pred_adj + tmp 636 | pred_adj = torch.min(pred_adj, torch.Tensor(1)) 637 | # print('adj1', torch.sum(pred_adj0) / torch.numel(pred_adj0)) 638 | # print('adj2', torch.sum(pred_adj) / torch.numel(pred_adj)) 639 | # self.link_loss = F.nll_loss(torch.log(pred_adj), adj) 640 | self.link_loss = -adj * torch.log(pred_adj + eps) - (1 - adj) * torch.log( 641 | 1 - pred_adj + eps 642 | ) 643 | if batch_num_nodes is None: 644 | num_entries = max_num_nodes * max_num_nodes * adj.size()[0] 645 | print("Warning: calculating link pred loss without masking") 646 | else: 647 | num_entries = np.sum(batch_num_nodes * batch_num_nodes) 648 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 649 | adj_mask = embedding_mask @ torch.transpose(embedding_mask, 1, 2) 650 | self.link_loss[1 - adj_mask.byte()] = 0.0 651 | 652 | self.link_loss = torch.sum(self.link_loss) / float(num_entries) 653 | # print('linkloss: ', self.link_loss) 654 | return loss + self.link_loss 655 | return loss 656 | -------------------------------------------------------------------------------- /gnnexplainer_utils/io_utils.py: -------------------------------------------------------------------------------- 1 | """ io_utils.py 2 | 3 | Utilities for reading and writing logs. 4 | """ 5 | import os 6 | import statistics 7 | import re 8 | import csv 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import scipy as sc 13 | 14 | 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | 18 | import numpy as np 19 | import torch 20 | import networkx as nx 21 | import tensorboardX 22 | 23 | import cv2 24 | 25 | import torch 26 | import torch.nn as nn 27 | from torch.autograd import Variable 28 | 29 | # Only necessary to rebuild the Chemistry example 30 | # from rdkit import Chem 31 | 32 | import gnnexplainer_utils.featgen as featgen 33 | 34 | use_cuda = torch.cuda.is_available() 35 | 36 | 37 | def gen_prefix(args): 38 | '''Generate label prefix for a graph model. 
39 | ''' 40 | if args.bmname is not None: 41 | name = args.bmname 42 | else: 43 | name = args.dataset 44 | name += "_" + args.method 45 | 46 | name += "_h" + str(args.hidden_dim) + "_o" + str(args.output_dim) 47 | if not args.bias: 48 | name += "_nobias" 49 | if len(args.name_suffix) > 0: 50 | name += "_" + args.name_suffix 51 | return name 52 | 53 | 54 | def gen_explainer_prefix(args): 55 | '''Generate label prefix for a graph explainer model. 56 | ''' 57 | name = gen_prefix(args) + "_explain" 58 | if len(args.explainer_suffix) > 0: 59 | name += "_" + args.explainer_suffix 60 | return name 61 | 62 | 63 | def create_filename(save_dir, args, isbest=False, num_epochs=-1): 64 | """ 65 | Args: 66 | args : the arguments parsed in the parser 67 | isbest : whether the saved model is the best-performing one 68 | num_epochs : epoch number of the model (when isbest=False) 69 | """ 70 | filename = os.path.join(save_dir, gen_prefix(args)) 71 | os.makedirs(filename, exist_ok=True) 72 | 73 | if isbest: 74 | filename = os.path.join(filename, "best") 75 | elif num_epochs > 0: 76 | filename = os.path.join(filename, str(num_epochs)) 77 | 78 | return filename + ".pth.tar" 79 | 80 | 81 | def save_checkpoint(model, optimizer, args, num_epochs=-1, isbest=False, cg_dict=None): 82 | """Save pytorch model checkpoint. 83 | 84 | Args: 85 | - model : The PyTorch model to save. 86 | - optimizer : The optimizer used to train the model. 87 | - args : A dict of meta-data about the model. 88 | - num_epochs : Number of training epochs. 89 | - isbest : True if the model has the highest accuracy so far. 90 | - cg_dict : A dictionary of the sampled computation graphs. 91 | """ 92 | filename = create_filename(args.ckptdir, args, isbest, num_epochs=num_epochs) 93 | torch.save( 94 | { 95 | "epoch": num_epochs, 96 | "model_type": args.method, 97 | "optimizer": optimizer, 98 | "model_state": model.state_dict(), 99 | "optimizer_state": optimizer.state_dict(), 100 | "cg": cg_dict, 101 | }, 102 | filename, 103 | ) 104 | 105 | 106 | def load_ckpt(args, isbest=False): 107 | '''Load a pre-trained pytorch model from checkpoint. 
106 | def load_ckpt(args, isbest=False):
107 |     '''Load a pre-trained pytorch model from checkpoint.
108 |     '''
109 |     print("loading model")
110 |     filename = create_filename(args.ckptdir, args, isbest)
111 |     print(filename)
112 |     if os.path.isfile(filename):
113 |         print("=> loading checkpoint '{}'".format(filename))
114 |         ckpt = torch.load(filename)
115 |     else:
116 |         print("Checkpoint does not exist!")
117 |         print("Checked path -- {}".format(filename))
118 |         print("Make sure you have provided the correct path!")
119 |         print("You may have forgotten to train a model for this dataset.")
120 |         print()
121 |         print("To train one of the paper's models, run the following")
122 |         print(">> python train.py --dataset=DATASET_NAME")
123 |         print()
124 |         raise Exception("File not found.")
125 |     return ckpt
126 | 
127 | def preprocess_cg(cg):
128 |     """Pre-process computation graph."""
129 |     if use_cuda:
130 |         preprocessed_cg_tensor = torch.from_numpy(cg).cuda()
131 |     else:
132 |         preprocessed_cg_tensor = torch.from_numpy(cg)
133 | 
134 |     preprocessed_cg_tensor.unsqueeze_(0)
135 |     return Variable(preprocessed_cg_tensor, requires_grad=False)
136 | 
137 | def load_model(path):
138 |     """Load a pytorch model."""
139 |     model = torch.load(path)
140 |     model.eval()
141 |     if use_cuda:
142 |         model.cuda()
143 | 
144 |     for p in model.features.parameters():
145 |         p.requires_grad = False
146 |     for p in model.classifier.parameters():
147 |         p.requires_grad = False
148 | 
149 |     return model
150 | 
151 | 
152 | def load_cg(path):
153 |     """Load a computation graph."""
154 |     cg = pickle.load(open(path, "rb"))  # pickle files are binary
155 |     return cg
156 | 
157 | 
158 | def save(mask_cg):
159 |     """Save a rendering of the computation graph mask."""
160 |     mask = mask_cg.cpu().data.numpy()[0]
161 |     mask = np.transpose(mask, (1, 2, 0))
162 | 
163 |     mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))  # min-max normalize to [0, 1]
164 |     mask = 1 - mask
165 | 
166 |     cv2.imwrite("mask.png", np.uint8(255 * mask))
167 | 
168 | def log_matrix(writer, mat, name, epoch, fig_size=(8, 6), dpi=200):
169 |     """Save an image of a matrix to disk.
170 | 
171 |     Args:
172 |         - writer : A file writer.
173 |         - mat : The matrix to write.
174 |         - name : Name of the file to save.
175 |         - epoch : Epoch number.
176 |         - fig_size : Size of the figure to save.
177 |         - dpi : Resolution.
178 |     """
179 |     plt.switch_backend("agg")
180 |     fig = plt.figure(figsize=fig_size, dpi=dpi)
181 |     mat = mat.cpu().detach().numpy()
182 |     if mat.ndim == 1:
183 |         mat = mat[:, np.newaxis]
184 |     plt.imshow(mat, cmap=plt.get_cmap("BuPu"))
185 |     cbar = plt.colorbar()
186 |     cbar.solids.set_edgecolor("face")
187 | 
188 |     plt.tight_layout()
189 |     fig.canvas.draw()
190 |     writer.add_image(name, tensorboardX.utils.figure_to_image(fig), epoch)
191 | 
192 | 
193 | def denoise_graph(adj, node_idx, feat=None, label=None, threshold=None, threshold_num=None, max_component=True):
194 |     """Clean a graph by thresholding its edge weights.
195 | 
196 |     Args:
197 |         - adj : Adjacency matrix.
198 |         - node_idx : Index of the query node, marked with a "self" attribute.
199 |         - feat : An array of node features.
200 |         - label : A list of node labels.
201 |         - threshold : The weight threshold.
202 |         - threshold_num : The maximum number of edges to keep.
203 |         - max_component : If True, return only the largest connected component.
204 |     """
205 |     num_nodes = adj.shape[-1]
206 |     G = nx.Graph()
207 |     G.add_nodes_from(range(num_nodes))
208 |     G.nodes[node_idx]["self"] = 1
209 |     if feat is not None:
210 |         for node in G.nodes():
211 |             G.nodes[node]["feat"] = feat[node]
212 |     if label is not None:
213 |         for node in G.nodes():
214 |             G.nodes[node]["label"] = label[node]
215 | 
216 |     if threshold_num is not None:
217 |         # this is for symmetric graphs: edges are repeated twice in adj
218 |         adj_threshold_num = threshold_num * 2
219 |         #adj += np.random.rand(adj.shape[0], adj.shape[1]) * 1e-4
220 |         neigh_size = len(adj[adj > 0])
221 |         threshold_num = min(neigh_size, adj_threshold_num)
222 |         threshold = np.sort(adj[adj > 0])[-threshold_num]
223 | 
224 |     if threshold is not None:
225 |         weighted_edge_list = [
226 |             (i, j, adj[i, j])
227 |             for i in range(num_nodes)
228 |             for j in range(num_nodes)
229 |             if adj[i, j] >= threshold
230 |         ]
231 |     else:
232 |         weighted_edge_list = [
233 |             (i, j, adj[i, j])
234 |             for i in range(num_nodes)
235 |             for j in range(num_nodes)
236 |             if adj[i, j] > 1e-6
237 |         ]
238 |     G.add_weighted_edges_from(weighted_edge_list)
239 |     if max_component:
240 |         largest_cc = max(nx.connected_components(G), key=len)
241 |         G = G.subgraph(largest_cc).copy()
242 |     else:
243 |         # remove zero degree nodes
244 |         G.remove_nodes_from(list(nx.isolates(G)))
245 |     return G
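# --- Added usage sketch (illustrative values, not part of the original
# module): with threshold_num=2, denoise_graph keeps the two strongest
# symmetric edges and returns the largest connected component.
def _denoise_graph_sketch():
    adj = np.array([[0.0, 0.9, 0.1, 0.0],
                    [0.9, 0.0, 0.7, 0.0],
                    [0.1, 0.7, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0]])
    G = denoise_graph(adj, node_idx=0, threshold_num=2, max_component=True)
    return sorted(G.edges())  # [(0, 1), (1, 2)]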
246 | 
247 | # TODO: unify log_graph and log_graph2
248 | def log_graph(
249 |     writer,
250 |     Gc,
251 |     name,
252 |     identify_self=True,
253 |     nodecolor="label",
254 |     epoch=0,
255 |     fig_size=(4, 3),
256 |     dpi=300,
257 |     label_node_feat=False,
258 |     edge_vmax=None,
259 |     args=None,
260 | ):
261 |     """
262 |     Args:
263 |         nodecolor: the color of the nodes; can be determined by 'label' or 'feat'.
264 |             For 'feat', the node features need to be one-hot.
265 |     """
266 |     cmap = plt.get_cmap("Set1")
267 |     plt.switch_backend("agg")
268 |     # (the figure itself is created below, once node colors are known)
269 | 
270 |     node_colors = []
271 |     # edge_colors = [min(max(w, 0.0), 1.0) for (u,v,w) in Gc.edges.data('weight', default=1)]
272 |     edge_colors = [w for (u, v, w) in Gc.edges.data("weight", default=1)]
273 | 
274 |     # maximum value for node color
275 |     vmax = 8
276 |     for i in Gc.nodes():
277 |         if nodecolor == "feat" and "feat" in Gc.nodes[i]:
278 |             num_classes = Gc.nodes[i]["feat"].size()[0]
279 |             if num_classes >= 10:
280 |                 cmap = plt.get_cmap("tab20")
281 |                 vmax = 19
282 |             elif num_classes >= 8:
283 |                 cmap = plt.get_cmap("tab10")
284 |                 vmax = 9
285 |             break
286 | 
287 |     feat_labels = {}
288 |     for i in Gc.nodes():
289 |         if identify_self and "self" in Gc.nodes[i]:
290 |             node_colors.append(0)
291 |         elif nodecolor == "label" and "label" in Gc.nodes[i]:
292 |             node_colors.append(Gc.nodes[i]["label"] + 1)
293 |         elif nodecolor == "feat" and "feat" in Gc.nodes[i]:
294 |             # print(Gc.nodes[i]['feat'])
295 |             feat = Gc.nodes[i]["feat"].detach().numpy()
296 |             # idx with pos val in 1D array
297 |             feat_class = 0
298 |             for j in range(len(feat)):
299 |                 if feat[j] == 1:
300 |                     feat_class = j
301 |                     break
302 |             node_colors.append(feat_class)
303 |             feat_labels[i] = feat_class
304 |         else:
305 |             node_colors.append(1)
306 |     if not label_node_feat:
307 |         feat_labels = None
308 | 
309 |     plt.switch_backend("agg")
310 |     fig = plt.figure(figsize=fig_size, dpi=dpi)
311 | 
312 |     if Gc.number_of_nodes() == 0:
313 |         raise Exception("empty graph")
314 |     if Gc.number_of_edges() == 0:
315 |         raise Exception("empty edge")
316 |     # remove_nodes = []
317 |     # for u in Gc.nodes():
318 |     #     if Gc
319 |     pos_layout = nx.kamada_kawai_layout(Gc, weight=None)
320 |     # pos_layout = nx.spring_layout(Gc, weight=None)
321 | 
322 |     weights = [d for (u, v, d) in Gc.edges(data="weight", default=1)]
323 |     if edge_vmax is None:
324 |         edge_vmax = statistics.median_high(
325 |             [d for (u, v, d) in Gc.edges(data="weight", default=1)]
326 |         )
327 |     min_color = min([d for (u, v, d) in Gc.edges(data="weight", default=1)])
328 |     # color range: gray to black
329 |     edge_vmin = 2 * min_color - edge_vmax
330 |     nx.draw(
331 |         Gc,
332 |         pos=pos_layout,
333 |         with_labels=False,
334 |         font_size=4,
335 |         labels=feat_labels,
336 |         node_color=node_colors,
337 |         vmin=0,
338 |         vmax=vmax,
339 |         cmap=cmap,
340 |         edge_color=edge_colors,
341 |         edge_cmap=plt.get_cmap("Greys"),
342 |         edge_vmin=edge_vmin,
343 |         edge_vmax=edge_vmax,
344 |         width=1.0,
345 |         node_size=50,
346 |         alpha=0.8,
347 |     )
348 |     fig.axes[0].xaxis.set_visible(False)
349 |     fig.canvas.draw()
350 | 
351 |     if args is None:
352 |         save_path = os.path.join("log/", name + ".pdf")
353 |     else:
354 |         save_path = os.path.join(
355 |             "log", name + gen_explainer_prefix(args) + "_" + str(epoch) + ".pdf"
356 |         )
357 |     print(save_path)
358 |     os.makedirs(os.path.dirname(save_path), exist_ok=True)
359 |     plt.savefig(save_path, format="pdf")
360 | 
361 |     img = tensorboardX.utils.figure_to_image(fig)
362 |     writer.add_image(name, img, epoch)
363 | 
364 | 
365 | def plot_cmap(cmap, ncolor):
366 |     """
367 |     A convenient function to plot the colors of a matplotlib cmap
368 |     Credit goes to http://gvallver.perso.univ-pau.fr/?p=712
369 | 
370 |     Args:
371 |         ncolor (int): number of colors to show
372 |         cmap: a cmap object or a matplotlib color name
373 |     """
374 | 
375 |     if isinstance(cmap, str):
376 |         name = cmap
377 |         try:
378 |             cm = plt.get_cmap(cmap)
379 |         except ValueError:
380 |             print("WARNINGS :", cmap, " is not a known colormap")
381 |             cm = plt.cm.gray
382 |     else:
383 |         cm = cmap
384 |         name = cm.name
385 | 
386 |     with matplotlib.rc_context(matplotlib.rcParamsDefault):
387 |         fig = plt.figure(figsize=(12, 1), frameon=False)
388 |         ax = fig.add_subplot(111)
389 |         ax.pcolor(np.linspace(1, ncolor, ncolor).reshape(1, ncolor), cmap=cm)
390 |         ax.set_title(name)
391 |         xt = ax.set_xticks([])
392 |         yt = ax.set_yticks([])
393 |     return fig
394 | 
395 | 
396 | def plot_cmap_tb(writer, cmap, ncolor, name):
397 |     """Log the color map used for plots to tensorboard."""
398 |     fig = plot_cmap(cmap, ncolor)
399 |     img = tensorboardX.utils.figure_to_image(fig)
400 |     writer.add_image(name, img, 0)
401 | 
402 | 
403 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
404 |     """Convert a scipy sparse matrix to a torch sparse tensor."""
405 |     sparse_mx = sparse_mx.tocoo().astype(np.float32)
406 |     indices = torch.from_numpy(
407 |         np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
408 |     )
409 |     values = torch.from_numpy(sparse_mx.data)
410 |     shape = torch.Size(sparse_mx.shape)
411 |     return torch.sparse.FloatTensor(indices, values, shape)
412 | 
413 | def numpy_to_torch(img, requires_grad=True):
414 |     if len(img.shape) < 3:
415 |         output = np.float32([img])
416 |     else:
417 |         output = np.transpose(img, (2, 0, 1))
418 | 
419 |     output = torch.from_numpy(output)
420 |     if use_cuda:
421 |         output = output.cuda()
422 | 
423 |     output.unsqueeze_(0)
424 |     v = Variable(output, requires_grad=requires_grad)
425 |     return v
426 | 
427 | 
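# --- Added usage sketch (illustrative, not part of the original module):
# round-tripping a small scipy matrix through sparse_mx_to_torch_sparse_tensor.
def _sparse_conversion_sketch():
    import scipy.sparse as sp
    mx = sp.csr_matrix(np.array([[0.0, 1.0], [2.0, 0.0]]))
    t = sparse_mx_to_torch_sparse_tensor(mx)
    return t.to_dense()  # tensor([[0., 1.], [2., 0.]])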
428 | def read_graphfile(datadir, dataname, max_nodes=None, edge_labels=False):
429 |     """ Read data from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
430 |     (graph indices start at 1 in the files)
431 | 
432 |     Returns:
433 |         List of networkx objects with graph and node labels
434 |     """
435 |     prefix = os.path.join(datadir, dataname, dataname)
436 |     filename_graph_indic = prefix + "_graph_indicator.txt"
437 |     # index of graphs that a given node belongs to
438 |     graph_indic = {}
439 |     with open(filename_graph_indic) as f:
440 |         i = 1
441 |         for line in f:
442 |             line = line.strip("\n")
443 |             graph_indic[i] = int(line)
444 |             i += 1
445 | 
446 |     filename_nodes = prefix + "_node_labels.txt"
447 |     node_labels = []
448 |     min_label_val = None
449 |     try:
450 |         with open(filename_nodes) as f:
451 |             for line in f:
452 |                 line = line.strip("\n")
453 |                 l = int(line)
454 |                 node_labels += [l]
455 |                 if min_label_val is None or min_label_val > l:
456 |                     min_label_val = l
457 |         # assume that node labels are consecutive
458 |         num_unique_node_labels = max(node_labels) - min_label_val + 1
459 |         node_labels = [l - min_label_val for l in node_labels]
460 |     except IOError:
461 |         print("No node labels")
462 | 
463 |     filename_node_attrs = prefix + "_node_attributes.txt"
464 |     node_attrs = []
465 |     try:
466 |         with open(filename_node_attrs) as f:
467 |             for line in f:
468 |                 line = line.strip()
469 |                 attrs = [
470 |                     float(attr) for attr in re.split("[,\s]+", line) if not attr == ""
471 |                 ]
472 |                 node_attrs.append(np.array(attrs))
473 |     except IOError:
474 |         print("No node attributes")
475 | 
476 |     label_has_zero = False
477 |     filename_graphs = prefix + "_graph_labels.txt"
478 |     graph_labels = []
479 | 
480 |     label_vals = []
481 |     with open(filename_graphs) as f:
482 |         for line in f:
483 |             line = line.strip("\n")
484 |             val = int(line)
485 |             if val not in label_vals:
486 |                 label_vals.append(val)
487 |             graph_labels.append(val)
488 | 
489 |     label_map_to_int = {val: i for i, val in enumerate(label_vals)}
490 |     graph_labels = np.array([label_map_to_int[l] for l in graph_labels])
491 | 
492 |     if edge_labels:
493 |         # For Tox21_AHR we want to know edge labels
494 |         filename_edges = prefix + "_edge_labels.txt"
495 |         edge_labels = []
496 | 
497 |         edge_label_vals = []
498 |         with open(filename_edges) as f:
499 |             for line in f:
500 |                 line = line.strip("\n")
501 |                 val = int(line)
502 |                 if val not in edge_label_vals:
503 |                     edge_label_vals.append(val)
504 |                 edge_labels.append(val)
505 | 
506 |         edge_label_map_to_int = {val: i for i, val in enumerate(edge_label_vals)}
507 | 
508 |     filename_adj = prefix + "_A.txt"
509 |     adj_list = {i: [] for i in range(1, len(graph_labels) + 1)}
510 |     # edge_label_list={i:[] for i in range(1,len(graph_labels)+1)}
511 |     index_graph = {i: [] for i in range(1, len(graph_labels) + 1)}
512 |     num_edges = 0
513 |     with open(filename_adj) as f:
514 |         for line in f:
515 |             line = line.strip("\n").split(",")
516 |             e0, e1 = (int(line[0].strip(" ")), int(line[1].strip(" ")))
517 |             adj_list[graph_indic[e0]].append((e0, e1))
518 |             index_graph[graph_indic[e0]] += [e0, e1]
519 |             # edge_label_list[graph_indic[e0]].append(edge_labels[num_edges])
520 |             num_edges += 1
521 |     for k in index_graph.keys():
522 |         index_graph[k] = [u - 1 for u in set(index_graph[k])]
523 | 
524 |     graphs = []
525 |     for i in range(1, 1 + len(adj_list)):
526 |         # indexed from 1 here
527 |         G = nx.from_edgelist(adj_list[i])
528 |         G_idx = graph_indic[adj_list[i][0][0]]
529 | 
530 |         if max_nodes is not None and G.number_of_nodes() > max_nodes:
531 |             continue
532 | 
533 |         # add features and labels
534 |         G.graph["label"] = graph_labels[i - 1]
535 |         G.graph["idx"] = G_idx
536 | 
537 |         # Special label for aromaticity experiment
538 |         # aromatic_edge = 2
539 |         # G.graph['aromatic'] = aromatic_edge in edge_label_list[i]
540 | 
541 |         for u in G.nodes():
542 |             if len(node_labels) > 0:
543 |                 node_label_one_hot = [0] * num_unique_node_labels
544 |                 node_label = node_labels[u - 1]
545 |                 node_label_one_hot[node_label] = 1
546 |                 G.nodes[u]["label"] = node_label_one_hot
547 |             if len(node_attrs) > 0:
548 |                 G.nodes[u]["feat"] = node_attrs[u - 1]
549 |         if len(node_attrs) > 0:
550 |             G.graph["feat_dim"] = node_attrs[0].shape[0]
551 | 
552 |         # relabeling
553 |         mapping = {}
554 |         min_id = min(G.nodes())
555 |         # if len(graphs)==2933:
556 |         #     print('z')
557 |         # if float(nx.__version__) < 2.0:
558 |         #     for n in G.nodes():
559 |         #         mapping[n] = it
560 |         #         it += 1
561 |         # else:
562 |         #     for n in G.nodes:
563 |         #         mapping[n] = it
564 |         #         it += 1
565 | 
566 |         # indexed from 0
567 |         graphs.append(nx.relabel_nodes(G, lambda x: x - min_id))
568 |     return graphs
569 | 
570 | 
571 | def read_biosnap(datadir, edgelist_file, label_file, feat_file=None, concat=True):
572 |     """ Read data from BioSnap
573 | 
574 |     Returns:
575 |         A networkx graph with node labels and features
576 |     """
577 |     G = nx.Graph()
578 |     delimiter = "\t" if "tsv" in edgelist_file else ","
579 |     print(delimiter)
580 |     df = pd.read_csv(
581 |         os.path.join(datadir, edgelist_file), delimiter=delimiter, header=None
582 |     )
583 |     data = list(map(tuple, df.values.tolist()))
584 |     G.add_edges_from(data)
585 |     print("Total nodes: ", G.number_of_nodes())
586 | 
587 |     G = G.subgraph(max(nx.connected_components(G), key=len)).copy()  # connected_component_subgraphs was removed in networkx >= 2.4
588 |     print("Total nodes in largest connected component: ", G.number_of_nodes())
589 | 
590 |     df = pd.read_csv(os.path.join(datadir, label_file), delimiter="\t", usecols=[0, 1])
591 |     data = list(map(tuple, df.values.tolist()))
592 | 
593 |     missing_node = 0
594 |     for line in data:
595 |         if int(line[0]) not in G:
596 |             missing_node += 1
597 |         else:
598 |             G.nodes[int(line[0])]["label"] = int(line[1] == "Essential")
599 | 
600 |     print("missing node: ", missing_node)
601 | 
602 |     missing_label = 0
603 |     remove_nodes = []
604 |     for u in G.nodes():
605 |         if "label" not in G.nodes[u]:
606 |             missing_label += 1
607 |             remove_nodes.append(u)
608 |     G.remove_nodes_from(remove_nodes)
609 |     print("missing_label: ", missing_label)
610 | 
611 |     if feat_file is None:
612 |         feature_generator = featgen.ConstFeatureGen(np.ones(10, dtype=float))
613 |         feature_generator.gen_node_features(G)
614 |     else:
615 |         df = pd.read_csv(os.path.join(datadir, feat_file), delimiter=",")
616 |         data = np.array(df.values)
617 |         print("Feat shape: ", data.shape)
618 | 
619 |         for row in data:
620 |             if int(row[0]) in G:
621 |                 if concat:
622 |                     node = int(row[0])
623 |                     onehot = np.zeros(10)
624 |                     onehot[min(G.degree[node], 10) - 1] = 1.0
625 |                     G.nodes[node]["feat"] = np.hstack(
626 |                         (np.log(row[1:] + 0.1), [1.0], onehot)
627 |                     )
628 |                 else:
629 |                     G.nodes[int(row[0])]["feat"] = np.log(row[1:] + 0.1)
630 | 
631 |     missing_feat = 0
632 |     remove_nodes = []
633 |     for u in G.nodes():
634 |         if "feat" not in G.nodes[u]:
635 |             missing_feat += 1
636 |             remove_nodes.append(u)
637 |     G.remove_nodes_from(remove_nodes)
638 |     print("missing feat: ", missing_feat)
639 | 
640 |     return G
641 | 
642 | 
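# --- Added usage sketch (the dataset name and data/ layout are assumptions;
# any benchmark in the TU format linked above should work the same way):
def _read_graphfile_sketch():
    graphs = read_graphfile("data", "REDDIT-BINARY")
    print(len(graphs), "graphs; first graph label:", graphs[0].graph["label"])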
643 | def build_aromaticity_dataset():
644 |     filename = "data/tox21_10k_data_all.sdf"
645 |     basename = filename.split(".")[0]
646 |     collector = []
647 |     sdprovider = Chem.SDMolSupplier(filename)  # requires rdkit (see the commented import at the top of this file)
648 |     for i, mol in enumerate(sdprovider):
649 |         try:
650 |             moldict = {}
651 |             moldict['smiles'] = Chem.MolToSmiles(mol)
652 |             # Parse data
653 |             for propname in mol.GetPropNames():
654 |                 moldict[propname] = mol.GetProp(propname)
655 |             nb_bonds = len(mol.GetBonds())
656 |             is_aromatic = False; aromatic_bonds = []
657 |             for j in range(nb_bonds):
658 |                 if mol.GetBondWithIdx(j).GetIsAromatic():
659 |                     aromatic_bonds.append(j)
660 |                     is_aromatic = True
661 |             moldict['aromaticity'] = is_aromatic
662 |             moldict['aromatic_bonds'] = aromatic_bonds
663 |             collector.append(moldict)
664 |         except Exception:
665 |             print("Molecule %s failed" % i)
666 |     data = pd.DataFrame(collector)
667 |     data.to_csv(basename + '_pandas.csv')
668 | 
669 | 
670 | def gen_train_plt_name(args):
671 |     return "results/" + gen_prefix(args) + ".png"
672 | 
673 | 
674 | def log_assignment(assign_tensor, writer, epoch, batch_idx):
675 |     plt.switch_backend("agg")
676 |     fig = plt.figure(figsize=(8, 6), dpi=300)
677 | 
678 |     # has to be smaller than args.batch_size
679 |     for i in range(len(batch_idx)):
680 |         plt.subplot(2, 2, i + 1)
681 |         plt.imshow(
682 |             assign_tensor.cpu().data.numpy()[batch_idx[i]], cmap=plt.get_cmap("BuPu")
683 |         )
684 |         cbar = plt.colorbar()
685 |         cbar.solids.set_edgecolor("face")
686 |     plt.tight_layout()
687 |     fig.canvas.draw()
688 | 
689 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # tostring_rgb returns bytes
690 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
691 |     writer.add_image("assignment", data, epoch)
692 | 
693 | # TODO: unify log_graph and log_graph2
694 | def log_graph2(adj, batch_num_nodes, writer, epoch, batch_idx, assign_tensor=None):
695 |     plt.switch_backend("agg")
696 |     fig = plt.figure(figsize=(8, 6), dpi=300)
697 | 
698 |     for i in range(len(batch_idx)):
699 |         ax = plt.subplot(2, 2, i + 1)
700 |         num_nodes = batch_num_nodes[batch_idx[i]]
701 |         adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy()
702 |         G = nx.from_numpy_matrix(adj_matrix)
703 |         nx.draw(
704 |             G,
705 |             pos=nx.spring_layout(G),
706 |             with_labels=True,
707 |             node_color="#336699",
708 |             edge_color="grey",
709 |             width=0.5,
710 |             node_size=300,
711 |             alpha=0.7,
712 |         )
713 |         ax.xaxis.set_visible(False)
714 | 
715 |     plt.tight_layout()
716 |     fig.canvas.draw()
717 | 
718 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
719 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
720 |     writer.add_image("graphs", data, epoch)
721 | 
722 |     # log a label-less version
723 |     # fig = plt.figure(figsize=(8,6), dpi=300)
724 |     # for i in range(len(batch_idx)):
725 |     #     ax = plt.subplot(2, 2, i+1)
726 |     #     num_nodes = batch_num_nodes[batch_idx[i]]
727 |     #     adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy()
728 |     #     G = nx.from_numpy_matrix(adj_matrix)
729 |     #     nx.draw(G, pos=nx.spring_layout(G), with_labels=False, node_color='#336699',
730 |     #             edge_color='grey', width=0.5, node_size=25,
731 |     #             alpha=0.8)
732 | 
733 |     # plt.tight_layout()
734 |     # fig.canvas.draw()
735 | 
736 |     # data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
737 |     # data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
738 |     # writer.add_image('graphs_no_label', data, epoch)
739 | 
740 |     # colored according to assignment
741 |     assignment = assign_tensor.cpu().data.numpy()
742 |     fig = plt.figure(figsize=(8, 6), dpi=300)
743 | 
744 |     num_clusters = assignment.shape[2]
745 |     all_colors = np.array(range(num_clusters))
746 | 
747 |     for i in range(len(batch_idx)):
748 |         ax = plt.subplot(2, 2, i + 1)
749 |         num_nodes = batch_num_nodes[batch_idx[i]]
750 |         adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy()
751 | 
752 |         label = np.argmax(assignment[batch_idx[i]], axis=1).astype(int)
753 |         label = label[: batch_num_nodes[batch_idx[i]]]
754 |         node_colors = all_colors[label]
755 | 
756 |         G = nx.from_numpy_matrix(adj_matrix)
757 |         nx.draw(
758 |             G,
759 |             pos=nx.spring_layout(G),
760 |             with_labels=False,
761 |             node_color=node_colors,
762 |             edge_color="grey",
763 |             width=0.4,
764 |             node_size=50,
765 |             cmap=plt.get_cmap("Set1"),
766 |             vmin=0,
767 |             vmax=num_clusters - 1,
768 |             alpha=0.8,
769 |         )
770 | 
771 |     plt.tight_layout()
772 |     fig.canvas.draw()
773 | 
774 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
775 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
776 |     writer.add_image("graphs_colored", data, epoch)
777 | 
--------------------------------------------------------------------------------
/gnnexplainer/explain.py:
--------------------------------------------------------------------------------
1 | """ explain.py
2 | 
3 | Implementation of the explainer.
4 | """
5 | 
6 | import math
7 | import time
8 | import os
9 | 
10 | import matplotlib
11 | import matplotlib.colors as colors
12 | import matplotlib.pyplot as plt
13 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
14 | from matplotlib.figure import Figure
15 | 
16 | import networkx as nx
17 | import numpy as np
18 | import pandas as pd
19 | import seaborn as sns
20 | import tensorboardX.utils
21 | 
22 | import torch
23 | import torch.nn as nn
24 | from torch.autograd import Variable
25 | 
26 | import sklearn.metrics as metrics
27 | from sklearn.metrics import roc_auc_score, recall_score, precision_score, precision_recall_curve
28 | from sklearn.cluster import DBSCAN
29 | 
30 | import pdb
31 | 
32 | import gnnexplainer_utils.io_utils as io_utils
33 | import gnnexplainer_utils.train_utils as train_utils
34 | import gnnexplainer_utils.graph_utils as graph_utils
35 | 
36 | 
37 | use_cuda = torch.cuda.is_available()
38 | FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
39 | LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
40 | Tensor = FloatTensor
41 | 
42 | class Explainer:
43 |     def __init__(
44 |         self,
45 |         model,
46 |         adj,
47 |         feat,
48 |         label,
49 |         pred,
50 |         train_idx,
51 |         args,
52 |         writer=None,
53 |         print_training=True,
54 |         graph_mode=False,
55 |         graph_idx=0,
56 |     ):
57 |         self.model = model
58 |         self.model.eval()
59 |         self.adj = adj
60 |         self.feat = feat
61 |         self.label = label
62 |         self.pred = pred
63 |         self.train_idx = train_idx
64 |         self.n_hops = args.num_gc_layers
65 |         self.graph_mode = graph_mode
66 |         self.graph_idx = graph_idx
67 |         self.neighborhoods = None if self.graph_mode else graph_utils.neighborhoods(adj=self.adj, n_hops=self.n_hops, use_cuda=use_cuda)
68 |         self.args = args
69 |         self.writer = writer
70 |         self.print_training = print_training
71 | 
72 | 
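    # Added note: `explain` below optimizes the GNNExplainer objective defined
    # in ExplainModule.loss -- a cross-entropy term that keeps the masked
    # subgraph predictive of the model's original label, plus size and entropy
    # regularizers on the learned edge and feature masks (and, in node mode, a
    # Laplacian smoothness term).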
73 |     # Main method
74 |     def explain(
75 |         self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp", original_idx=None
76 |     ):
77 |         """Explain a single node prediction
78 |         """
79 |         # index of the query node in the new adj
80 |         if graph_mode:
81 |             node_idx_new = node_idx
82 |             sub_adj = self.adj[graph_idx]
83 |             sub_feat = self.feat[graph_idx, :]
84 |             sub_label = self.label[graph_idx]
85 |             neighbors = np.asarray(range(self.adj.shape[0]))
86 |         else:
87 |             print("node label: ", self.label[graph_idx][node_idx])
88 |             node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
89 |                 node_idx, graph_idx
90 |             )
91 |             print("neigh graph idx: ", node_idx, node_idx_new)
92 |             sub_label = np.expand_dims(sub_label, axis=0)
93 | 
94 |         sub_adj = np.expand_dims(sub_adj, axis=0)
95 |         sub_feat = np.expand_dims(sub_feat, axis=0)
96 | 
97 |         adj = torch.tensor(sub_adj, dtype=torch.float)
98 |         x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
99 |         label = torch.tensor(sub_label, dtype=torch.long)
100 | 
101 |         if self.graph_mode:
102 |             pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
103 |             print("Graph predicted label: ", pred_label)
104 |         else:
105 |             pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
106 |             print("Node predicted label: ", pred_label[node_idx_new])
107 | 
108 |         # with open('sidelog/%d.%d.adj.npy' % (original_idx, pred_label), 'wb') as outfile:
109 |         #     np.save(outfile, np.asarray(sub_adj.copy()))
110 |         final_node_embeddings = self.model.final_node_embeddings(x, adj).detach().numpy()[0]
111 |         with open('embeddings-%s/%d.%d.embs.npy' % (self.args.bmname, original_idx, pred_label), 'wb') as outfile:  # assumes the embeddings-<bmname> directory exists
112 |             np.save(outfile, np.asarray(final_node_embeddings.copy()))
113 | 
114 | 
115 |         explainer = ExplainModule(
116 |             adj=adj,
117 |             x=x,
118 |             model=self.model,
119 |             label=label,
120 |             args=self.args,
121 |             writer=self.writer,
122 |             graph_idx=self.graph_idx,
123 |             graph_mode=self.graph_mode,
124 |         )
125 |         if self.args.gpu:
126 |             explainer = explainer.cuda()
127 | 
128 |         self.model.eval()
129 | 
130 | 
131 |         # gradient baseline
132 |         if model == "grad":
133 |             explainer.zero_grad()
134 |             # pdb.set_trace()
135 |             adj_grad = torch.abs(
136 |                 explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0]
137 |             )[graph_idx]
138 |             masked_adj = adj_grad + adj_grad.t()
139 |             masked_adj = torch.sigmoid(masked_adj)
140 |             masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()
141 |         else:
142 |             explainer.train()
143 |             begin_time = time.time()
144 |             for epoch in range(self.args.num_epochs + 1):
145 |                 explainer.zero_grad()
146 |                 explainer.optimizer.zero_grad()
147 |                 ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained)
148 |                 loss = explainer.loss(ypred, pred_label, node_idx_new, epoch)
149 |                 loss.backward()
150 | 
151 |                 explainer.optimizer.step()
152 |                 if explainer.scheduler is not None:
153 |                     explainer.scheduler.step()
154 | 
155 |                 mask_density = explainer.mask_density()
156 |                 if self.print_training:
157 |                     print(
158 |                         "epoch: ",
159 |                         epoch,
160 |                         "; loss: ",
161 |                         loss.item(),
162 |                         "; mask density: ",
163 |                         mask_density.item(),
164 |                         "; pred: ",
165 |                         ypred,
166 |                     )
167 |                 single_subgraph_label = sub_label.squeeze()
168 | 
169 |                 if self.writer is not None:
170 |                     self.writer.add_scalar("mask/density", mask_density, epoch)
171 |                     self.writer.add_scalar(
172 |                         "optimization/lr",
173 |                         explainer.optimizer.param_groups[0]["lr"],
174 |                         epoch,
175 |                     )
176 |                     # comment the following lines when mass explaining
177 |                     # if epoch % 25 == 0:
178 |                     #     explainer.log_mask(epoch)
179 |                     #     explainer.log_masked_adj(
180 |                     #         node_idx_new, epoch, label=single_subgraph_label
181 |                     #     )
182 |                     #     explainer.log_adj_grad(
183 |                     #         node_idx_new, pred_label, epoch, label=single_subgraph_label
184 |                     #     )
185 | 
186 |                 if epoch == 0:
187 |                     if self.model.att:
188 |                         # explain node
189 |                         print("adj att size: ", adj_atts.size())
190 |                         adj_att = torch.sum(adj_atts[0], dim=2)
191 |                         # adj_att = adj_att[neighbors][:, neighbors]
192 |                         node_adj_att = adj_att * (adj.float().cuda() if self.args.gpu else adj.float())
193 |                         io_utils.log_matrix(
194 |                             self.writer, node_adj_att[0], "att/matrix", epoch
195 |                         )
196 |                         node_adj_att = node_adj_att[0].cpu().detach().numpy()
197 |                         G = io_utils.denoise_graph(
198 |                             node_adj_att,
199 |                             node_idx_new,
200 |                             threshold=3.8,  # threshold_num=20,
201 |                             max_component=True,
202 |                         )
203 |                         io_utils.log_graph(
204 |                             self.writer,
205 |                             G,
206 |                             name="att/graph",
207 |                             identify_self=not self.graph_mode,
208 |                             nodecolor="label",
209 |                             edge_vmax=None,
210 |                             args=self.args,
211 |                         )
212 |                 if model != "exp":
213 |                     break
214 | 
215 |             print("finished training in ", time.time() - begin_time)
216 |             if model == "exp":
217 |                 masked_adj = (
218 |                     explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze()
219 |                 )
220 |             else:
221 |                 adj_atts = torch.sigmoid(adj_atts).squeeze()
222 |                 masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze()
223 | 
224 |         fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + (
225 |             'node_idx_' + str(node_idx) + 'graph_idx_' + str(self.graph_idx) + '.npy')
226 |         # with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
227 |         with open('explanations/gnnexplainer/%d.%d.masked_adj.npy' % (original_idx, pred_label), 'wb') as outfile:  # assumes this directory exists
228 |             np.save(outfile, np.asarray(masked_adj.copy()))
229 |         print("Saved adjacency matrix to ", 'explanations/gnnexplainer/%d.%d.masked_adj.npy' % (original_idx, pred_label))
230 |         return masked_adj
231 | 
232 | 
233 |     # NODE EXPLAINER
234 |     def explain_nodes(self, node_indices, args, graph_idx=0):
235 |         """
236 |         Explain nodes
237 | 
238 |         Args:
239 |             - node_indices : Indices of the nodes to be explained
240 |             - args : Program arguments (mainly for logging paths)
241 |             - graph_idx : Index of the graph to explain the nodes from (if multiple).
242 |         """
243 |         masked_adjs = [
244 |             self.explain(node_idx, graph_idx=graph_idx) for node_idx in node_indices
245 |         ]
246 |         ref_idx = node_indices[0]
247 |         ref_adj = masked_adjs[0]
248 |         curr_idx = node_indices[1]
249 |         curr_adj = masked_adjs[1]
250 |         new_ref_idx, _, ref_feat, _, _ = self.extract_neighborhood(ref_idx)
251 |         new_curr_idx, _, curr_feat, _, _ = self.extract_neighborhood(curr_idx)
252 | 
253 |         G_ref = io_utils.denoise_graph(ref_adj, new_ref_idx, ref_feat, threshold=0.1)
254 |         denoised_ref_feat = np.array(
255 |             [G_ref.nodes[node]["feat"] for node in G_ref.nodes()]
256 |         )
257 |         denoised_ref_adj = nx.to_numpy_matrix(G_ref)
258 |         # ref center node
259 |         ref_node_idx = list(G_ref.nodes()).index(new_ref_idx)
260 | 
261 |         G_curr = io_utils.denoise_graph(
262 |             curr_adj, new_curr_idx, curr_feat, threshold=0.1
263 |         )
264 |         denoised_curr_feat = np.array(
265 |             [G_curr.nodes[node]["feat"] for node in G_curr.nodes()]
266 |         )
267 |         denoised_curr_adj = nx.to_numpy_matrix(G_curr)
268 |         # curr center node
269 |         curr_node_idx = list(G_curr.nodes()).index(new_curr_idx)
270 | 
271 |         P, aligned_adj, aligned_feat = self.align(
272 |             denoised_ref_feat,
273 |             denoised_ref_adj,
274 |             ref_node_idx,
275 |             denoised_curr_feat,
276 |             denoised_curr_adj,
277 |             curr_node_idx,
278 |             args=args,
279 |         )
280 |         io_utils.log_matrix(self.writer, P, "align/P", 0)
281 | 
282 |         G_ref = nx.convert_node_labels_to_integers(G_ref)
283 |         io_utils.log_graph(self.writer, G_ref, "align/ref")
284 |         G_curr = nx.convert_node_labels_to_integers(G_curr)
285 |         io_utils.log_graph(self.writer, G_curr, "align/before")
286 | 
287 |         P = P.cpu().detach().numpy()
288 |         aligned_adj = aligned_adj.cpu().detach().numpy()
289 |         aligned_feat = aligned_feat.cpu().detach().numpy()
290 | 
291 |         aligned_idx = np.argmax(P[:, curr_node_idx])
292 |         print("aligned self: ", aligned_idx)
293 |         G_aligned = io_utils.denoise_graph(
294 |             aligned_adj, aligned_idx, aligned_feat, threshold=0.5
295 |         )
"mask/aligned") 297 | 298 | # io_utils.log_graph(self.writer, aligned_adj.cpu().detach().numpy(), new_curr_idx, 299 | # 'align/aligned', epoch=1) 300 | 301 | return masked_adjs 302 | 303 | 304 | def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"): 305 | masked_adjs = [ 306 | self.explain(node_idx, graph_idx=graph_idx, model=model) 307 | for node_idx in node_indices 308 | ] 309 | # pdb.set_trace() 310 | graphs = [] 311 | feats = [] 312 | adjs = [] 313 | pred_all = [] 314 | real_all = [] 315 | for i, idx in enumerate(node_indices): 316 | new_idx, _, feat, _, _ = self.extract_neighborhood(idx) 317 | G = io_utils.denoise_graph(masked_adjs[i], new_idx, feat, threshold_num=20) 318 | pred, real = self.make_pred_real(masked_adjs[i], new_idx) 319 | pred_all.append(pred) 320 | real_all.append(real) 321 | denoised_feat = np.array([G.nodes[node]["feat"] for node in G.nodes()]) 322 | denoised_adj = nx.to_numpy_matrix(G) 323 | graphs.append(G) 324 | feats.append(denoised_feat) 325 | adjs.append(denoised_adj) 326 | io_utils.log_graph( 327 | self.writer, 328 | G, 329 | "graph/{}_{}_{}".format(self.args.dataset, model, i), 330 | identify_self=True, 331 | ) 332 | 333 | pred_all = np.concatenate((pred_all), axis=0) 334 | real_all = np.concatenate((real_all), axis=0) 335 | 336 | auc_all = roc_auc_score(real_all, pred_all) 337 | precision, recall, thresholds = precision_recall_curve(real_all, pred_all) 338 | 339 | plt.switch_backend("agg") 340 | plt.plot(recall, precision) 341 | plt.savefig("log/pr/pr_" + self.args.dataset + "_" + model + ".png") 342 | 343 | plt.close() 344 | 345 | auc_all = roc_auc_score(real_all, pred_all) 346 | precision, recall, thresholds = precision_recall_curve(real_all, pred_all) 347 | 348 | plt.switch_backend("agg") 349 | plt.plot(recall, precision) 350 | plt.savefig("log/pr/pr_" + self.args.dataset + "_" + model + ".png") 351 | 352 | plt.close() 353 | 354 | with open("log/pr/auc_" + self.args.dataset + "_" + model + ".txt", "w") as f: 355 | f.write( 356 | "dataset: {}, model: {}, auc: {}\n".format( 357 | self.args.dataset, "exp", str(auc_all) 358 | ) 359 | ) 360 | 361 | return masked_adjs 362 | 363 | # GRAPH EXPLAINER 364 | def explain_graphs(self, graph_indices): 365 | """ 366 | Explain graphs. 367 | """ 368 | masked_adjs = [] 369 | 370 | for graph_idx in graph_indices: 371 | masked_adj = self.explain(node_idx=0, graph_idx=graph_idx, graph_mode=True) 372 | G_denoised = io_utils.denoise_graph( 373 | masked_adj, 374 | 0, 375 | threshold_num=20, 376 | feat=self.feat[graph_idx], 377 | max_component=False, 378 | ) 379 | label = self.label[graph_idx] 380 | io_utils.log_graph( 381 | self.writer, 382 | G_denoised, 383 | "graph/graphidx_{}_label={}".format(graph_idx, label), 384 | identify_self=False, 385 | nodecolor="feat", 386 | ) 387 | masked_adjs.append(masked_adj) 388 | 389 | G_orig = io_utils.denoise_graph( 390 | self.adj[graph_idx], 391 | 0, 392 | feat=self.feat[graph_idx], 393 | threshold=None, 394 | max_component=False, 395 | ) 396 | 397 | io_utils.log_graph( 398 | self.writer, 399 | G_orig, 400 | "graph/graphidx_{}".format(graph_idx), 401 | identify_self=False, 402 | nodecolor="feat", 403 | ) 404 | 405 | # plot cmap for graphs' node features 406 | io_utils.plot_cmap_tb(self.writer, "tab20", 20, "tab20_cmap") 407 | 408 | return masked_adjs 409 | 410 | def log_representer(self, rep_val, sim_val, alpha, graph_idx=0): 411 | """ visualize output of representer instances. 
""" 412 | rep_val = rep_val.cpu().detach().numpy() 413 | sim_val = sim_val.cpu().detach().numpy() 414 | alpha = alpha.cpu().detach().numpy() 415 | sorted_rep = sorted(range(len(rep_val)), key=lambda k: rep_val[k]) 416 | print(sorted_rep) 417 | topk = 5 418 | most_neg_idx = [sorted_rep[i] for i in range(topk)] 419 | most_pos_idx = [sorted_rep[-i - 1] for i in range(topk)] 420 | rep_idx = [most_pos_idx, most_neg_idx] 421 | 422 | if self.graph_mode: 423 | pred = np.argmax(self.pred[0][graph_idx], axis=0) 424 | else: 425 | pred = np.argmax(self.pred[graph_idx][self.train_idx], axis=1) 426 | print(metrics.confusion_matrix(self.label[graph_idx][self.train_idx], pred)) 427 | plt.switch_backend("agg") 428 | fig = plt.figure(figsize=(5, 3), dpi=600) 429 | for i in range(2): 430 | for j in range(topk): 431 | idx = self.train_idx[rep_idx[i][j]] 432 | print( 433 | "node idx: ", 434 | idx, 435 | "; node label: ", 436 | self.label[graph_idx][idx], 437 | "; pred: ", 438 | pred, 439 | ) 440 | 441 | idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood( 442 | idx, graph_idx 443 | ) 444 | G = nx.from_numpy_matrix(sub_adj) 445 | node_colors = [1 for i in range(G.number_of_nodes())] 446 | node_colors[idx_new] = 0 447 | # node_color='#336699', 448 | 449 | ax = plt.subplot(2, topk, i * topk + j + 1) 450 | nx.draw( 451 | G, 452 | pos=nx.spring_layout(G), 453 | with_labels=True, 454 | font_size=4, 455 | node_color=node_colors, 456 | cmap=plt.get_cmap("Set1"), 457 | vmin=0, 458 | vmax=8, 459 | edge_vmin=0.0, 460 | edge_vmax=1.0, 461 | width=0.5, 462 | node_size=25, 463 | alpha=0.7, 464 | ) 465 | ax.xaxis.set_visible(False) 466 | fig.canvas.draw() 467 | self.writer.add_image( 468 | "local/representer_neigh", tensorboardX.utils.figure_to_image(fig), 0 469 | ) 470 | 471 | def representer(self): 472 | """ 473 | experiment using representer theorem for finding supporting instances. 474 | https://papers.nips.cc/paper/8141-representer-point-selection-for-explaining-deep-neural-networks.pdf 475 | """ 476 | self.model.train() 477 | self.model.zero_grad() 478 | adj = torch.tensor(self.adj, dtype=torch.float) 479 | x = torch.tensor(self.feat, requires_grad=True, dtype=torch.float) 480 | label = torch.tensor(self.label, dtype=torch.long) 481 | if self.args.gpu: 482 | adj, x, label = adj.cuda(), x.cuda(), label.cuda() 483 | 484 | preds, _ = self.model(x, adj) 485 | preds.retain_grad() 486 | self.embedding = self.model.embedding_tensor 487 | loss = self.model.loss(preds, label) 488 | loss.backward() 489 | self.preds_grad = preds.grad 490 | pred_idx = np.expand_dims(np.argmax(self.pred, axis=2), axis=2) 491 | pred_idx = torch.LongTensor(pred_idx) 492 | if self.args.gpu: 493 | pred_idx = pred_idx.cuda() 494 | self.alpha = self.preds_grad 495 | 496 | 497 | # Utilities 498 | def extract_neighborhood(self, node_idx, graph_idx=0): 499 | """Returns the neighborhood of a given ndoe.""" 500 | neighbors_adj_row = self.neighborhoods[graph_idx][node_idx, :] 501 | # index of the query node in the new adj 502 | node_idx_new = sum(neighbors_adj_row[:node_idx]) 503 | neighbors = np.nonzero(neighbors_adj_row)[0] 504 | sub_adj = self.adj[graph_idx][neighbors][:, neighbors] 505 | sub_feat = self.feat[graph_idx, neighbors] 506 | sub_label = self.label[graph_idx][neighbors] 507 | return node_idx_new, sub_adj, sub_feat, sub_label, neighbors 508 | 509 | def align( 510 | self, ref_feat, ref_adj, ref_node_idx, curr_feat, curr_adj, curr_node_idx, args 511 | ): 512 | """ Tries to find an alignment between two graphs. 
513 | """ 514 | ref_adj = torch.FloatTensor(ref_adj) 515 | curr_adj = torch.FloatTensor(curr_adj) 516 | 517 | ref_feat = torch.FloatTensor(ref_feat) 518 | curr_feat = torch.FloatTensor(curr_feat) 519 | 520 | P = nn.Parameter(torch.FloatTensor(ref_adj.shape[0], curr_adj.shape[0])) 521 | with torch.no_grad(): 522 | nn.init.constant_(P, 1.0 / ref_adj.shape[0]) 523 | P[ref_node_idx, :] = 0.0 524 | P[:, curr_node_idx] = 0.0 525 | P[ref_node_idx, curr_node_idx] = 1.0 526 | opt = torch.optim.Adam([P], lr=0.01, betas=(0.5, 0.999)) 527 | for i in range(args.align_steps): 528 | opt.zero_grad() 529 | feat_loss = torch.norm(P @ curr_feat - ref_feat) 530 | 531 | aligned_adj = P @ curr_adj @ torch.transpose(P, 0, 1) 532 | align_loss = torch.norm(aligned_adj - ref_adj) 533 | loss = feat_loss + align_loss 534 | loss.backward() # Calculate gradients 535 | self.writer.add_scalar("optimization/align_loss", loss, i) 536 | print("iter: ", i, "; loss: ", loss) 537 | opt.step() 538 | 539 | return P, aligned_adj, P @ curr_feat 540 | 541 | def make_pred_real(self, adj, start): 542 | # house graph 543 | if self.args.dataset == "syn1" or self.args.dataset == "syn2": 544 | # num_pred = max(G.number_of_edges(), 6) 545 | pred = adj[np.triu(adj) > 0] 546 | real = adj.copy() 547 | 548 | if real[start][start + 1] > 0: 549 | real[start][start + 1] = 10 550 | if real[start + 1][start + 2] > 0: 551 | real[start + 1][start + 2] = 10 552 | if real[start + 2][start + 3] > 0: 553 | real[start + 2][start + 3] = 10 554 | if real[start][start + 3] > 0: 555 | real[start][start + 3] = 10 556 | if real[start][start + 4] > 0: 557 | real[start][start + 4] = 10 558 | if real[start + 1][start + 4]: 559 | real[start + 1][start + 4] = 10 560 | real = real[np.triu(real) > 0] 561 | real[real != 10] = 0 562 | real[real == 10] = 1 563 | 564 | # cycle graph 565 | elif self.args.dataset == "syn4": 566 | pred = adj[np.triu(adj) > 0] 567 | real = adj.copy() 568 | # pdb.set_trace() 569 | if real[start][start + 1] > 0: 570 | real[start][start + 1] = 10 571 | if real[start + 1][start + 2] > 0: 572 | real[start + 1][start + 2] = 10 573 | if real[start + 2][start + 3] > 0: 574 | real[start + 2][start + 3] = 10 575 | if real[start + 3][start + 4] > 0: 576 | real[start + 3][start + 4] = 10 577 | if real[start + 4][start + 5] > 0: 578 | real[start + 4][start + 5] = 10 579 | if real[start][start + 5]: 580 | real[start][start + 5] = 10 581 | real = real[np.triu(real) > 0] 582 | real[real != 10] = 0 583 | real[real == 10] = 1 584 | 585 | return pred, real 586 | 587 | 588 | class ExplainModule(nn.Module): 589 | def __init__( 590 | self, 591 | adj, 592 | x, 593 | model, 594 | label, 595 | args, 596 | graph_idx=0, 597 | writer=None, 598 | use_sigmoid=True, 599 | graph_mode=False, 600 | ): 601 | super(ExplainModule, self).__init__() 602 | self.adj = adj 603 | self.x = x 604 | self.model = model 605 | self.label = label 606 | self.graph_idx = graph_idx 607 | self.args = args 608 | self.writer = writer 609 | self.mask_act = args.mask_act 610 | self.use_sigmoid = use_sigmoid 611 | self.graph_mode = graph_mode 612 | 613 | init_strategy = "normal" 614 | num_nodes = adj.size()[1] 615 | self.mask, self.mask_bias = self.construct_edge_mask( 616 | num_nodes, init_strategy=init_strategy 617 | ) 618 | 619 | self.feat_mask = self.construct_feat_mask(x.size(-1), init_strategy="constant") 620 | params = [self.mask, self.feat_mask] 621 | if self.mask_bias is not None: 622 | params.append(self.mask_bias) 623 | # For masking diagonal entries 624 | self.diag_mask = 
625 |         if args.gpu:
626 |             self.diag_mask = self.diag_mask.cuda()
627 | 
628 |         self.scheduler, self.optimizer = train_utils.build_optimizer(args, params)
629 | 
630 |         self.coeffs = {
631 |             "size": 0.005,
632 |             "feat_size": 1.0,
633 |             "ent": 1.0,
634 |             "feat_ent": 0.1,
635 |             "grad": 0,
636 |             "lap": 1.0,
637 |         }
638 | 
639 |     def construct_feat_mask(self, feat_dim, init_strategy="normal"):
640 |         mask = nn.Parameter(torch.FloatTensor(feat_dim))
641 |         if init_strategy == "normal":
642 |             std = 0.1
643 |             with torch.no_grad():
644 |                 mask.normal_(1.0, std)
645 |         elif init_strategy == "constant":
646 |             with torch.no_grad():
647 |                 nn.init.constant_(mask, 0.0)
648 |                 # mask[0] = 2
649 |         return mask
650 | 
651 |     def construct_edge_mask(self, num_nodes, init_strategy="normal", const_val=1.0):
652 |         mask = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes))
653 |         if init_strategy == "normal":
654 |             std = nn.init.calculate_gain("relu") * math.sqrt(
655 |                 2.0 / (num_nodes + num_nodes)
656 |             )
657 |             with torch.no_grad():
658 |                 mask.normal_(1.0, std)
659 |                 # mask.clamp_(0.0, 1.0)
660 |         elif init_strategy == "const":
661 |             nn.init.constant_(mask, const_val)
662 | 
663 |         if self.args.mask_bias:
664 |             mask_bias = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes))
665 |             nn.init.constant_(mask_bias, 0.0)
666 |         else:
667 |             mask_bias = None
668 | 
669 |         return mask, mask_bias
670 | 
671 |     def _masked_adj(self):
672 |         sym_mask = self.mask
673 |         if self.mask_act == "sigmoid":
674 |             sym_mask = torch.sigmoid(self.mask)
675 |         elif self.mask_act == "ReLU":
676 |             sym_mask = nn.ReLU()(self.mask)
677 |         sym_mask = (sym_mask + sym_mask.t()) / 2
678 |         adj = self.adj.cuda() if self.args.gpu else self.adj
679 |         masked_adj = adj * sym_mask
680 |         if self.args.mask_bias:
681 |             bias = (self.mask_bias + self.mask_bias.t()) / 2
682 |             bias = nn.ReLU6()(bias * 6) / 6
683 |             masked_adj += (bias + bias.t()) / 2
684 |         return masked_adj * self.diag_mask
685 | 
686 |     def mask_density(self):
687 |         mask_sum = torch.sum(self._masked_adj()).cpu()
688 |         adj_sum = torch.sum(self.adj)
689 |         return mask_sum / adj_sum
690 | 
691 |     def forward(self, node_idx, unconstrained=False, mask_features=True, marginalize=False):
692 |         x = self.x.cuda() if self.args.gpu else self.x
693 | 
694 |         if unconstrained:
695 |             sym_mask = torch.sigmoid(self.mask) if self.use_sigmoid else self.mask
696 |             self.masked_adj = (
697 |                 torch.unsqueeze((sym_mask + sym_mask.t()) / 2, 0) * self.diag_mask
698 |             )
699 |         else:
700 |             self.masked_adj = self._masked_adj()
701 |             if mask_features:
702 |                 feat_mask = (
703 |                     torch.sigmoid(self.feat_mask)
704 |                     if self.use_sigmoid
705 |                     else self.feat_mask
706 |                 )
707 |                 if marginalize:
708 |                     std_tensor = torch.ones_like(x, dtype=torch.float) / 2
709 |                     mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
710 |                     z = torch.normal(mean=mean_tensor, std=std_tensor)
711 |                     x = x + z * (1 - feat_mask)
712 |                 else:
713 |                     x = x * feat_mask
714 | 
715 |         ypred, adj_att = self.model(x, self.masked_adj)
716 |         if self.graph_mode:
717 |             res = nn.Softmax(dim=0)(ypred[0])
718 |         else:
719 |             node_pred = ypred[self.graph_idx, node_idx, :]
720 |             res = nn.Softmax(dim=0)(node_pred)
721 |         return res, adj_att
722 | 
723 |     def adj_feat_grad(self, node_idx, pred_label_node):
724 |         self.model.zero_grad()
725 |         self.adj.requires_grad = True
726 |         self.x.requires_grad = True
727 |         if self.adj.grad is not None:
728 |             self.adj.grad.zero_()
729 |             self.x.grad.zero_()
730 |         if self.args.gpu:
731 |             adj = self.adj.cuda()
732 |             x = self.x.cuda()
733 |             label = self.label.cuda()
734 |         else:
735 |             x, adj = self.x, self.adj
736 |         ypred, _ = self.model(x, adj)
737 |         if self.graph_mode:
738 |             logit = nn.Softmax(dim=0)(ypred[0])
739 |         else:
740 |             logit = nn.Softmax(dim=0)(ypred[self.graph_idx, node_idx, :])
741 |         logit = logit[pred_label_node]
742 |         loss = -torch.log(logit)
743 |         loss.backward()
744 |         return self.adj.grad, self.x.grad
745 | 
746 |     def loss(self, pred, pred_label, node_idx, epoch):
747 |         """
748 |         Args:
749 |             pred: prediction made by the current model
750 |             pred_label: the label predicted by the original model.
751 |         """
752 |         mi_obj = False
753 |         if mi_obj:
754 |             pred_loss = -torch.sum(pred * torch.log(pred))
755 |         else:
756 |             pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
757 |             gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
758 |             logit = pred[gt_label_node]
759 |             pred_loss = -torch.log(logit)
760 |         # size
761 |         mask = self.mask
762 |         if self.mask_act == "sigmoid":
763 |             mask = torch.sigmoid(self.mask)
764 |         elif self.mask_act == "ReLU":
765 |             mask = nn.ReLU()(self.mask)
766 |         size_loss = self.coeffs["size"] * torch.sum(mask)
767 | 
768 |         # pre_mask_sum = torch.sum(self.feat_mask)
769 |         feat_mask = (
770 |             torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
771 |         )
772 |         feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
773 | 
774 |         # entropy
775 |         mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
776 |         mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
777 | 
778 |         feat_mask_ent = - feat_mask \
779 |                         * torch.log(feat_mask) \
780 |                         - (1 - feat_mask) \
781 |                         * torch.log(1 - feat_mask)
782 | 
783 |         feat_mask_ent_loss = self.coeffs["feat_ent"] * torch.mean(feat_mask_ent)
784 | 
785 |         # laplacian
786 |         D = torch.diag(torch.sum(self.masked_adj[0], 0))
787 |         m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
788 |         L = D - m_adj
789 |         pred_label_t = torch.tensor(pred_label, dtype=torch.float)
790 |         if self.args.gpu:
791 |             pred_label_t = pred_label_t.cuda()
792 |             L = L.cuda()
793 |         if self.graph_mode:
794 |             lap_loss = 0
795 |         else:
796 |             lap_loss = (self.coeffs["lap"]
797 |                         * (pred_label_t @ L @ pred_label_t)
798 |                         / self.adj.numel()
799 |                         )
800 | 
801 |         # grad
802 |         # adj
803 |         # adj_grad, x_grad = self.adj_feat_grad(node_idx, pred_label_node)
804 |         # adj_grad = adj_grad[self.graph_idx]
805 |         # x_grad = x_grad[self.graph_idx]
806 |         # if self.args.gpu:
807 |         #     adj_grad = adj_grad.cuda()
808 |         # grad_loss = self.coeffs['grad'] * -torch.mean(torch.abs(adj_grad) * mask)
809 | 
810 |         # feat
811 |         # x_grad_sum = torch.sum(x_grad, 1)
812 |         # grad_feat_loss = self.coeffs['featgrad'] * -torch.mean(x_grad_sum * mask)
813 | 
814 |         loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss  # note: feat_mask_ent_loss is computed and logged but not part of the objective
815 |         if self.writer is not None:
816 |             self.writer.add_scalar("optimization/size_loss", size_loss, epoch)
817 |             self.writer.add_scalar("optimization/feat_size_loss", feat_size_loss, epoch)
818 |             self.writer.add_scalar("optimization/mask_ent_loss", mask_ent_loss, epoch)
819 |             self.writer.add_scalar(
820 |                 "optimization/feat_mask_ent_loss", feat_mask_ent_loss, epoch
821 |             )
822 |             # self.writer.add_scalar('optimization/grad_loss', grad_loss, epoch)
823 |             self.writer.add_scalar("optimization/pred_loss", pred_loss, epoch)
824 |             self.writer.add_scalar("optimization/lap_loss", lap_loss, epoch)
825 |             self.writer.add_scalar("optimization/overall_loss", loss, epoch)
826 |         return loss
827 | 
828 |     def log_mask(self, epoch):
829 |         plt.switch_backend("agg")
830 |         fig = plt.figure(figsize=(4, 3), dpi=400)
831 |         plt.imshow(self.mask.cpu().detach().numpy(), cmap=plt.get_cmap("BuPu"))
832 |         cbar = plt.colorbar()
833 |         cbar.solids.set_edgecolor("face")
834 | 
835 |         plt.tight_layout()
836 |         fig.canvas.draw()
837 |         self.writer.add_image(
838 |             "mask/mask", tensorboardX.utils.figure_to_image(fig), epoch
839 |         )
840 | 
841 |         # fig = plt.figure(figsize=(4,3), dpi=400)
842 |         # plt.imshow(self.feat_mask.cpu().detach().numpy()[:,np.newaxis], cmap=plt.get_cmap('BuPu'))
843 |         # cbar = plt.colorbar()
844 |         # cbar.solids.set_edgecolor("face")
845 | 
846 |         # plt.tight_layout()
847 |         # fig.canvas.draw()
848 |         # self.writer.add_image('mask/feat_mask', tensorboardX.utils.figure_to_image(fig), epoch)
849 |         io_utils.log_matrix(
850 |             self.writer, torch.sigmoid(self.feat_mask), "mask/feat_mask", epoch
851 |         )
852 | 
853 |         fig = plt.figure(figsize=(4, 3), dpi=400)
854 |         # use [0] to remove the batch dim
855 |         plt.imshow(self.masked_adj[0].cpu().detach().numpy(), cmap=plt.get_cmap("BuPu"))
856 |         cbar = plt.colorbar()
857 |         cbar.solids.set_edgecolor("face")
858 | 
859 |         plt.tight_layout()
860 |         fig.canvas.draw()
861 |         self.writer.add_image(
862 |             "mask/adj", tensorboardX.utils.figure_to_image(fig), epoch
863 |         )
864 | 
865 |         if self.args.mask_bias:
866 |             fig = plt.figure(figsize=(4, 3), dpi=400)
867 |             # use [0] to remove the batch dim
868 |             plt.imshow(self.mask_bias.cpu().detach().numpy(), cmap=plt.get_cmap("BuPu"))
869 |             cbar = plt.colorbar()
870 |             cbar.solids.set_edgecolor("face")
871 | 
872 |             plt.tight_layout()
873 |             fig.canvas.draw()
874 |             self.writer.add_image(
875 |                 "mask/bias", tensorboardX.utils.figure_to_image(fig), epoch
876 |             )
877 | 
878 |     def log_adj_grad(self, node_idx, pred_label, epoch, label=None):
879 |         log_adj = False
880 | 
881 |         if self.graph_mode:
882 |             predicted_label = pred_label
883 |             # adj_grad, x_grad = torch.abs(self.adj_feat_grad(node_idx, predicted_label)[0])[0]
884 |             adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
885 |             adj_grad = torch.abs(adj_grad)[0]
886 |             x_grad = torch.sum(x_grad[0], 0, keepdim=True).t()
887 |         else:
888 |             predicted_label = pred_label[node_idx]
889 |             # adj_grad = torch.abs(self.adj_feat_grad(node_idx, predicted_label)[0])[self.graph_idx]
890 |             adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
891 |             adj_grad = torch.abs(adj_grad)[self.graph_idx]
892 |             x_grad = x_grad[self.graph_idx][node_idx][:, np.newaxis]
893 |             # x_grad = torch.sum(x_grad[self.graph_idx], 0, keepdim=True).t()
894 |         adj_grad = (adj_grad + adj_grad.t()) / 2
895 |         adj_grad = (adj_grad * self.adj).squeeze()
896 |         if log_adj:
897 |             io_utils.log_matrix(self.writer, adj_grad, "grad/adj_masked", epoch)
898 |             self.adj.requires_grad = False
899 |             io_utils.log_matrix(self.writer, self.adj.squeeze(), "grad/adj_orig", epoch)
900 | 
901 |         masked_adj = self.masked_adj[0].cpu().detach().numpy()
902 | 
903 |         # only for graph mode since many node neighborhoods for syn tasks are relatively large for
904 |         # visualization
905 |         if self.graph_mode:
906 |             G = io_utils.denoise_graph(
907 |                 masked_adj, node_idx, feat=self.x[0], threshold=None, max_component=False
908 |             )
909 |             io_utils.log_graph(
910 |                 self.writer,
911 |                 G,
912 |                 name="grad/graph_orig",
913 |                 epoch=epoch,
914 |                 identify_self=False,
915 |                 label_node_feat=True,
916 |                 nodecolor="feat",
917 |                 edge_vmax=None,
918 |                 args=self.args,
919 |             )
920 |         io_utils.log_matrix(self.writer, x_grad, "grad/feat", epoch)
921 | 
922 |         adj_grad = adj_grad.detach().numpy()
923 |         if self.graph_mode:
924 |             print("GRAPH model")
925 |             G = io_utils.denoise_graph(
926 |                 adj_grad,
927 |                 node_idx,
928 |                 feat=self.x[0],
929 |                 threshold=0.0003,  # threshold_num=20,
930 |                 max_component=True,
931 |             )
932 |             # io_utils.log_graph(
933 |             #     self.writer,
934 |             #     G,
935 |             #     name="grad/graph",
936 |             #     epoch=epoch,
937 |             #     identify_self=False,
938 |             #     label_node_feat=True,
939 |             #     nodecolor="feat",
940 |             #     edge_vmax=None,
941 |             #     args=self.args,
942 |             # )
943 |         else:
944 |             # G = io_utils.denoise_graph(adj_grad, node_idx, label=label, threshold=0.5)
945 |             G = io_utils.denoise_graph(adj_grad, node_idx, threshold_num=12)
946 |             io_utils.log_graph(
947 |                 self.writer, G, name="grad/graph", epoch=epoch, args=self.args
948 |             )
949 | 
950 |         # if graph attention, also visualize att
951 | 
952 |     def log_masked_adj(self, node_idx, epoch, name="mask/graph", label=None):
953 |         # use [0] to remove the batch dim
954 |         masked_adj = self.masked_adj[0].cpu().detach().numpy()
955 |         if self.graph_mode:
956 |             G = io_utils.denoise_graph(
957 |                 masked_adj,
958 |                 node_idx,
959 |                 feat=self.x[0],
960 |                 threshold=0.2,  # threshold_num=20,
961 |                 max_component=True,
962 |             )
963 |             io_utils.log_graph(
964 |                 self.writer,
965 |                 G,
966 |                 name=name,
967 |                 identify_self=False,
968 |                 nodecolor="feat",
969 |                 epoch=epoch,
970 |                 label_node_feat=True,
971 |                 edge_vmax=None,
972 |                 args=self.args,
973 |             )
974 |         else:
975 |             G = io_utils.denoise_graph(
976 |                 masked_adj, node_idx, threshold_num=12, max_component=True
977 |             )
978 |             io_utils.log_graph(
979 |                 self.writer,
980 |                 G,
981 |                 name=name,
982 |                 identify_self=True,
983 |                 nodecolor="label",
984 |                 epoch=epoch,
985 |                 edge_vmax=None,
986 |                 args=self.args,
987 |             )
988 | 
989 | 
--------------------------------------------------------------------------------