├── .gitignore ├── LICENSE ├── README.md ├── baselines ├── __init__.py ├── eval_utils.py ├── methods │ ├── OrphicX │ │ ├── __init__.py │ │ ├── causaleffect.py │ │ ├── gae │ │ │ ├── layers.py │ │ │ ├── model.py │ │ │ └── optimizer.py │ │ └── orphicx.py │ ├── __init__.py │ ├── graphsvx.py │ ├── pgm_explainer.py │ ├── subgraphx.py │ └── subgraphx_base.py ├── run_gnnexplainer.py ├── run_graphsvx.py ├── run_orphicx.py ├── run_pgexplainer.py ├── run_pgmexplainer.py ├── run_sa.py ├── run_subgraphx.py └── utils.py ├── cppextension └── cudagnnshap.cu ├── dataset ├── __init__.py ├── configs.py └── utils.py ├── examples ├── BaselinesEvaluation.ipynb ├── CustomModelData.ipynb └── Visualization.ipynb ├── gnnshap ├── __init__.py ├── eval_utils.py ├── explainer.py ├── explanation.py ├── samplers │ ├── __init__.py │ ├── _base.py │ ├── _exact.py │ ├── _gnnshap.py │ ├── _shap.py │ ├── _shap_unique.py │ └── _svx.py ├── solvers │ ├── __init__.py │ ├── _base.py │ ├── _wlr.py │ └── _wls.py └── utils.py ├── models ├── GATModel.py ├── GCNModel.py └── __init__.py ├── pretrained ├── CiteSeer_pretrained.pt ├── Coauthor-CS_pretrained.pt ├── Coauthor-Physics_pretrained.pt ├── Cora_GAT_pretrained.pt ├── Cora_pretrained.pt ├── Facebook_pretrained.pt ├── PubMed_pretrained.pt ├── Reddit_explain_data.pt ├── Reddit_pretrained.pt ├── ogbn-products_explain_data.pt ├── ogbn-products_pretrained.pt ├── split_Coauthor-CS.pt ├── split_Coauthor-Physics.pt └── split_Facebook.pt ├── requirements.txt ├── results └── .placeholder ├── run_baseline_experiments.sh ├── run_gnnshap.py ├── run_gnnshap_experiments.sh ├── train.py └── train_large.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | data/
163 | .vscode/
164 | results/*.pkl
165 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GNNShap: Scalable and Accurate GNN Explanation using Shapley Values
2 | This repository contains the source code for the paper
3 | `GNNShap: Scalable and Accurate GNN Explanation using Shapley Values`, accepted at The
4 | Web Conference 2024.
5 |
6 | ### Setup
7 | Our implementation is based on PyTorch and PyG. Our Shapley sampling strategy is implemented
8 | in CUDA, so GNNShap requires a GPU with CUDA support.
9 |
10 | First, install PyTorch with GPU support from [here](https://pytorch.org/get-started/locally/) and
11 | make sure PyTorch can access the GPU.
12 |
13 | The rest of the required packages and their versions are listed in the `requirements.txt` file.
14 |
15 | You can install the requirements by running:
16 | ```bash
17 | pip install -r requirements.txt
18 | ```
19 |
20 | ### Dataset Configs
21 |
22 | Dataset and dataset-specific model configurations are in the `dataset/configs.py` file.
23 |
24 |
25 | ### Model training
26 |
27 | We provide pretrained models in the `pretrained` folder.
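If you only want to use a pretrained checkpoint, it can be loaded together with its dataset and configuration through the dataset utilities. The snippet below is a minimal sketch that mirrors how the baseline run scripts (e.g., `baselines/run_gnnexplainer.py`) load them:

```python
import torch

from dataset.utils import get_model_data_config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loads the pretrained model, the dataset, and the dataset-specific config
# (e.g., config['num_hops'] and config['test_nodes']) for Cora.
model, data, config = get_model_data_config('Cora', load_pretrained=True, device=device)
model.eval()

# Predicted classes for all nodes; the run scripts use these as explanation targets.
target = torch.argmax(model(data.x, data.edge_index), dim=-1)
```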
28 |
29 | To train on the Cora, CiteSeer, PubMed, Coauthor-CS, Coauthor-Physics, and Facebook datasets:
30 | ```bash
31 | python train.py --dataset Cora
32 | ```
33 |
34 | Reddit and ogbn-products require `NeighborLoader` for training. To train them:
35 | ```bash
36 | python train_large.py --dataset Reddit
37 | ```
38 |
39 | ### Experiments
40 |
41 | We provide scripts for the baseline and GNNShap experiments. The scripts save explanation results to
42 | the `results` folder. Note that the scripts repeat each experiment five times; this can be changed in the
43 | scripts.
44 |
45 | For the baselines, you can use the following script. To run an individual baseline, refer to
46 | the commands inside the script file.
47 |
48 | ```bash
49 | ./run_baseline_experiments.sh
50 | ```
51 |
52 | For the GNNShap experiments, you can use the following script:
53 | ```bash
54 | ./run_gnnshap_experiments.sh
55 | ```
56 |
57 | - _We ran experiments on a GPU with 24GB of memory. You may need to adjust the `batch_size`
58 | and `num_samples` parameters if you have less GPU memory._
59 | - _The first run might take some time: it needs to compile the CUDA code._
60 |
61 | For individual dataset experiments, an example is provided below:
62 | ```bash
63 | python run_gnnshap.py --dataset Cora --num_samples 25000 --sampler GNNShapSampler
64 | --batch_size 1024 --repeat 1
65 | ```
66 |
67 | The results will be saved to the `results` folder. The default results folder can be changed
68 | in `dataset/configs.py`.
69 |
70 |
71 | ### Evaluation
72 | We used the `BaselinesEvaluation.ipynb` notebook under the `examples` folder for the explanation-time
73 | and fidelity results.
74 |
75 | ### Visualization
76 | We provide explanation visualization examples in the `Visualization.ipynb` notebook
77 | under `examples`.
78 |
79 | ### Custom Model & Data Explanations
80 | We provide an example in the `CustomModelData.ipynb` notebook under `examples`.
81 |
82 | ### Citation
83 | Please cite our work if you find it useful.
84 |
85 | ```
86 | Selahattin Akkas and Ariful Azad. 2024. GNNShap: Scalable and Accurate GNN Explanation using Shapley Values.
87 | In Proceedings of the ACM Web Conference 2024 (WWW ’24), May 13–17, 2024, Singapore, Singapore.
88 | ```
89 |
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/baselines/__init__.py
--------------------------------------------------------------------------------
/baselines/eval_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import pickle
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from torch_geometric.data import Data
8 | from tqdm import tqdm
9 |
10 | from dataset.utils import get_model_data_config
11 | from gnnshap.eval_utils import fidelity
12 |
13 |
14 | def read_results(file_path: str) -> tuple[list, float]:
15 |     """Reads a pickled results file and returns the explanation results and the training time.
16 |     Note that only some methods, such as PGExplainer, require separate training; other methods
17 |     will have 0.0 as the training time.
18 |
19 |     Args:
20 |         file_path (str): file path
21 |
22 |     Returns:
23 |         tuple[list, float]: results as a list, and the training time.
24 | """ 25 | try: 26 | res = pickle.load(open(file_path, 'rb')) 27 | if len(res) == 2: 28 | return res[0], res[1] 29 | else: 30 | return res, 0.0 31 | except Exception as e: 32 | return [], 0.0 33 | 34 | def compute_fidelity_score(results: list, data: Data, model: torch.nn.Module, 35 | sparsity: float = 0.3, fid_type: str = 'neg', topk: int = 0, 36 | target_class: int = None, testing_pred: str = 'mix', 37 | apply_abs: bool=True) -> tuple: 38 | """Computes fidelity+ and fidelity- scores. It supports both topk and sparsity. If sparsity set 39 | to 0.3, it drops 30% of the edges. Based on the neg or pos, it drops unimportant or important 40 | edges. It applies topk based keep if topk is set to a positive integer other than zero. 41 | 42 | `testing_pred` helps to further analyze fidelity scores for correct and wrong classified nodes. 43 | 44 | Note that it computes fidelity scores for the predicted class if target class is not provided. 45 | Args: 46 | results (list): List of dictionaries. Each dictionary should have node_id, num_players, 47 | scores keys. 48 | data (Data): pyG Data. 49 | model (torch.nn.Module): a PyTorch model. 50 | sparsity (float, optional): target sparsity value. Defaults to 0.3. 51 | fid_type (str, optional): Fidelity type: neg or pos. Defaults to 'neg'. 52 | topk (int, optional): Topk edges to keep. Defaults to 0. 53 | target_class (int, optional): Target class to compute fidelity score. If None, it computes 54 | fidelity score for the predicted class. Defaults to None. 55 | testing_pred (str, optional): Testing prediction filter. Options are 'mix', 'wrong', 56 | and 'correct' Defaults to mix. 57 | apply_abs (bool, optional): applies absolute to scores. Some methods can find negative and 58 | positive contributing nodes/edges. Fidelity-wise, we only care the change amount. We can 59 | use this to get rid of negative contributing edges to improve accuracy. Defaults to 60 | True. 61 | 62 | Returns: 63 | tuple: average score, list of individual scores[node_id, nplayers, fidelity prob score, 64 | current sparsity, correct_class, init_pred_class, sparse_pred_class, fidelity acc] 65 | """ 66 | assert testing_pred in [ 67 | 'mix', 'wrong', 'correct'], "Testing prediction option is not correct!" 68 | 69 | 70 | fid_scores = [] 71 | sum_prob = 0 72 | for res in results: 73 | if res['num_players'] < 2: 74 | continue 75 | try: 76 | (node_id, num_players, prob_score, current_sparsity, correct_class, init_pred_class, 77 | sparse_pred_class) = fidelity(res, data, model, sparsity, fid_type, topk, 78 | target_class, apply_abs) 79 | 80 | if testing_pred == 'wrong' and correct_class == init_pred_class: # skip correct preds 81 | continue 82 | if testing_pred == 'correct' and correct_class != init_pred_class: # skip wrong preds 83 | continue 84 | 85 | sum_prob += prob_score 86 | fid_scores.append([node_id, num_players, prob_score, current_sparsity, correct_class, 87 | init_pred_class, sparse_pred_class]) 88 | except: 89 | print(f"Error in: {res}") 90 | return -1.0, [] 91 | 92 | overall_prob = sum_prob / len(fid_scores) 93 | return overall_prob, fid_scores 94 | 95 | 96 | def run_times_table(path_gen_fn, dataset_names, num_repeats: int=5) -> pd.DataFrame: 97 | """Reads pickle files and extract times. When there is multiple run, it computes the average. 98 | 99 | Args: 100 | path_gen_fn: function that returns the paths of the results 101 | dataset_names (list): list of dataset names 102 | num_repeats (int, optional): Number of repeats. Defaults to 5. 
103 | 104 | Returns: 105 | explanation_times: pd.DataFrame: table with the average times for each method and dataset 106 | exp_model_train_times: pd.DataFrame: table with the average training times for 107 | each method and dataset. Valid for PGExplainer like methods. 108 | 109 | """ 110 | method_names = np.array(path_gen_fn(dataset_names[0]))[:,0].tolist() 111 | 112 | res_table = np.zeros((len(method_names), len(dataset_names)), dtype=object) 113 | res_tr_table = np.zeros((len(method_names), len(dataset_names)), dtype=object) 114 | 115 | for i, dataset_name in enumerate(tqdm(dataset_names)): 116 | total_times = [] 117 | tr_total_times = [] 118 | 119 | for rep_num in range(1, num_repeats+1): 120 | result_file_paths = path_gen_fn(dataset_name, rep_num) 121 | total_times.append([]) 122 | tr_total_times.append([]) 123 | for res_file in result_file_paths: 124 | res, tr_time = read_results(res_file[1]) 125 | tmp_time = 0 126 | for r in res: 127 | tmp_time += r['time'] 128 | total_times[-1].append(tmp_time) 129 | tr_total_times[-1].append(tr_time) 130 | total_times = np.array(total_times) 131 | tr_total_times = np.array(tr_total_times) 132 | 133 | 134 | # compute the mean and std 135 | mean_res = np.mean(total_times, axis=0) 136 | std_res = np.std(total_times, axis=0) 137 | mean_tr_res = np.mean(tr_total_times, axis=0) 138 | std_tr_res = np.std(tr_total_times, axis=0) 139 | 140 | 141 | # fill the table with the results 142 | for k in range(len(method_names)): 143 | res_table[k,i] = f"{mean_res[k]:.2f}\u00B1{std_res[k]:.2f}" 144 | res_tr_table[k,i] = f"{mean_tr_res[k]:.2f}\u00B1{std_tr_res[k]:.2f}" 145 | 146 | expl_df = pd.DataFrame(res_table, columns=dataset_names, index=method_names) 147 | expl_tr_df = pd.DataFrame(res_tr_table, columns=dataset_names, index=method_names) 148 | 149 | return expl_df, expl_tr_df 150 | 151 | def fidelity_table(path_gen_fn, dataset_names, sparsity=0.1, score_type='neg', topk=0, 152 | num_repeats=1, device='cpu', testing_pred='mix', apply_abs=True) -> pd.DataFrame: 153 | """Create a table with the fidelity scores of the methods. It applies topk based edge keep if 154 | topk is set to a positive integer other than zero. 155 | 156 | Args: 157 | path_gen_fn: function that returns the paths of the results 158 | dataset_names (list): list of dataset names 159 | sparsity (float, optional): sparsity of the explanations. Defaults to 0.1. 160 | score_type (str, optional): score_type (str, optional): Fidelity type: 'neg' or 'pos'. 161 | Defaults to 'neg'. 162 | topk (int, optional): Topk edges to keep. Defaults to 0. 163 | num_repeats (int, optional): number of experiment repeats. Defaults to 1. 164 | device (str, optional): device. Defaults to 'cpu'. 165 | testing_pred (str, optional): Testing prediction filter. Options are 'mix', 'wrong', 166 | and 'correct' Defaults to mix. 167 | apply_abs (bool, optional): applies absolute to scores. Some methods can find negative and 168 | positive contributing nodes/edges. Fidelity-wise, we only care the change amount. We can 169 | use this to get rid of negative contributing edges to improve accuracy. Defaults to 170 | True. 
171 | 172 | Returns: 173 | pd.DataFrame: Fidelity table 174 | """ 175 | 176 | method_names = np.array(path_gen_fn(dataset_names[0]))[:,0].tolist() 177 | # create the table 178 | res_table = np.zeros((len(method_names), len(dataset_names)), dtype=object) 179 | 180 | # iterate over the datasets 181 | for i, dataset_name in enumerate(tqdm(dataset_names)): 182 | model, data, _ = get_model_data_config(dataset_name, load_pretrained=True, device=device) 183 | res_runs = [] 184 | not_founds = {name: [] for name in method_names} 185 | # iterate over the repeats 186 | for rep_num in range(1, num_repeats+1): 187 | f_paths = path_gen_fn(dataset_name, rep_num) 188 | res_runs.append([]) 189 | for name, path in f_paths: 190 | res_data, _ = read_results(path) 191 | if len(res_data) == 0: 192 | not_founds[name].append(path) 193 | res_runs[-1].append(-1.0) 194 | else: 195 | res_runs[-1].append(compute_fidelity_score(res_data, data, model, sparsity, 196 | score_type, topk, 197 | testing_pred=testing_pred, 198 | apply_abs=apply_abs)[0]) 199 | 200 | res_runs = np.array(res_runs) 201 | 202 | # compute the mean and std 203 | mean_res = np.mean(res_runs, axis=0) 204 | std_res = np.std(res_runs, axis=0) 205 | # fill the table with the results 206 | for k, n in enumerate(method_names): 207 | if len(not_founds[n]) > 0: 208 | res_table[k,i] = "N/A" 209 | print(f"Results not found for {not_founds[n]}. Check the paths.") 210 | else: 211 | res_table[k,i] = f"{mean_res[k]:.3f}\u00B1{std_res[k]:.3f}" 212 | res_table = pd.DataFrame(res_table, columns=dataset_names, index=method_names) 213 | return res_table 214 | -------------------------------------------------------------------------------- /baselines/methods/OrphicX/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/baselines/methods/OrphicX/__init__.py -------------------------------------------------------------------------------- /baselines/methods/OrphicX/causaleffect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | """ 7 | joint_uncond: 8 | Sample-based estimate of "joint, unconditional" causal effect, -I(alpha; Yhat). 
9 | Inputs: 10 | - params['Nalpha'] monte-carlo samples per causal factor 11 | - params['Nbeta'] monte-carlo samples per noncausal factor 12 | - params['K'] number of causal factors 13 | - params['L'] number of noncausal factors 14 | - params['M'] number of classes (dimensionality of classifier output) 15 | - decoder 16 | - classifier 17 | - device 18 | Outputs: 19 | - negCausalEffect (sample-based estimate of -I(alpha; Yhat)) 20 | - info['xhat'] 21 | - info['yhat'] 22 | """ 23 | def joint_uncond(params, decoder, classifier, adj, feat, node_idx=None, act=torch.sigmoid, mu=0, std=1, device=None): 24 | eps = 1e-8 25 | I = 0.0 26 | q = torch.zeros(params['M'], device=device) 27 | feat = feat.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 28 | adj = adj.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 29 | if torch.is_tensor(mu): 30 | alpha_mu = mu[:,:params['K']] 31 | beta_mu = mu[:,params['K']:] 32 | 33 | alpha_std = std[:,:params['K']] 34 | beta_std = std[:,params['K']:] 35 | else: 36 | alpha_mu = 0 37 | beta_mu = 0 38 | alpha_std = 1 39 | beta_std = 1 40 | 41 | alpha = torch.randn((params['Nalpha'], adj.shape[-1], params['K']), device=device).mul(alpha_std).add_(alpha_mu).repeat(1,params['Nbeta'],1).view(params['Nalpha'] * params['Nbeta'] , adj.shape[-1], params['K']) 42 | beta = torch.randn((params['Nalpha'] * params['Nbeta'], adj.shape[-1], params['L']), device=device).mul(beta_std).add_(beta_mu) 43 | zs = torch.cat([alpha, beta], dim=-1) 44 | xhat = act(decoder(zs)) * adj 45 | if node_idx is None: 46 | logits = classifier(feat, xhat)[0] 47 | else: 48 | logits = classifier(feat, xhat)[0][:,node_idx,:] 49 | yhat = F.softmax(logits, dim=1).view(params['Nalpha'], params['Nbeta'] ,params['M']) 50 | p = yhat.mean(1) 51 | I = torch.sum(torch.mul(p, torch.log(p+eps)), dim=1).mean() 52 | q = p.mean(0) 53 | I = I - torch.sum(torch.mul(q, torch.log(q+eps))) 54 | return -I, None 55 | 56 | 57 | def beta_info_flow(params, decoder, classifier, adj, feat, node_idx=None, act=torch.sigmoid, mu=0, std=1, device=None): 58 | eps = 1e-8 59 | I = 0.0 60 | q = torch.zeros(params['M'], device=device) 61 | feat = feat.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 62 | adj = adj.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 63 | if torch.is_tensor(mu): 64 | alpha_mu = mu[:,:params['K']] 65 | beta_mu = mu[:,params['K']:] 66 | 67 | alpha_std = std[:,:params['K']] 68 | beta_std = std[:,params['K']:] 69 | else: 70 | alpha_mu = 0 71 | beta_mu = 0 72 | alpha_std = 1 73 | beta_std = 1 74 | 75 | alpha = torch.randn((params['Nalpha'] * params['Nbeta'], adj.shape[-1], params['K']), device=device).mul(alpha_std).add_(alpha_mu) 76 | beta = torch.randn((params['Nalpha'], adj.shape[-1], params['L']), device=device).mul(beta_std).add_(beta_mu).repeat(1,params['Nbeta'],1).view(params['Nalpha'] * params['Nbeta'] , adj.shape[-1], params['L']) 77 | zs = torch.cat([alpha, beta], dim=-1) 78 | xhat = act(decoder(zs)) * adj 79 | if node_idx is None: 80 | logits = classifier(feat, xhat)[0] 81 | else: 82 | logits = classifier(feat, xhat)[0][:,node_idx,:] 83 | yhat = F.softmax(logits, dim=1).view(params['Nalpha'], params['Nbeta'] ,params['M']) 84 | p = yhat.mean(1) 85 | I = torch.sum(torch.mul(p, torch.log(p+eps)), dim=1).mean() 86 | q = p.mean(0) 87 | I = I - torch.sum(torch.mul(q, torch.log(q+eps))) 88 | return -I, None 89 | for i in range(0, params['Nalpha']): 90 | # alpha = torch.randn((100, params['K']), device=device) 91 | # zs = torch.zeros((params['Nbeta'], 100, params['z_dim']), device=device) 92 | # for j in range(0, 
params['Nbeta']): 93 | # beta = torch.randn((100, params['L']), device=device) 94 | # zs[j,:,:params['K']] = alpha 95 | # zs[j,:,params['K']:] = beta 96 | 97 | alpha = torch.randn((100, params['K']), device=device).mul(alpha_std).add_(alpha_mu).unsqueeze(0).repeat(params['Nbeta'],1,1) 98 | beta = torch.randn((params['Nbeta'], 100, params['L']), device=device).mul(beta_std).add_(beta_mu) 99 | zs = torch.cat([alpha, beta], dim=-1) 100 | # decode and classify batch of Nbeta samples with same alpha 101 | xhat = torch.sigmoid(decoder(zs)) * adj 102 | yhat = F.softmax(classifier(feat, xhat)[0], dim=1) 103 | p = 1./float(params['Nbeta']) * torch.sum(yhat,0) # estimate of p(y|alpha) 104 | I = I + 1./float(params['Nalpha']) * torch.sum(torch.mul(p, torch.log(p+eps))) 105 | q = q + 1./float(params['Nalpha']) * p # accumulate estimate of p(y) 106 | I = I - torch.sum(torch.mul(q, torch.log(q+eps))) 107 | negCausalEffect = -I 108 | info = {"xhat" : xhat, "yhat" : yhat} 109 | return negCausalEffect, info 110 | 111 | 112 | """ 113 | joint_uncond_singledim: 114 | Sample-based estimate of "joint, unconditional" causal effect 115 | for single latent factor, -I(z_i; Yhat). Note the interpretation 116 | of params['Nalpha'] and params['Nbeta'] here: Nalpha is the number 117 | of samples of z_i, and Nbeta is the number of samples of the other 118 | latent factors. 119 | Inputs: 120 | - params['Nalpha'] 121 | - params['Nbeta'] 122 | - params['K'] 123 | - params['L'] 124 | - params['M'] 125 | - decoder 126 | - classifier 127 | - device 128 | - dim (i : compute -I(z_i; Yhat) **note: i is zero-indexed!**) 129 | Outputs: 130 | - negCausalEffect (sample-based estimate of -I(z_i; Yhat)) 131 | - info['xhat'] 132 | - info['yhat'] 133 | """ 134 | def joint_uncond_singledim(params, decoder, classifier, adj, feat, dim, node_idx=None, act=torch.sigmoid, mu=0, std=1, device=None): 135 | eps = 1e-8 136 | I = 0.0 137 | q = torch.zeros(params['M'], device=device) 138 | feat = feat.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 139 | adj = adj.repeat(params['Nalpha'] * params['Nbeta'], 1, 1) 140 | if torch.is_tensor(mu): 141 | alpha_mu = mu 142 | beta_mu = mu[:,dim] 143 | 144 | alpha_std = std 145 | beta_std = std[:,dim] 146 | else: 147 | alpha_mu = 0 148 | beta_mu = 0 149 | alpha_std = 1 150 | beta_std = 1 151 | 152 | alpha = torch.randn((params['Nalpha'], adj.shape[-1]), device=device).mul(alpha_std).add_(alpha_mu).repeat(1,params['Nbeta']).view(params['Nalpha'] * params['Nbeta'] , adj.shape[-1]) 153 | zs = torch.randn((params['Nalpha'] * params['Nbeta'], adj.shape[-1], params['z_dim']), device=device).mul(beta_std).add_(beta_mu) 154 | zs[:,:,dim] = alpha 155 | xhat = act(decoder(zs)) * adj 156 | if node_idx is None: 157 | logits = classifier(feat, xhat)[0] 158 | else: 159 | logits = classifier(feat, xhat)[0][:,node_idx,:] 160 | yhat = F.softmax(logits, dim=1).view(params['Nalpha'], params['Nbeta'] ,params['M']) 161 | p = yhat.mean(1) 162 | I = torch.sum(torch.mul(p, torch.log(p+eps)), dim=1).mean() 163 | q = p.mean(0) 164 | I = I - torch.sum(torch.mul(q, torch.log(q+eps))) 165 | return -I, None 166 | # eps = 1e-8 167 | # I = 0.0 168 | # q = torch.zeros(params['M']).to(device) 169 | # zs = np.zeros((params['Nalpha']*params['Nbeta'], params['z_dim'])) 170 | # for i in range(0, params['Nalpha']): 171 | # z_fix = np.random.randn(1) 172 | # zs = np.zeros((params['Nbeta'],params['z_dim'])) 173 | # for j in range(0, params['Nbeta']): 174 | # zs[j,:] = np.random.randn(params['K']+params['L']) 175 | # zs[j,dim] = z_fix 176 | # # 
decode and classify batch of Nbeta samples with same alpha 177 | # xhat = decoder(torch.from_numpy(zs).float().to(device)) 178 | # yhat = classifier(xhat)[0] 179 | # p = 1./float(params['Nbeta']) * torch.sum(yhat,0) # estimate of p(y|alpha) 180 | # I = I + 1./float(params['Nalpha']) * torch.sum(torch.mul(p, torch.log(p+eps))) 181 | # q = q + 1./float(params['Nalpha']) * p # accumulate estimate of p(y) 182 | # I = I - torch.sum(torch.mul(q, torch.log(q+eps))) 183 | # negCausalEffect = -I 184 | # info = {"xhat" : xhat, "yhat" : yhat} 185 | # return negCausalEffect, info -------------------------------------------------------------------------------- /baselines/methods/OrphicX/gae/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.module import Module 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class GraphConvolution(Module): 8 | """ 9 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 10 | """ 11 | 12 | def __init__(self, in_features, out_features, dropout=0., act=F.relu): 13 | super(GraphConvolution, self).__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.dropout = dropout 17 | self.act = act 18 | self.linear = torch.nn.Linear(in_features, out_features, bias=False) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | torch.nn.init.xavier_uniform_(self.linear.weight) 23 | 24 | def forward(self, input, adj): 25 | input = F.dropout(input, self.dropout, self.training) 26 | support = self.linear(input) 27 | output = torch.bmm(adj, support) 28 | output = self.act(output) 29 | return output 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ + ' (' \ 33 | + str(self.in_features) + ' -> ' \ 34 | + str(self.out_features) + ')' 35 | -------------------------------------------------------------------------------- /baselines/methods/OrphicX/gae/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from baselines.methods.OrphicX.gae.layers import GraphConvolution 6 | 7 | 8 | class VGAE(nn.Module): 9 | def __init__(self, input_feat_dim, hidden_dim1, output_dim, dropout): 10 | super(VGAE, self).__init__() 11 | self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu) 12 | self.gc2 = GraphConvolution(hidden_dim1, output_dim, dropout, act=lambda x: x) 13 | self.gc3 = GraphConvolution(hidden_dim1, output_dim, dropout, act=lambda x: x) 14 | self.dc = InnerProductDecoder(dropout, act=lambda x: x) 15 | 16 | def encode(self, x, adj): 17 | hidden1 = self.gc1(x, adj) 18 | return self.gc2(hidden1, adj), self.gc3(hidden1, adj) 19 | 20 | def reparameterize(self, mu, logvar): 21 | if self.training: 22 | std = torch.exp(logvar) 23 | eps = torch.randn_like(std) 24 | return eps.mul(std).add_(mu) 25 | else: 26 | return mu 27 | 28 | def forward(self, x, adj): 29 | mu, logvar = self.encode(x, adj) 30 | z = self.reparameterize(mu, logvar) 31 | return self.dc(z), mu, logvar 32 | 33 | 34 | class VGAE3(VGAE): 35 | def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, output_dim, dropout): 36 | super(VGAE, self).__init__() 37 | self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu) 38 | self.gc1_1 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=F.relu) 39 | self.gc2 = GraphConvolution(hidden_dim2, output_dim, dropout, act=lambda 
x: x) 40 | self.gc3 = GraphConvolution(hidden_dim2, output_dim, dropout, act=lambda x: x) 41 | self.dc = InnerProductDecoder(dropout, act=lambda x: x) 42 | 43 | def encode(self, x, adj): 44 | hidden1 = self.gc1(x, adj) 45 | hidden2 = self.gc1_1(hidden1, adj) 46 | return self.gc2(hidden2, adj), self.gc3(hidden2, adj) 47 | 48 | 49 | class InnerProductDecoder(nn.Module): 50 | """Decoder for using inner product for prediction.""" 51 | 52 | def __init__(self, dropout, act=torch.sigmoid): 53 | super(InnerProductDecoder, self).__init__() 54 | self.dropout = dropout 55 | self.act = act 56 | 57 | def forward(self, z): 58 | z = F.dropout(z, self.dropout, training=self.training) 59 | adj = self.act(torch.bmm(z, torch.transpose(z, 1, 2))) 60 | return adj 61 | 62 | 63 | class VGAE2MLP(VGAE): # added by sakkas. This one uses 2 GCN layers instead three. 64 | def __init__(self, input_feat_dim, hidden_dim1, output_dim, decoder_hidden_dim1, decoder_hidden_dim2, K, dropout): 65 | super(VGAE2MLP, self).__init__(input_feat_dim, hidden_dim1, output_dim, dropout) 66 | self.dc = InnerProductDecoderMLP(output_dim, decoder_hidden_dim1, decoder_hidden_dim2, dropout, act=lambda x: x) 67 | 68 | 69 | class VGAE3MLP(VGAE3): 70 | def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, output_dim, decoder_hidden_dim1, decoder_hidden_dim2, K, dropout): 71 | super(VGAE3MLP, self).__init__(input_feat_dim, hidden_dim1, hidden_dim2, output_dim, dropout) 72 | self.dc = InnerProductDecoderMLP(output_dim, decoder_hidden_dim1, decoder_hidden_dim2, dropout, act=lambda x: x) 73 | 74 | 75 | class InnerProductDecoderMLP(nn.Module): 76 | """Decoder for using inner product for prediction.""" 77 | def __init__(self, input_dim, hidden_dim1, hidden_dim2, dropout, act=torch.sigmoid): 78 | super(InnerProductDecoderMLP, self).__init__() 79 | self.fc = nn.Linear(input_dim, hidden_dim1) 80 | self.fc2 = nn.Linear(hidden_dim1, hidden_dim2) 81 | self.dropout = dropout 82 | self.act = act 83 | self.reset_parameters() 84 | 85 | def reset_parameters(self): 86 | torch.nn.init.xavier_uniform_(self.fc.weight) 87 | torch.nn.init.zeros_(self.fc.bias) 88 | torch.nn.init.xavier_uniform_(self.fc2.weight) 89 | torch.nn.init.zeros_(self.fc2.bias) 90 | 91 | def forward(self, z): 92 | z = F.relu(self.fc(z)) 93 | z = torch.sigmoid(self.fc2(z)) 94 | z = F.dropout(z, self.dropout, training=self.training) 95 | adj = self.act(torch.bmm(z, torch.transpose(z, 1, 2))) 96 | return adj -------------------------------------------------------------------------------- /baselines/methods/OrphicX/gae/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.modules.loss 3 | import torch.nn.functional as F 4 | 5 | 6 | def loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight): 7 | bce = F.binary_cross_entropy_with_logits(preds.flatten(1).T, labels.flatten(1).T,pos_weight=pos_weight,reduce=False).mean(0) 8 | cost = norm * bce 9 | 10 | # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 11 | # https://arxiv.org/abs/1312.6114 12 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 13 | KLD = -0.5 / n_nodes * torch.mean(torch.sum( 14 | 1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), -1)) 15 | return cost + KLD 16 | -------------------------------------------------------------------------------- /baselines/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/baselines/methods/__init__.py -------------------------------------------------------------------------------- /baselines/methods/pgm_explainer.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/DS3Lab/GraphFramEx/blob/main/code/explainer/pgmexplainer.py 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from pgmpy.estimators.CITests import chi_square 7 | from scipy.special import softmax 8 | from torch_geometric.utils import k_hop_subgraph 9 | 10 | ###### Node Classification ###### 11 | 12 | 13 | class PGM_Node_Explainer: 14 | def __init__(self, model, edge_index, edge_weight, X, num_layers, device=None, mode=0, print_result=1): 15 | self.model = model 16 | self.model.eval() 17 | self.edge_index = edge_index 18 | self.edge_weight = edge_weight 19 | self.X = X 20 | self.num_layers = num_layers 21 | self.device = device 22 | self.mode = mode 23 | self.print_result = print_result 24 | 25 | def perturb_features_on_node(self, feature_matrix, node_idx, random=0, mode=0): 26 | # return a random perturbed feature matrix 27 | # random = 0 for nothing, 1 for random. 28 | # mode = 0 for random 0-1, 1 for scaling with original feature 29 | 30 | X_perturb = feature_matrix 31 | if mode == 0: 32 | if random == 0: 33 | perturb_array = X_perturb[node_idx] 34 | elif random == 1: 35 | perturb_array = np.random.randint(2, size=X_perturb[node_idx].shape[0]) 36 | X_perturb[node_idx] = perturb_array 37 | elif mode == 1: 38 | if random == 0: 39 | perturb_array = X_perturb[node_idx] 40 | elif random == 1: 41 | perturb_array = np.multiply( 42 | X_perturb[node_idx], np.random.uniform(low=0.0, high=2.0, size=X_perturb[node_idx].shape[0]) 43 | ) 44 | X_perturb[node_idx] = perturb_array 45 | return X_perturb 46 | 47 | def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=0.05, pred_threshold=0.1): 48 | neighbors, _, _, _ = k_hop_subgraph(node_idx, self.num_layers, self.edge_index) 49 | neighbors = neighbors.cpu().detach().numpy() 50 | 51 | if node_idx not in neighbors: 52 | neighbors = np.append(neighbors, node_idx) 53 | 54 | pred_torch = self.model(self.X, self.edge_index, self.edge_weight).cpu() 55 | soft_pred = np.asarray([softmax(np.asarray(pred_torch[node_].data)) for node_ in range(self.X.shape[0])]) 56 | 57 | pred_node = np.asarray(pred_torch[node_idx].data) 58 | label_node = np.argmax(pred_node) 59 | soft_pred_node = softmax(pred_node) 60 | 61 | Samples = [] 62 | Pred_Samples = [] 63 | 64 | for iteration in range(num_samples): 65 | 66 | X_perturb = self.X.cpu().detach().numpy() 67 | sample = [] 68 | for node in neighbors: 69 | seed = np.random.randint(2) 70 | if seed == 1: 71 | latent = 1 72 | X_perturb = self.perturb_features_on_node(X_perturb, node, random=seed) 73 | else: 74 | latent = 0 75 | sample.append(latent) 76 | 77 | X_perturb_torch = torch.tensor(X_perturb, dtype=torch.float).to(self.device) 78 | pred_perturb_torch = self.model(X_perturb_torch, self.edge_index, 
self.edge_weight).cpu() 79 | soft_pred_perturb = np.asarray( 80 | [softmax(np.asarray(pred_perturb_torch[node_].data)) for node_ in range(self.X.shape[0])] 81 | ) 82 | 83 | sample_bool = [] 84 | for node in neighbors: 85 | if (soft_pred_perturb[node, target] + pred_threshold) < soft_pred[node, target]: 86 | sample_bool.append(1) 87 | else: 88 | sample_bool.append(0) 89 | 90 | Samples.append(sample) 91 | Pred_Samples.append(sample_bool) 92 | 93 | Samples = np.asarray(Samples) 94 | Pred_Samples = np.asarray(Pred_Samples) 95 | Combine_Samples = Samples - Samples 96 | for s in range(Samples.shape[0]): 97 | Combine_Samples[s] = np.asarray( 98 | [Samples[s, i] * 10 + Pred_Samples[s, i] + 1 for i in range(Samples.shape[1])] 99 | ) 100 | 101 | data_pgm = pd.DataFrame(Combine_Samples) 102 | data_pgm = data_pgm.rename(columns={0: "A", 1: "B"}) # Trick to use chi_square test on first two data columns 103 | ind_ori_to_sub = dict(zip(neighbors, list(data_pgm.columns))) 104 | 105 | p_values = [] 106 | for node in neighbors: 107 | if node == node_idx: 108 | p = 0 # p<0.05 => we are confident that we can reject the null hypothesis (i.e. the prediction is the same after perturbing the neighbouring node 109 | # => this neighbour has no influence on the prediction - should not be in the explanation) 110 | else: 111 | chi2, p, _ = chi_square( 112 | ind_ori_to_sub[node], ind_ori_to_sub[node_idx], [], data_pgm, boolean=False, significance_level=0.05 113 | ) 114 | p_values.append(p) 115 | 116 | pgm_stats = dict(zip(neighbors, p_values)) 117 | 118 | node_attr = np.zeros(self.X.size(0)) 119 | for node, p_value in pgm_stats.items(): 120 | node_attr[node] = 1 - p_value 121 | # edge_mask = node_attr_to_edge(data.edge_index, node_attr) 122 | 123 | return node_attr -------------------------------------------------------------------------------- /baselines/methods/subgraphx_base.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/DS3Lab/GraphFramEx/blob/main/code/explainer/shapley.py 2 | import copy 3 | import torch 4 | import numpy as np 5 | from scipy.special import comb 6 | from itertools import combinations 7 | import torch.nn.functional as F 8 | from torch_geometric.utils import to_networkx 9 | from torch_geometric.data import Data, Batch, InMemoryDataset 10 | from torch_geometric.loader import DataLoader 11 | 12 | 13 | def GnnNetsGC2valueFunc(gnnNets, target_class): 14 | def value_func(data): 15 | with torch.no_grad(): 16 | logits = gnnNets(data.x, data.edge_index) 17 | probs = F.softmax(logits, dim=-1) 18 | score = probs[:, target_class] 19 | return score 20 | 21 | return value_func 22 | 23 | 24 | def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class): 25 | def value_func(data): 26 | with torch.no_grad(): 27 | probs = gnnNets_NC(data.x, data.edge_index) 28 | # select the corresponding node prob through the node idx on all the sampling graphs 29 | batch_size = data.batch.max() + 1 30 | probs = probs.reshape(batch_size, -1, probs.shape[-1]) 31 | score = probs[:, node_idx, target_class] 32 | return score 33 | 34 | return value_func 35 | 36 | 37 | def get_graph_build_func(build_method): 38 | if build_method.lower() == "zero_filling": 39 | return graph_build_zero_filling 40 | elif build_method.lower() == "split": 41 | return graph_build_split 42 | else: 43 | raise NotImplementedError 44 | 45 | 46 | class MarginalSubgraphDataset(InMemoryDataset): 47 | def __init__(self, data, exclude_mask, include_mask, subgraph_build_func): 48 | self.num_nodes = 
data.num_nodes 49 | self.X = data.x 50 | self.edge_index = data.edge_index 51 | self.device = self.X.device 52 | 53 | self.label = data.y 54 | self.exclude_mask = torch.tensor(exclude_mask).type(torch.float32).to(self.device) 55 | self.include_mask = torch.tensor(include_mask).type(torch.float32).to(self.device) 56 | self.subgraph_build_func = subgraph_build_func 57 | 58 | def __len__(self): 59 | return self.exclude_mask.shape[0] 60 | 61 | def __getitem__(self, idx): 62 | exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func( 63 | self.X, self.edge_index, self.exclude_mask[idx] 64 | ) 65 | include_graph_X, include_graph_edge_index = self.subgraph_build_func( 66 | self.X, self.edge_index, self.include_mask[idx] 67 | ) 68 | exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index) 69 | include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index) 70 | return exclude_data, include_data 71 | 72 | def marginal_contribution(data: Data, exclude_mask: np.array, include_mask: np.array, value_func, subgraph_build_func): 73 | """Calculate the marginal value for each pair. Here exclude_mask and include_mask are node mask.""" 74 | marginal_subgraph_dataset = MarginalSubgraphDataset(data, exclude_mask, include_mask, subgraph_build_func) 75 | #dataloader = DataLoader(marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0) 76 | 77 | dataloader = DataLoader(marginal_subgraph_dataset, batch_size=1, shuffle=False, num_workers=0) 78 | 79 | marginal_contribution_list = [] 80 | 81 | for exclude_data, include_data in dataloader: 82 | exclude_values = value_func(exclude_data) 83 | include_values = value_func(include_data) 84 | margin_values = include_values - exclude_values 85 | marginal_contribution_list.append(margin_values) 86 | 87 | marginal_contributions = torch.cat(marginal_contribution_list, dim=0) 88 | return marginal_contributions 89 | 90 | 91 | def graph_build_zero_filling(X, edge_index, node_mask: np.array): 92 | """subgraph building through masking the unselected nodes with zero features""" 93 | ret_X = X * node_mask.unsqueeze(1) 94 | return ret_X, edge_index 95 | 96 | 97 | def graph_build_split(X, edge_index, node_mask: np.array): 98 | """subgraph building through spliting the selected nodes from the original graph""" 99 | ret_X = X 100 | row, col = edge_index 101 | edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1) 102 | ret_edge_index = edge_index[:, edge_mask] 103 | return ret_X, ret_edge_index 104 | 105 | 106 | def l_shapley(coalition: list, data: Data, local_radius: int, value_func: str, subgraph_building_method="zero_filling"): 107 | """shapley value where players are local neighbor nodes""" 108 | graph = to_networkx(data) 109 | num_nodes = graph.number_of_nodes() 110 | subgraph_build_func = get_graph_build_func(subgraph_building_method) 111 | 112 | local_region = copy.copy(coalition) 113 | for k in range(local_radius - 1): 114 | k_neiborhoood = [] 115 | for node in local_region: 116 | k_neiborhoood += list(graph.neighbors(node)) 117 | local_region += k_neiborhoood 118 | local_region = list(set(local_region)) 119 | 120 | set_exclude_masks = [] 121 | set_include_masks = [] 122 | nodes_around = [node for node in local_region if node not in coalition] 123 | num_nodes_around = len(nodes_around) 124 | 125 | for subset_len in range(0, num_nodes_around + 1): 126 | node_exclude_subsets = combinations(nodes_around, subset_len) 127 | for node_exclude_subset in node_exclude_subsets: 128 | set_exclude_mask = np.ones(num_nodes) 129 | 
set_exclude_mask[local_region] = 0.0 130 | if node_exclude_subset: 131 | set_exclude_mask[list(node_exclude_subset)] = 1.0 132 | set_include_mask = set_exclude_mask.copy() 133 | set_include_mask[coalition] = 1.0 134 | 135 | set_exclude_masks.append(set_exclude_mask) 136 | set_include_masks.append(set_include_mask) 137 | 138 | exclude_mask = np.stack(set_exclude_masks, axis=0) 139 | include_mask = np.stack(set_include_masks, axis=0) 140 | num_players = len(nodes_around) + 1 141 | num_player_in_set = num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1) 142 | p = num_players 143 | S = num_player_in_set 144 | coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6)) 145 | 146 | marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func) 147 | 148 | l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item() 149 | return l_shapley_value 150 | 151 | 152 | def mc_shapley( 153 | coalition: list, data: Data, value_func: str, subgraph_building_method="zero_filling", sample_num=1000 154 | ) -> float: 155 | """monte carlo sampling approximation of the shapley value""" 156 | subset_build_func = get_graph_build_func(subgraph_building_method) 157 | 158 | num_nodes = data.num_nodes 159 | node_indices = np.arange(num_nodes) 160 | coalition_placeholder = num_nodes 161 | set_exclude_masks = [] 162 | set_include_masks = [] 163 | 164 | for example_idx in range(sample_num): 165 | subset_nodes_from = [node for node in node_indices if node not in coalition] 166 | random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder]) 167 | random_nodes_permutation = np.random.permutation(random_nodes_permutation) 168 | split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0] 169 | selected_nodes = random_nodes_permutation[:split_idx] 170 | set_exclude_mask = np.zeros(num_nodes) 171 | set_exclude_mask[selected_nodes] = 1.0 172 | set_include_mask = set_exclude_mask.copy() 173 | set_include_mask[coalition] = 1.0 174 | 175 | set_exclude_masks.append(set_exclude_mask) 176 | set_include_masks.append(set_include_mask) 177 | 178 | exclude_mask = np.stack(set_exclude_masks, axis=0) 179 | include_mask = np.stack(set_include_masks, axis=0) 180 | marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subset_build_func) 181 | mc_shapley_value = marginal_contributions.mean().item() 182 | 183 | return mc_shapley_value 184 | 185 | 186 | def mc_l_shapley( 187 | coalition: list, 188 | data: Data, 189 | local_radius: int, 190 | value_func: str, 191 | subgraph_building_method="zero_filling", 192 | sample_num=1000, 193 | ) -> float: 194 | """monte carlo sampling approximation of the l_shapley value""" 195 | graph = to_networkx(data) 196 | num_nodes = graph.number_of_nodes() 197 | subgraph_build_func = get_graph_build_func(subgraph_building_method) 198 | 199 | local_region = copy.copy(coalition) 200 | for k in range(local_radius - 1): 201 | k_neiborhoood = [] 202 | for node in local_region: 203 | k_neiborhoood += list(graph.neighbors(node)) 204 | local_region += k_neiborhoood 205 | local_region = list(set(local_region)) 206 | 207 | coalition_placeholder = num_nodes 208 | set_exclude_masks = [] 209 | set_include_masks = [] 210 | for example_idx in range(sample_num): 211 | subset_nodes_from = [node for node in local_region if node not in coalition] 212 | random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder]) 213 | random_nodes_permutation = 
np.random.permutation(random_nodes_permutation) 214 | split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0] 215 | selected_nodes = random_nodes_permutation[:split_idx] 216 | set_exclude_mask = np.ones(num_nodes) 217 | set_exclude_mask[local_region] = 0.0 218 | set_exclude_mask[selected_nodes] = 1.0 219 | set_include_mask = set_exclude_mask.copy() 220 | set_include_mask[coalition] = 1.0 221 | 222 | set_exclude_masks.append(set_exclude_mask) 223 | set_include_masks.append(set_include_mask) 224 | 225 | exclude_mask = np.stack(set_exclude_masks, axis=0) 226 | include_mask = np.stack(set_include_masks, axis=0) 227 | marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func) 228 | 229 | mc_l_shapley_value = (marginal_contributions).mean().item() 230 | return mc_l_shapley_value 231 | 232 | 233 | def gnn_score(coalition: list, data: Data, value_func: str, subgraph_building_method="zero_filling") -> torch.Tensor: 234 | """the value of subgraph with selected nodes""" 235 | num_nodes = data.num_nodes 236 | subgraph_build_func = get_graph_build_func(subgraph_building_method) 237 | mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device) 238 | mask[coalition] = 1.0 239 | ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask) 240 | mask_data = Data(x=ret_x, edge_index=ret_edge_index) 241 | mask_data = Batch.from_data_list([mask_data]) 242 | score = value_func(mask_data) 243 | # get the score of predicted class for graph or specific node idx 244 | return score.item() 245 | 246 | 247 | def NC_mc_l_shapley( 248 | coalition: list, 249 | data: Data, 250 | local_radius: int, 251 | value_func: str, 252 | node_idx: int = -1, 253 | subgraph_building_method="zero_filling", 254 | sample_num=1000, 255 | ) -> float: 256 | """monte carlo approximation of l_shapley where the target node is kept in both subgraph""" 257 | graph = to_networkx(data) 258 | num_nodes = graph.number_of_nodes() 259 | subgraph_build_func = get_graph_build_func(subgraph_building_method) 260 | 261 | local_region = copy.copy(coalition) 262 | for k in range(local_radius - 1): 263 | k_neiborhoood = [] 264 | for node in local_region: 265 | k_neiborhoood += list(graph.neighbors(node)) 266 | local_region += k_neiborhoood 267 | local_region = list(set(local_region)) 268 | 269 | coalition_placeholder = num_nodes 270 | set_exclude_masks = [] 271 | set_include_masks = [] 272 | for example_idx in range(sample_num): 273 | subset_nodes_from = [node for node in local_region if node not in coalition] 274 | random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder]) 275 | random_nodes_permutation = np.random.permutation(random_nodes_permutation) 276 | split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0] 277 | selected_nodes = random_nodes_permutation[:split_idx] 278 | set_exclude_mask = np.ones(num_nodes) 279 | set_exclude_mask[local_region] = 0.0 280 | set_exclude_mask[selected_nodes] = 1.0 281 | if node_idx != -1: 282 | set_exclude_mask[node_idx] = 1.0 283 | set_include_mask = set_exclude_mask.copy() 284 | set_include_mask[coalition] = 1.0 # include the node_idx 285 | 286 | set_exclude_masks.append(set_exclude_mask) 287 | set_include_masks.append(set_include_mask) 288 | 289 | exclude_mask = np.stack(set_exclude_masks, axis=0) 290 | include_mask = np.stack(set_include_masks, axis=0) 291 | marginal_contributions = marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func) 292 | 
293 | mc_l_shapley_value = (marginal_contributions).mean().item() 294 | return mc_l_shapley_value 295 | 296 | 297 | def sparsity(coalition: list, data: Data, subgraph_building_method="zero_filling"): 298 | if subgraph_building_method == "zero_filling": 299 | return 1.0 - len(coalition) / data.num_nodes 300 | 301 | elif subgraph_building_method == "split": 302 | row, col = data.edge_index 303 | node_mask = torch.zeros(data.x.shape[0]) 304 | node_mask[coalition] = 1.0 305 | edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1) 306 | return 1.0 - edge_mask.sum() / edge_mask.shape[0] -------------------------------------------------------------------------------- /baselines/run_gnnexplainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import torch 6 | from torch_geometric.explain import Explainer, GNNExplainer 7 | from tqdm.auto import tqdm 8 | 9 | from baselines.utils import result2dict 10 | from dataset.utils import get_model_data_config 11 | from gnnshap.utils import pruned_comp_graph 12 | from torch_geometric.utils import k_hop_subgraph 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--dataset', default='Cora', type=str) 17 | parser.add_argument('--repeat', default=1, type=int) 18 | 19 | args = parser.parse_args() 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, 25 | device=device) 26 | 27 | 28 | model.eval() 29 | 30 | # target = data.y 31 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 32 | 33 | test_nodes = config['test_nodes'] 34 | 35 | 36 | for r in range(args.repeat): 37 | explainer = Explainer( 38 | model=model, 39 | algorithm=GNNExplainer(epochs=200), 40 | explanation_type="phenomenon", 41 | node_mask_type= None, 42 | edge_mask_type='object', 43 | model_config=dict( 44 | mode='multiclass_classification', 45 | task_level='node', #node level prediction. 46 | return_type='raw', 47 | ), 48 | ) 49 | results = [] 50 | 51 | for ind in tqdm(test_nodes, desc=f"GNNExplainer explanations - run{r+1}"): 52 | try: 53 | start_time = time.time() 54 | 55 | # explain just using the k-hop subgraph: original paper uses this, 56 | # but pyg implementation does not. 
57 | (subset, sub_edge_index, sub_mapping, 58 | sub_edge_mask) = k_hop_subgraph(ind, config['num_hops'], 59 | data.edge_index, relabel_nodes=True) 60 | target2 = target[subset] 61 | explanation = explainer(data.x[subset], sub_edge_index, index=sub_mapping, 62 | target=target2) 63 | 64 | 65 | # save in our format: pruned edges 66 | (_, _, mapping2, mask2) = pruned_comp_graph(sub_mapping, config['num_hops'], 67 | sub_edge_index, relabel_nodes=False) 68 | edge_importance = explanation.edge_mask[mask2].detach().cpu().numpy() 69 | results.append(result2dict(ind, edge_importance, time.time() - start_time)) 70 | except Exception as e: 71 | print(f"Node {ind} failed!") 72 | rfile = f'{config["results_path"]}/{args.dataset}_GNNExplainer_run{r+1}.pkl' 73 | with open(rfile, 'wb') as pkl_file: 74 | pickle.dump([results, 0], pkl_file) 75 | -------------------------------------------------------------------------------- /baselines/run_graphsvx.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | import pickle 8 | import time 9 | 10 | from tqdm.auto import tqdm 11 | 12 | from baselines.methods.graphsvx import GraphSVX, arg_parse 13 | from baselines.utils import result2dict 14 | from dataset.utils import get_model_data_config 15 | 16 | 17 | def main(): 18 | 19 | args = arg_parse() 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, device=device, 24 | log_softmax_return=True) 25 | args.gpu = True 26 | args.num_samples= config['graphSVX_args']['num_samples'] 27 | args.hops=config['num_hops'] 28 | args.hv='compute_pred' 29 | args.coal = "SmarterSeparate" 30 | args.S = config['graphSVX_args']['S'] 31 | args.regu=0 # only cosidering graph structure. Node features are not considered. 
32 | args.feat = config['graphSVX_args']['feat'] 33 | data.name = args.dataset 34 | data.num_classes = (max(data.y)+1).item() 35 | 36 | num_samples = args.num_samples 37 | 38 | 39 | test_nodes = config['test_nodes'] 40 | 41 | 42 | for r in range(args.repeat): 43 | # Explain it with GraphSVX 44 | explainer = GraphSVX(data, model, args.gpu) 45 | results = [] 46 | 47 | 48 | for ind in tqdm(test_nodes, 49 | desc=f"GraphSVX explanations - run{r+1} - nsamp:{num_samples}"): 50 | try: 51 | start_time = time.time() 52 | explanations = explainer.explain([ind], args.hops, num_samples, args.info, 53 | args.multiclass, args.fullempty, args.S, 54 | args.hv, args.feat, args.coal, args.g, 55 | args.regu, False) 56 | results.append(result2dict(ind, explanations[0], time.time() - start_time)) 57 | except Exception as e: 58 | print(f"Node {ind} failed!") 59 | print(e) 60 | 61 | rfile = (f'{config["results_path"]}/{data.name}_GraphSVX_{args.coal}_{args.S }_' 62 | f'{num_samples}_run{r+1}.pkl') 63 | with open(rfile, 'wb') as pkl_file: 64 | pickle.dump([results, 0], pkl_file) 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /baselines/run_orphicx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | from tqdm.auto import tqdm 8 | 9 | from baselines.methods.OrphicX.orphicx import OrphicXExplainer 10 | from baselines.utils import result2dict 11 | from dataset.utils import get_model_data_config 12 | from gnnshap.utils import pruned_comp_graph 13 | from torch_geometric.utils import dense_to_sparse 14 | from torch_geometric.data import Data 15 | from torch_geometric.loader import DataLoader 16 | from torch_geometric.utils import to_dense_adj, k_hop_subgraph 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument('--dataset', default='Cora', type=str) 21 | parser.add_argument('--repeat', default=1, type=int) 22 | parser.add_argument('--epoch', default=50, type=int) 23 | 24 | # reduced the number of samples for alpha and beta due to gpu memory limitations 25 | parser.add_argument('--Nalpha', type=int, default=15, help='Number of samples of alpha.') 26 | parser.add_argument('--Nbeta', type=int, default=50, help='Number of samples of beta.') 27 | 28 | args = parser.parse_args() 29 | 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | 32 | 33 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, device=device) 34 | 35 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 36 | num_classes = data.y.max().item() + 1 37 | num_features = data.x.shape[1] 38 | 39 | class WrapperModel(torch.nn.Module): 40 | """We need to wrap the model in a class to match input and output dimensions of OrphicXExplainer. 41 | """ 42 | def __init__(self, model): 43 | super().__init__() 44 | self.model = model 45 | 46 | def forward(self, x, adj): 47 | if x.size(0) == 1: 48 | edge_index, weight = dense_to_sparse(adj) 49 | return self.model(x[0], edge_index, weight).unsqueeze(0).unsqueeze(0) 50 | else: 51 | out = torch.zeros((x.shape[0], x.shape[1], num_classes)).to(device) 52 | # for i in range(x.shape[1]): 53 | # edge_index, weight = dense_to_sparse(adj[i]) 54 | # out[i] = self.model(x[i], edge_index, weight) 55 | # return out.unsqueeze(0) 56 | 57 | # batched inference. this speeds up the process by a lot. 
58 | # without batched inference, the process takes a lot of time. 59 | batch_size = 128 60 | num_nodes = x.shape[1] 61 | for i in range(0, x.shape[0], batch_size): 62 | data_list = [] 63 | for j in range(batch_size): 64 | if i+j >= x.shape[0]: 65 | break 66 | edge_index, weight = dense_to_sparse(adj[i+j]) 67 | data_list.append(Data(x=x[i+j], edge_index=edge_index, edge_weight=weight)) 68 | 69 | loader = DataLoader(data_list, batch_size=len(data_list)) 70 | batched_data = next(iter(loader)) 71 | out[i:i+batch_size] = self.model(batched_data.x, batched_data.edge_index, 72 | batched_data.edge_weight 73 | ).reshape(-1, num_nodes, num_classes) 74 | 75 | return out.unsqueeze(0) 76 | 77 | model = WrapperModel(model).to(device) 78 | 79 | result_file = f'{config["results_path"]}/{args.dataset}_OrphicX.txt' 80 | 81 | 82 | model.eval() 83 | 84 | test_nodes = config['test_nodes'] 85 | 86 | 87 | for r in range(args.repeat): 88 | results = [] 89 | explainer = OrphicXExplainer(data, model, config['num_hops'], device=device) 90 | start = time.time() 91 | explainer.train(args.epoch) 92 | train_time = time.time() - start 93 | for ind in tqdm(test_nodes, desc=f"OrphicX Individual explanations - run{r+1}"): 94 | try: 95 | start_time = time.time() 96 | explanation = explainer.explain(ind) 97 | 98 | 99 | (subset, sub_edge_index, sub_mapping, 100 | sub_edge_mask) = pruned_comp_graph(ind, config['num_hops'], data.edge_index, 101 | relabel_nodes=True) 102 | 103 | exp_results = np.zeros(sub_edge_index.size(1)) 104 | for i in range(sub_edge_index.size(1)): 105 | exp_results[i] = explanation[sub_edge_index[0, i], sub_edge_index[1, i]].item() 106 | 107 | results.append(result2dict(ind, exp_results, time.time() - start_time)) 108 | 109 | except Exception as e: 110 | print(f"Node {ind} has failed. General error: {e}") 111 | 112 | 113 | rfile = f'{config["results_path"]}/{args.dataset}_OrphicX_{args.epoch}_run{r+1}.pkl' 114 | with open(rfile, 'wb') as pkl_file: 115 | pickle.dump([results, train_time], pkl_file) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /baselines/run_pgexplainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | import torch 5 | 6 | from torch_geometric.explain import Explainer, PGExplainer 7 | from tqdm.auto import tqdm 8 | 9 | from dataset.utils import get_model_data_config 10 | from gnnshap.utils import pruned_comp_graph 11 | from baselines.utils import result2dict 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--dataset', default='Cora', type=str) 16 | parser.add_argument('--repeat', default=1, type=int) 17 | 18 | args = parser.parse_args() 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, device=device) 24 | 25 | 26 | model.eval() 27 | 28 | train_nodes = data.train_mask.nonzero(as_tuple=False).cpu().numpy().flatten().tolist() 29 | test_nodes = config['test_nodes'] 30 | 31 | # use the predictions as the target 32 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 33 | 34 | for r in range(args.repeat): 35 | explainer = Explainer( 36 | model=model, 37 | algorithm=PGExplainer(epochs=20, lr=0.005, device=device), 38 | explanation_type='phenomenon', # it only supports this. 
no model option 39 | edge_mask_type='object', 40 | model_config=dict( 41 | mode='multiclass_classification', 42 | task_level='node', #node level prediction. 43 | return_type='raw',),) 44 | 45 | train_start = time.time() 46 | # Train the explainer. 47 | for epoch in tqdm(range(20), desc="PGExplainer Model Training"): 48 | if len(train_nodes) > 500: 49 | # Randomly sample 500 nodes to train against. 50 | tr_nodes = torch.randperm(data.num_nodes, device=device)[:500] 51 | else: 52 | tr_nodes = train_nodes 53 | for index in tr_nodes: # train on a subset of the training nodes 54 | loss = explainer.algorithm.train(epoch, model, data.x, data.edge_index, 55 | target=target, index=int(index)) 56 | train_time = time.time() - train_start 57 | 58 | results = [] 59 | for ind in tqdm(test_nodes, desc=f"PGExplainer explanations - run{r+1}"): 60 | start_time = time.time() 61 | explanation = explainer(data.x, data.edge_index, index=ind, 62 | edge_weight=data.edge_weight, target=target) 63 | (subset, sub_edge_index, sub_mapping, 64 | sub_edge_mask) = pruned_comp_graph(ind, config['num_hops'], 65 | data.edge_index, relabel_nodes=False) 66 | edge_importance = explanation.edge_mask[sub_edge_mask].detach().cpu().numpy() 67 | results.append(result2dict(ind, edge_importance, time.time() - start_time)) 68 | 69 | rfile = f'{config["results_path"]}/{args.dataset}_PGExplainer_run{r+1}.pkl' 70 | with open(rfile, 'wb') as pkl_file: 71 | pickle.dump([results, train_time], pkl_file) 72 | -------------------------------------------------------------------------------- /baselines/run_pgmexplainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from baselines.methods.pgm_explainer import PGM_Node_Explainer 3 | import numpy as np 4 | import time 5 | from tqdm.auto import tqdm 6 | from baselines.utils import result2dict 7 | import pickle 8 | from gnnshap.utils import pruned_comp_graph 9 | from dataset.utils import get_model_data_config 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--dataset', default='Cora', type=str) 15 | parser.add_argument('--repeat', default=1, type=int) 16 | 17 | args = parser.parse_args() 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, device=device, 23 | log_softmax_return=True) 24 | 25 | 26 | model.eval() 27 | 28 | # use the predictions as the target 29 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 30 | test_nodes = config['test_nodes'] 31 | 32 | for r in range(args.repeat): 33 | results = [] 34 | pgm_explainer = PGM_Node_Explainer(model, data.edge_index, None, data.x, 35 | num_layers=config['num_hops'], device=device, mode=0, 36 | print_result=1) 37 | 38 | for ind in tqdm(test_nodes, desc=f"PGMExp explanations - run{r+1}"): 39 | start_time = time.time() 40 | explanation = pgm_explainer.explain(ind, target=target[ind], num_samples=100, 41 | top_node=None) 42 | subset, e_index, _, _ = pruned_comp_graph(ind, config['num_hops'], 43 | data.edge_index) 44 | 45 | results.append(result2dict(ind, np.array(explanation[subset.cpu().numpy()]), 46 | time.time() - start_time)) 47 | rfile = f'{config["results_path"]}/{args.dataset}_PGMExplainer_run{r+1}.pkl' 48 | with open(rfile, 'wb') as pkl_file: 49 | pickle.dump([results, 0], pkl_file) 50 | -------------------------------------------------------------------------------- /baselines/run_sa.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import torch 6 | from captum.attr import Saliency 7 | from tqdm.auto import tqdm 8 | 9 | from baselines.utils import result2dict 10 | from dataset.utils import get_model_data_config 11 | from gnnshap.utils import pruned_comp_graph 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--dataset', default='Cora', type=str) 16 | parser.add_argument('--repeat', default=5, type=int) 17 | 18 | args = parser.parse_args() 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, 24 | device=device) 25 | 26 | 27 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 28 | 29 | 30 | #model.eval() 31 | test_nodes = config['test_nodes'] 32 | 33 | def model_forward_node(x, model, edge_index, node_idx): 34 | out = model(x, edge_index).softmax(dim=-1) 35 | return out[[node_idx]] 36 | 37 | for r in range(args.repeat): 38 | results = [] 39 | for i, ind in tqdm(enumerate(test_nodes), desc=f"SA explanations - run{r+1}"): 40 | start_time = time.time() 41 | explainer = Saliency(model_forward_node) 42 | 43 | (subset, sub_edge_index, sub_mapping, 44 | sub_edge_mask) = pruned_comp_graph(ind, config['num_hops'], data.edge_index, 45 | relabel_nodes=True) 46 | x_mask = data.x[subset].clone().requires_grad_(True).to(device) 47 | saliency_mask = explainer.attribute( 48 | x_mask, target=target[i].item(), 49 | additional_forward_args=(model, sub_edge_index, sub_mapping.item()), abs=False) 50 | 51 | node_importance = saliency_mask.cpu().numpy().sum(axis=1) 52 | results.append(result2dict(ind, node_importance, time.time() - start_time)) 53 | 54 | 55 | rfile = f'{config["results_path"]}/{args.dataset}_SA_run{r+1}.pkl' 56 | with open(rfile, 'wb') as pkl_file: 57 | pickle.dump([results, 0], pkl_file) 58 | -------------------------------------------------------------------------------- /baselines/run_subgraphx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import torch 6 | from tqdm.auto import tqdm 7 | 8 | from baselines.methods.subgraphx import SubgraphX 9 | from baselines.utils import result2dict 10 | from dataset.utils import get_model_data_config 11 | from gnnshap.utils import pruned_comp_graph 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--dataset', default='Cora', type=str) 16 | parser.add_argument('--repeat', default=1, type=int) 17 | 18 | args = parser.parse_args() 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | model, data, config = get_model_data_config(args.dataset, load_pretrained=True, device=device, 24 | log_softmax_return=True) 25 | 26 | 27 | result_file = f'{config["results_path"]}/{args.dataset}_SubgraphX.txt' 28 | 29 | 30 | model.eval() 31 | 32 | 33 | target = torch.argmax(model(data.x, data.edge_index), dim=-1) 34 | 35 | test_nodes = config['test_nodes'] 36 | 37 | max_nodes = config['subgraphx_args']['max_nodes'] 38 | print(f"max nodes: {max_nodes}") 39 | 40 | #TODO: It's is working but number of nodes in the explanation is a parameter. Can't control edge sparsity. 41 | # It also requires a new run for each node based sparsity. 
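# A minimal sketch of such a per-sparsity sweep (the max_nodes values here are illustrative only):
#   for mn in (5, 10, 15):
#       explanation = explainer.explain(data.x, data.edge_index, edge_weight=None,
#                                       label=target[ind].item(), node_idx=ind, max_nodes=mn)
# Each max_nodes value triggers a separate search, which is what makes sparsity sweeps costly here.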
42 | 43 | for r in range(args.repeat): 44 | results = [] 45 | explainer = SubgraphX(model, num_classes=config['num_classes'], 46 | num_hops=config['num_hops'], explain_graph=False, device=device, high2low=True, 47 | reward_method='nc_mc_l_shapley', rollout=20, min_atoms=4, expand_atoms=14, 48 | sample_num=50, local_radius=4, subgraph_building_method='zero_filling') 49 | 50 | 51 | 52 | for ind in tqdm(test_nodes, desc=f"SubgraphX Individual explanations - run{r+1}"): 53 | start_time = time.time() 54 | explanation = explainer.explain(data.x, data.edge_index, edge_weight=None, 55 | label=target[ind].item(), node_idx=ind, max_nodes=max_nodes) 56 | (subset, sub_edge_index, sub_mapping, 57 | sub_edge_mask) = pruned_comp_graph(ind, config['num_hops'], data.edge_index, 58 | relabel_nodes=False) 59 | edge_importance = explanation[sub_edge_mask.detach().cpu().numpy()] 60 | results.append(result2dict(ind, edge_importance, time.time() - start_time)) 61 | rfile = f'{config["results_path"]}/{args.dataset}_SubgraphX_run{r+1}.pkl' 62 | with open(rfile, 'wb') as pkl_file: 63 | pickle.dump([results, 0], pkl_file) 64 | -------------------------------------------------------------------------------- /baselines/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def result2dict(node_id: int, scores: np.array, comp_time: float) -> dict: 4 | """Converts an explanation result to a dictionary 5 | 6 | Args: 7 | node_id (int): node id 8 | scores (np.array): importance scores 9 | comp_time (float): computation time 10 | 11 | Returns: 12 | dict: result as dictionary 13 | """ 14 | return {'node_id': node_id, 'scores': scores, 'num_players': len(scores), 'time': comp_time} 15 | -------------------------------------------------------------------------------- /cppextension/cudagnnshap.cu: -------------------------------------------------------------------------------- 1 | // This is GPU sampler of GNNShap. 2 | // It is compiled as a shared library during the first run and called by the main GNNShap code. 3 | 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | #include 18 | #include 19 | 20 | 21 | 22 | namespace py = pybind11; 23 | 24 | 25 | // ######################DEVICE FUNCTIONS######################### 26 | 27 | 28 | /* combination generation algorithm on cuda from 29 | https://forums.developer.nvidia.com/t/enumerating-combinations/19980/4 30 | */ 31 | 32 | //calculate binomial coefficient 33 | __inline__ __host__ __device__ unsigned int BinCoef(int n, int r) { 34 | unsigned int b = 1; 35 | for(int i=0;i<=(r-1);i++) { 36 | b= (b * (n-i))/(i+1); 37 | } 38 | return(b); 39 | 40 | //the following is slower on CPU. I didn't test on GPU. 
41 | //lround(std::exp( std::lgamma(n+1)-std::lgamma(n-k+1)-std::lgamma(k+1))); 42 | } 43 | 44 | //assigns the rth combination of n choose k to the array maskMat's row 45 | __device__ int rthComb(unsigned int r, bool* rowPtr, bool* symRowPtr, int n, int k) { 46 | int x = 1; 47 | unsigned int y; 48 | for(int i=1; i <= k; i++) { 49 | y = BinCoef(n-x,k-i); 50 | while (y <= r) { 51 | r = r - y; 52 | x = x+1; 53 | if (x > n) 54 | return 0; 55 | y= BinCoef(n-x,k-i); 56 | } 57 | rowPtr[x-1] = true; 58 | symRowPtr[x-1] = false; 59 | x = x + 1; 60 | } 61 | return 1; 62 | 63 | } 64 | 65 | __global__ void cudaSampleGenerator(int nPlayers, int nHalfSamp, int* sizeLookup, 66 | bool* maskMat, int rndStartInd, 67 | int* devStartInds, int* devShuffleArr) { 68 | 69 | int tid = blockIdx.x * blockDim.x + threadIdx.x; 70 | int nTotalThreads = gridDim.x * gridDim.y * blockDim.x * blockDim.y; 71 | int chunk = nHalfSamp/nTotalThreads + 1; 72 | 73 | if (tid*chunk >= nHalfSamp) 74 | return; 75 | 76 | int* localShuffleArr = devShuffleArr + tid * nPlayers; 77 | 78 | 79 | 80 | 81 | bool* mask = maskMat + chunk * tid * nPlayers; 82 | // symmetrics starts from the middle 83 | bool* maskSym = maskMat + nHalfSamp * nPlayers + chunk * tid * nPlayers; 84 | int i, k; 85 | 86 | // nchoosek based sampling 87 | int fullCoalTaskEndInd = min(chunk*(tid + 1), rndStartInd); 88 | int rndTaskEndInd = min(chunk*(tid + 1), nHalfSamp); 89 | 90 | for (i = tid * chunk; i < fullCoalTaskEndInd; i++) { 91 | k = sizeLookup[i]; 92 | rthComb( i - devStartInds[k-1], mask, maskSym, nPlayers, k); //generate combination 93 | mask += nPlayers; //move pointer to the next combination 94 | maskSym += nPlayers; 95 | } 96 | 97 | 98 | if (rndTaskEndInd <= fullCoalTaskEndInd) 99 | return; 100 | 101 | 102 | // random sampling 103 | // do random sampling here! 104 | 105 | curandState_t state; 106 | curand_init(1234, tid, 0, &state); 107 | int temp, y, z; 108 | for (z = 0; z < nPlayers; z++) 109 | localShuffleArr[z] = z; 110 | 111 | for (;i < rndTaskEndInd; i++) { 112 | //knuthShuffle Algorithm 113 | for (z = nPlayers - 1; z > 0; z--) { 114 | y = (int)(curand_uniform(&state)*(z + .999999)); 115 | temp = localShuffleArr[z]; 116 | localShuffleArr[z] = localShuffleArr[y]; 117 | localShuffleArr[y] = temp; 118 | } 119 | 120 | for (int j = 0; j < sizeLookup[i]; j++){ 121 | mask[localShuffleArr[j]] = true; 122 | maskSym[localShuffleArr[j]] = false; 123 | } 124 | 125 | mask += nPlayers; //move pointer to the next combination 126 | maskSym += nPlayers; 127 | } 128 | } 129 | 130 | // ######################HOST FUNCTIONS######################### 131 | 132 | __inline__ double arraySum(double* arr, int n) { 133 | double sum = 0; 134 | for (int i = 0; i < n; i++) 135 | sum += arr[i]; 136 | return sum; 137 | } 138 | 139 | __inline__ int arraySum(int* arr, int n) { 140 | int sum = 0; 141 | for (int i = 0; i < n; i++) 142 | sum += arr[i]; 143 | return sum; 144 | } 145 | 146 | __inline__ void normalizeArray(double* arr, int n){ 147 | double sum = arraySum(arr, n); 148 | for (int i = 0; i < n; i++) 149 | arr[i] /= sum; 150 | } 151 | 152 | __inline__ void divideArray(double* arr, int n, double divisor){ 153 | for (int i = 0; i < n; i++) 154 | arr[i] /= divisor; 155 | } 156 | 157 | void cudaSample(torch:: Tensor maskMatTensor, torch::Tensor kWTensor, int nPlayers, 158 | int nSamples, int nBlocks = 1, int nThreads = 6) { 159 | 160 | int nHalfSamp = nSamples/2 ; // we will get symmetric samples. 
no need to compute the other half 161 | int nSamplesLeft = nHalfSamp; 162 | 163 | int* sizeLookup = new int[nSamplesLeft]; 164 | double* kernelWeights = new double[nSamples]; 165 | 166 | int* tmpSLookPtr = sizeLookup; 167 | double* tmpKWPointer = kernelWeights; 168 | 169 | 170 | int nSubsetSizes = ceil((nPlayers - 1) / 2.0); // number of subset sizes 171 | // coalition size in the middle not a paired subset 172 | // if nPlayers=4, 1 and 3 are pairs, 2 doesn't have a pair 173 | int nPairedSubsetSizes = floor((nPlayers - 1) / 2.0); 174 | 175 | // number of samples for each subset size 176 | int* coalSizeNSamples = new int[nSubsetSizes]; 177 | 178 | int* startInds = new int[nSubsetSizes+1]; // coalition size sample start indices 179 | 180 | // weight vector to distribute samples 181 | double* weightVect = new double[nSubsetSizes]; 182 | 183 | 184 | // compute weight vector 185 | for (int i = 1; i <= nSubsetSizes; i++) { 186 | weightVect[i-1] = ((nPlayers - 1.0) / (i * (nPlayers - i))); 187 | } 188 | 189 | 190 | // we will get the symmetric except in the middle 191 | if (nSubsetSizes != nPairedSubsetSizes) 192 | weightVect[nPairedSubsetSizes] /= 2; 193 | 194 | // normalize weight vector to sum to 1 195 | normalizeArray(weightVect, nSubsetSizes); 196 | 197 | double * remWeightVect = new double[nSubsetSizes]; 198 | std::copy(weightVect, weightVect + nSubsetSizes, remWeightVect); 199 | 200 | // std::cout << "initial remWeightVect: "; 201 | // for (int b = 0; b < nSubsetSizes; b++){ 202 | // std::cout << remWeightVect[b] << " "; 203 | // } 204 | // std::cout << std::endl; 205 | 206 | double sumKW = 0; 207 | startInds[0] = 0; 208 | 209 | // check if we have enough samples to iterate all coalitions for each subset size. 210 | int nFullSubsets = 0; 211 | long nSubsets; 212 | for(int i = 1; i <= nSubsetSizes; i++){ 213 | nSubsets = BinCoef(nPlayers, i);//nChoosek(nPlayers, i); 214 | 215 | if (i > nPairedSubsetSizes){ 216 | if (nSubsets % 2 != 0) 217 | std::cout << "Error: nSubsets is not even. Be careful!!!!" << std::endl; 218 | nSubsets /= 2; 219 | // std::cout << "inside if middle full sample control case" << std::endl; 220 | } 221 | 222 | if (nSamplesLeft * remWeightVect[i-1] + 1e-8 >= nSubsets){ 223 | nFullSubsets++; 224 | coalSizeNSamples[i-1] = nSubsets; 225 | nSamplesLeft -= nSubsets; 226 | startInds[i] = startInds[i-1] + nSubsets; 227 | 228 | sumKW += (50*weightVect[i-1]); 229 | std::fill(tmpKWPointer, tmpKWPointer + nSubsets, (50*weightVect[i-1]) / nSubsets); 230 | std::fill(tmpSLookPtr, tmpSLookPtr + nSubsets, i); 231 | 232 | tmpKWPointer += nSubsets; 233 | tmpSLookPtr += nSubsets; 234 | 235 | if (remWeightVect[i-1] < 1.0){ 236 | divideArray(remWeightVect + i-1, nSubsetSizes -i+1, 1-remWeightVect[i-1]); 237 | } 238 | 239 | } 240 | else{ 241 | break; 242 | } 243 | } 244 | 245 | // use this if we want equal weights for each randomly sampled coalitions. 246 | double remKw = (50.0 - sumKW)/nSamplesLeft; 247 | std::fill(tmpKWPointer, tmpKWPointer + nSamplesLeft, remKw); 248 | tmpKWPointer += nSamplesLeft; 249 | 250 | int rndStartInd = nHalfSamp - nSamplesLeft; 251 | 252 | // if we have enough samples to iterate all coalitions for each subset size, then we are done. 253 | if (nFullSubsets != nSubsetSizes){ 254 | int remSamples = nSamplesLeft; 255 | bool roundUp = true; 256 | for (int i = nFullSubsets; i < nSubsetSizes - 1; i++){ 257 | 258 | // extra check to avoid negative number of samples for the middle coal. 
Might be redundant 259 | if (nSamplesLeft <= 0) { 260 | nSamplesLeft = 0; 261 | break; 262 | } 263 | 264 | if (roundUp) 265 | coalSizeNSamples[i] = min((int)ceil(remSamples * remWeightVect[i]), nSamplesLeft); 266 | else 267 | coalSizeNSamples[i] = min((int)floor(remSamples * remWeightVect[i]), nSamplesLeft); 268 | nSamplesLeft -= coalSizeNSamples[i]; 269 | 270 | // if we want different weights for each randomly sampled coalition sizes, we can use this. 271 | // However, experiments show that it doesn't make a difference. 272 | //std::fill(tmpKWPointer, tmpKWPointer + coalSizeNSamples[i], (50*weightVect[i]) / coalSizeNSamples[i]); 273 | //tmpKWPointer += coalSizeNSamples[i]; 274 | 275 | std::fill(tmpSLookPtr, tmpSLookPtr + coalSizeNSamples[i], i+1); 276 | tmpSLookPtr += coalSizeNSamples[i]; 277 | 278 | startInds[i+1] = startInds[i] + coalSizeNSamples[i]; 279 | 280 | roundUp = !roundUp; 281 | } 282 | //add the remaining samples to the middle coal. I removed the middle coal from the loop above 283 | // to avoid negative number of samples for the middle coal. 284 | coalSizeNSamples[nSubsetSizes-1] = nSamplesLeft; 285 | 286 | //startInds[nSubsetSizes-1] = startInds[nSubsetSizes-2] + nSamplesLeft; 287 | 288 | 289 | // uncomment this if we want different weights for each randomly sampled coalition sizes. 290 | // However, experiments show that it doesn't make a difference. 291 | // std::fill(tmpKWPointer, tmpKWPointer + nSamplesLeft, (50*remWeightVect[nSubsetSizes-1]) / nSamplesLeft); 292 | //tmpKWPointer += nSamplesLeft; 293 | 294 | std::fill(tmpSLookPtr, tmpSLookPtr + nSamplesLeft, nSubsetSizes); 295 | 296 | if (coalSizeNSamples[nSubsetSizes-1] < 0) 297 | std::cout << "Error: negative number of samples for the middle coalition" << std::endl; 298 | 299 | nSamplesLeft = 0; 300 | } 301 | 302 | // symmetric weights. 
No need to compute the other half, no need to flip 303 | memcpy(tmpKWPointer, kernelWeights, nHalfSamp * sizeof(double)); 304 | 305 | bool *devMaskMat = maskMatTensor.data_ptr(); 306 | // cudaMalloc(&devMaskMat, nSamples * nPlayers * sizeof(bool)); 307 | // cudaMemset(devMaskMat, false, nHalfSamp * nPlayers * sizeof(bool)); 308 | cudaMemset(devMaskMat + nPlayers * nHalfSamp, true, nHalfSamp * nPlayers * sizeof(bool)); 309 | 310 | int *deviceSizeLookup; 311 | cudaMalloc(&deviceSizeLookup, nHalfSamp * sizeof(int)); 312 | cudaMemcpy(deviceSizeLookup, sizeLookup, nHalfSamp * sizeof(int), cudaMemcpyHostToDevice); 313 | 314 | 315 | int* devShuffleArr; 316 | cudaMalloc(&devShuffleArr, nBlocks * nThreads * nPlayers * sizeof(int)); 317 | 318 | int* devStartInds; // device start indices 319 | cudaMalloc(&devStartInds, nSubsetSizes * sizeof(int)); 320 | cudaMemcpy(devStartInds, startInds, nSubsetSizes * sizeof(int), cudaMemcpyHostToDevice); 321 | 322 | cudaSampleGenerator<<>>(nPlayers, nHalfSamp, deviceSizeLookup, 323 | devMaskMat, rndStartInd, devStartInds, devShuffleArr); 324 | 325 | 326 | cudaMemcpy(kWTensor.data_ptr(), kernelWeights, 327 | nSamples * sizeof(double), cudaMemcpyHostToDevice); 328 | 329 | cudaDeviceSynchronize(); 330 | 331 | free(sizeLookup); 332 | free(kernelWeights); 333 | free(coalSizeNSamples); 334 | free(startInds); 335 | free(weightVect); 336 | free(remWeightVect); 337 | cudaFree(deviceSizeLookup); 338 | cudaFree(devStartInds); 339 | cudaFree(devShuffleArr); 340 | } 341 | 342 | 343 | 344 | // ######################PYTHON BINDINGS######################### 345 | 346 | 347 | 348 | 349 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 350 | m.def("sample", &cudaSample, "Cuda Sample"); 351 | } -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/configs.py: -------------------------------------------------------------------------------- 1 | def get_config(conf_name): 2 | """ Required configs for each dataset 3 | 4 | Args: 5 | conf_name (str): Dataset name 6 | 7 | Returns: 8 | dict: configuration dictionary 9 | """ 10 | 11 | root_path = './' # root path. Can be changed to the path where the dataset is stored. 
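# Datasets are downloaded and cached under f'{root_path}/data/<dataset>' by
# dataset/utils.get_model_data_config.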
12 | results_path = f'./results' # path to store the explanation results 13 | 14 | dataset_configs = { 15 | 'Cora': { 16 | 'hidden_dim': 16, 17 | 'model': 'GCNModel', 18 | 'num_layers': 2, 19 | 'epoch': 200, 20 | 'lr': 0.01, 21 | 'weight_decay': 5e-4, 22 | 'dropout': 0.5, 23 | 'normalize': True, 24 | 'add_self_loops': True, 25 | 'graphSVX_args': {'num_samples': 1000, 'S': 3, 'coal': 'SmarterSeparate', 26 | 'feat':'Expectation'}, 27 | }, 28 | 'Cora_GAT': { 29 | 'hidden_dim': 16, 30 | 'num_layers': 2, 31 | 'epoch': 200, 32 | 'lr': 0.005, 33 | 'weight_decay': 5e-4, 34 | 'dropout': 0.5, 35 | 'normalize': True, 36 | 'add_self_loops': True, 37 | 'heads': 8, 38 | 'model': 'GATModel', 39 | }, 40 | 'CiteSeer': { 41 | 'hidden_dim': 16, 42 | 'model': 'GCNModel', 43 | 'num_layers': 2, 44 | 'epoch': 200, 45 | 'lr': 0.01, 46 | 'weight_decay': 5e-4, 47 | 'dropout': 0.5, 48 | 'normalize': True, 49 | 'add_self_loops': True, 50 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 51 | 'feat':'Expectation'}, 52 | 53 | }, 54 | 'PubMed': { 55 | 'hidden_dim': 16, 56 | 'model': 'GCNModel', 57 | 'num_layers': 2, 58 | 'epoch': 200, 59 | 'lr': 0.01, 60 | 'weight_decay': 5e-4, 61 | 'dropout': 0.5, 62 | 'normalize': True, 63 | 'add_self_loops': True, 64 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 65 | 'feat':'Expectation'}, 66 | }, 67 | 'Facebook': { 68 | 'hidden_dim': 16, 69 | 'model': 'GCNModel', 70 | 'num_layers': 2, 71 | 'epoch': 200, 72 | 'lr': 0.01, 73 | 'weight_decay': 5e-4, 74 | 'dropout': 0.5, 75 | 'normalize': True, 76 | 'add_self_loops': True, 77 | 'graphSVX_args': {'num_samples': 1000, 'S': 3, 'coal': 'SmarterSeparate', 78 | 'feat':'Expectation'}, 79 | }, 80 | 'Coauthor-CS': { 81 | 'hidden_dim': 64, 82 | 'model': 'GCNModel', 83 | 'num_layers': 2, 84 | 'epoch': 200, 85 | 'lr': 0.01, 86 | 'weight_decay': 5e-4, 87 | 'dropout': 0.5, 88 | 'normalize': True, 89 | 'add_self_loops': True, 90 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 91 | 'feat':'Expectation'}, 92 | }, 93 | 'Coauthor-Physics': { 94 | 'hidden_dim': 64, 95 | 'model': 'GCNModel', 96 | 'num_layers': 2, 97 | 'epoch': 200, 98 | 'lr': 0.01, 99 | 'weight_decay': 5e-4, 100 | 'dropout': 0.5, 101 | 'normalize': True, 102 | 'add_self_loops': True, 103 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 104 | 'feat':'Expectation'}, 105 | }, 106 | 'Reddit': { 107 | 'hidden_dim': 128, 108 | 'model': 'GCNModel', 109 | 'num_layers': 2, 110 | 'epoch': 11, 111 | 'lr': 0.01, 112 | 'weight_decay': 5e-4, 113 | 'dropout': 0.5, 114 | 'normalize': True, 115 | 'add_self_loops': True, 116 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 117 | 'feat':'Expectation'}, 118 | 'nei_sampler_args': {'sizes': [25, 10], 'batch_size': 1024}, 119 | }, 120 | 'ogbn-products': { 121 | 'hidden_dim': 128, 122 | 'model': 'GCNModel', 123 | 'num_layers': 2, 124 | 'failed_test_nodes': [], 125 | 'epoch': 11, 126 | 'lr': 0.01, 127 | 'weight_decay': 5e-4, 128 | 'dropout': 0.5, 129 | 'normalize': True, 130 | 'add_self_loops': True, 131 | 'graphSVX_args': {'num_samples': 1000, 'S':3, 'coal': 'SmarterSeparate', 132 | 'feat':'Expectation'}, 133 | 'nei_sampler_args': {'sizes': [25, 10], 'batch_size': 2048}, 134 | } 135 | } 136 | if conf_name is None: 137 | return dataset_configs 138 | 139 | conf = dataset_configs[conf_name] 140 | conf['root_path'] = root_path 141 | conf['results_path'] = results_path 142 | 143 | conf['num_hops'] = conf['num_layers'] 144 | return conf 
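# Adding a dataset is a config-plus-loader change: add an entry to dataset_configs above
# (hidden_dim, model, num_layers, training hyperparameters, optional graphSVX_args), then add a
# matching loading branch in dataset/utils.get_model_data_config; unknown names raise a ValueError.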
145 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | from torch_geometric.datasets import (Coauthor, FacebookPagePage, Planetoid, 7 | Reddit) 8 | from torch_geometric.transforms import NormalizeFeatures 9 | 10 | from dataset.configs import get_config 11 | from models.GATModel import GATModel 12 | from models.GCNModel import GCNModel 13 | 14 | 15 | def get_model(model_name, config, num_features, num_classes, log_softmax_return=False): 16 | """Gets the model for a given model name. 17 | 18 | Args: 19 | model_name (str): model name like 'GCNModel', 'GATModel' 20 | config (dict): configuration dictionary 21 | num_features (int): number of input features 22 | num_classes (int): number of output classes 23 | log_softmax_return (bool, optional): whether to return raw output or log softmax. 24 | Defaults to False. 25 | 26 | Returns: 27 | model: PyTorch model 28 | """ 29 | 30 | hidden_channels = config['hidden_dim'] 31 | num_layers = config['num_layers'] 32 | dropout = config['dropout'] 33 | normalize = config.get('normalize', None) 34 | add_self_loops = config.get('add_self_loops', None) 35 | 36 | 37 | if model_name == 'GCNModel': 38 | model = GCNModel(num_layers=num_layers, 39 | hidden_channels=hidden_channels, 40 | num_features=num_features, 41 | num_classes=num_classes, 42 | dropout=dropout, 43 | normalize=normalize, 44 | add_self_loops=add_self_loops, 45 | log_softmax_return=log_softmax_return) 46 | elif model_name == 'GATModel': 47 | model = GATModel(hidden_channels=hidden_channels, num_features=num_features, 48 | num_classes=num_classes, 49 | num_layers=num_layers, 50 | add_self_loops=add_self_loops, dropout=dropout, normalize=normalize, 51 | log_softmax_return=log_softmax_return, 52 | heads=config.get('heads', 1)) 53 | else: 54 | raise ValueError(f"Model name {model_name} is not supported") 55 | return model 56 | 57 | def get_model_data_config(dataset_name, device='cpu', load_pretrained=True, 58 | log_softmax_return=False, full_data=False): 59 | """Gets model, data and configuration for a given dataset name. Check the dataset.configs.py 60 | for the configuration details. 61 | Args: 62 | dataset_name (str): name of the dataset 63 | device (str, optional): 'cpu' or 'cuda'. Defaults to 'cpu'. 64 | load_pretrained (bool, optional): whether to load pretrained model. Defaults to True. 65 | log_softmax_return (bool, optional): whether to return raw output or log softmax. 66 | Defaults to False. 67 | full_data (bool, optional): full data or explanation data. If True for large datasets, only 68 | loads the subset of the data used in explanation that created by neighbor sampling. 69 | It doesn't make any difference for others. Defaults to False. 
70 | 71 | Returns: 72 | model, data, config 73 | """ 74 | 75 | config = get_config(dataset_name) 76 | root_path = config['root_path'] 77 | 78 | if 'Cora' in dataset_name: 79 | dataset = Planetoid(root=f'{root_path}/data/Planetoid', name='Cora', 80 | transform=NormalizeFeatures()) 81 | elif 'CiteSeer' in dataset_name: 82 | dataset = Planetoid(root=f'{root_path}/data/Planetoid', name='CiteSeer', 83 | transform=NormalizeFeatures()) 84 | elif 'PubMed' in dataset_name: 85 | dataset = Planetoid(root=f'{root_path}/data/Planetoid', name='PubMed', 86 | transform=NormalizeFeatures()) 87 | elif 'Coauthor-CS' in dataset_name: 88 | dataset = Coauthor(root=f'{root_path}/data/Coauthor', name='CS', 89 | transform=NormalizeFeatures()) 90 | elif 'Coauthor-Physics' in dataset_name: 91 | dataset = Coauthor(root=f'{root_path}/data/Coauthor', name='Physics', 92 | transform=NormalizeFeatures()) 93 | elif dataset_name == 'Facebook': 94 | dataset = FacebookPagePage(root=f'{root_path}/data/FacebookPagePage', 95 | transform=NormalizeFeatures()) 96 | elif dataset_name == 'Reddit': 97 | dataset = Reddit(f'{root_path}/data/Reddit') 98 | elif dataset_name == 'ogbn-products': 99 | from ogb.nodeproppred import PygNodePropPredDataset 100 | dataset = PygNodePropPredDataset(name='ogbn-products', 101 | root=f'{root_path}/data/ogbn-products') 102 | else: 103 | raise ValueError(f"Dataset name {dataset_name} is not supported") 104 | 105 | model_name = config['model'] 106 | 107 | model = get_model(model_name, config, dataset.num_features, 108 | dataset.num_classes, log_softmax_return=log_softmax_return).to(device) 109 | 110 | if load_pretrained: 111 | try: 112 | state_dict = torch.load(f"{root_path}/pretrained/{dataset_name}_pretrained.pt") 113 | model.load_state_dict(state_dict) 114 | model.eval() 115 | except FileNotFoundError: 116 | print(f"No pretrained model found for {dataset_name}. Don't forget to train the model.") 117 | 118 | data = dataset[0] 119 | config['num_classes'] = dataset.num_classes 120 | 121 | if dataset_name == 'ogbn-products': 122 | split_idx = dataset.get_idx_split() 123 | data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool) 124 | data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool) 125 | data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool) 126 | data.train_mask[split_idx['train']] = True 127 | data.val_mask[split_idx['valid']] = True 128 | data.test_mask[split_idx['test']] = True 129 | data.y = data.y.squeeze(1) 130 | 131 | elif not hasattr(data, 'train_mask'): 132 | split = get_split(f'{root_path}/pretrained/', dataset_name, data) 133 | data.train_mask, data.val_mask, data.test_mask = split[0], split[1], split[2] 134 | 135 | if dataset_name in ['Reddit', 'ogbn-products'] and not full_data: 136 | data = torch.load(f"{root_path}/pretrained/{dataset_name}_explain_data.pt") 137 | # these are not saved in the file. Load them from the original data 138 | data.x = dataset[0].x[data.n_id] 139 | data.y = dataset[0].y[data.n_id] 140 | del dataset # free up some memory 141 | if dataset_name == 'ogbn-products': 142 | data.y = data.y.squeeze(1) 143 | data = data.to(device) 144 | else: 145 | data = data.to(device) 146 | 147 | config['test_nodes'] = data.test_mask.nonzero().cpu().numpy()[:,0].tolist()[:100] 148 | 149 | return model, data, config 150 | 151 | 152 | def generate_balanced_split(data, num_train_per_class, num_val_per_class): 153 | """Generates a balanced train, validation, test split for a given dataset. This is used for 154 | datasets that do not have built-in splits. 
155 | 156 | Args: 157 | data (torch.Tensor): PyTorch Geometric Data object 158 | num_train_per_class (int): number of training samples per class 159 | num_val_per_class (_type_): number of validation samples per class 160 | 161 | Returns: 162 | list: a list contains splits. 163 | """ 164 | 165 | labels = data.y.cpu() 166 | num_classes = labels.max().item() + 1 167 | 168 | train_mask = np.zeros(len(labels), dtype=np.bool) 169 | val_mask = np.zeros(len(labels), dtype=np.bool) 170 | test_mask = np.zeros(len(labels), dtype=np.bool) 171 | 172 | for c in range(num_classes): 173 | idx = np.where(labels == c)[0] 174 | tr_idx = idx[np.random.choice(len(idx), size=num_train_per_class, replace=False)] 175 | train_mask[tr_idx] = True 176 | 177 | val_idx = np.array([i for i in idx if i not in tr_idx]) 178 | val_idx = val_idx[np.random.choice(len(val_idx), size=num_val_per_class, replace=False)] 179 | val_mask[val_idx] = True 180 | # print(f"class: {c+1}, tr count: {len(tr_idx)}, val count: {len(val_idx)}") 181 | # print(f'train mask count: {np.count_nonzero(train_mask)}') 182 | # print(f'val mask count: {np.count_nonzero()}') 183 | remaining = np.where((train_mask == False) & (val_mask == False))[0] 184 | test_mask[remaining] = True 185 | return [torch.from_numpy(train_mask), torch.from_numpy(val_mask), torch.from_numpy(test_mask)] 186 | 187 | def get_split(pretrained_dir, dataset_name, data): 188 | """Gets the train, validation, test split for a given dataset. If the split is not found, 189 | creates a new split and saves it to the disk. This is used for datasets that do not have 190 | built-in splits. 191 | 192 | Args: 193 | pretrained_dir (str): directory to save the split 194 | dataset_name (str): name of the dataset 195 | data: PyTorch Geometric Data object 196 | 197 | Returns: 198 | split: a list contains splits 199 | """ 200 | f_name = f'{pretrained_dir}/split_{dataset_name}.pt' 201 | if os.path.exists(f_name): 202 | split = torch.load(f_name) 203 | else: 204 | split = generate_balanced_split(data, num_train_per_class=30, num_val_per_class=30) 205 | torch.save(split, f_name) 206 | print(f'No previous split found for {dataset_name}.', 207 | f'New split was created and saved to {f_name}.') 208 | return split 209 | -------------------------------------------------------------------------------- /gnnshap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/gnnshap/__init__.py -------------------------------------------------------------------------------- /gnnshap/eval_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.data import Data 7 | 8 | from gnnshap.utils import pruned_comp_graph 9 | 10 | 11 | def node2edge_score(edge_index: torch.Tensor, node_scores: np.array): 12 | """Converts node scores to edge scores: an edge score is equal to average of connected nodes. 13 | Needed for some baselines that only provide node scores. 14 | 15 | Args: 16 | edge_index (torch.Tensor): PyG edge index. 
17 | node_scores (np.array[float]): node scores 18 | 19 | Returns: 20 | np.array: edge scores 21 | """ 22 | 23 | edge_scores = np.zeros(edge_index.size(1)) 24 | np_node_scores = np.array(node_scores) 25 | edge_scores += np_node_scores[edge_index[0].cpu().numpy()] 26 | edge_scores += np_node_scores[edge_index[1].cpu().numpy()] 27 | return edge_scores/2 28 | 29 | 30 | 31 | def fidelity(node_data: dict, data: Data, model: torch.nn.Module, sparsity: float = 0.3, 32 | fid_type: str = 'neg', topk: int = 0, target_class: int = None, 33 | apply_abs: bool=True) -> tuple: 34 | """Computes fidelity+ and fidelity- score of a node. It supports both topk and sparsity. 35 | If sparsity set to 0.3, it drops 30% of the edges. Based on the neg or pos, it drops 36 | unimportant or important edges. It applies topk based keep if topk is set to a positive 37 | integer other than zero. 38 | 39 | Note that it computes fidelity scores for the predicted class if target class is not provided. 40 | 41 | Args: 42 | node_data (dict): a node's explanation data with node_id, num_players, scores keys. 43 | data (Data): pyG Data. 44 | model (torch.nn.Module): a PyTorch model. 45 | sparsity (float, optional): target sparsity value. Defaults to 0.3. 46 | fid_type (str, optional): Fidelity type: neg or pos. Defaults to 'neg'. 47 | topk (int, optional): Topk edges to keep. Defaults to 0. 48 | target_class (int, optional): Target class to compute fidelity score. Defaults to None. 49 | apply_abs (bool, optional): applies absolute to scores. Some methods can find negative and 50 | positive contributing nodes/edges. Fidelity-wise, we only care the change amount. We can 51 | use this to get rid of negative contributing edges to improve accuracy. Defaults to 52 | True. 53 | 54 | Returns: 55 | tuple: node_id, nplayers, fidelity score, current sparsity, correct_class, init_pred_class, 56 | and sparse_pred_class. 57 | """ 58 | assert topk >= 0, "topk cannot be a negative number" 59 | assert 0 <= sparsity <= 1, "Sparsity should be between zero and one." 
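# Score semantics (computed below): prob_score is p_sparse(target_class) - p_full(target_class),
# optionally in absolute value. With fid_type='neg' the least-important edges are dropped, so a
# faithful explanation keeps the prediction stable and the score stays near 0 (lower is better);
# with fid_type='pos' the most-important edges are dropped, so a large change is better.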
60 | 61 | node_id = int(node_data['node_id']) 62 | correct_class = data.y[node_id].item() 63 | 64 | model.eval() 65 | 66 | 67 | 68 | # find khop computational graph 69 | (subset, sub_edge_index, new_node_id, 70 | _) = pruned_comp_graph(node_id, model.num_layers, data.edge_index, relabel_nodes=True) 71 | # new node id due to relabeling 72 | new_node_id = int(new_node_id[0].cpu().numpy()) 73 | num_initial_edges = sub_edge_index.size(1) # number of players 74 | 75 | 76 | subset = subset.cpu().numpy() 77 | 78 | # initial prediction 79 | init_pred = F.softmax(model(data.x[subset], sub_edge_index), dim=1)[new_node_id] 80 | init_pred_class = init_pred.argmax(dim=-1).item() 81 | if target_class is None: 82 | target_class = init_pred_class 83 | init_prob = init_pred[target_class].item() 84 | 85 | 86 | if node_data['num_players'] == num_initial_edges: 87 | edge_scores = np.array(node_data['scores']) 88 | 89 | # convert node scores to edge scores if node score is provided 90 | elif node_data['num_players'] == subset.shape[0]: 91 | edge_scores = node2edge_score(sub_edge_index, node_data['scores']) 92 | 93 | else: 94 | raise ValueError("Number of players should be equal to either" 95 | " number of edges or number of nodes!") 96 | 97 | 98 | edge_scores = np.abs(edge_scores) if apply_abs else edge_scores 99 | 100 | 101 | # less important edge at first index 102 | edge_importance_sorted = edge_scores.argsort() 103 | 104 | if topk == 0: # sparsity based 105 | if fid_type == 'pos': # reverse the list: most important edge at first index 106 | edge_importance_sorted = edge_scores.argsort()[::-1].copy() 107 | # copy required for bug fixing. pytorch doesn't support negative index 108 | 109 | # how many edges to drop 110 | drop_len = num_initial_edges - math.ceil(num_initial_edges * (1 - sparsity)) 111 | keep_edges = edge_importance_sorted[drop_len:] 112 | 113 | else: # topk based 114 | if fid_type == 'neg': 115 | keep_edges = edge_importance_sorted[topk:] # drop least important topk edges 116 | else: # fid+ 117 | keep_edges = edge_importance_sorted[:-topk] # keep edges except topk 118 | 119 | drop_len = num_initial_edges - len(keep_edges) 120 | 121 | 122 | keep_edges.sort() 123 | 124 | sparse_pred = F.softmax(model(data.x[subset], sub_edge_index[:, keep_edges]), 125 | dim=-1)[new_node_id] 126 | 127 | sparse_pred_class = sparse_pred.argmax(dim=-1).item() 128 | sparse_prob = sparse_pred[target_class].item() 129 | 130 | prob_score = sparse_prob - init_prob 131 | prob_score = np.abs(prob_score) if apply_abs else prob_score 132 | 133 | 134 | current_sparsity = drop_len / num_initial_edges 135 | return (node_id, num_initial_edges, prob_score, current_sparsity, 136 | correct_class, init_pred_class, sparse_pred_class) -------------------------------------------------------------------------------- /gnnshap/explainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from typing import Callable, List, Tuple, Union 4 | 5 | import torch 6 | import torch_geometric 7 | from torch import BoolTensor, Tensor 8 | from torch.nn.functional import softmax 9 | from torch_geometric.utils import add_remaining_self_loops 10 | from tqdm import tqdm 11 | 12 | from gnnshap.samplers import get_sampler 13 | from gnnshap.solvers import get_solver 14 | from gnnshap.utils import * 15 | from gnnshap.explanation import GNNShapExplanation 16 | 17 | log = get_logger(__name__) 18 | 19 | 20 | def default_predict_fn(model: torch.nn.Module, 21 | node_features: Tensor, 22 | 
edge_index: Tensor, 23 | node_idx: Union[int, List[int]], 24 | edge_weight: Tensor = None) -> Tensor: 25 | r"""Model prediction function for prediction. A custom predict function can be provided for 26 | different tasks. 27 | 28 | Args: 29 | model (torch.nn.Module): a PyG model. 30 | node_features (Tensor): node feature tensor. 31 | edge_index (Tensor): edge_index tensor. 32 | node_idx (Union[int, List[int]]): node index. Can be an integer or list (list 33 | for batched data). 34 | edge_weight (Tensor, optional): edge weights. Defaults to None. 35 | 36 | Returns: 37 | Tensor: model prediction for the node being explained. 38 | """ 39 | 40 | model.eval() 41 | 42 | # [node_idx] will only work for non-batched. [node_idx, :] works for both 43 | pred = model.forward(node_features, edge_index, edge_weight=edge_weight) 44 | 45 | # this is for 3d predictions tensors 46 | # pred = pred[node_idx, :] if len(pred.size()) == 2 else pred[:, node_idx, :] 47 | 48 | pred = softmax(pred[node_idx, :], dim=-1) 49 | return pred 50 | 51 | 52 | class GNNShapExplainer: 53 | """GNNShap main Explainer class. 54 | 55 | Args: 56 | model (torch.nn.Module): a pyg model. 57 | data (torch_geometric.data.Data): a pyg data. 58 | nhops (int, optional): number of hops. It will be computed if not provided. 59 | Defaults to None. 60 | device (Tuple[str, torch.device], optional): torch device. Defaults to 'cpu'. 61 | forward_fn (Callable, optional): A forward function. It can be customized for custom 62 | needs. Defaults to default_predict_fn. 63 | progress_hide (bool, optional): Hides tqdm progress if set to True. Defaults to False. 64 | verbose (int, optional): Shows some information if set to a positive number. 65 | Defaults to 0. 66 | """ 67 | 68 | def __init__(self, model: torch.nn.Module, data: torch_geometric.data.Data, 69 | nhops: int = None, device: Tuple[str, torch.device] = 'cpu', 70 | forward_fn: Callable = default_predict_fn, 71 | progress_hide: bool = False, 72 | verbose: int = 0): 73 | 74 | self.model = model 75 | self.num_hops = nhops if nhops is not None else len(get_gnn_layers(self.model)) 76 | self.data = data 77 | self.forward_fn = forward_fn # prediction function 78 | self.progress_hide = progress_hide # tqdm progress bar show or hide 79 | self.verbose = verbose # to show or hide extra info 80 | self.device = device 81 | 82 | self.has_self_loops = data.has_self_loops() 83 | 84 | self.sampler = None # will be set in explain. 85 | self.preds = None # will be set in compute_model_predictions. 86 | 87 | 88 | def __compute_preds_no_batching(self, node_features: Tensor, 89 | edge_index: Tensor, mask_matrix: BoolTensor, 90 | node_idx: int, target_class: int) -> Tensor: 91 | """Computes predictions by iterating each coalitions one by one. 92 | 93 | Args: 94 | node_features (Tensor): node features. 95 | edge_index (Tensor): edge index. 96 | mask_matrix (BoolTensor): boolean 2d mask matrix. 97 | node_idx (int): node index (it should be the relabeled node index in the subgraph). 98 | target_class (int): Target class. 99 | 100 | Returns: 101 | Tensor: Returns predictions tensor. 
102 | """ 103 | preds = torch.zeros((mask_matrix.size(0)), 104 | dtype=torch.double, device=mask_matrix.device, 105 | requires_grad=False) 106 | 107 | for i in tqdm(range(mask_matrix.shape[0]), desc="Coalition scores", 108 | disable=self.progress_hide): 109 | mask = mask_matrix[i] 110 | masked_edges = edge_index[:, mask] 111 | # masked_edges = self.filter_fn(masked_edges, node_idx, self.num_hops) 112 | y_hat = self.forward_fn(model=self.model, node_features=node_features, 113 | edge_index=masked_edges, node_idx=node_idx, edge_weight=None) 114 | preds[i] = y_hat[target_class] 115 | return preds 116 | 117 | def __compute_preds_batched(self, node_features: Tensor, edge_index: Tensor, 118 | mask_matrix: torch.BoolTensor, node_idx: int, 119 | batch_size: int, target_class: int) -> Tensor: 120 | """Computes model predictions by mini-batching. Creates a large graph by stacking edge 121 | indices. Note that individual subgraphs are disconnected. So, for 5 nodes subgraphs, 122 | first subgraph node numbers are 0, 1, 2, 3, and 4. Second subgraph node numbers are 5, 6, 123 | 7, and 8. Three key tensors are created like below: 124 | batch_edge_index = [edge_index1, edge_index2, ...] : (2, nplayers * batch_size) 125 | node_features = node_features.repeat(batch_size, 1) : (num_nodes * batch_size, F) 126 | batch_mask = mask_matrix[batch_start:batch_end].flatten() : (nplayers * batch_size,) 127 | 128 | 129 | Args: 130 | node_features (Tensor): node features. 131 | edge_index (Tensor): edge index. 132 | mask_matrix (torch.BoolTensor): boolean 2d mask matrix. 133 | node_idx (int): node index (it should be the relabeled node index in the subgraph). 134 | batch_size (int): batch size 135 | target_class (int): Target class. 136 | 137 | Returns: 138 | Tensor: Returns predictions tensor. 139 | """ 140 | preds = torch.zeros((mask_matrix.size(0)), 141 | dtype=torch.double, device=mask_matrix.device, 142 | requires_grad=False) 143 | 144 | num_batches = math.ceil(mask_matrix.shape[0] / batch_size) 145 | num_nodes = node_features.size(0) 146 | 147 | # pyg creates a large graph by combining our masked subgraphs. 148 | # We need to get predictions of the same node for each subgraph. 149 | # [node0, node1, node2, node3 ... node0, node1, node2] 150 | node_indices = torch.arange(node_idx, batch_size * num_nodes, num_nodes, device=self.device) 151 | current_ind = 0 152 | 153 | edge_size = edge_index.size(1) 154 | 155 | # Creating batched_data using PyG minibatching approach. It has a small overhead. 
156 | # tmp_data = Data(x=node_features, edge_index=edge_index).to(node_features.device) 157 | # data_list = [tmp_data for _ in range(batch_size)] 158 | # loader = DataLoader(data_list, batch_size=len(data_list)) 159 | # batched_data = next(iter(loader)) 160 | # batch_edge_index = batched_data.x 161 | # batch_node_features = batched_data.edge_index 162 | 163 | 164 | # Alternative approach to PyG minibatch since we don't need many features of PyG minibatch 165 | batch_node_features = node_features.repeat(batch_size, 1) 166 | #create large batched edge indices 167 | batch_edge_index = torch.zeros((2, edge_index.size(1) * batch_size), device=self.device, 168 | dtype=torch.long) 169 | for k, n_ind in enumerate(range(0, batch_size * edge_size, edge_size)): 170 | batch_edge_index[:, n_ind:n_ind + edge_size] = edge_index + k * num_nodes 171 | 172 | 173 | # predictions for batches 174 | for i in tqdm(range(num_batches), desc="Batched coalition scores", 175 | disable=self.progress_hide, leave=False): 176 | 177 | batch_start = batch_size * i 178 | batch_end = min(batch_size * (i + 1), mask_matrix.shape[0]) 179 | 180 | tmp_batch_size = batch_end - batch_start 181 | if tmp_batch_size < batch_size: # for the last batch 182 | batch_edge_index=batch_edge_index[:, :edge_index.size(1) * tmp_batch_size] 183 | node_features=batch_node_features[: tmp_batch_size * num_nodes] 184 | 185 | tmp_node_indices = node_indices[:tmp_batch_size] # to make sure for the last batch 186 | y_hat = self.forward_fn(model=self.model, 187 | node_features=batch_node_features, 188 | edge_index=batch_edge_index[:,mask_matrix[ 189 | batch_start:batch_end].flatten()], 190 | edge_weight=None, 191 | node_idx=tmp_node_indices) 192 | preds[current_ind: current_ind + tmp_batch_size] = y_hat[:, target_class] 193 | current_ind += tmp_batch_size 194 | return preds 195 | 196 | def compute_model_predictions(self, node_features: Tensor, edge_index: Tensor, 197 | mask_matrix: torch.BoolTensor, node_idx: int, 198 | batch_size: int, target_class: int) -> Tuple[Tensor, int]: 199 | """Computes model predictions and writes results to self.preds variable. 200 | 201 | Args: 202 | node_features (Tensor): node features. 203 | edge_index (Tensor): edge index. 204 | mask_matrix (torch.BoolTensor): boolean 2d mask (coalition) matrix. 205 | node_idx (int): node index (it should be the relabeled node index in the subgraph). 206 | batch_size (int): batch size. No batching if set to zero. 207 | target_class (int): Target class. 208 | 209 | Returns: 210 | Tuple[Tensor, int]: Returns predictions tensor and number of computed samples. 211 | 212 | """ 213 | assert batch_size >= 0, "Batch size can not be a negative number" 214 | 215 | 216 | # empty coalition prediction, use only self loop edges 217 | self.fnull = self.forward_fn(self.model, node_features, 218 | edge_index[:, self.sampler.nplayers:], 219 | node_idx)[target_class].item() 220 | 221 | # full coalition prediction, use all edges in the subgraph 222 | self.fx = self.forward_fn(self.model, node_features, edge_index, 223 | node_idx)[target_class].item() 224 | 225 | # s_time = time.time() 226 | 227 | # there could be added self loops. 
Limit with nplayers 228 | one_hop_incoming_idx = (edge_index[1, :self.sampler.nplayers] == node_idx).nonzero()[:,0] 229 | 230 | # only compute indices when the target node is not isolated 231 | compute_indices = 0 != mask_matrix[:, one_hop_incoming_idx].count_nonzero(dim=-1) 232 | 233 | preds = torch.zeros((mask_matrix.size(0)), 234 | dtype=torch.double, device=mask_matrix.device, 235 | requires_grad=False).fill_(self.fnull) 236 | 237 | # no batch 238 | if batch_size == 0: 239 | tmp_preds = self.__compute_preds_no_batching( 240 | node_features, edge_index, mask_matrix[compute_indices], node_idx, target_class) 241 | # batch 242 | else: 243 | tmp_preds = self.__compute_preds_batched( 244 | node_features, edge_index, mask_matrix[compute_indices], node_idx, 245 | batch_size,target_class) 246 | preds[compute_indices] = tmp_preds 247 | 248 | return preds, compute_indices.count_nonzero().item() 249 | 250 | @torch.no_grad() 251 | def explain(self, node_idx: int, nsamples: int, 252 | batch_size: int = 512, sampler_name: str = 'GNNShapSampler', 253 | target_class: Union[int, None]=None, l1_reg: bool= False, 254 | solver_name: str = "WLSSolver", **kwargs): 255 | r"""Computes shapley scores. It has four steps: 256 | 257 | | 1. finds computational graph and players 258 | | 2. samples coalitions 259 | | 3. runs model and get predictions for sampled graphs. 260 | | 4. solves linear regression problem to compute shapley scores. 261 | 262 | Args: 263 | node_idx (int): Node index to explain 264 | nsamples (int, optional): number of samples. 265 | batch_size (int, optional): batch size. Defaults to 512. 266 | sampler_name (str, optional): Sampler class name for sampling. 267 | Defaults to 'SHAPSampler'. 268 | target_class (int, optional): Computes Shapley scores for the target class. 269 | Predicted class is used if target_class is not provided. Defaults to None. 270 | l1_reg (bool, optional): use l1 reg or not. l1 reg will not be used if 271 | fraction_evaluated > 0.2 272 | solver_name (str, optional): Solver name. Defaults to 'TorchSolver'. 273 | kwargs: Additional sampler & solver args if needed. 274 | 275 | Returns: 276 | GNNSHAPExplanation: Returns GNNSHAPExplanation objects that contain many information. 277 | """ 278 | 279 | device = self.device 280 | 281 | # use the predicted class if no target class is provided. 
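# (i.e. target_class defaults to the argmax of forward_fn on the full graph, which with
#  default_predict_fn is softmax(model(x, edge_index))[node_idx]; pass target_class explicitly to
#  explain another class, e.g. the ground-truth label.)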
282 | if target_class is None: 283 | # target_class = self.data.y[node_idx].item() # for ground truth 284 | target_class = torch.argmax(self.forward_fn(self.model, self.data.x, 285 | self.data.edge_index, node_idx)).item() 286 | start_time = time.time() 287 | # we only need k-hop neighbors for explanation 288 | (subset, sub_edge_index, sub_mapping, 289 | sub_edge_mask) = pruned_comp_graph(node_idx, self.num_hops, 290 | self.data.edge_index, 291 | relabel_nodes=True) 292 | 293 | nplayers = sub_edge_index.size(1) # number of players 294 | 295 | compgraph_time = time.time() 296 | log.info(f"Computational graph finding time(s):\t{compgraph_time - start_time:.4f}") 297 | 298 | # get samples 299 | self.sampler = get_sampler( 300 | sampler_name=sampler_name, nplayers=nplayers, nsamples=nsamples, 301 | edge_index=sub_edge_index, nhops=self.num_hops, 302 | target_node=sub_mapping[0].item(), **kwargs) 303 | mask_matrix, kernel_weights = self.sampler.sample() 304 | 305 | mask_matrix = mask_matrix.to(device) 306 | kernel_weights = kernel_weights.to(device) 307 | 308 | nsamples = self.sampler.nsamples # samplers may update nsamples 309 | 310 | sampling_time = time.time() 311 | 312 | log.info(f"Sampling time(s):\t\t{sampling_time - compgraph_time:.4f}") 313 | 314 | 315 | # temporarily switch add_self_loops to False if it is enabled. 316 | # we will add self loops manually. 317 | use_add_self_loops = has_add_self_loops(self.model) 318 | add_self_loops_swithced = False 319 | if use_add_self_loops: 320 | switch_add_self_loops(self.model) 321 | add_self_loops_swithced = True 322 | 323 | 324 | self.model.eval() 325 | 326 | if self.verbose == 1: 327 | print(f"Number of samples: {self.sampler.nsamples}, " 328 | f"sampler:{self.sampler.__class__.__name__}, " 329 | "batch size: {batch_size}") 330 | 331 | # new node_idx after relabeling in k hop subgraph. 332 | new_node_idx = sub_mapping[0].item() 333 | 334 | node_features = self.data.x[subset].to(device) 335 | sub_edge_index = sub_edge_index.to(device) 336 | 337 | # add remaining self loops if GNN layers' add_self_loops param set to True 338 | self_loop_sub_edge_index = add_remaining_self_loops( 339 | edge_index=sub_edge_index)[0] if use_add_self_loops else sub_edge_index 340 | 341 | if use_add_self_loops: 342 | torch_self_loop_mask_matrix_bool = torch.ones((mask_matrix.size(0), 343 | self_loop_sub_edge_index.size(1)), 344 | dtype=torch.bool).to(device) 345 | # Self loop indices are always True, rest is based on coalition matrix 346 | torch_self_loop_mask_matrix_bool[:, :mask_matrix.size(1)] = mask_matrix 347 | else: 348 | torch_self_loop_mask_matrix_bool = mask_matrix 349 | 350 | preds, comp_samp = self.compute_model_predictions(node_features, self_loop_sub_edge_index, 351 | torch_self_loop_mask_matrix_bool, 352 | new_node_idx, batch_size, target_class) 353 | 354 | # revert back if add_self_loops are disabled 355 | if add_self_loops_swithced: 356 | switch_add_self_loops(self.model) 357 | del torch_self_loop_mask_matrix_bool 358 | 359 | 360 | del self_loop_sub_edge_index 361 | pred_time = time.time() 362 | log.info(f"Model predictions time(s):{pred_time - sampling_time:.4f}") 363 | 364 | fraction_evaluated = nsamples / self.sampler.max_samples 365 | 366 | # We'll most likely get OOM error if we use WLSSolver on GPU for large number of players. 367 | # This can be disabled if GPU has enough memory. Our GPU has 24GB memory. 
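# The 5000-player cutoff below is an empirical heuristic for that card; callers can also pass
# solver_name='WLRSolver' to explain() directly to avoid the WLS solve altogether.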
368 | if nplayers > 5000 and solver_name == "WLSSolver" and device != 'cpu': 369 | solver_name = "WLRSolver" 370 | log.warning(f"Switching to WLRSolver. Reason: large number of players: {nplayers}") 371 | 372 | solver = get_solver(solver_name=solver_name, 373 | mask_matrix=mask_matrix, 374 | kernel_weights=kernel_weights, 375 | yhat=preds, fnull=self.fnull, 376 | ffull=self.fx, device=device, fraction_evaluated=fraction_evaluated, 377 | l1_reg=l1_reg) 378 | 379 | shap_vals = solver.solve() 380 | 381 | solve_time = time.time() 382 | 383 | log.info(f"Solve time(s):\t{solve_time - pred_time:.4f}") 384 | 385 | 386 | 387 | # non-relabeled computional edge index 388 | sub_edge_index = self.data.edge_index[:, sub_edge_mask] 389 | 390 | total_time = time.time() - start_time 391 | 392 | 393 | explanation = GNNShapExplanation(node_idx, nplayers, float(self.fnull), shap_vals, nsamples, 394 | self.fx, target_class, sub_edge_index, subset, 395 | self.data.y[subset], 396 | time_total_comp=total_time, 397 | time_comp_graph=compgraph_time - start_time, 398 | time_sampling=sampling_time - compgraph_time, 399 | time_predictions=pred_time - sampling_time, 400 | time_solver=solve_time - pred_time, 401 | computed_samples=comp_samp) 402 | 403 | return explanation 404 | -------------------------------------------------------------------------------- /gnnshap/explanation.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union 3 | 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from matplotlib.lines import Line2D 10 | from numba.core.errors import (NumbaDeprecationWarning, 11 | NumbaPendingDeprecationWarning) 12 | from shap._explanation import Explanation as SHAPExplanation 13 | from shap.plots._bar import bar 14 | from shap.plots._force import force as shap_force_plt 15 | from torch import Tensor 16 | 17 | from gnnshap.eval_utils import fidelity 18 | from gnnshap.utils import get_logger 19 | 20 | warnings.simplefilter('ignore', category=NumbaDeprecationWarning) 21 | warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) 22 | 23 | 24 | log = get_logger(__name__) 25 | 26 | 27 | class GNNShapExplanation: 28 | """This class is used to return explanation results. Tensor values are converted and stored 29 | as numpy array for convention. Time parts are used for benchmarking and they are optional. 30 | The results can be visualized, or fidelity scores can be computed via methods. 31 | 32 | Args: 33 | node_idx (int): Explained node idx. 34 | nplayers (int): Number of players. 35 | base_value (int): Base value. 36 | shap_values (np.array): Shapley values. 37 | nsamples (int): Number of samples. 38 | fx (float): Model's prediction with the subgraph. 39 | target_class (int): Target class for the explanation. 40 | sub_edge_index (Tensor): computational graph's edge index. 41 | sub_nodes (Tensor): nodes in the computational graph. 42 | sub_node_labels (Tensor): node labels. 43 | time_total_comp (float, optional): Total SHAP computation time. Defaults to None. 44 | time_comp_graph (float, optional): computational graph extraction time. Defaults to None. 45 | time_sampling (float, optional): Sampling time. Defaults to None. 46 | time_predictions (float, optional): Total coalition predictions time. Defaults to None. 47 | time_solver (float, optional): Solver time. Defaults to None. 48 | computed_samples (int, optional): Number of computed samples. 
Some samples doesn't need 49 | computing when the target node is isolated. Defaults to None. 50 | **kwargs (dict): Other arguments. 51 | """ 52 | 53 | def __init__(self, node_idx: int, nplayers: int, base_value: int, 54 | shap_values: np.array, nsamples: int, fx: float, target_class: int, 55 | sub_edge_index: np.array, sub_nodes: np.array, sub_node_labels: np.array, 56 | time_total_comp: float = None, time_comp_graph: float = None, 57 | time_sampling: float = None, time_predictions: float = None, 58 | time_solver: float = None, **kwargs: dict): 59 | self.node_idx = node_idx 60 | self.nplayers = nplayers 61 | self.base_value = base_value 62 | self.shap_values = shap_values 63 | self.nsamples = nsamples 64 | self.fx = fx 65 | self.target_class = target_class 66 | self.sub_edge_index = sub_edge_index.detach().cpu().numpy() 67 | self.sub_nodes = sub_nodes.detach().cpu().numpy() 68 | self.sub_node_labels = sub_node_labels.detach().cpu().numpy() 69 | 70 | self.time_total_comp = time_total_comp 71 | self.time_comp_graph = time_comp_graph 72 | self.time_sampling = time_sampling 73 | self.time_predictions = time_predictions 74 | self.time_solver = time_solver 75 | self.computed_samples = kwargs.get('computed_samples', None) 76 | self.kwargs = kwargs 77 | 78 | def __get_edge_names(self) -> list: 79 | """Gets edge names as list of strings in `source -> target` format. 80 | 81 | Returns: 82 | list: edge names as list of strings. 83 | """ 84 | labels = [] 85 | for src, trgt in self.sub_edge_index.T: 86 | lbl = f'{src}\u2192{trgt}' 87 | labels.append(lbl) 88 | 89 | return labels 90 | 91 | def __get_edge_names_series(self) -> pd.Series: 92 | """Gets edge names and shapley values as pandas series. This can be used if shapley values 93 | wants to be seen together with edge names in the visualizations. 94 | 95 | Returns: 96 | pd.Series: Pandas series 97 | """ 98 | tmp_dict = {} 99 | for i, (src, trgt) in enumerate(self.sub_edge_index.T): 100 | lbl = f'{src}\u2192{trgt}' 101 | tmp_dict[lbl] = self.shap_values[i] 102 | return pd.Series(tmp_dict) 103 | 104 | def plot_force(self, contrib_threshold: float = 0.005, 105 | show_values: bool = False) -> None: 106 | """Plots force plot using SHAP package's force plot. 107 | 108 | Args: 109 | contrib_threshold (float, optional): A threshold value to discard some edges. 110 | Defaults to 0.005. 111 | show_values (bool, optional): Shows Shapley values along with edge names. 112 | Defaults to False. 113 | """ 114 | features = self.__get_edge_names_series() if show_values else None 115 | feature_names = self.__get_edge_names() 116 | shap_force_plt(self.base_value, self.shap_values, features, feature_names, matplotlib=True, 117 | contribution_threshold=contrib_threshold) 118 | 119 | def plot_bar(self, max_display: int = 10, show=True) -> None: 120 | """Plots force plot using SHAP package's force plot. 121 | 122 | Args: 123 | max_display (int, optional): A threshold to show top number of edges/players. 124 | Defaults to 10. 125 | show (bool, optional): Shows the plot. Defaults to True. Set to False if you want more 126 | costumization. 
127 | """ 128 | feature_names = self.__get_edge_names() 129 | 130 | shap_explanation = SHAPExplanation(self.shap_values, np.array([self.base_value]), 131 | feature_names=feature_names) 132 | bar(shap_explanation, max_display=max_display, show=show) 133 | 134 | def plot_graph(self, topk: int = 25, save_path: str = None, pos=None, show_scores: bool = False, 135 | return_pos: bool = False, show: bool = True) -> None: 136 | """Plots computational computational graph with topk edges. Since it uses topk edges, some 137 | nodes will not be visible if connecting edges to target node not in the topk. 138 | 139 | Args: 140 | topk (int, optional): maximum number of topk edges in the plot. Defaults to 25. 141 | save_path (str, optional): Save path. The plot will be saved if provided. 142 | Defaults to None. 143 | pos (dict, optional): Position dictionary. If not provided, it will be computed. 144 | Defaults to None. 145 | show_scores (bool, optional): Shows Shapley values along with edges. 146 | Defaults to False. 147 | return_pos (bool, optional): Returns position dictionary. Defaults to False. 148 | show (bool, optional): Shows the plot. Defaults to True. Set to False if you want more 149 | costumization. 150 | 151 | Returns: 152 | dict: position dictionary. 153 | 154 | """ 155 | 156 | # maximum one at the first index 157 | top_edges = np.argsort(-np.abs(self.shap_values)) 158 | topk_edges = top_edges[:topk] 159 | max_shap_val = np.abs(self.shap_values).max() 160 | color_list = ['orange', 'blue', 'red', 'green', 161 | '#D3B98C', "lightblue", '#D3A38C', 'red', 'yellow', 162 | 'pink', 'grey', 'purple', 'gold'] 163 | 164 | # predefined colors are not enough for dataset. Remaining class backgrounds are white. 165 | if self.sub_node_labels.max() > len(color_list): 166 | color_list += ['white' for i in range( 167 | self.sub_node_labels.max() - len(color_list))] 168 | log.warning("Predefined colors are not enough for each class. 
Classes from %d to %d" 169 | "are colored as white", len(color_list), self.sub_node_labels.max()) 170 | 171 | fig, ax = plt.subplots() 172 | 173 | G = nx.DiGraph() 174 | for i in topk_edges: 175 | G.add_edge(int(self.sub_edge_index[0, i]), int(self.sub_edge_index[1, i]), 176 | label=self.shap_values[i]) 177 | 178 | edge_labels = nx.get_edge_attributes(G, 'label') 179 | formatted_edge_labels = { 180 | (elem[0], elem[1]): f"{edge_labels[elem]:.4f}" for elem in edge_labels} 181 | 182 | edge_colors = ['blue' if edge_labels[elem] < 0 else 'red' for elem in edge_labels] 183 | 184 | nodes = list(G.nodes) 185 | node_sizes = [600 if n == self.node_idx else 500 for n in nodes] 186 | 187 | node_colors = [color_list[self.sub_node_labels 188 | [np.where(self.sub_nodes == n)[0][0]]] for n in nodes] 189 | node_border_colors = ['black' if n == self.node_idx 190 | else node_colors[i] for i, n in enumerate(nodes)] 191 | 192 | 193 | #pos = nx.kamada_kawai_layout(G, scale=5) 194 | if pos is None: 195 | pos = nx.spring_layout(G, scale=5) 196 | edge_transparency = np.array( 197 | [np.abs(v/max_shap_val) for v in self.shap_values[topk_edges]]) 198 | tmp_min, tmp_max = edge_transparency.min(), edge_transparency.max() 199 | # scale transparencies between 0.2 and 1.0 200 | edge_transparency = ( 201 | 1-0.2) * ((edge_transparency - tmp_min)/(tmp_max - tmp_min)) + 0.2 202 | 203 | nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, 204 | edgecolors=node_border_colors) 205 | nx.draw_networkx_labels( 206 | G, pos, labels={n: f'{n}'for n in nodes}, font_size=7) 207 | nx.draw_networkx_edges(G, pos, width=2, connectionstyle="arc3,rad=0.1", arrows=True, 208 | arrowsize=14, arrowstyle='-|>', node_size=500, 209 | edge_color=edge_colors, alpha=edge_transparency) 210 | 211 | legend_labels = ["lower", "higher"] 212 | handles = [Line2D([0], [0], color='b', lw=2, label='Line'), 213 | Line2D([0], [0], color='r', lw=2, label='Line2')] 214 | ax.legend(handles, legend_labels, loc='best', fontsize='small', 215 | fancybox=True, framealpha=0.7) 216 | 217 | if show_scores: 218 | nx.draw_networkx_edge_labels(G, pos, edge_labels=formatted_edge_labels, 219 | rotate=True, label_pos=0.75, font_size=6, ax=ax, 220 | bbox=dict(alpha=0)) 221 | 222 | if save_path is not None: 223 | plt.savefig(save_path) 224 | 225 | if show: 226 | plt.show() 227 | 228 | if return_pos: 229 | return pos 230 | 231 | def fidelity_prob(self, model: torch.nn.Module, data: Tensor, sparsity: float=0.1, 232 | score_type: str = 'neg', topk: int = 0, 233 | apply_abs: bool = True) -> Union[float, float]: 234 | """Computes fidelity probabilty score. Score type 'neg' computes fidelity- score, and 'pos' 235 | computes fidelity+ score. If topk is other than 0, then fidelity score is computed by 236 | droping k edges. 237 | 238 | Args: 239 | model (torch.nn.Module): A PyG model. 240 | data (Tensor): A PyG data. 241 | sparsity (float): Target sparsity value. It should be a value in range (0, 1). Defaults 242 | to 0.1. 243 | score_type (str, optional): 'neg' for fidelity- or 'pos' for fidelity+. 244 | Defaults to 'neg'. 245 | topk (int, optional): It is used to drop k edges if not 0. Sparsity value will not be 246 | used if topk is used (a dummy value can be provided for sparsity). Defaults to 0. 247 | apply_abs (bool, optional): Applies absolute to scores and fidelity. Defaults to True. 248 | 249 | Returns: 250 | Union[float, float]: fidelity score and sparsity value. 251 | """ 252 | assert 0 <= sparsity <= 1, "Sparsity should be between zero and one." 
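# Illustrative usage sketch (the names `model`, `data`, and `explanation` are assumptions;
# `explanation` stands for an object returned by the explainer's explain() method):
#     fid_minus, spars = explanation.fidelity_prob(model, data, sparsity=0.3, score_type='neg')
#     fid_plus, spars = explanation.fidelity_prob(model, data, sparsity=0.3, score_type='pos')
#     fid_drop10, _ = explanation.fidelity_prob(model, data, sparsity=0.0, topk=10)  # drop 10 edges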
253 | assert topk >= 0, "topk cannot be a negative number" 254 | 255 | res = fidelity(self.result2dict(), data, model, sparsity, score_type, topk, 256 | self.target_class, apply_abs) 257 | 258 | return res[2], res[3] 259 | 260 | def result2dict(self) -> dict: 261 | """Converts an explanation result to a dictionary. 262 | 263 | Args: 264 | node_id (int): node id 265 | scores (np.array): importance scores 266 | 267 | Returns: 268 | dict: result as dictionary 269 | """ 270 | return {'node_id': self.node_idx, 'scores': self.shap_values, 271 | 'num_players': self.nplayers, 'num_samples': self.nsamples, 272 | 'base_val': self.base_value, 273 | 'time': self.time_total_comp} 274 | -------------------------------------------------------------------------------- /gnnshap/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from ._base import BaseSampler 4 | from ._exact import SHAPExactSampler 5 | from ._gnnshap import GNNShapSampler 6 | from ._svx import SVXSampler 7 | from ._shap import SHAPSampler 8 | from ._shap_unique import SHAPUniqueSampler 9 | 10 | 11 | def get_sampler(nplayers: int, nsamples: int = None, sampler_name: str = "shap_sampler", 12 | **kwargs: Optional) -> BaseSampler: 13 | """Returns the instanciated sampler based on the name. 14 | 15 | Args: 16 | nplayers (int): number of players 17 | nsamples (int, optional): number of samples. Defaults to None. 18 | sampler_name (str, optional): sampler name. Defaults to "shap_sampler". 19 | kwargs (optional): extra arguments if sampler needs it. 20 | 21 | Raises: 22 | KeyError: Raises error if sampler doesnot exist. 23 | 24 | Returns: 25 | BaseSampler: A sampler class instance. 26 | """ 27 | samplers = { 28 | 'SHAPSampler': SHAPSampler, 29 | 'SHAPExactSampler': SHAPExactSampler, 30 | 'SHAPUniqueSampler': SHAPUniqueSampler, 31 | 'GNNShapSampler': GNNShapSampler, 32 | 'SVXSampler': SVXSampler 33 | } 34 | 35 | try: 36 | return samplers[sampler_name](nplayers=nplayers, nsamples=nsamples, **kwargs) 37 | except KeyError: 38 | raise KeyError(f"Sampler '{sampler_name}' not found!") 39 | -------------------------------------------------------------------------------- /gnnshap/samplers/_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | from torch import Tensor 5 | 6 | from gnnshap.utils import get_logger 7 | 8 | log = get_logger(__name__) 9 | 10 | 11 | class BaseSampler(ABC): 12 | """A base class for samplers 13 | """ 14 | 15 | def __init__(self, nplayers: int, nsamples: int) -> None: 16 | if not isinstance(nplayers, int): 17 | raise TypeError("nplayers must be an integer.") 18 | assert nplayers > 1, "Number of players should be greater than 1" 19 | if not isinstance(nsamples, int): 20 | raise TypeError("nsamples must be an integer.") 21 | if nsamples is not None: 22 | assert nsamples > 1, "Number of samples should be a positive number" 23 | self.nplayers = nplayers 24 | self.nsamples = nsamples 25 | 26 | self.max_samples = 2 ** 30 27 | if self.nplayers <= 30: 28 | self.max_samples = 2 ** self.nplayers - 2 29 | # don't use more samples than 2 ** 30 30 | self.nsamples = min(self.nsamples, self.max_samples) 31 | 32 | @abstractmethod 33 | def sample(self) -> Tuple[Tensor, Tensor]: 34 | """An abstract method that all samplers should override. 35 | 36 | Returns: 37 | Tuple[Tensor, Tensor]: 2d booelan mask_matrix and 1d coalition weights. 
38 | """ -------------------------------------------------------------------------------- /gnnshap/samplers/_exact.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Iterable 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | import scipy 9 | 10 | from gnnshap.utils import get_logger 11 | 12 | from ._base import BaseSampler 13 | 14 | log = get_logger(__name__) 15 | 16 | class SHAPExactSampler(BaseSampler): 17 | """Brute Force Kernel SHAP sampler from: 18 | `SHAP package. `__ 19 | It samples all :math:`2^{N}` possible coalitions. Note that it does not include 20 | empty (no players) and full (all players) coalitions. In addition, the weights are normalized 21 | to match Shapley formula. It is not suggested to use this sampler since it is not practical. 22 | """ 23 | 24 | def __init__(self, nplayers: int, **kwargs) -> None: 25 | """Only requires number of players 26 | 27 | Args: 28 | nplayers (int): number of players 29 | 30 | Raises: 31 | AssertionError: Raises error if there are more than 30 players since it is 32 | not practical. 33 | """ 34 | if nplayers > 30: 35 | raise AssertionError("It is not possible to iterate all possible coalitions" 36 | f" when there are more than 30 players: 2^30={2 ** 30}") 37 | super().__init__(nplayers=nplayers, nsamples=2 ** nplayers) 38 | 39 | def shapley_kernel(self, s): 40 | """ 41 | Computes coalition weight 42 | :param M: total number of players 43 | :param s: number of players in the coalition 44 | :return: coalition weight 45 | """ 46 | M = self.nplayers 47 | # return a large number for empty and full coalition 48 | if s == 0 or s == M: 49 | return 10000 50 | if scipy.special.binom(M, s) == float('+inf'): 51 | return 0 52 | return (M - 1) / (scipy.special.binom(M, s) * s * (M - s)) 53 | 54 | def _powerset(self, iterable: Iterable) -> itertools.chain: 55 | """Generates and returns powerset. 56 | 57 | Args: 58 | iterable (Iterable): an iterable object. Example: range(10) 59 | 60 | Returns: 61 | itertools.chain: a chain object. 62 | """ 63 | coal_size = list(iterable) 64 | return itertools.chain.from_iterable( 65 | itertools.combinations(coal_size, r) for r in range(len(coal_size) + 1)) 66 | 67 | def sample(self) -> Tuple[Tensor, Tensor]: 68 | """Returns all possible coalitions and weights. Note that it doesn't include empty and full 69 | coalitions (Solver does not need them). 70 | 71 | Returns: 72 | Tuple[Tensor, Tensor]: mask_matrix and kernel weights. 
73 | """ 74 | mask_matrix = np.zeros((2 ** self.nplayers, self.nplayers)) 75 | weights = np.zeros(2 ** self.nplayers) 76 | 77 | # exact kernel weights 78 | p_w = np.array([self.shapley_kernel(s) for s in range(0, self.nplayers + 1)]) 79 | 80 | for i, coal in enumerate(self._powerset(range(self.nplayers))): 81 | coal = list(coal) 82 | mask_matrix[i, coal] = 1 83 | weights[i] = p_w[len(coal)] # shapley_kernel(M, len(s)) 84 | 85 | return (torch.tensor(mask_matrix[1:-1], requires_grad=False).bool(), 86 | torch.tensor(weights[1:-1], requires_grad=False)) 87 | -------------------------------------------------------------------------------- /gnnshap/samplers/_gnnshap.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.special import binom 8 | from torch import Tensor 9 | 10 | from gnnshap.utils import get_logger 11 | 12 | from ._base import BaseSampler 13 | 14 | log = get_logger(__name__) 15 | 16 | from torch.utils.cpp_extension import load 17 | 18 | cppsamp = load(name='cudaGNNShapSampler', sources=['cppextension/cudagnnshap.cu'], 19 | extra_cflags=['-O2'], verbose=False) 20 | 21 | 22 | class GNNShapSampler(BaseSampler): 23 | r"""This sampling algorithm is implemented in Cuda to speed up the sampling process. It 24 | creates samples in parallel. The number of blocks and threads can be adjusted. 25 | The total weights are scaled to 100 to increase numerical stability. 26 | """ 27 | 28 | def __init__(self, nplayers: int, nsamples: int, **kwargs) -> None: 29 | """number of players and number of samples are required. 30 | 31 | Args: 32 | nplayers (int): number of players 33 | nsamples (int): number of samples 34 | num_blocks (int, optional): number of blocks for cuda. Defaults to 16. 35 | num_threads (int, optional): number of threads for cuda. Defaults to 128. 36 | """ 37 | super().__init__(nplayers=nplayers, nsamples=nsamples) 38 | self.num_blocks = kwargs.get('num_blocks', 16) 39 | self.num_threads = kwargs.get('num_threads', 128) 40 | 41 | def sample(self) -> Tuple[Tensor, Tensor]: 42 | mask_matrix = torch.zeros((self.nsamples, self.nplayers), 43 | dtype=torch.bool, requires_grad=False).cuda() 44 | kernel_weights = torch.zeros((self.nsamples), dtype=torch.float64, 45 | requires_grad=False).cuda() 46 | 47 | cppsamp.sample(mask_matrix, kernel_weights, self.nplayers, self.nsamples, 48 | self.num_blocks, self.num_threads) 49 | return mask_matrix, kernel_weights 50 | -------------------------------------------------------------------------------- /gnnshap/samplers/_shap.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.special import binom 8 | from torch import Tensor 9 | 10 | from gnnshap.utils import get_logger 11 | 12 | from ._base import BaseSampler 13 | 14 | log = get_logger(__name__) 15 | 16 | 17 | class SHAPSampler(BaseSampler): 18 | r"""`"This sampling algorithm is a modified version of KernelSHAP from:" 19 | `. This one skips 20 | uniqueness check. It is faster than the original one. 21 | """ 22 | 23 | def __init__(self, nplayers: int, nsamples: int, **kwargs) -> None: 24 | """number of players and number of samples are required. 
25 | 26 | Args: 27 | nplayers (int): number of players 28 | nsamples (int): number of samples 29 | """ 30 | super().__init__(nplayers=nplayers, nsamples=nsamples) 31 | 32 | def sample(self) -> Tuple[Tensor, Tensor]: 33 | r"""Returns sampled coalitions and weights. Note that it doesn't include empty and full 34 | coalitions(Solver does not need them). It's identical to ShapSampler except sampled 35 | coalitions uniqueness check is skipped. Please refer to the ShapSampler for the details. 36 | 37 | Returns: 38 | Tuple[Tensor, Tensor]: mask_matrix (boolean) and weights (float) 39 | """ 40 | mask_matrix = np.zeros((self.nsamples, self.nplayers)) 41 | kernel_weights = np.zeros(self.nsamples) 42 | nsamples_added = 0 43 | 44 | def addsample(m, w, n_samples_added): 45 | mask_matrix[n_samples_added, :] = m 46 | kernel_weights[n_samples_added] = w 47 | 48 | # weight the different subset sizes 49 | num_subset_sizes = int(np.ceil((self.nplayers - 1) / 2.0)) 50 | 51 | # coalition size in the middle not a paired subset 52 | # if nplayers=4, 1 and 3 are pairs, 2 doesnt have a pair 53 | num_paired_subset_sizes = int(np.floor((self.nplayers - 1) / 2.0)) 54 | 55 | weight_vector = np.array([(self.nplayers - 1.0) / (i * ( 56 | self.nplayers - i)) for i in range(1, num_subset_sizes + 1)]) 57 | weight_vector[:num_paired_subset_sizes] *= 2 58 | weight_vector /= np.sum(weight_vector) 59 | 60 | # fill out all the subset sizes we can completely enumerate 61 | # given nsamples*remaining_weight_vector[subset_size] 62 | num_full_subsets = 0 63 | num_samples_left = self.nsamples 64 | # no grouping in edge based shap 65 | group_inds = np.arange(self.nplayers, dtype='int64') 66 | mask = np.zeros(self.nplayers) 67 | remaining_weight_vector = copy.copy(weight_vector) 68 | 69 | for subset_size in range(1, num_subset_sizes + 1): 70 | 71 | # determine how many subsets (and their complements) are of the current size 72 | nsubsets = binom(self.nplayers, subset_size) 73 | if subset_size <= num_paired_subset_sizes: 74 | nsubsets *= 2 75 | 76 | # see if we have enough samples to enumerate all subsets of this size 77 | if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8: 78 | num_full_subsets += 1 79 | num_samples_left -= nsubsets 80 | 81 | # rescale what's left of the remaining weight vector to sum to 1 82 | # it works as like not used samples distributed to other bins. 
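# Worked illustration (example numbers only, not part of the original logic): with nplayers = 4
# the raw size weights are proportional to [(4-1)/(1*3), (4-1)/(2*2)] = [1.0, 0.75]; the paired
# size 1 (complement size 3) is doubled to 2.0 and the vector normalizes to roughly [0.73, 0.27].
# Once sizes 1 and 3 are fully enumerated (2 * C(4, 1) = 8 coalitions), the rescaling below divides
# what is left by (1 - 0.73), so the size-2 entry becomes 1.0 again.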
83 | if remaining_weight_vector[subset_size - 1] < 1.0: 84 | remaining_weight_vector /= (1 - 85 | remaining_weight_vector[subset_size - 1]) 86 | 87 | # add all the samples of the current subset size 88 | w = weight_vector[subset_size - 1] / \ 89 | binom(self.nplayers, subset_size) 90 | if subset_size <= num_paired_subset_sizes: 91 | w /= 2.0 92 | for inds in itertools.combinations(group_inds, subset_size): 93 | mask[:] = 0.0 94 | mask[np.array(inds, dtype='int64')] = 1.0 95 | addsample(mask, w, nsamples_added) 96 | nsamples_added += 1 97 | 98 | if subset_size <= num_paired_subset_sizes: 99 | mask[:] = np.abs(mask - 1) 100 | addsample(mask, w, nsamples_added) 101 | nsamples_added += 1 102 | else: 103 | break 104 | 105 | # add random samples from what is left of the subset space 106 | nfixed_samples = nsamples_added 107 | samples_left = self.nsamples - nsamples_added 108 | if num_full_subsets != num_subset_sizes: 109 | remaining_weight_vector = copy.copy(weight_vector) 110 | # because we draw two samples each below 111 | remaining_weight_vector[:num_paired_subset_sizes] /= 2 112 | remaining_weight_vector = remaining_weight_vector[num_full_subsets:] 113 | remaining_weight_vector /= np.sum(remaining_weight_vector) 114 | 115 | # four times generated since it does not sample same coalition twice. 116 | # we use random samples until we reach target number of samples. 117 | ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, 118 | p=remaining_weight_vector) 119 | ind_set_pos = 0 120 | # used_masks = {} 121 | while samples_left > 0 and ind_set_pos < len(ind_set): 122 | mask.fill(0.0) 123 | # we call np.random.choice once to save time and then just read it here 124 | ind = ind_set[ind_set_pos] 125 | ind_set_pos += 1 126 | subset_size = ind + num_full_subsets + 1 127 | mask[np.random.permutation(self.nplayers)[:subset_size]] = 1.0 128 | 129 | samples_left -= 1 130 | addsample(mask, 1.0, nsamples_added) 131 | nsamples_added += 1 132 | # add the symmetric sample 133 | if samples_left > 0 and subset_size <= num_paired_subset_sizes: 134 | mask[:] = np.abs(mask - 1) 135 | samples_left -= 1 136 | addsample(mask, 1.0, nsamples_added) 137 | nsamples_added += 1 138 | 139 | # normalize the kernel weights for the random samples to equal the weight left after 140 | # the fixed enumerated samples have been already counted 141 | weight_left = np.sum(weight_vector[num_full_subsets:]) 142 | kernel_weights[nfixed_samples:] *= weight_left / kernel_weights[nfixed_samples:].sum() 143 | 144 | return (torch.tensor(mask_matrix, requires_grad=False).bool(), 145 | torch.tensor(kernel_weights, requires_grad=False)) -------------------------------------------------------------------------------- /gnnshap/samplers/_shap_unique.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.special import binom 8 | from torch import Tensor 9 | 10 | from gnnshap.utils import get_logger 11 | 12 | from ._base import BaseSampler 13 | 14 | log = get_logger(__name__) 15 | 16 | 17 | class SHAPUniqueSampler(BaseSampler): 18 | r"""`"This sampling algorithm is a modified version of KernelSHAP from:" 19 | `. It makes sure 20 | each coalition is sampled only once. 21 | """ 22 | 23 | def __init__(self, nplayers: int, nsamples: int, **kwargs) -> None: 24 | """number of players and number of samples are required. 
25 | 26 | Args: 27 | nplayers (int): number of players 28 | nsamples (int): number of samples 29 | """ 30 | super().__init__(nplayers=nplayers, nsamples=nsamples) 31 | 32 | def sample(self) -> Tuple[Tensor, Tensor]: 33 | r"""Returns sampled coalitions and weights. Note that it doesn't include empty and full 34 | coalitions(Solver does not need them). 35 | 36 | It makes symmetric sampling: if [0, 0, 1] is sampled, [1, 1, 0] is also included. 37 | Therefore, the loops care only half of the samples. 38 | 39 | Returns: 40 | Tuple[Tensor, Tensor]: mask_matrix (boolean) and weights 41 | """ 42 | 43 | mask_matrix = np.zeros((self.nsamples, self.nplayers)) 44 | kernel_weights = np.zeros(self.nsamples) 45 | nsamples_added = 0 46 | 47 | def addsample(m, w, n_samples_added): 48 | mask_matrix[n_samples_added, :] = m 49 | kernel_weights[n_samples_added] = w 50 | 51 | # weight the different subset sizes 52 | num_subset_sizes = int(np.ceil((self.nplayers - 1) / 2.0)) 53 | 54 | # coalition size in the middle not a paired subset 55 | # if nplayers=4, 1 and 3 are pairs, 2 doesnt have a pair 56 | num_paired_subset_sizes = int(np.floor((self.nplayers - 1) / 2.0)) 57 | 58 | weight_vector = np.array([(self.nplayers - 1.0) / (i * ( 59 | self.nplayers - i)) for i in range(1, num_subset_sizes + 1)]) 60 | weight_vector[:num_paired_subset_sizes] *= 2 61 | weight_vector /= np.sum(weight_vector) 62 | 63 | # fill out all the subset sizes we can completely enumerate 64 | # given nsamples*remaining_weight_vector[subset_size] 65 | num_full_subsets = 0 66 | num_samples_left = self.nsamples 67 | # no grouping in edge based shap 68 | group_inds = np.arange(self.nplayers, dtype='int64') 69 | mask = np.zeros(self.nplayers) 70 | remaining_weight_vector = copy.copy(weight_vector) 71 | 72 | for subset_size in range(1, num_subset_sizes + 1): 73 | 74 | # determine how many subsets (and their complements) are of the current size 75 | nsubsets = binom(self.nplayers, subset_size) 76 | if subset_size <= num_paired_subset_sizes: 77 | nsubsets *= 2 78 | 79 | # see if we have enough samples to enumerate all subsets of this size 80 | if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8: 81 | num_full_subsets += 1 82 | num_samples_left -= nsubsets 83 | 84 | # rescale what's left of the remaining weight vector to sum to 1 85 | # it works as like not used samples distributed to other bins. 
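# Clarifying note: unlike SHAPSampler, the random-sampling phase further below tracks every drawn
# mask in a `used_masks` dict; a repeated coalition is not stored again, its kernel weight is
# incremented instead, and a repeated complement reuses the slot right after the original sample.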
86 | if remaining_weight_vector[subset_size - 1] < 1.0: 87 | remaining_weight_vector /= (1 - 88 | remaining_weight_vector[subset_size - 1]) 89 | 90 | # add all the samples of the current subset size 91 | w = weight_vector[subset_size - 1] / \ 92 | binom(self.nplayers, subset_size) 93 | if subset_size <= num_paired_subset_sizes: 94 | w /= 2.0 95 | for inds in itertools.combinations(group_inds, subset_size): 96 | mask[:] = 0.0 97 | mask[np.array(inds, dtype='int64')] = 1.0 98 | addsample(mask, w, nsamples_added) 99 | nsamples_added += 1 100 | 101 | if subset_size <= num_paired_subset_sizes: 102 | mask[:] = np.abs(mask - 1) 103 | addsample(mask, w, nsamples_added) 104 | nsamples_added += 1 105 | else: 106 | break 107 | 108 | # add random samples from what is left of the subset space 109 | nfixed_samples = nsamples_added 110 | samples_left = self.nsamples - nsamples_added 111 | if num_full_subsets != num_subset_sizes: 112 | remaining_weight_vector = copy.copy(weight_vector) 113 | # because we draw two samples each below 114 | remaining_weight_vector[:num_paired_subset_sizes] /= 2 115 | remaining_weight_vector = remaining_weight_vector[num_full_subsets:] 116 | remaining_weight_vector /= np.sum(remaining_weight_vector) 117 | 118 | # four times generated since it does not sample same coalition twice. 119 | # we use random samples until we reach target number of samples. 120 | ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, 121 | p=remaining_weight_vector) 122 | ind_set_pos = 0 123 | used_masks = {} 124 | while samples_left > 0 and ind_set_pos < len(ind_set): 125 | mask.fill(0.0) 126 | # we call np.random.choice once to save time and then just read it here 127 | ind = ind_set[ind_set_pos] 128 | ind_set_pos += 1 129 | subset_size = ind + num_full_subsets + 1 130 | mask[np.random.permutation(self.nplayers)[:subset_size]] = 1.0 131 | 132 | # only add the sample if we have not seen it before, otherwise just 133 | # increment a previous sample's weight 134 | mask_tuple = tuple(mask) 135 | new_sample = False 136 | if mask_tuple not in used_masks: # temporarily disabled 137 | new_sample = True 138 | used_masks[mask_tuple] = nsamples_added 139 | samples_left -= 1 140 | addsample(mask, 1.0, nsamples_added) 141 | nsamples_added += 1 142 | else: 143 | kernel_weights[used_masks[mask_tuple]] += 1.0 144 | 145 | # add the symmetric sample 146 | if samples_left > 0 and subset_size <= num_paired_subset_sizes: 147 | mask[:] = np.abs(mask - 1) 148 | 149 | # only add the sample if we have not seen it before, otherwise just 150 | # increment a previous sample's weight 151 | if new_sample: 152 | samples_left -= 1 153 | addsample(mask, 1.0, nsamples_added) 154 | nsamples_added += 1 155 | else: 156 | # we know the compliment sample is the next one after the 157 | # original sample, so + 1 158 | kernel_weights[used_masks[mask_tuple] + 1] += 1.0 159 | 160 | # normalize the kernel weights for the random samples to equal the weight left after 161 | # the fixed enumerated samples have been already counted 162 | weight_left = np.sum(weight_vector[num_full_subsets:]) 163 | kernel_weights[nfixed_samples:] *= weight_left / \ 164 | kernel_weights[nfixed_samples:].sum() 165 | 166 | return (torch.tensor(mask_matrix, requires_grad=False).bool(), 167 | torch.tensor(kernel_weights, requires_grad=False)) -------------------------------------------------------------------------------- /gnnshap/samplers/_svx.py: -------------------------------------------------------------------------------- 1 | import itertools 
2 | import random 3 | import time 4 | from copy import deepcopy 5 | from itertools import combinations 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | import scipy.special 10 | import torch 11 | from torch import Tensor 12 | 13 | from gnnshap.utils import get_logger 14 | 15 | from ._base import BaseSampler 16 | 17 | log = get_logger(__name__) 18 | 19 | class SVXSampler(BaseSampler): 20 | """ SVXSampler is based on GraphSVX’s "SmarterSeparate" sampling method. 21 | Source: https://raw.githubusercontent.com/AlexDuvalinho/GraphSVX/master/src/explainers.py 22 | """ 23 | 24 | def __init__(self, nplayers: int, nsamples: int, **kwargs) -> None: 25 | """number of players and number of samples are required. 26 | 27 | Args: 28 | nplayers (int): number of players 29 | nsamples (int): number of samples 30 | size_lim (int): maximum size of coalitions to sample from. Defaults to 3. 31 | """ 32 | super().__init__(nplayers=nplayers, nsamples=nsamples) 33 | self.size_lim = kwargs.get('size_lim', 3) 34 | 35 | def shapley_kernel(self, s, M): 36 | """ Computes a weight for each newly created sample 37 | 38 | Args: 39 | s (tensor): contains dimension of z for all instances 40 | (number of features + neighbours included) 41 | M (tensor): total number of features/nodes in dataset 42 | 43 | Returns: 44 | [tensor]: shapley kernel value for each sample 45 | """ 46 | shapley_kernel = [] 47 | 48 | for i in range(s.shape[0]): 49 | a = s[i].item() 50 | if a == 0 or a == M: 51 | # Enforce high weight on full/empty coalitions 52 | shapley_kernel.append(1000) 53 | elif scipy.special.binom(M, a) == float('+inf'): 54 | # Treat specific case - impossible computation 55 | shapley_kernel.append(1/ (M**2)) 56 | else: 57 | shapley_kernel.append( 58 | (M-1)/(scipy.special.binom(M, a)*a*(M-a))) 59 | 60 | shapley_kernel = np.array(shapley_kernel) 61 | shapley_kernel = np.where(shapley_kernel<1.0e-40, 1.0e-40,shapley_kernel) 62 | return torch.tensor(shapley_kernel) 63 | 64 | def smarter_separate(self): 65 | num_samples = self.nsamples 66 | M = self.nplayers 67 | args_K = self.size_lim 68 | z_ = torch.ones(num_samples, M) 69 | z_[1::2] = torch.zeros(num_samples//2, M) 70 | i = 0 # modified by sakkas. We don't need empty and full coalitions 71 | k = 1 72 | # Loop until all samples are created 73 | while i < num_samples: 74 | # Look at each feat/nei individually if have enough sample 75 | # Coalitions of the form (All nodes/feat, All-1 feat/nodes) & (No nodes/feat, 1 feat/nodes) 76 | if i + 2 * M < num_samples and k == 1: 77 | z_[i:i+M, :] = torch.ones(M, M) 78 | z_[i:i+M, :].fill_diagonal_(0) 79 | z_[i+M:i+2*M, :] = torch.zeros(M, M) 80 | z_[i+M:i+2*M, :].fill_diagonal_(1) 81 | i += 2 * M 82 | k += 1 83 | 84 | else: 85 | # Split in two number of remaining samples 86 | # Half for specific coalitions with low k and rest random samples 87 | #samp = i + 9*(num_samples - i)//10 88 | samp = num_samples 89 | while i < samp and k <= min(args_K, M): 90 | # Sample coalitions of k1 neighbours or k1 features without repet and order. 
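# Illustrative example: for M = 4 players and k = 2 the next line builds the C(4, 2) = 6 tuples
# (0,1), (0,2), (0,3), (1,2), (1,3), (2,3) and shuffles them; each surviving tuple then yields two
# coalitions below, one with those k players removed from an otherwise full coalition and one with
# only those k players present.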
91 | L = list(combinations(range(0, M), k)) 92 | random.shuffle(L) 93 | L = L[:samp+1] 94 | 95 | for j in range(len(L)): 96 | # Coalitions (All nei, All-k feat) or (All feat, All-k nei) 97 | z_[i, L[j]] = torch.zeros(k) 98 | i += 1 99 | # If limit reached, sample random coalitions 100 | if i == samp: 101 | z_[i:, :] = torch.empty(num_samples-i, M).random_(2) 102 | return z_ 103 | # Coalitions (No nei, k feat) or (No feat, k nei) 104 | z_[i, L[j]] = torch.ones(k) 105 | i += 1 106 | # If limit reached, sample random coalitions 107 | if i == samp: 108 | z_[i:, :] = torch.empty(num_samples-i, M).random_(2) 109 | return z_ 110 | k += 1 111 | 112 | # Sample random coalitions 113 | z_[i:, :] = torch.empty(num_samples-i, M).random_(2) 114 | return z_ 115 | return z_ 116 | 117 | def sample(self) -> Tuple[Tensor, Tensor]: 118 | """Returns all possible coalitions and weights. Note that it doesn't include empty and full 119 | coalitions (Solver does not need them). 120 | 121 | Returns: 122 | Tuple[Tensor, Tensor]: mask_matrix and kernel weights. 123 | """ 124 | z_bis = self.smarter_separate() 125 | # z_bis = z_bis[torch.randperm(z_bis.size()[0])] # no need to shuffle 126 | s = (z_bis != 0).sum(dim=1) 127 | weights = self.shapley_kernel(s, self.nplayers) 128 | 129 | return z_bis, weights -------------------------------------------------------------------------------- /gnnshap/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from ._base import BaseSolver 4 | from ._wlr import WLRSolver 5 | from ._wls import WLSSolver 6 | 7 | 8 | def get_solver(solver_name: str, mask_matrix: Tensor, kernel_weights: Tensor, yhat: Tensor, 9 | fnull: float, ffull: float, **kwargs: dict) -> BaseSolver: 10 | """Returns the instanciated solver based on the name. 11 | 12 | Args: 13 | solver_name (str): Solver name 14 | mask_matrix (Tensor): mask matrix 15 | kernel_weights (Tensor): kernel weights 16 | yhat (Tensor): model predictions 17 | fnull (float): null model prediction 18 | ffull (float): full model prediction 19 | 20 | Raises: 21 | KeyError: If solver name is not found 22 | 23 | Returns: 24 | BaseSolver: Instanciated solver 25 | """ 26 | solvers = { 27 | 'WLSSolver': WLSSolver, 28 | 'WLRSolver': WLRSolver 29 | } 30 | 31 | try: 32 | return solvers[solver_name](mask_matrix, kernel_weights, yhat, fnull, ffull, **kwargs) 33 | except KeyError as exc: 34 | raise KeyError(f"Solver '{solver_name}' not found!") from exc 35 | -------------------------------------------------------------------------------- /gnnshap/solvers/_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | from torch import Tensor 5 | import numpy as np 6 | 7 | from gnnshap.utils import get_logger 8 | 9 | log = get_logger(__name__) 10 | 11 | 12 | class BaseSolver(ABC): 13 | """A base class for solvers 14 | """ 15 | 16 | def __init__(self, mask_matrix: Tensor, kernel_weights: Tensor, 17 | yhat: Tensor, fnull: Tensor, ffull: Tensor, **kwargs: dict) -> None: 18 | """Common initialization for all solvers. 
19 | 20 | Args: 21 | mask_matrix (Tensor): mask matrix 22 | kernel_weights (Tensor): kernel weights 23 | yhat (Tensor): model predictions 24 | fnull (Tensor): null model prediction 25 | ffull (Tensor): full model prediction 26 | **kwargs (dict): additional arguments 27 | """ 28 | self.mask_matrix = mask_matrix 29 | self.kernel_weights = kernel_weights 30 | self.yhat = yhat 31 | self.fnull = fnull 32 | self.ffull = ffull 33 | self.kwargs = kwargs 34 | 35 | @abstractmethod 36 | def solve(self) -> Tuple[np.array, dict]: 37 | """An abstract method that all solvers should override. 38 | 39 | Returns: 40 | Tuple[np.array, dict]: shapley values and solver statistics dictionary. 41 | """ 42 | -------------------------------------------------------------------------------- /gnnshap/solvers/_wlr.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | from gnnshap.utils import get_logger 8 | 9 | from ._base import BaseSolver 10 | 11 | log = get_logger(__name__) 12 | 13 | 14 | class WLRSolver(BaseSolver): 15 | """Solver that uses pytorch to solve the linear regression.""" 16 | def __init__(self, mask_matrix: Tensor, kernel_weights: Tensor, yhat: Tensor, 17 | fnull: Tensor, ffull: Tensor, **kwargs: dict) -> None: 18 | """Initialization for WLRSolver. 19 | 20 | Args: 21 | mask_matrix (Tensor): mask matrix 22 | kernel_weights (Tensor): kernel weights 23 | yhat (Tensor): model predictions 24 | fnull (float): null model prediction 25 | ffull (float): full model prediction 26 | **kwargs (dict): additional arguments 27 | """ 28 | super().__init__(mask_matrix, kernel_weights, yhat, fnull, ffull, **kwargs) 29 | self.device = kwargs.get('device', 'cpu') 30 | 31 | # will convert to double later 32 | self.mask_matrix = self.mask_matrix.to(self.device).type(torch.int8) 33 | self.kernel_weights = self.kernel_weights.to(self.device).unsqueeze(1) 34 | self.yhat = self.yhat.to(self.device) 35 | self.nplayers = self.mask_matrix.size(1) 36 | 37 | def solve(self) -> Tuple[np.array, dict]: 38 | r"""Solves weighted linear regression problem by training a linear model via PyTorch. 
39 | 40 | Args: 41 | mask_matrix (Tensor): coalition matrix 42 | kernel_weights (Tensor): coalition weight values 43 | ey (Tensor): coalition predictions 44 | 45 | Returns: 46 | np.array: shapley_values 47 | """ 48 | 49 | # no need to add base value as player thanks to this: (base + shap_values) = ffull 50 | eyAdj = self.yhat - self.fnull 51 | del self.yhat 52 | 53 | # eliminate one variable with the constraint that all features sum to the output 54 | eyAdj2 = (eyAdj - self.mask_matrix[:, -1] * (self.ffull - self.fnull)).unsqueeze(1) 55 | etmp = self.mask_matrix[:, :-1] - self.mask_matrix[:, -1].unsqueeze(1) 56 | del self.mask_matrix 57 | 58 | 59 | torch.set_grad_enabled(True) 60 | 61 | lin_model = torch.nn.Linear(etmp.size(1), 1, dtype=torch.double, bias=False).to(self.device) 62 | optimizer = torch.optim.Adam(lin_model.parameters(), lr=0.001, weight_decay=0.005) 63 | 64 | 65 | etmp = etmp.double() 66 | # solve a weighted least squares equation to estimate phi 67 | lin_model.train() 68 | for i in range(200): 69 | optimizer.zero_grad() 70 | pred = lin_model(etmp) 71 | loss = torch.sum(self.kernel_weights * ((eyAdj2 - pred)**2)) 72 | loss.backward() 73 | optimizer.step() 74 | 75 | 76 | phi = torch.zeros(self.nplayers) 77 | phi[:-1] = lin_model.weight.squeeze() 78 | phi[-1] = (self.ffull - self.fnull) - torch.sum(lin_model.weight) 79 | 80 | # clean up any rounding errors 81 | #phi[torch.abs(phi) < 1e-10] = 0 82 | 83 | 84 | return phi.detach().numpy() -------------------------------------------------------------------------------- /gnnshap/solvers/_wls.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | from gnnshap.utils import get_logger 8 | 9 | from ._base import BaseSolver 10 | 11 | log = get_logger(__name__) 12 | 13 | 14 | class WLSSolver(BaseSolver): 15 | """Solver that uses pytorch to solve the weighted least squares problem.""" 16 | def __init__(self, mask_matrix: Tensor, kernel_weights: Tensor, yhat: Tensor, 17 | fnull: Tensor, ffull: Tensor, **kwargs: dict) -> None: 18 | """Initialization for WLSSolver. 
19 | 20 | Args: 21 | mask_matrix (Tensor): mask matrix 22 | kernel_weights (Tensor): kernel weights 23 | yhat (Tensor): model predictions 24 | fnull (float): null model prediction 25 | ffull (float): full model prediction 26 | **kwargs (dict): additional arguments 27 | """ 28 | super().__init__(mask_matrix, kernel_weights, yhat, fnull, ffull, **kwargs) 29 | self.device = kwargs.get('device', 'cpu') 30 | 31 | self.mask_matrix = self.mask_matrix.to(self.device).double() 32 | self.kernel_weights = self.kernel_weights.to(self.device) 33 | self.yhat = self.yhat.to(self.device) 34 | 35 | def solve(self) -> Tuple[np.array, dict]: 36 | r"""Solves weighted least squares problem and learns shapley values 37 | 38 | Args: 39 | mask_matrix (Tensor): coalition matrix 40 | kernel_weights (Tensor): coalition weight values 41 | ey (Tensor): coalition predictions 42 | 43 | Returns: 44 | np.array: shapley_values 45 | """ 46 | 47 | 48 | # no need to add base value as player thanks to this: (base + shap_values) = ffull 49 | eyAdj = self.yhat - self.fnull 50 | 51 | # eliminate one variable with the constraint that all features sum to the output 52 | eyAdj2 = eyAdj - self.mask_matrix[:, -1] * (self.ffull - self.fnull) 53 | etmp = self.mask_matrix[:, :-1] - self.mask_matrix[:, -1].unsqueeze(1) 54 | 55 | 56 | # solve a weighted least squares equation to estimate phi 57 | tmp_transpose = (etmp * self.kernel_weights.unsqueeze(1)).transpose(0, 1) 58 | 59 | etmp_dot = torch.mm(tmp_transpose, etmp) 60 | try: 61 | tmp2 = torch.linalg.inv(etmp_dot) 62 | except torch.linalg.LinAlgError: 63 | tmp2 = torch.linalg.pinv(etmp_dot) 64 | print("Equation is singular, using pseudo-inverse.", 65 | "Consider increasing the number of samples.") 66 | w = torch.mm(tmp2, torch.mm(tmp_transpose, eyAdj2.unsqueeze(1)))[:,0].cpu() 67 | 68 | phi = torch.zeros(self.mask_matrix.size(1)) 69 | phi[:-1] = w 70 | phi[-1] = (self.ffull - self.fnull) - torch.sum(w) 71 | 72 | # clean up any rounding errors 73 | #phi[torch.abs(phi) < 1e-10] = 0 74 | 75 | 76 | return phi.numpy() 77 | -------------------------------------------------------------------------------- /gnnshap/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import sys 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import colorlog 7 | import numpy as np 8 | import torch 9 | from torch import Tensor 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.utils.num_nodes import maybe_num_nodes 12 | 13 | 14 | def get_logger(name: str) -> logging.Logger: 15 | """Returns a logger 16 | 17 | Args: 18 | name (str): Logger name 19 | 20 | Returns: 21 | logging.Logger: logger 22 | """ 23 | logger = colorlog.getLogger(name) 24 | handler = colorlog.StreamHandler(stream=sys.stdout) 25 | 26 | formatter = colorlog.ColoredFormatter( 27 | #"%(name)s: %(asctime)s : %(levelname)s : %(filename)s:%(lineno)s : %(message)s" 28 | "%(log_color)s%(filename)s:%(lineno)s:%(levelname)s: %(message)s", 29 | log_colors={ 30 | 'DEBUG': 'cyan', 31 | 'INFO': 'green', 32 | 'WARNING': 'yellow', 33 | 'ERROR': 'red', 34 | 'CRITICAL': 'red,bg_white', 35 | }, 36 | ) 37 | 38 | handler.setFormatter(formatter) 39 | logger.addHandler(handler) 40 | 41 | logger.setLevel(logging.WARNING) 42 | 43 | return logger 44 | 45 | 46 | 47 | 48 | def get_coalition_counts(mask_matrix: Union[np.array, Tensor]) -> np.array: 49 | """Finds counts of each coalition size for a given mask matrix. 
50 | 51 | Args: 52 | mask_matrix (Union[np.array, Tensor]): mask matrix obtained from a sampler 53 | 54 | Returns: 55 | np.array: coalition counts 56 | """ 57 | if torch.is_tensor(mask_matrix): 58 | mask_matrix = mask_matrix.cpu().numpy() 59 | coal_sizes = mask_matrix.sum(1).astype(int) 60 | unique, counts = np.unique(coal_sizes, return_counts=True) 61 | return counts 62 | 63 | 64 | def get_coalition_size_weights(mask_matrix: Union[np.array, Tensor], 65 | weights: Union[np.array, Tensor]) -> np.array: 66 | """Finds sum of total weights for each coalition size. 67 | 68 | Args: 69 | mask_matrix (Union[np.array, Tensor]): mask matrix obtained from a sampler 70 | weights (Union[np.array, Tensor]): weights vector obtained from a sampler 71 | 72 | Returns: 73 | np.array: coalition size weights 74 | """ 75 | if torch.is_tensor(mask_matrix): 76 | mask_matrix = mask_matrix.cpu().numpy() 77 | 78 | if torch.is_tensor(weights): 79 | weights = weights.cpu().numpy() 80 | 81 | counts = mask_matrix.sum(1) 82 | nplayers = mask_matrix.shape[1] 83 | 84 | weight_sums = np.zeros(nplayers -1) 85 | for i in range(1, nplayers): 86 | weight_sums[i-1] = weights[counts == i].sum() 87 | return weight_sums 88 | 89 | 90 | def get_gnn_layers(model: torch.nn.Module) -> List[torch.nn.Module]: 91 | """Finds and returns GNN layers. 92 | 93 | Args: 94 | model (torch.nn.Module): pyg model. 95 | 96 | Returns: 97 | List[torch.nn.Module]: GNN layers as a list 98 | """ 99 | gnn_layers = [] 100 | for module in model.modules(): 101 | if isinstance(module, MessagePassing): 102 | gnn_layers.append(module) 103 | return gnn_layers 104 | 105 | def switch_add_self_loops(model: torch.nn.Module): 106 | """Switches each layers add_self_loops value to True or False. 107 | 108 | Args: 109 | model (torch.nn.Module): pyg model. 110 | """ 111 | layers = get_gnn_layers(model) 112 | for layer in layers: 113 | layer.add_self_loops = not layer.add_self_loops 114 | 115 | def switch_normalize(model: torch.nn.Module): 116 | """Switches each layers normalize value to True or False. 117 | 118 | Args: 119 | model (torch.nn.Module): pyg model. 120 | """ 121 | layers = get_gnn_layers(model) 122 | for layer in layers: 123 | layer.normalize = not layer.normalize 124 | 125 | def has_normalization(model: torch.nn.Module) -> bool: 126 | """Checks if gnn layers have normalization. It controls whether all layers 127 | have same configuration. 128 | 129 | Args: 130 | model (torch.nn.Module): pyg model. 131 | 132 | Raises: 133 | AssertionError: Raises assertion error if different layers have different configurations. 134 | AssertionError: Raises assertion error if there is no gnn layers. 135 | 136 | Returns: 137 | bool: boolean value whether gnn layers have normalization 138 | """ 139 | layers = get_gnn_layers(model) 140 | if len(layers) > 0: 141 | try: # some GNN types has no normalize attribute 142 | normalize = layers[0].normalize 143 | except: 144 | return False 145 | if len(layers) > 1: 146 | for layer in layers[1:]: 147 | if layer.normalize != normalize: 148 | raise AssertionError(("Layers have different normalization settings." 149 | " This is not supported!")) 150 | return normalize 151 | raise AssertionError("No GNN layers found!") 152 | 153 | 154 | def has_add_self_loops(model: torch.nn.Module) -> bool: 155 | """Checks if model adds self loops. It controls whether all layers have same configuration. 156 | 157 | Args: 158 | model (torch.nn.Module): pyg model. 
159 | 160 | Raises: 161 | AssertionError: Raises assertion error if different layers have different configurations. 162 | AssertionError: Raises assertion error if there is no gnn layers. 163 | 164 | Returns: 165 | bool: boolean value whether model adds self loops or not. 166 | """ 167 | 168 | layers = get_gnn_layers(model) 169 | if len(layers) > 0: 170 | try: 171 | self_loop = layers[0].add_self_loops 172 | except: 173 | return False 174 | 175 | if len(layers) > 1: 176 | for layer in layers[1:]: 177 | if layer.add_self_loops != self_loop: 178 | raise AssertionError(("Layers have different add_self_loops settings." 179 | " This is not supported!")) 180 | return self_loop 181 | raise AssertionError("No GNN layers found!") 182 | 183 | 184 | @torch.no_grad() 185 | def pruned_comp_graph(node_idx: Union[int, List[int], Tensor], 186 | num_hops: int, 187 | edge_index: Tensor, 188 | relabel_nodes: bool = False, 189 | num_nodes: Optional[int] = None, 190 | flow: str = 'source_to_target', 191 | directed: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 192 | """Finds the pruned computational graph for a given node index. Similar to k_hop_subgraph, but 193 | k_hop_subgraph returns all edges between k-hop nodes. We are only interested in edges that 194 | carries message to target node in k_hops.""" 195 | 196 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 197 | 198 | assert flow in ['source_to_target', 'target_to_source'] 199 | if flow == 'target_to_source': 200 | row, col = edge_index 201 | else: 202 | col, row = edge_index 203 | 204 | my_edge_mask = row.new_empty(row.size(0), dtype=torch.bool) # added by sakkas 205 | my_edge_mask.fill_(False) # added by sakkas 206 | 207 | node_mask = row.new_empty(num_nodes, dtype=torch.bool) 208 | edge_mask = row.new_empty(row.size(0), dtype=torch.bool) 209 | 210 | if isinstance(node_idx, (int, list, tuple)): 211 | node_idx = torch.tensor([node_idx], device=row.device).flatten() 212 | else: 213 | node_idx = node_idx.to(row.device) 214 | 215 | subsets = [node_idx] 216 | 217 | for _ in range(num_hops): 218 | node_mask.fill_(False) 219 | node_mask[subsets[-1]] = True 220 | torch.index_select(node_mask, 0, row, out=edge_mask)# input, dimension, index 221 | my_edge_mask[edge_mask] = True 222 | subsets.append(col[edge_mask]) 223 | 224 | subset, inv = torch.cat(subsets).unique(return_inverse=True) 225 | inv = inv[:node_idx.numel()] 226 | 227 | edge_index = edge_index[:, my_edge_mask] 228 | 229 | if relabel_nodes: 230 | node_idx = row.new_full((num_nodes, ), -1) 231 | node_idx[subset] = torch.arange(subset.size(0), device=row.device) 232 | edge_index = node_idx[edge_index] 233 | 234 | return subset, edge_index, inv, my_edge_mask 235 | -------------------------------------------------------------------------------- /models/GATModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn.conv import GATConv 5 | from tqdm.auto import tqdm 6 | 7 | 8 | class GATModel(torch.nn.Module): 9 | def __init__(self, hidden_channels, 10 | num_features, num_classes, num_layers=2, 11 | add_self_loops=True, 12 | dropout = 0.0, 13 | normalize = True, 14 | log_softmax_return=False, 15 | heads=1): 16 | super().__init__() 17 | self.num_layers = num_layers 18 | self.dropout = dropout 19 | self.normalize = normalize 20 | self.add_self_loops = add_self_loops 21 | assert num_layers >= 2, "Number of layers should be two or larger" 22 | self.heads = heads 
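# Dimension note (illustrative): GATConv concatenates head outputs by default, so each hidden
# layer below produces hidden_channels * heads features; that is why the intermediate and final
# layers take hidden_channels * heads as their input size, while the last layer uses heads=1 to
# keep the output num_classes-dimensional. For example, hidden_channels=8 with heads=8 gives
# 64-dimensional hidden node embeddings.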
23 | self.convs = nn.ModuleList( 24 | [GATConv(num_features, hidden_channels, normalize=normalize, 25 | add_self_loops=add_self_loops, dropout=0.6, heads=heads)] + 26 | [GATConv(hidden_channels * heads, hidden_channels, normalize=normalize, 27 | add_self_loops=add_self_loops, dropout=0.6, heads=heads 28 | ) for i in range(num_layers - 2)] + 29 | [GATConv(hidden_channels * heads, num_classes, normalize=normalize, 30 | add_self_loops=add_self_loops, dropout=0.6, heads=1)]) 31 | self.softmax_fn = nn.LogSoftmax(dim=-1) if log_softmax_return else nn.Identity() 32 | 33 | def forward(self, x, edge_index, edge_weight=None): 34 | for i in range(self.num_layers - 1): 35 | x = F.relu(self.convs[i](x, edge_index, edge_weight)) 36 | x = F.dropout(x, p=self.dropout, training=self.training) 37 | x = self.convs[-1](x, edge_index, edge_weight) 38 | x = self.softmax_fn(x) # applied based on parameter. Default: not applied 39 | return x 40 | -------------------------------------------------------------------------------- /models/GCNModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv 5 | from torch_geometric.utils import add_remaining_self_loops, degree 6 | from torch_geometric.utils.num_nodes import maybe_num_nodes 7 | from torch_scatter import scatter_add 8 | from tqdm.auto import tqdm 9 | 10 | 11 | class GCNModel(torch.nn.Module): 12 | r"""GCN model 13 | Args: 14 | hidden_channels (int): hidden layer dimensions 15 | num_features (int): number of input features 16 | num_classes (int): number of output classes 17 | num_layers (int, optional): number of layers. Defaults to 2. 18 | add_self_loops (bool, optional): whether to add self loops. Defaults to True. 19 | dropout (float, optional): dropout rate. Defaults to 0.0. 20 | normalize (bool, optional): whether to normalize. Defaults to True. 21 | log_softmax_return (bool, optional): whether to return raw output or log softmax. 22 | Defaults to False. 23 | """ 24 | 25 | def __init__(self, hidden_channels, 26 | num_features, num_classes, num_layers=2, 27 | add_self_loops=True, 28 | dropout = 0.0, 29 | normalize = True, 30 | log_softmax_return=False): 31 | 32 | super().__init__() 33 | self.num_layers = num_layers 34 | self.dropout = dropout 35 | self.normalize = normalize 36 | self.add_self_loops = add_self_loops 37 | assert num_layers >= 2, "Number of layers should be two or larger" 38 | self.convs = nn.ModuleList( 39 | [GCNConv(num_features, hidden_channels, normalize=normalize, 40 | add_self_loops=add_self_loops)] + 41 | [GCNConv(hidden_channels, hidden_channels, normalize=normalize, 42 | add_self_loops=add_self_loops) for i in range(num_layers - 2)] + 43 | [GCNConv(hidden_channels, num_classes, normalize=normalize, 44 | add_self_loops=add_self_loops)]) 45 | self.softmax_fn = nn.LogSoftmax(dim=-1) if log_softmax_return else nn.Identity() 46 | 47 | def forward(self, x, edge_index, edge_weight=None): 48 | for i in range(self.num_layers - 1): 49 | x = F.relu(self.convs[i](x, edge_index, edge_weight)) 50 | x = F.dropout(x, p=self.dropout, training=self.training) 51 | x = self.convs[-1](x, edge_index, edge_weight) 52 | x = self.softmax_fn(x) # applied based on parameter. Default: not applied 53 | return x 54 | 55 | # faster inference for reddit dataset 56 | # assumes there are two gcn layers. 
It won't work if there are more than two layers 57 | @torch.no_grad() 58 | def inference(self, x_all, subgraph_loader, device): 59 | pbar = tqdm(total=len(subgraph_loader.dataset) * 2) 60 | pbar.set_description('Evaluating') 61 | 62 | # Compute representations of nodes layer by layer, using *all* 63 | # available edges. This leads to faster computation in contrast to 64 | # immediately computing the final representations of each batch: 65 | for i, conv in enumerate(self.convs): 66 | xs = [] 67 | for batch in subgraph_loader: 68 | x = x_all[batch.n_id.to(x_all.device)].to(device) 69 | x = conv(x, batch.edge_index.to(device)) 70 | if i == 0: # first layer 71 | x = x.relu_() 72 | xs.append(x[:batch.batch_size].cpu()) 73 | pbar.update(batch.batch_size) 74 | x_all = torch.cat(xs, dim=0) 75 | pbar.close() 76 | return x_all -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/models/__init__.py -------------------------------------------------------------------------------- /pretrained/CiteSeer_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/CiteSeer_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Coauthor-CS_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Coauthor-CS_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Coauthor-Physics_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Coauthor-Physics_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Cora_GAT_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Cora_GAT_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Cora_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Cora_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Facebook_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Facebook_pretrained.pt -------------------------------------------------------------------------------- /pretrained/PubMed_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/PubMed_pretrained.pt -------------------------------------------------------------------------------- /pretrained/Reddit_explain_data.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Reddit_explain_data.pt -------------------------------------------------------------------------------- /pretrained/Reddit_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/Reddit_pretrained.pt -------------------------------------------------------------------------------- /pretrained/ogbn-products_explain_data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/ogbn-products_explain_data.pt -------------------------------------------------------------------------------- /pretrained/ogbn-products_pretrained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/ogbn-products_pretrained.pt -------------------------------------------------------------------------------- /pretrained/split_Coauthor-CS.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/split_Coauthor-CS.pt -------------------------------------------------------------------------------- /pretrained/split_Coauthor-Physics.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/split_Coauthor-Physics.pt -------------------------------------------------------------------------------- /pretrained/split_Facebook.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/pretrained/split_Facebook.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | captum>=0.6.0 2 | colorlog>=6.7.0 3 | matplotlib>=3.5.2 4 | networkx>=2.8.4 5 | numba>=0.55.1 6 | numpy>=1.21.5 7 | ogb>=1.3.6 8 | pandas>=1.4.4 9 | pgmpy>=0.1.21 10 | scikit_learn>=1.4.0 11 | scipy>=1.9.1 12 | shap>=0.41.0 13 | torch>=2.0.1 14 | torch_geometric>=2.3.1 15 | torch_scatter>=2.1.1 16 | tqdm>=4.64.1 17 | ninja>=1.10.2 18 | -------------------------------------------------------------------------------- /results/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HipGraph/GNNShap/f9672297394493ebe1ea9cf60bd14530e06d4916/results/.placeholder -------------------------------------------------------------------------------- /run_baseline_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # since baseline scripts are not in the root directory, they can't find modules. 
4 | # so, we need to add the root directory to the PYTHONPATH 5 | export PYTHONPATH="$PYTHONPATH:$(pwd)" 6 | 7 | datasets=("Cora" "CiteSeer" "PubMed" "Coauthor-CS" "Coauthor-Physics" "Facebook") 8 | 9 | nrepeat=5 # number of repeats for each experiment 10 | 11 | for d in "${datasets[@]}"; do 12 | python baselines/run_gnnexplainer.py --dataset $d --repeat $nrepeat 13 | echo "${d} - GNNExplainer done!!!!" 14 | python baselines/run_pgexplainer.py --dataset $d --repeat $nrepeat 15 | echo "${d} - PGExplainer done!!!!" 16 | python baselines/run_pgmexplainer.py --dataset $d --repeat $nrepeat 17 | echo "${d} - PGMExplainer done!!!!" 18 | python baselines/run_graphsvx.py --dataset $d --repeat $nrepeat 19 | echo "${d} - GraphSVX done!!!!" 20 | python baselines/run_sa.py --dataset $d --repeat $nrepeat 21 | echo "${d} - Saliency done!!!!" 22 | 23 | # OrphicX gives OOM error for some datasets 24 | # python baselines/run_orphicx.py --dataset $d --repeat $nrepeat --epoch 50 25 | # echo "${d} - OrphicX done!!!!" 26 | done 27 | -------------------------------------------------------------------------------- /run_gnnshap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import time 4 | 5 | import torch 6 | from tqdm.auto import tqdm 7 | 8 | from dataset.utils import get_model_data_config 9 | from gnnshap.explainer import GNNShapExplainer 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--dataset', type=str, default='Cora') 15 | parser.add_argument('--result_path', type=str, default=None, 16 | help=('Path to save the results. It will be saved in the config results ' 17 | 'path if not provided.')) 18 | parser.add_argument('--num_samples', type=int, default=10000, 19 | help='Number of samples to use for GNNShap') 20 | parser.add_argument('--repeat', default=1, type=int) 21 | parser.add_argument('--batch_size', type=int, default=1024) 22 | parser.add_argument('--sampler', type=str, default='GNNShapSampler', 23 | help='Sampler to use for sampling coalitions', 24 | choices=['GNNShapSampler', 'SVXSampler', 'SHAPSampler', 25 | 'SHAPUniqueSampler']) 26 | parser.add_argument('--solver', type=str, default='WLSSolver', 27 | help='Solver to use for solving SVX', choices=['WLSSolver', 'WLRSolver']) 28 | 29 | # SVXSampler maximum size of coalitions to sample from 30 | parser.add_argument('--size_lim', type=int, default=3) 31 | 32 | args = parser.parse_args() 33 | 34 | dataset_name = args.dataset 35 | num_samples = args.num_samples 36 | batch_size = args.batch_size 37 | sampler_name = args.sampler 38 | solver_name = args.solver 39 | 40 | 41 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 42 | model, data, config = get_model_data_config(dataset_name, load_pretrained=True, device=device) 43 | 44 | test_nodes = config['test_nodes'] 45 | 46 | result_path = args.result_path if args.result_path is not None else config["results_path"] 47 | 48 | 49 | 50 | 51 | if sampler_name == "SVXSampler": 52 | extra_param_suffixes = f"_{args.size_lim}" 53 | else: 54 | extra_param_suffixes = "" 55 | 56 | #explain_node_idx = 0 57 | for r in range(args.repeat): 58 | results = [] 59 | 60 | shap = GNNShapExplainer(model, data, nhops=config['num_hops'], verbose=0, device=device, 61 | progress_hide=True) 62 | start_time = time.time() 63 | 64 | failed_indices = [] 65 | for ind in tqdm(test_nodes, desc=f"GNNShap explanations - run{r+1}"): 66 | try: 67 | explanation = shap.explain(ind, 
nsamples=num_samples, 68 | sampler_name=sampler_name, batch_size=batch_size, 69 | solver_name=solver_name, size_lim=args.size_lim) 70 | results.append(explanation.result2dict()) 71 | except RuntimeError as e: 72 | failed_indices.append(ind) 73 | if 'out of memory' in str(e): 74 | print(f"Node {ind} has failed: out of memory") 75 | else: 76 | print(f"Node {ind} has failed: {e}") 77 | except Exception as e: 78 | print(f"Node {ind} has failed. General error: {e}") 79 | failed_indices.append(ind) 80 | 81 | rfile = (f'{result_path}/{dataset_name}_GNNShap_{sampler_name}_{solver_name}_' 82 | f'{num_samples}_{batch_size}{extra_param_suffixes}_run{r+1}.pkl') 83 | with open(rfile, 'wb') as pkl_file: 84 | pickle.dump([results, 0], pkl_file) 85 | 86 | if len(failed_indices) > 0: 87 | print(f"Failed indices: {failed_indices}") 88 | -------------------------------------------------------------------------------- /run_gnnshap_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # num_samples=(10000 25000 50000) # number of samples for GNNShap 4 | num_samples=(10000) 5 | 6 | datasets=("Cora" "CiteSeer" "PubMed" "Coauthor-CS" "Coauthor-Physics" "Facebook") 7 | solvers=("WLSSolver") # WLRSolver is suggested for Reddit and ogbn-products 8 | 9 | # samplers=("GNNShapSampler" "SVXSampler") 10 | samplers=("GNNShapSampler") 11 | 12 | for n in "${num_samples[@]}"; do 13 | for d in "${datasets[@]}"; do 14 | 15 | # batch size is set to 1024 for all datasets except Coauthor-CS and Coauthor-Physics. 16 | if [[ "$d" == "Coauthor"* ]]; then 17 | if [[ "$d" == "Coauthor-CS" ]]; then 18 | batch_size=512 19 | else 20 | batch_size=128 # coauthor-physics 21 | fi 22 | else 23 | batch_size=1024 24 | fi 25 | 26 | for solv in "${solvers[@]}"; do 27 | for samp in "${samplers[@]}"; do 28 | python run_gnnshap.py --dataset $d --num_samples $n \ 29 | --batch_size $batch_size --repeat 5 --sampler $samp --solver $solv 30 | echo "${solv} ${samp} ${d} - ${n} done!!!!" 31 | done 32 | done 33 | done 34 | done 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Train models and save the model. The model will be saved in the pretrained folder. 2 | # Can't train large models with this script. Use train_large.py instead. 3 | # Large models are trained with NeighborLoader, which is not supported in this script. 4 | 5 | # The trained model is used for benchmarking explanation methods. 6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import torch 12 | 13 | from dataset.configs import get_config 14 | from dataset.utils import get_model_data_config 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataset', default='Cora', type=str) 18 | args = parser.parse_args() 19 | 20 | 21 | config = get_config(args.dataset) 22 | pretrained_file = f"{config['root_path']}/pretrained/{args.dataset}_pretrained.pt" 23 | 24 | if os.path.exists(pretrained_file): 25 | user_input = input('A pretrained file exist. Do you want to retrain? (y/n):') 26 | if user_input.lower() != 'y': 27 | print("Skipping training!") 28 | sys.exit(0) 29 | 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | print(f"Device: {device}") 33 | 34 | model, data, config = get_model_data_config(args.dataset, load_pretrained=False, device=device) 35 | 36 | 37 | criterion = torch.nn.CrossEntropyLoss() # Define loss criterion. 
38 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'],
39 |                              weight_decay=config['weight_decay'])  # Define optimizer.
40 | 
41 | if 'grad_clip' in config:
42 |     grad_clip = config['grad_clip']
43 | else:
44 |     grad_clip = False
45 | 
46 | def train():
47 |     model.train()
48 |     optimizer.zero_grad()  # Clear gradients.
49 |     out = model(data.x, data.edge_index)  # Perform a single forward pass.
50 |     loss = criterion(out[data.train_mask], data.y[data.train_mask])
51 |     loss.backward()  # Derive gradients.
52 |     if grad_clip:  # Clip gradients after backward and before the optimizer step.
53 |         torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
54 |     optimizer.step()  # Update parameters based on gradients.
55 |     return loss
56 | 
57 | 
58 | def acc_test():
59 |     model.eval()
60 |     out = model(data.x, data.edge_index)
61 |     pred = out.argmax(dim=1)
62 |     test_correct = pred[data.test_mask] == data.y[data.test_mask]
63 |     test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
64 |     train_correct = pred[data.train_mask] == data.y[data.train_mask]
65 |     train_acc = int(train_correct.sum()) / int(data.train_mask.sum())
66 | 
67 |     val_correct = pred[data.val_mask] == data.y[data.val_mask]
68 |     val_acc = int(val_correct.sum()) / int(data.val_mask.sum())
69 | 
70 | 
71 |     return train_acc, val_acc, test_acc
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     best_val_acc = 0
76 |     best_test_acc = 0
77 |     for epoch in range(0, config['epoch']):
78 |         loss = train()
79 |         train_acc, val_acc, test_acc = acc_test()
80 |         if val_acc > best_val_acc:
81 |             best_val_acc = val_acc
82 |             best_test_acc = test_acc
83 |             torch.save(model.state_dict(), pretrained_file)
84 | 
85 |         print(f'Epoch: {epoch+1:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, '
86 |               f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
87 |     print(f"Best Val Accuracy: {best_val_acc:.4f}, Best Test Accuracy: {best_test_acc:.4f}")
--------------------------------------------------------------------------------
/train_large.py:
--------------------------------------------------------------------------------
1 | # Description: This file is used to train GNN models on large datasets. The trained model is saved
2 | # in the pretrained folder. The model is trained using NeighborLoader, which is not supported
3 | # in train.py.
4 | 
5 | # The trained model is used for benchmarking explanation methods.
6 | 7 | 8 | import argparse 9 | import copy 10 | import os 11 | import sys 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch_geometric.loader import NeighborLoader 16 | from tqdm import tqdm 17 | 18 | from dataset.utils import get_model_data_config 19 | from gnnshap.utils import pruned_comp_graph 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataset', default='Reddit', type=str) 23 | args = parser.parse_args() 24 | dataset_name = args.dataset 25 | 26 | def train(epoch): 27 | model.train() 28 | 29 | pbar = tqdm(total=int(len(train_loader.dataset))) 30 | pbar.set_description(f'Epoch {epoch:02d}') 31 | 32 | total_loss = total_correct = total_examples = 0 33 | for batch in train_loader: 34 | optimizer.zero_grad() 35 | y = batch.y[:batch.batch_size] 36 | y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size] 37 | loss = F.cross_entropy(y_hat, y) 38 | loss.backward() 39 | optimizer.step() 40 | 41 | total_loss += float(loss) * batch.batch_size 42 | total_correct += int((y_hat.argmax(dim=-1) == y).sum()) 43 | total_examples += batch.batch_size 44 | pbar.update(batch.batch_size) 45 | pbar.close() 46 | 47 | return total_loss / total_examples, total_correct / total_examples 48 | 49 | @torch.no_grad() 50 | def test(): 51 | model.eval() 52 | y_hat = model.inference(data.x, subgraph_loader, device=device).argmax(dim=-1) 53 | y = data.y.to(y_hat.device) 54 | 55 | accs = [] 56 | for mask in [data.train_mask, data.val_mask, data.test_mask]: 57 | accs.append(int((y_hat[mask] == y[mask]).sum()) / int(mask.sum())) 58 | return accs 59 | 60 | 61 | if __name__ == '__main__': 62 | 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | 65 | # don't load data to GPU, as it will be loaded to GPU during sampling. 66 | model, data, config = get_model_data_config(dataset_name, load_pretrained=False, device='cpu', 67 | full_data=True) 68 | 69 | model = model.to(device) 70 | 71 | 72 | pretrained_file = f"{config['root_path']}/pretrained/{dataset_name}_pretrained.pt" 73 | 74 | if os.path.exists(pretrained_file): 75 | user_input = input('A pretrained file exist. Do you want to retrain? (y/n):') 76 | if user_input.lower() != 'y': 77 | print("Skipping training!") 78 | sys.exit(0) 79 | 80 | 81 | # Already send node features/labels to GPU for faster access during sampling: 82 | data = data.to(device, 'x', 'y') 83 | neig_args = config['nei_sampler_args'] 84 | kwargs = {'batch_size': neig_args['batch_size'], 'num_workers': 6, 'persistent_workers': True} 85 | train_loader = NeighborLoader(data, input_nodes=data.train_mask, 86 | num_neighbors=neig_args['sizes'], shuffle=True, **kwargs) 87 | 88 | subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None, 89 | num_neighbors=[-1], shuffle=False, **kwargs) 90 | 91 | # No need to maintain these features during evaluation: 92 | del subgraph_loader.data.x, subgraph_loader.data.y 93 | # Add global node index information. 94 | subgraph_loader.data.num_nodes = data.num_nodes 95 | subgraph_loader.data.n_id = torch.arange(data.num_nodes) 96 | 97 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 98 | 99 | for epoch in range(1, 11): 100 | loss, acc = train(epoch) 101 | print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. 
Train: {acc:.4f}')
102 |         train_acc, val_acc, test_acc = test()
103 |         print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
104 |               f'Test: {test_acc:.4f}')
105 |         # Keep the checkpoint with the best test accuracy; best_test is set on the first epoch.
106 |         if epoch == 1 or test_acc > best_test:
107 |             best_test = test_acc
108 |             torch.save(model.state_dict(), pretrained_file)
109 |     print(f"Model saved to {pretrained_file}.")
110 | 
111 | 
112 | 
113 |     # Sample explain data and save. This is used for benchmarking explanation methods.
114 |     # This makes sure that the explain data is the same for all methods.
115 | 
116 |     num_test_nodes = 100
117 |     explain_loader = NeighborLoader(data, input_nodes=data.test_mask.nonzero()[:num_test_nodes, 0],
118 |                                     num_neighbors=[200, 50], batch_size=num_test_nodes,
119 |                                     num_workers=8, persistent_workers=True)
120 | 
121 |     max_size = 0
122 |     max_ind = 0
123 |     avg_size = 0
124 |     batch = next(iter(explain_loader))
125 |     for i in range(batch.batch_size):
126 |         m = pruned_comp_graph(i, 2, batch.edge_index)[1].size(1)
127 |         if m > max_size:
128 |             max_size = m
129 |             max_ind = i
130 |         avg_size += m
131 | 
132 |     del batch.x, batch.y  # reduce saved file size on disk. Can be reloaded from the original data.
133 |     torch.save(batch, f"{config['root_path']}/pretrained/{dataset_name}_explain_data.pt")
134 |     print(f"Explain data saved to {config['root_path']}/pretrained.")
135 |     print("Maximum size: ", max_size, "max index: ", max_ind, "avg size: ",
136 |           avg_size / num_test_nodes)
--------------------------------------------------------------------------------
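
Note on `pruned_comp_graph` (defined in `gnnshap/utils.py` and used in `train_large.py` above to measure computational-graph sizes): unlike `torch_geometric.utils.k_hop_subgraph`, which keeps every edge between k-hop nodes, it keeps only the edges that can deliver a message to the target node within `num_hops`. The snippet below is a minimal, illustrative sketch on a hand-made toy graph; the toy graph and the values shown in comments are not part of the repository and simply follow from the function as written above. It assumes you run it from the repository root so that `gnnshap` is importable.

```python
# Minimal sketch (not from the repository): contrast pruned_comp_graph with k_hop_subgraph
# on a toy directed graph. Node 2 is the target node to be explained.
import torch
from torch_geometric.utils import k_hop_subgraph

from gnnshap.utils import pruned_comp_graph

# Toy graph with edges 0->1, 1->2, 3->2, 2->0 (source row on top, target row below).
edge_index = torch.tensor([[0, 1, 3, 2],
                           [1, 2, 2, 0]])

subset, sub_edge_index, mapping, edge_mask = pruned_comp_graph(2, 2, edge_index)
# subset         -> tensor([0, 1, 2, 3])       nodes that can reach node 2 within 2 hops
# sub_edge_index -> tensor([[0, 1, 3],
#                           [1, 2, 2]])         only edges that carry messages to node 2
# mapping        -> tensor([2])                 position of node 2 inside `subset`
# edge_mask      -> tensor([True, True, True, False])  2->0 cannot reach node 2, so it is dropped

# k_hop_subgraph keeps 2->0 as well, because both of its endpoints lie in the 2-hop node set.
_, khop_edge_index, _, khop_mask = k_hop_subgraph(2, 2, edge_index)
# khop_mask -> tensor([True, True, True, True])
```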