├── .gitignore ├── README.md ├── another_small_example.ipynb ├── baseline_execution.py ├── data └── syn2.npz ├── evaluate_zorro.py ├── evaluation.py ├── example.ipynb ├── execution.py ├── explainer.py ├── generate_gnnexplainer_dataset.py ├── gnn_explainer.py ├── grad_explainer.py ├── models.py ├── print_progress_overview.py ├── requirements.txt └── results ├── CiteSeer_selected_nodes.npy ├── Cora_selected_nodes.npy ├── PubMed_selected_nodes.npy ├── citeseer_APPNP2Net └── appnp_2_layers.pt ├── citeseer_GAT └── gat_2_layers.pt ├── citeseer_GCN └── gcn_2_layers.pt ├── citeseer_GINConv └── gin_2_layers.pt ├── cora_APPNP2Net └── appnp_2_layers.pt ├── cora_GAT └── gat_2_layers.pt ├── cora_GCN └── gcn_2_layers.pt ├── cora_GINConv └── gin_2_layers.pt ├── pubmed_APPNP2Net └── appnp_2_layers.pt ├── pubmed_GAT └── gat_2_layers.pt ├── pubmed_GCN └── gcn_2_layers.pt ├── pubmed_GINConv └── gin_2_layers.pt ├── syn2_1_GCN_syn2 └── gcn_3_layers.pt ├── syn2_2_GCN_syn2 └── gcn_3_layers.pt ├── syn2_3_GCN_syn2 └── gcn_3_layers.pt ├── syn2_4_GCN_syn2 ├── gcn_3_layers.pt ├── gcn_3_layers_epoch_0.pt ├── gcn_3_layers_epoch_1400.pt ├── gcn_3_layers_epoch_200.pt ├── gcn_3_layers_epoch_400.pt └── gcn_3_layers_epoch_600.pt └── syn2_GCN_syn2 └── gcn_3_layers.pt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm 132 | .idea/ 133 | 134 | # Temporary files 135 | tmp/ 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Valid, Sparse, and Stable Explanations in Graph Neural Networks 2 | 3 | ##### by Thorben Funke, Megha Khosla, Mandeep Rathee, and Avishek Anand 4 | 5 | 6 | ## 1. Requirements 7 | 8 | See `requirements.txt` for the main Python packages required to run this repository with Python 3.7. 9 | 10 | ## 2. Data 11 | 12 | The real-world datasets are downloaded automatically via `pytorch-geometric`. 13 | For the synthetic dataset, we include the generation script `generate_gnnexplainer_dataset.py` and our resulting graph in `data/syn2.npz`. 14 | 15 | ## 3. Execute Zorro 16 | 17 | You can simply run 18 | ``` 19 | python3 execution.py 20 | ``` 21 | to get explanations for the default setting: 10 nodes of Cora, explained for a GCN with tau=0.85. 22 | 23 | We include the checkpoints of the trained models and the randomly selected nodes in the `results` directory. 24 | 25 | ## 4. Evaluation 26 | 27 | Running `evaluate_zorro.py` will create CSV files with the evaluated explanations.
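## 5. Programmatic usage

Zorro can also be called directly from Python. The following sketch is condensed from `example.ipynb` and `another_small_example.ipynb`; as noted there, `tau=0.03` in the code corresponds to the tau=0.98 setting of the paper, and `recursion_depth=1` retrieves a single explanation:

```
from explainer import *
from models import *
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load Cora, move model and data to the device, and train a GCN
dataset, data, results_path = load_dataset("Cora")
model = GCNNet(dataset)
model.to(device)
data = data.to(device)
train_model(model, data)

# explain node 10 with Zorro and unpack the first explanation
explainer = Zorro(model, device)
explanation = explainer.explain_node(10, data.x, data.edge_index, tau=0.03, recursion_depth=1)
selected_nodes, selected_features, executed_selections = explanation[0]
```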
28 | -------------------------------------------------------------------------------- /another_small_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "T8KuPgslQ-CV" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "### Setup in google colab\n", 19 | "\n", 20 | "Uncomment the code in the following cells to use this notebook in google colab" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "colab": { 28 | "base_uri": "https://localhost:8080/" 29 | }, 30 | "id": "nv6x1s15RB-D", 31 | "outputId": "6bff7b91-c7a1-46d2-daa1-532bc3c8a315" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# def format_pytorch_version(version):\n", 36 | "# return version.split('+')[0]\n", 37 | "#\n", 38 | "# TORCH_version = torch.__version__\n", 39 | "# TORCH = format_pytorch_version(TORCH_version)\n", 40 | "#\n", 41 | "# def format_cuda_version(version):\n", 42 | "# return 'cu' + version.replace('.', '')\n", 43 | "#\n", 44 | "# CUDA_version = torch.version.cuda\n", 45 | "# CUDA = \"cpu\"\n", 46 | "#\n", 47 | "# !pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n", 48 | "# !pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n", 49 | "# !pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n", 50 | "# !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n", 51 | "# !pip install torch-geometric\n", 52 | "#" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "colab": { 60 | "base_uri": "https://localhost:8080/" 61 | }, 62 | "id": "sk-fKWz6RO9M", 63 | "outputId": "650ab4d3-45e2-4d88-91e9-7095a32d5abb" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "# !git clone https://github.com/funket/zorro.git\n", 68 | "#" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "colab": { 76 | "base_uri": "https://localhost:8080/" 77 | }, 78 | "id": "-oXUDeOyRgt4", 79 | "outputId": "bc2c1be3-89a0-4bf9-8e52-444d113d12a8" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# !pwd\n", 84 | "#" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "colab": { 92 | "base_uri": "https://localhost:8080/" 93 | }, 94 | "id": "_Eo-J8SYRicX", 95 | "outputId": "6004253a-2dcc-4763-8634-e0d7ae0db0d8" 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "# %cd zorro/\n", 100 | "#" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "colab": { 108 | "base_uri": "https://localhost:8080/" 109 | }, 110 | "id": "WT0kEgBuRnvt", 111 | "outputId": "214699ba-151e-4bc7-a100-56552b16fc31" 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "# !pwd" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "id": "Y_qfpRevQ8wm" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "from explainer import *\n", 127 | "from models import *\n", 128 | "import torch\n", 129 | "import matplotlib.pylab as plt" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "3Ik8h53wQ8wr" 136 | }, 137 | "source": [ 138 | "# Data 
loading and GNN training" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "colab": { 146 | "base_uri": "https://localhost:8080/" 147 | }, 148 | "id": "7m1OlRIAQ8wt", 149 | "outputId": "7c0201db-4530-42c4-9f4a-1787480b2705", 150 | "pycharm": { 151 | "name": "#%%\n" 152 | } 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 157 | "dataset, data, results_path = load_dataset(\"Cora\")\n", 158 | "model = GCNNet(dataset)\n", 159 | "model.to(device)\n", 160 | "data = data.to(device)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "colab": { 168 | "base_uri": "https://localhost:8080/" 169 | }, 170 | "id": "e5bNvZzmQ8wu", 171 | "outputId": "76c6e660-5dfc-45d3-f62f-14a62495823f", 172 | "pycharm": { 173 | "name": "#%%\n" 174 | } 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "train_model(model, data)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "id": "8Ifr8w0JQ8wv" 185 | }, 186 | "source": [ 187 | "# Gradient based explanation" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "4FaYbl62Q8ww", 195 | "pycharm": { 196 | "name": "#%%\n" 197 | } 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "from gnn_explainer import GNNExplainer\n", 202 | "\n", 203 | "# GNNExplainer class needed for retrieval of computational graph\n", 204 | "gnn_explainer = GNNExplainer(model, log=False)\n", 205 | "\n", 206 | "explain_node = 0" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": { 213 | "id": "QVtGOuIIQ8wx", 214 | "pycharm": { 215 | "name": "#%%\n" 216 | } 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "def execute_model_with_gradient(model, node, x, edge_index):\n", 221 | " \"\"\"Helper function, which mainly does a forward pass of the GNN\"\"\"\n", 222 | " ypred = model(x, edge_index)\n", 223 | "\n", 224 | " predicted_labels = ypred.argmax(dim=-1)\n", 225 | " predicted_label = predicted_labels[node]\n", 226 | " logit = torch.nn.functional.softmax((ypred[node, :]).squeeze(), dim=0)\n", 227 | "\n", 228 | " logit = logit[predicted_label]\n", 229 | " loss = -torch.log(logit)\n", 230 | " loss.backward()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "id": "CppOdP6NQ8wx", 238 | "pycharm": { 239 | "name": "#%%\n" 240 | } 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "def get_grad_node_explanation(model, node, data):\n", 245 | " \"\"\"Calculates the gradient feature and node explanation\"\"\"\n", 246 | "\n", 247 | " # retrieve computational graph\n", 248 | " computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \\\n", 249 | " gnn_explainer.__subgraph__(node, data.x, data.edge_index)\n", 250 | " # from now only work on the computational graph\n", 251 | " x = computation_graph_feature_matrix\n", 252 | " edge_index = computation_graph_edge_index\n", 253 | "\n", 254 | " # create a mask of ones which will be differentiated\n", 255 | " num_nodes, num_features = x.size()\n", 256 | " node_grad = torch.nn.Parameter(torch.ones(num_nodes, device=x.device))\n", 257 | " feature_grad = torch.nn.Parameter(torch.ones(num_features, device=x.device))\n", 258 | " node_grad.requires_grad = True\n", 259 | " feature_grad.requires_grad = True\n", 260 | " mask = 
node_grad.unsqueeze(0).T.matmul(feature_grad.unsqueeze(0))\n", 261 | "\n", 262 | " model.zero_grad()\n", 263 | " execute_model_with_gradient(model, mapping, mask*x, edge_index)\n", 264 | "\n", 265 | " node_mask = torch.abs(node_grad.grad).cpu().detach().numpy()\n", 266 | " feature_mask = torch.abs(feature_grad.grad).cpu().detach().numpy()\n", 267 | "\n", 268 | " return feature_mask, node_mask" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "colab": { 276 | "base_uri": "https://localhost:8080/" 277 | }, 278 | "id": "jmjUCbApRwpe", 279 | "outputId": "0742718e-0231-4e3d-8e8d-44676f5087b2" 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "grad_explanation = get_grad_node_explanation(model, explain_node, data)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "pycharm": { 291 | "name": "#%%\n" 292 | } 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "plt.title(\"Distribution of Feature mask\")\n", 297 | "plt.hist(grad_explanation[0])\n", 298 | "plt.yscale(\"log\")" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": { 304 | "id": "LtBkV46GQ8wz" 305 | }, 306 | "source": [ 307 | "##### Possible task: implementation of GradInput" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": { 314 | "id": "55Htke8YQ8wz", 315 | "pycharm": { 316 | "name": "#%%\n" 317 | } 318 | }, 319 | "outputs": [], 320 | "source": [] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": { 325 | "id": "2qM-yixWQ8wz" 326 | }, 327 | "source": [ 328 | "# GNNExplainer" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": { 335 | "id": "-y8VWKjhQ8w0", 336 | "pycharm": { 337 | "name": "#%%\n" 338 | } 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "def get_gnn_explainer(node, data):\n", 343 | " feature_mask, edge_mask = gnn_explainer.explain_node(node, data.x, data.edge_index)\n", 344 | " return feature_mask, edge_mask" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "colab": { 352 | "base_uri": "https://localhost:8080/" 353 | }, 354 | "id": "Z5BRP2rGR1Wx", 355 | "outputId": "f5c38cec-b10f-453d-9f0e-11896871a683" 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "gnn_explanation = get_gnn_explainer(explain_node, data)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "pycharm": { 367 | "name": "#%%\n" 368 | } 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "plt.title(\"Distribution of Feature mask\")\n", 373 | "plt.hist(gnn_explanation[0])\n", 374 | "plt.yscale(\"log\")" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": { 380 | "id": "Az3xEGHASQrz" 381 | }, 382 | "source": [ 383 | "# Zorro" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": { 390 | "id": "mMbIchxVSSva" 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | "from explainer import Zorro\n", 395 | "\n", 396 | "zorro = Zorro(model, device)\n", 397 | "def get_zorro(node):\n", 398 | " # Same as the 0.98 in the paper\n", 399 | " tau = .03\n", 400 | " # only retrieve 1 explanation\n", 401 | " recursion_depth = 1\n", 402 | "\n", 403 | " explanation = zorro.explain_node(node, data.x, data.edge_index, tau=tau, recursion_depth=recursion_depth,)\n", 404 | "\n", 405 | " selected_nodes, selected_features, 
executed_selections = explanation[0]\n", 406 | "\n", 407 | " return selected_features[0], selected_nodes[0]" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": { 414 | "colab": { 415 | "base_uri": "https://localhost:8080/" 416 | }, 417 | "id": "RSXLOLDATCDo", 418 | "outputId": "2f592425-203b-4338-abd7-c9250b972670" 419 | }, 420 | "outputs": [], 421 | "source": [ 422 | "zorro_explanation = get_zorro(explain_node)" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": { 429 | "pycharm": { 430 | "name": "#%%\n" 431 | } 432 | }, 433 | "outputs": [], 434 | "source": [ 435 | "plt.title(\"Distribution of Feature mask\")\n", 436 | "plt.hist(zorro_explanation[0])\n", 437 | "plt.yscale(\"log\")" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "# SoftZorro" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "metadata": { 451 | "pycharm": { 452 | "name": "#%%\n" 453 | } 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "from explainer import SoftZorro\n", 458 | "\n", 459 | "soft_zorro = SoftZorro(model, device)\n", 460 | "\n", 461 | "def get_soft_zorro(node):\n", 462 | " node_mask, feature_mask = soft_zorro.explain_node(node, data.x, data.edge_index)\n", 463 | " return feature_mask[0], node_mask[0]" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": { 470 | "pycharm": { 471 | "name": "#%%\n" 472 | } 473 | }, 474 | "outputs": [], 475 | "source": [ 476 | "soft_zorro_explanation = get_soft_zorro(explain_node)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": { 483 | "pycharm": { 484 | "name": "#%%\n" 485 | } 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "plt.title(\"Distribution of Feature mask\")\n", 490 | "plt.hist(soft_zorro_explanation[0])\n", 491 | "plt.yscale(\"log\")" 492 | ] 493 | } 494 | ], 495 | "metadata": { 496 | "colab": { 497 | "collapsed_sections": [], 498 | "name": "xaiss_hands_on_xai_gnn.ipynb", 499 | "provenance": [] 500 | }, 501 | "kernelspec": { 502 | "display_name": "Python 3 (ipykernel)", 503 | "language": "python", 504 | "name": "python3" 505 | }, 506 | "language_info": { 507 | "codemirror_mode": { 508 | "name": "ipython", 509 | "version": 3 510 | }, 511 | "file_extension": ".py", 512 | "mimetype": "text/x-python", 513 | "name": "python", 514 | "nbconvert_exporter": "python", 515 | "pygments_lexer": "ipython3", 516 | "version": "3.9.10" 517 | }, 518 | "toc": { 519 | "base_numbering": 1, 520 | "nav_menu": {}, 521 | "number_sections": true, 522 | "sideBar": true, 523 | "skip_h1_title": false, 524 | "title_cell": "Table of Contents", 525 | "title_sidebar": "Contents", 526 | "toc_cell": false, 527 | "toc_position": {}, 528 | "toc_section_display": true, 529 | "toc_window_display": false 530 | } 531 | }, 532 | "nbformat": 4, 533 | "nbformat_minor": 1 534 | } -------------------------------------------------------------------------------- /baseline_execution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pathlib import Path 4 | import numpy as np 5 | import pandas as pd 6 | import time 7 | 8 | from models import load_model, load_dataset, GCNNet, GATNet, APPNP2Net, GINConvNet, GCN_syn2 9 | from execution import MODEL_SAVE_NAMES 10 | 11 | from gnn_explainer import GNNExplainer 12 | from grad_explainer import 
grad_edge_explanation, grad_node_explanation, gradinput_node_explanation 13 | 14 | from evaluation import evaluate_explanations, evaluate_synthetic 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def save_gnn_explanation(path_prefix, node, feature_mask, edge_mask=None, node_mask=None): 20 | save_dict = {"node": np.array(node), 21 | "feature_mask": np.array(feature_mask), 22 | } 23 | if edge_mask is not None: 24 | save_dict["edge_mask"] = np.array(edge_mask) 25 | 26 | if node_mask is not None: 27 | save_dict["node_mask"] = np.array(node_mask) 28 | 29 | np.savez_compressed(str(path_prefix) + str(node) + ".npz", **save_dict) 30 | 31 | 32 | def load_gnn_explanation(path_prefix, node): 33 | save = np.load(str(path_prefix) + str(node) + ".npz") 34 | 35 | saved_node = save["node"] 36 | if saved_node != node: 37 | raise ValueError("Other node than specified", saved_node, node) 38 | 39 | feature_mask = torch.Tensor(save["feature_mask"]) 40 | if "edge_mask" in save: 41 | edge_mask = torch.Tensor(save["edge_mask"]) 42 | else: 43 | edge_mask = None 44 | 45 | if "node_mask" in save: 46 | node_mask = torch.Tensor(save["node_mask"]) 47 | else: 48 | node_mask = None 49 | 50 | return feature_mask, edge_mask, node_mask 51 | 52 | 53 | def five_random_closest_neighbors(explainer_instance, node, data_full): 54 | from torch_geometric.utils import k_hop_subgraph 55 | import networkx as nx 56 | import random 57 | 58 | # get neighbors with subset - and without relabeling 59 | subset, edge_index, _, _ = k_hop_subgraph( 60 | node, explainer_instance.__num_hops__(), data_full.edge_index, relabel_nodes=False, 61 | num_nodes=None, flow=explainer_instance.__flow__()) 62 | num_features = data_full.x.size(1) 63 | num_nodes = subset.size() 64 | 65 | f_mask = torch.ones(num_features) 66 | 67 | n_mask = torch.zeros(num_nodes) 68 | 69 | node_explanation = {node} 70 | subset = list(subset.numpy()) 71 | 72 | graph = nx.Graph() 73 | graph.add_nodes_from(subset) 74 | graph.add_edges_from(edge_index.numpy().transpose()) 75 | 76 | neighbors = set(graph.neighbors(node)) 77 | for _ in range(6): # add whole neighbor rings while they still fit into the budget of five nodes 78 | if len(neighbors) > 5 - len(node_explanation): 79 | break 80 | node_explanation = node_explanation.union(neighbors) 81 | 82 | new_neighbors = set() 83 | for neighbor in neighbors: 84 | new_neighbors = new_neighbors.union(graph.neighbors(neighbor)) 85 | neighbors = new_neighbors.difference(node_explanation) 86 | node_explanation = node_explanation.union(random.sample(neighbors, min(len(neighbors), 5 - len(node_explanation)))) 87 | 88 | if len(node_explanation) < 5: 89 | print("Found only " + str(len(node_explanation)) + " neighbors for node " + str(node)) 90 | 91 | for selected_neighbor in node_explanation: 92 | n_mask[subset.index(selected_neighbor)] = 1 93 | 94 | return f_mask, n_mask 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | # datasets = [ 100 | # # "Cora", 101 | # # "CiteSeer", 102 | # # "PubMed", 103 | # # "AmazonC" 104 | # # "syn2", 105 | # 106 | # ] 107 | # models = [ 108 | # # "GCN", 109 | # # "GAT", 110 | # # "GINConv", 111 | # # # "APPNP", 112 | # # "APPNP2Net" 113 | # ] 114 | # config_name, datasets, models, epochs, explainers = "syn2_faith", \ 115 | # ["syn2_4", ], \ 116 | # ["GCN_syn2", ], \ 117 | # [-1, 0, 200, 400, 600, 1400], \ 118 | # ["five_neighbors", "GNNExplainer", "Grad", "GradInput"] 119 | 120 | # config_name, datasets, models, epochs, explainers = "real", \ 121 | # ["Cora", "CiteSeer", "PubMed", ], \ 122 | # ["GCN", "GAT", "GINConv", "APPNP2Net"], \ 123 | #
[-1, ], \ 124 | # ["five_neighbors", "GNNExplainer", "Grad", "GradInput"] 125 | 126 | config_name, datasets, models, epochs, explainers = "retrain", \ 127 | ["Cora",], \ 128 | ["GCN",], \ 129 | [-1, ], \ 130 | ["GNNExplainer", "Grad", "GradInput"] 131 | 132 | raw_time = { 133 | "explainer": [], 134 | "dataset": [], 135 | "model": [], 136 | "epoch": [], 137 | "node": [], 138 | "time": [], 139 | } 140 | 141 | working_directory = Path(".").resolve() 142 | global_results_directory = working_directory.joinpath("results_baselines") 143 | global_model_save_directory = working_directory.joinpath("results") 144 | global_results_directory.mkdir(parents=False, exist_ok=True) 145 | 146 | create_mode = True 147 | 148 | evaluate_retrieved_explanations = False 149 | add_trivial_explanations = False 150 | 151 | for dataset_name in datasets: 152 | for model_name in models: 153 | for epoch in epochs: 154 | for explainer in explainers: 155 | explanations = {} 156 | dataset, data, results_path = load_dataset(dataset_name, working_directory=working_directory) 157 | 158 | results_path += "_" + model_name 159 | results_directory = global_results_directory.joinpath(explainer).joinpath(results_path) 160 | results_directory.mkdir(parents=True, exist_ok=True) 161 | 162 | model_save_directory = global_model_save_directory.joinpath(results_path) 163 | 164 | model_classes = {"GCN": GCNNet, 165 | "GAT": GATNet, 166 | "GINConv": GINConvNet, 167 | "APPNP2Net": APPNP2Net, 168 | "GCN_syn2": GCN_syn2, 169 | } 170 | model_class = model_classes[model_name] 171 | 172 | if epoch == -1: 173 | path_to_saved_model = str(model_save_directory.joinpath(MODEL_SAVE_NAMES[model_name] + ".pt")) 174 | else: 175 | path_to_saved_model = str( 176 | model_save_directory.joinpath( 177 | MODEL_SAVE_NAMES[model_name] + "_epoch_" + str(epoch) + ".pt")) 178 | 179 | path_to_saved_explanation_prefix = MODEL_SAVE_NAMES[model_name] 180 | 181 | if epoch != -1: 182 | path_to_saved_explanation_prefix += "_epoch_" + str(epoch) 183 | 184 | path_to_saved_explanation_prefix += "_" + explainer.lower() + "_soft_masks_node_" 185 | 186 | path_to_saved_explanation_prefix = results_directory.joinpath(path_to_saved_explanation_prefix) 187 | 188 | model = model_class(dataset) 189 | model.to(device) 190 | data = data.to(device) 191 | 192 | load_model(path_to_saved_model, model) 193 | print("Loaded saved model") 194 | 195 | gnn_explainer = GNNExplainer(model, log=False) 196 | 197 | if dataset_name[:4] == "syn2": 198 | selected_nodes = list(range(300, 700)) + list(range(1000, 1400)) 199 | else: 200 | selected_nodes = np.load(global_model_save_directory.joinpath(dataset_name + "_selected_nodes.npy")) 201 | 202 | if config_name == "retrain": 203 | selected_nodes = np.array(range(data.num_nodes))[data.train_mask] 204 | 205 | # cast to same data format as before and limit to specified block 206 | selected_nodes = set(int(node) for node in selected_nodes) 207 | # selected_nodes = range(2708) # setting corra full 208 | print("Selected nodes: " + str(selected_nodes)) 209 | 210 | for i, node in enumerate(selected_nodes): 211 | start_time = time.time() 212 | 213 | node_mask = None 214 | edge_mask = None 215 | 216 | if explainer == "GNNExplainer": 217 | try: 218 | feature_mask, edge_mask, node_mask = load_gnn_explanation( 219 | path_to_saved_explanation_prefix, 220 | node) 221 | create_mode = False 222 | except FileNotFoundError: 223 | feature_mask, edge_mask = gnn_explainer.explain_node(node, data.x, data.edge_index) 224 | save_gnn_explanation(path_to_saved_explanation_prefix, 
node, feature_mask, edge_mask) 225 | elif explainer == "GradEdge": 226 | computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \ 227 | gnn_explainer.__subgraph__(node, data.x, data.edge_index) 228 | try: 229 | feature_mask, edge_mask, node_mask = load_gnn_explanation( 230 | path_to_saved_explanation_prefix, 231 | node) 232 | create_mode = False 233 | except FileNotFoundError: 234 | feature_mask, edge_mask = grad_edge_explanation(model, 235 | mapping, 236 | computation_graph_feature_matrix, 237 | computation_graph_edge_index) 238 | save_gnn_explanation(path_to_saved_explanation_prefix, node, feature_mask, edge_mask) 239 | 240 | edge_mask = torch.tensor(edge_mask) 241 | feature_mask = torch.tensor(feature_mask) 242 | 243 | edge_mask_long = torch.zeros(data.edge_index.shape[1]) 244 | edge_mask_long[hard_edge_mask] = edge_mask 245 | edge_mask = edge_mask_long 246 | elif explainer in ["Grad", "GradInput"]: 247 | try: 248 | feature_mask, edge_mask, node_mask = load_gnn_explanation( 249 | path_to_saved_explanation_prefix, 250 | node) 251 | create_mode = False 252 | except FileNotFoundError: 253 | computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \ 254 | gnn_explainer.__subgraph__(node, data.x, data.edge_index) 255 | if explainer == "Grad": 256 | feature_mask, node_mask = grad_node_explanation(model, 257 | mapping, 258 | computation_graph_feature_matrix, 259 | computation_graph_edge_index) 260 | elif explainer == "GradInput": 261 | feature_mask, node_mask = gradinput_node_explanation(model, 262 | mapping, 263 | computation_graph_feature_matrix, 264 | computation_graph_edge_index) 265 | else: 266 | raise NotImplementedError("") 267 | save_gnn_explanation(path_to_saved_explanation_prefix, node, feature_mask, 268 | node_mask=node_mask) 269 | feature_mask = torch.tensor(feature_mask) 270 | node_mask = torch.tensor(node_mask) 271 | 272 | elif explainer == "five_neighbors": 273 | try: 274 | feature_mask, edge_mask, node_mask = load_gnn_explanation( 275 | path_to_saved_explanation_prefix, 276 | node) 277 | create_mode = False 278 | except FileNotFoundError: 279 | feature_mask, node_mask = five_random_closest_neighbors(gnn_explainer, node, data) 280 | save_gnn_explanation(path_to_saved_explanation_prefix, node, feature_mask, 281 | node_mask=node_mask) 282 | else: 283 | raise NotImplementedError("Explainer not implemented") 284 | 285 | end_time = time.time() 286 | 287 | raw_time["explainer"].append(explainer) 288 | raw_time["dataset"].append(dataset_name) 289 | raw_time["model"].append(model_name) 290 | raw_time["epoch"].append(epoch) 291 | raw_time["node"].append(node) 292 | raw_time["time"].append(end_time - start_time) 293 | 294 | if i % 10 == 0: 295 | print("\n", explainer, dataset_name, model_name, epoch, i, end=" ") 296 | 297 | explanations[node] = feature_mask, edge_mask, node_mask 298 | print("") 299 | 300 | if evaluate_retrieved_explanations: 301 | print("Starting evaluation", time.ctime()) 302 | evaluate_explanations(explainer, model_name, dataset_name, model, data, explanations, 303 | global_results_directory, epoch) 304 | 305 | if dataset_name[:4] == "syn2": 306 | evaluate_synthetic(explainer, model_name, dataset_name, model, data, explanations, 307 | global_results_directory, epoch) 308 | print("Finished evaluation", time.ctime()) 309 | 310 | if add_trivial_explanations: 311 | print("Adding trivial explanations") 312 | empty_explanations = {} 313 | for node in explanations: 314 | feature_mask, _, 
node_mask = explanations[node] 315 | 316 | feature_mask = torch.zeros_like(feature_mask) 317 | edge_mask = None 318 | node_mask = torch.zeros_like(node_mask) 319 | 320 | empty_explanations[node] = feature_mask, edge_mask, node_mask 321 | 322 | print("Starting evaluation", time.ctime()) 323 | evaluate_explanations("edge_only", model_name, dataset_name, model, data, empty_explanations, 324 | global_results_directory, epoch) 325 | if dataset_name[:4] == "syn2": 326 | evaluate_synthetic("edge_only", model_name, dataset_name, model, data, empty_explanations, 327 | global_results_directory, epoch) 328 | print("Finished evaluation", time.ctime()) 329 | print("") 330 | 331 | if create_mode: 332 | df_time = pd.DataFrame(data=raw_time) 333 | df_time.to_csv(global_results_directory.joinpath(config_name + "_time.csv")) 334 | -------------------------------------------------------------------------------- /data/syn2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/data/syn2.npz -------------------------------------------------------------------------------- /evaluate_zorro.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import time 4 | import torch 5 | import pandas as pd 6 | 7 | from print_progress_overview import get_combinations, DATASET_RESULT_PATHS, MODEL_SAVE_NAMES 8 | from explainer import load_minimal_nodes_and_features_sets 9 | from models import load_model, load_dataset, GCNNet, GATNet, APPNP2Net, GINConvNet, GCN_syn2 10 | from evaluation import evaluate_explanations, evaluate_synthetic 11 | 12 | if __name__ == "__main__": 13 | working_directory = Path(".") 14 | global_results_directory = working_directory.joinpath("results_evaluated") 15 | global_model_save_directory = working_directory.joinpath("results") 16 | combinations = get_combinations() 17 | device = "cpu" 18 | 19 | for dataset_name, model_name, search_paths in combinations: 20 | for i, path_pattern in enumerate(search_paths): 21 | explainer = "Zorro" 22 | if path_pattern.find("_t_3") != -1: 23 | explainer += "_t_3" 24 | 25 | epoch = -1 26 | if path_pattern.find("_epoch_") != -1: 27 | # retrieve epoch from path 28 | epoch_pos_start = path_pattern.find("_epoch_") + 7 29 | epoch_pos_end = path_pattern[epoch_pos_start:].find("_") 30 | epoch = int(path_pattern[epoch_pos_start:epoch_pos_start + epoch_pos_end]) 31 | 32 | print(time.ctime(), explainer, dataset_name, model_name, epoch) 33 | 34 | results_directory = working_directory.joinpath("results") 35 | results_path_prefix = DATASET_RESULT_PATHS[dataset_name] 36 | 37 | if dataset_name[:4] == "syn2": 38 | selected_nodes = list(range(300, 700)) + list(range(1000, 1400)) 39 | else: 40 | selected_nodes = np.load(results_directory.joinpath(dataset_name + "_selected_nodes.npy")) 41 | 42 | results_path = results_path_prefix + "_" + model_name 43 | results_directory = working_directory.joinpath("results") 44 | results_directory = results_directory.joinpath(results_path) 45 | 46 | file_prefix = results_directory.joinpath(MODEL_SAVE_NAMES[model_name] + "_explanation") 47 | 48 | minimal_sets = {} 49 | first_explanations = {} 50 | skip = False 51 | for node in selected_nodes: 52 | node = int(node) 53 | try: 54 | # remove _node_{:d}.npz from path pattern 55 | minimal_sets[node] = load_minimal_nodes_and_features_sets(str(file_prefix) + path_pattern[:-14], 56 | node) 57 | except 
FileNotFoundError: 58 | skip = True 59 | break 60 | 61 | selected_nodes, selected_features, executed_selections = minimal_sets[node][0] 62 | selected_nodes = torch.Tensor(selected_nodes.squeeze()) 63 | selected_features = torch.Tensor(selected_features.squeeze()) 64 | first_explanations[node] = selected_features, None, selected_nodes 65 | 66 | if skip: 67 | continue 68 | 69 | dataset, data, results_path = load_dataset(dataset_name, working_directory=working_directory) 70 | 71 | results_path += "_" + model_name 72 | results_directory = global_results_directory.joinpath(explainer).joinpath(results_path) 73 | results_directory.mkdir(parents=True, exist_ok=True) 74 | 75 | model_save_directory = global_model_save_directory.joinpath(results_path) 76 | 77 | model_classes = {"GCN": GCNNet, 78 | "GAT": GATNet, 79 | "GINConv": GINConvNet, 80 | "APPNP2Net": APPNP2Net, 81 | "GCN_syn2": GCN_syn2, 82 | } 83 | model_class = model_classes[model_name] 84 | 85 | if epoch == -1: 86 | path_to_saved_model = str(model_save_directory.joinpath(MODEL_SAVE_NAMES[model_name] + ".pt")) 87 | else: 88 | path_to_saved_model = str( 89 | model_save_directory.joinpath( 90 | MODEL_SAVE_NAMES[model_name] + "_epoch_" + str(epoch) + ".pt")) 91 | 92 | model = model_class(dataset) 93 | model.to(device) 94 | data = data.to(device) 95 | 96 | load_model(path_to_saved_model, model) 97 | 98 | print("Evaluate") 99 | 100 | if dataset_name[:4] != "syn2": 101 | evaluate_explanations(explainer, model_name, dataset_name, model, data, first_explanations, 102 | global_results_directory, epoch) 103 | else: 104 | evaluate_synthetic(explainer, model_name, dataset_name, model, data, first_explanations, 105 | global_results_directory, epoch) 106 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gnn_explainer import GNNExplainer 3 | from scipy.stats import entropy 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def binarize_tensor(tensor, number_of_ones): 9 | binary_tensor = torch.zeros_like(tensor) 10 | _, top_indices = torch.topk(tensor, number_of_ones, sorted=False) 11 | binary_tensor[top_indices] = 1 12 | 13 | return binary_tensor 14 | 15 | 16 | SAVE_CACHE = ("path", {}) 17 | 18 | 19 | def get_gnn_distortion_with_cache(path_prefix, node, model, data, 20 | feature_mask, edge_mask=None, node_mask=None, 21 | feature_ones=None, edge_ones=None, node_ones=None, 22 | validity=False, single_save=False, skip_save=False, 23 | ): 24 | global SAVE_CACHE 25 | if single_save: 26 | save_path = str(path_prefix) + ".npz" 27 | key = str(node) + "_" 28 | else: 29 | save_path = str(path_prefix) + str(node) + ".npz" 30 | key = "" 31 | 32 | if feature_ones is None: 33 | key += "f_complete" 34 | else: 35 | key += "f_" + str(feature_ones) 36 | 37 | key += "_" 38 | 39 | if edge_mask is not None: 40 | if edge_ones is None: 41 | key += "e_complete" 42 | else: 43 | key += "e_" + str(edge_ones) 44 | 45 | if node_mask is not None: 46 | if node_ones is None: 47 | key += "n_complete" 48 | else: 49 | key += "n_" + str(node_ones) 50 | 51 | if validity: 52 | key += "_validity" 53 | 54 | if save_path != SAVE_CACHE[0]: 55 | SAVE_CACHE = (save_path, {}) 56 | save = SAVE_CACHE[1] 57 | try: 58 | save_file = np.load(save_path) 59 | for saved_key in save_file: 60 | save[saved_key] = save_file[saved_key] 61 | except FileNotFoundError: 62 | pass 63 | save = SAVE_CACHE[1] 64 | 65 | if key not in save: 66 | if feature_ones 
is not None: 67 | feature_mask = binarize_tensor(feature_mask, feature_ones) 68 | 69 | if edge_ones is not None: 70 | edge_mask = binarize_tensor(edge_mask, edge_ones) 71 | 72 | if node_ones is not None: 73 | node_mask = binarize_tensor(node_mask, node_ones) 74 | 75 | if node_mask is not None: 76 | if len(node_mask.shape) == 1: 77 | node_mask = node_mask.unsqueeze(0) 78 | 79 | if len(feature_mask.shape) == 1: 80 | feature_mask = feature_mask.unsqueeze(0) 81 | 82 | save[key] = model.distortion(node, data.x, data.edge_index, 83 | feature_mask=feature_mask, 84 | edge_mask=edge_mask, 85 | node_mask=node_mask, 86 | validity=validity, 87 | ) 88 | if not skip_save: 89 | np.savez_compressed(save_path, **save) 90 | 91 | return float(save[key]) 92 | 93 | 94 | def evaluate_explanations(explainer_name, model_name, dataset_name, model, data, explanations, global_results_directory, 95 | epoch=-1, save_appendix=""): 96 | raw_explanation_info_softmax = { 97 | "explainer": [], 98 | "dataset": [], 99 | "model": [], 100 | "epoch": [], 101 | "node": [], 102 | "Fidelity": [], 103 | "Validity": [], 104 | "Entropy Feature Mask": [], 105 | "Entropy Edge Mask": [], 106 | "Entropy Node Mask": [], 107 | "Possible Nodes": [], 108 | "Possible Edges": [], 109 | "Possible Features": [], 110 | "SUM(Feature Mask)": [], 111 | "SUM(Edge Mask)": [], 112 | "SUM(Node Mask)": [], 113 | } 114 | 115 | # needed for the subgraph 116 | gnn_explainer = GNNExplainer(model, log=False) 117 | 118 | save_path = str(global_results_directory.joinpath(explainer_name + "_" + model_name + "_" + dataset_name)) 119 | if epoch != -1: 120 | save_path += "_epoch_" + str(epoch) 121 | if save_appendix: 122 | save_path += save_appendix 123 | 124 | for counter, node in enumerate(explanations): 125 | feature_mask, edge_mask, node_mask = explanations[node] 126 | 127 | computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \ 128 | gnn_explainer.__subgraph__(node, data.x, data.edge_index) 129 | 130 | distortion = get_gnn_distortion_with_cache(save_path, 131 | node, 132 | gnn_explainer, 133 | feature_mask=feature_mask, 134 | edge_mask=edge_mask, 135 | node_mask=node_mask, 136 | data=data, 137 | single_save=True, 138 | skip_save=True, 139 | ) 140 | 141 | validity = get_gnn_distortion_with_cache(save_path, 142 | node, 143 | gnn_explainer, 144 | feature_mask=feature_mask, 145 | edge_mask=edge_mask, 146 | node_mask=node_mask, 147 | data=data, 148 | validity=True, 149 | single_save=True, 150 | skip_save=(counter != len(explanations) - 1) and (counter % 25 != 0), 151 | ) 152 | 153 | raw_explanation_info_softmax["explainer"].append(explainer_name) 154 | raw_explanation_info_softmax["dataset"].append(dataset_name) 155 | raw_explanation_info_softmax["model"].append(model_name) 156 | raw_explanation_info_softmax["epoch"].append(epoch) 157 | raw_explanation_info_softmax["node"].append(node) 158 | raw_explanation_info_softmax["Fidelity"].append(distortion) 159 | raw_explanation_info_softmax["Validity"].append(validity) 160 | raw_explanation_info_softmax["SUM(Feature Mask)"].append(np.sum(feature_mask.numpy())) 161 | if np.abs(raw_explanation_info_softmax["SUM(Feature Mask)"][-1]) > 0.0001: 162 | raw_explanation_info_softmax["Entropy Feature Mask"].append(entropy(feature_mask.numpy().flatten())) 163 | else: 164 | raw_explanation_info_softmax["Entropy Feature Mask"].append(np.nan) 165 | 166 | if edge_mask is not None: 167 | raw_explanation_info_softmax["SUM(Edge Mask)"].append(np.sum(edge_mask.numpy())) 168 | else: 169 | 
raw_explanation_info_softmax["SUM(Edge Mask)"].append(0.0) 170 | if np.abs(raw_explanation_info_softmax["SUM(Edge Mask)"][-1]) > 0.0001: 171 | raw_explanation_info_softmax["Entropy Edge Mask"].append(entropy(edge_mask.numpy().flatten())) 172 | else: 173 | raw_explanation_info_softmax["Entropy Edge Mask"].append(np.nan) 174 | 175 | if node_mask is not None: 176 | raw_explanation_info_softmax["SUM(Node Mask)"].append(np.sum(node_mask.numpy())) 177 | else: 178 | raw_explanation_info_softmax["SUM(Node Mask)"].append(0.0) 179 | if np.abs(raw_explanation_info_softmax["SUM(Node Mask)"][-1]) > 0.0001: 180 | raw_explanation_info_softmax["Entropy Node Mask"].append(entropy(node_mask.numpy().flatten())) 181 | else: 182 | raw_explanation_info_softmax["Entropy Node Mask"].append(np.nan) 183 | 184 | num_nodes, num_features = computation_graph_feature_matrix.size() 185 | num_edges = computation_graph_edge_index.size(1) 186 | raw_explanation_info_softmax["Possible Nodes"].append(int(num_nodes)) 187 | raw_explanation_info_softmax["Possible Features"].append(int(num_features)) 188 | raw_explanation_info_softmax["Possible Edges"].append(int(num_edges)) 189 | 190 | df_explanation_info = pd.DataFrame(data=raw_explanation_info_softmax) 191 | if not save_appendix: 192 | df_explanation_info.to_csv(save_path + "_info.csv") 193 | return df_explanation_info 194 | 195 | 196 | def get_ground_truth_syn(node): 197 | # taken from https://github.com/vunhatminh/PGMExplainer/ 198 | base = [0, 1, 2, 3, 4] 199 | ground_truth = [] 200 | offset = node % 5 201 | ground_truth = [node - offset + val for val in base] 202 | return ground_truth 203 | 204 | 205 | def evaluate_synthetic(explainer_name, model_name, dataset_name, model, data, explanations, global_results_directory, 206 | epoch=-1): 207 | from torch_geometric.utils import k_hop_subgraph 208 | 209 | components_with_explained_node = [] 210 | components_without_explained_node = [] 211 | 212 | number_of_nodes_selected = [] 213 | 214 | node_true_positive = [] 215 | node_false_positive = [] 216 | node_true_negative = [] 217 | node_false_negative = [] 218 | node_tpr = [] 219 | node_precision = [] 220 | node_accuracy = [] 221 | 222 | # needed for the subgraph 223 | gnn_explainer = GNNExplainer(model, log=False) 224 | 225 | select_top_k_nodes = 10 226 | 227 | reduced_explanations = {} 228 | 229 | for node in explanations: 230 | subset, edge_index, _, _ = k_hop_subgraph( 231 | node, gnn_explainer.__num_hops__(), data.edge_index, relabel_nodes=False, 232 | num_nodes=None, flow=gnn_explainer.__flow__()) 233 | 234 | subset = subset.numpy() 235 | 236 | node_ground_truth = set(get_ground_truth_syn(node)) 237 | feature_mask, edge_mask, node_mask = explanations[node] 238 | 239 | if explainer_name == "GNNExplainer": 240 | # select top nodes based on edge mask 241 | top_edges_index = np.argpartition(edge_mask, -select_top_k_nodes)[-select_top_k_nodes:] 242 | # sort them descending (reason for the -edge_mask) 243 | top_edges_index = top_edges_index[np.argsort(-edge_mask[top_edges_index])] 244 | 245 | selected_nodes = set() 246 | for u, v in data.edge_index[:, top_edges_index].numpy().T: 247 | if len(selected_nodes) > 4: 248 | break 249 | selected_nodes.add(u) 250 | selected_nodes.add(v) 251 | 252 | if len(selected_nodes) < 5: 253 | raise Exception("Not enough elements" + str(node)) 254 | 255 | nodes_selected = set(selected_nodes) 256 | nodes_not_selected = set(subset).difference(nodes_selected) 257 | 258 | elif explainer_name in ["five_neighbors", "Grad", "GradInput"]: 259 | top_node_index = 
np.argpartition(node_mask, -5)[-5:] 260 | nodes_selected = set(subset[top_node_index]) 261 | nodes_not_selected = set(subset).difference(nodes_selected) 262 | elif explainer_name == "edge_only": 263 | nodes_selected = set() 264 | nodes_not_selected = set(subset).difference(nodes_selected) 265 | elif explainer_name in ["Zorro", "Zorro_t_3"]: 266 | nodes_selected = set(subset[node_mask > 0]) 267 | nodes_not_selected = set(subset).difference(nodes_selected) 268 | else: 269 | raise NotImplementedError("Not catched") 270 | 271 | number_of_nodes_selected.append(len(nodes_selected)) 272 | 273 | # create top 5/6 node mask 274 | top_5_node_mask = torch.zeros(len(subset)) 275 | list_subset = list(subset) 276 | for selected_neighbor in nodes_selected: 277 | top_5_node_mask[list_subset.index(selected_neighbor)] = 1 278 | 279 | reduced_explanations[node] = feature_mask, None, top_5_node_mask 280 | 281 | node_true_positive.append(len(nodes_selected.intersection(node_ground_truth))) 282 | node_false_positive.append(len(nodes_selected.difference(node_ground_truth))) 283 | node_true_negative.append(len(nodes_not_selected.difference(node_ground_truth))) 284 | node_false_negative.append(len(node_ground_truth.difference(nodes_selected))) 285 | 286 | node_tpr.append(node_true_positive[-1] / (node_true_positive[-1] + node_false_negative[-1])) 287 | if (node_true_positive[-1] + node_false_positive[-1]) > 0: 288 | node_precision.append( 289 | node_true_positive[-1] / (node_true_positive[-1] + node_false_positive[-1])) 290 | else: 291 | node_precision.append(np.nan) 292 | 293 | node_accuracy.append((node_true_positive[-1] + node_true_negative[-1]) / ( 294 | node_true_positive[-1] + node_true_negative[-1] 295 | + node_false_positive[-1] + node_false_negative[-1] 296 | )) 297 | 298 | save_appendix = "top5" 299 | general_eval_of_top_5 = evaluate_explanations( 300 | explainer_name, model_name, dataset_name, model, data, reduced_explanations, global_results_directory, 301 | epoch=epoch, save_appendix=save_appendix) 302 | 303 | save_path = str(global_results_directory.joinpath(explainer_name + "_" + model_name + "_" + dataset_name)) 304 | if epoch != -1: 305 | save_path += "_epoch_" + str(epoch) 306 | if save_appendix: 307 | save_path += save_appendix 308 | 309 | syn_details = pd.DataFrame(data={ 310 | "#nodes": number_of_nodes_selected, 311 | "node_true_positive": node_true_positive, 312 | "node_false_positive": node_false_positive, 313 | "node_true_negative": node_true_negative, 314 | "node_false_negative": node_false_negative, 315 | "node_tpr": node_tpr, 316 | "node_precision": node_precision, 317 | "node_accuracy": node_accuracy} 318 | ) 319 | 320 | full_evaluation = pd.concat([general_eval_of_top_5, syn_details], axis=1) 321 | full_evaluation.to_csv(save_path + "_info.csv") 322 | 323 | return full_evaluation 324 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from explainer import *\n", 12 | "from models import *" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "outputs": [], 19 | "source": [ 20 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 21 | "dataset, data, results_path = load_dataset(\"Cora\")\n", 22 | "model = GCNNet(dataset)\n", 23 | 
"model.to(device)\n", 24 | "data = data.to(device)" 25 | ], 26 | "metadata": { 27 | "collapsed": false, 28 | "pycharm": { 29 | "name": "#%%\n" 30 | } 31 | } 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Accuracy: 0.3210\n", 42 | "Accuracy: 0.7250\n", 43 | "Accuracy: 0.7450\n", 44 | "Accuracy: 0.7490\n", 45 | "Accuracy: 0.7490\n", 46 | "Accuracy: 0.7640\n", 47 | "Accuracy: 0.7660\n", 48 | "Accuracy: 0.7590\n" 49 | ] 50 | }, 51 | { 52 | "data": { 53 | "text/plain": "[]" 54 | }, 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "train_model(model, data)" 62 | ], 63 | "metadata": { 64 | "collapsed": false, 65 | "pycharm": { 66 | "name": "#%%\n" 67 | } 68 | } 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "outputs": [], 74 | "source": [ 75 | "explainer = Zorro(model, device)" 76 | ], 77 | "metadata": { 78 | "collapsed": false, 79 | "pycharm": { 80 | "name": "#%%\n" 81 | } 82 | } 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "outputs": [], 88 | "source": [ 89 | "# Same as the Zorro \\tau=0.98 in the paper\n", 90 | "tau = .03\n", 91 | "# Explain node 10\n", 92 | "node = 10\n", 93 | "# only retrieve 1 explanation\n", 94 | "recursion_depth = 1" 95 | ], 96 | "metadata": { 97 | "collapsed": false, 98 | "pycharm": { 99 | "name": "#%%\n" 100 | } 101 | } 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "outputs": [ 107 | { 108 | "name": "stderr", 109 | "output_type": "stream", 110 | "text": [ 111 | "\n", 112 | " 0%| | 0/10031 [00:00/results/_selected_nodes.npy as process those nodes") 83 | parser.add_argument("--offset", default=0, type=int, help="Specify which block of the predefined nodes") 84 | parser.add_argument("-wd", "--working_directory", type=str, help="Specify path to working directory") 85 | parser.add_argument("--tau", default=15, type=int, help="Specify tau (threshold)") 86 | parser.add_argument("--recursion_depth", default=-1, type=int, help="Specify maximum recursion depth") 87 | parser.add_argument("--full_search", action="store_true", default=False, 88 | help="Always check all nodes and all features (non greedy variant)") 89 | parser.add_argument("--save_initial_improve", action="store_true", default=True, 90 | help="Store distortion improve values of first round") 91 | parser.add_argument("--record_processing_time", action="store_true", default=True, 92 | help="Save in addition to selected nodes and features the processing time") 93 | parser.add_argument("--samples", default=100, type=int, help="Specify samples for fidelity") 94 | args = parser.parse_args() 95 | 96 | if args.gpu != -1: 97 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 98 | 99 | # imports not at start to limit the GPU 100 | from pathlib import Path 101 | from explainer import * 102 | from models import * 103 | 104 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 105 | if args.cpu_only: 106 | device = torch.device('cpu') 107 | 108 | # limit number of CPU cores 109 | torch.set_num_threads(16) 110 | 111 | if args.working_directory is None: 112 | working_directory = Path(".").resolve() 113 | else: 114 | working_directory = Path(args.working_directory).resolve() 115 | 116 | global_results_directory = working_directory.joinpath("results") 117 | global_results_directory.mkdir(parents=False, exist_ok=True) 118 | dataset, data, 
results_path = load_dataset(args.dataset, working_directory=working_directory) 119 | 120 | results_path += "_" + args.model 121 | results_directory = global_results_directory.joinpath(results_path) 122 | results_directory.mkdir(parents=False, exist_ok=True) 123 | 124 | model_classes = {"GCN": GCNNet, 125 | "GAT": GATNet, 126 | "GINConv": GINConvNet, 127 | "APPNP2Net": APPNP2Net, 128 | "GCN_syn2": GCN_syn2, 129 | } 130 | model_class = model_classes[args.model] 131 | 132 | if args.epoch == -1: 133 | path_to_saved_model = str(results_directory.joinpath(MODEL_SAVE_NAMES[args.model] + ".pt")) 134 | else: 135 | path_to_saved_model = str( 136 | results_directory.joinpath(MODEL_SAVE_NAMES[args.model] + "_epoch_" + str(args.epoch) + ".pt")) 137 | 138 | path_to_saved_explanation_prefix = results_directory.joinpath(MODEL_SAVE_NAMES[args.model] + "_explanation") 139 | 140 | past_execution_counter = 0 141 | while True: 142 | path_to_log_file = results_directory.joinpath(args.model + "_execution_" + str(past_execution_counter) + ".log") 143 | try: 144 | with open(path_to_log_file): 145 | pass 146 | except FileNotFoundError: 147 | break 148 | 149 | past_execution_counter += 1 150 | 151 | if args.recursion_depth == -1: 152 | recursion_depth = np.inf 153 | else: 154 | recursion_depth = args.recursion_depth 155 | 156 | tau = args.tau / 100 157 | 158 | logger = set_up_logger(path_to_log_file) 159 | logger.info("Working directory: " + str(working_directory)) 160 | logger.info("CPU only: " + str(args.cpu_only)) 161 | logger.info("GPU Number: " + str(args.gpu)) 162 | logger.info("Device: " + str(device)) 163 | logger.info("Model: " + args.model) 164 | if args.epoch != -1: 165 | logger.info("Epoch: " + str(args.epoch)) 166 | logger.info("Dataset: " + args.dataset) 167 | logger.info("Tau: " + str(tau)) 168 | logger.info("Recursion depth: " + str(recursion_depth)) 169 | logger.info("Samples: " + str(args.samples)) 170 | if args.predefined_nodes: 171 | logger.info("Load predefined nodes: True") 172 | else: 173 | logger.info("Load predefined nodes: False") 174 | logger.info("Number of Nodes: " + str(args.nnodes)) 175 | 176 | model = model_class(dataset) 177 | model.to(device) 178 | data = data.to(device) 179 | 180 | try: 181 | load_model(path_to_saved_model, model) 182 | logger.info("Loaded saved model") 183 | except FileNotFoundError: 184 | if args.epoch != -1: 185 | raise Exception("Not supported if epoch is not last") 186 | 187 | if args.dataset[:4] == "syn2" or args.dataset == "syn1": 188 | if args.dataset == "syn2_4": 189 | accuracies = train_model(model, data, epochs=2000, lr=0.001, weight_decay=0.005, clip=2.0, 190 | loss_function="cross_entropy", epoch_save_path=path_to_saved_model) 191 | np.savez_compressed(str(results_directory.joinpath("accuracies.npz")), 192 | **{"accuracies": np.array(accuracies)}) 193 | else: 194 | train_model(model, data, epochs=2000, lr=0.001, weight_decay=0.005, clip=2.0, 195 | loss_function="cross_entropy") 196 | else: 197 | train_model(model, data) 198 | logger.info("Finished training model") 199 | save_model(model, path_to_saved_model) 200 | logger.info("Saved model") 201 | 202 | logger.info(retrieve_accuracy(model, data)) 203 | 204 | explainer = Zorro(model, device, greedy=not args.full_search, 205 | record_process_time=args.record_processing_time, samples=args.samples) 206 | 207 | if args.dataset == "syn1": 208 | explainer.add_noise = True 209 | 210 | total_number_of_nodes, _ = data.x.size() 211 | if args.predefined_nodes: 212 | if args.dataset == "syn1": 213 | selected_nodes
= list(range(300, 700)) 214 | elif args.dataset[:4] == "syn2": 215 | selected_nodes = list(range(300, 700)) + list(range(1000, 1400)) 216 | else: 217 | selected_nodes = np.load(global_results_directory.joinpath(args.dataset + "_selected_nodes.npy")) 218 | 219 | if args.dataset == "Cora": 220 | # select training nodes 221 | selected_nodes = np.array(range(data.num_nodes))[data.train_mask.cpu().numpy()] 222 | 223 | # cast to same data format as before and limit to specified block 224 | selected_nodes = set(int(node) for node in selected_nodes[args.offset:args.offset + args.nnodes]) 225 | logger.info("Selected nodes: " + str(selected_nodes)) 226 | elif args.nnodes < total_number_of_nodes: 227 | selected_nodes = set() 228 | possible_nodes = list(range(total_number_of_nodes)) 229 | while len(selected_nodes) < args.nnodes and len(possible_nodes) > 0: 230 | possible_node = random.choice(possible_nodes) 231 | # check if explanation does not exists 232 | try: 233 | with open( 234 | get_save_file_path(path_to_saved_explanation_prefix, possible_node, args.tau, recursion_depth, 235 | args.full_search, args.samples, args.epoch)): 236 | pass 237 | except FileNotFoundError: 238 | possible_nodes.remove(possible_node) 239 | selected_nodes.add(possible_node) 240 | 241 | logger.info("Selected nodes: " + str(selected_nodes)) 242 | else: 243 | selected_nodes = range(total_number_of_nodes) 244 | logger.info("Selected nodes: all") 245 | 246 | for node in selected_nodes: 247 | explanation_save_path = get_save_file_path(path_to_saved_explanation_prefix, node, args.tau, recursion_depth, 248 | args.full_search, args.samples, args.epoch) 249 | # skip existing explanations 250 | try: 251 | with open(explanation_save_path): 252 | continue 253 | except FileNotFoundError: 254 | pass 255 | 256 | explanation = explainer.explain_node(node, data.x, data.edge_index, 257 | tau=tau, 258 | recursion_depth=recursion_depth, 259 | save_initial_improve=args.save_initial_improve) 260 | if args.save_initial_improve: 261 | save_minimal_nodes_and_features_sets(explanation_save_path, node, explanation[0], explanation[1], 262 | explanation[2]) 263 | else: 264 | save_minimal_nodes_and_features_sets(explanation_save_path, node, explanation) 265 | -------------------------------------------------------------------------------- /explainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from torch_geometric.nn import MessagePassing 4 | from torch_geometric.nn import APPNP 5 | from torch_geometric.utils import k_hop_subgraph 6 | import numpy as np 7 | import logging 8 | import time 9 | 10 | 11 | 12 | class AbstractGraphExplainer(torch.nn.Module): 13 | 14 | def __init__(self, model, device, log=True, record_process_time=False): 15 | super(AbstractGraphExplainer, self).__init__() 16 | self.model = model 17 | self.log = log 18 | self.logger = logging.getLogger("explainer") 19 | self.device = device 20 | 21 | self.record_process_time = record_process_time 22 | 23 | @staticmethod 24 | def num_hops(model): 25 | num_hops = 0 26 | for module in model.modules(): 27 | if isinstance(module, MessagePassing): 28 | if isinstance(module, APPNP): 29 | num_hops += module.K 30 | else: 31 | num_hops += 1 32 | return num_hops 33 | 34 | def __num_hops__(self): 35 | return self.num_hops(self.model) 36 | 37 | @staticmethod 38 | def flow(model): 39 | for module in model.modules(): 40 | if isinstance(module, MessagePassing): 41 | return module.flow 42 | return 'source_to_target' 43 | 44 
44 | def __flow__(self): 45 | return self.flow(self.model) 46 | 47 | def __subgraph__(self, node_idx, x, edge_index, **kwargs): 48 | num_nodes, num_edges = x.size(0), edge_index.size(1) 49 | 50 | subset, edge_index, mapping, edge_mask = k_hop_subgraph( 51 | node_idx, self.__num_hops__(), edge_index, relabel_nodes=True, 52 | num_nodes=num_nodes, flow=self.__flow__()) 53 | 54 | x = x[subset] 55 | for key, item in kwargs.items(): 56 | if torch.is_tensor(item) and item.size(0) == num_nodes: 57 | item = item[subset] 58 | elif torch.is_tensor(item) and item.size(0) == num_edges: 59 | item = item[edge_mask] 60 | kwargs[key] = item 61 | 62 | return subset, x, edge_index, mapping, edge_mask, kwargs 63 | 64 | def distortion(self, node_idx=None, full_feature_matrix=None, computation_graph_feature_matrix=None, 65 | edge_index=None, node_mask=None, feature_mask=None, predicted_label=None, samples=None, 66 | random_seed=12345): 67 | if node_idx is None: 68 | node_idx = self.node_idx 69 | 70 | if full_feature_matrix is None: 71 | full_feature_matrix = self.full_feature_matrix 72 | 73 | if computation_graph_feature_matrix is None: 74 | computation_graph_feature_matrix = self.computation_graph_feature_matrix 75 | 76 | if edge_index is None: 77 | edge_index = self.computation_graph_edge_index 78 | 79 | if node_mask is None: 80 | node_mask = self.selected_nodes 81 | 82 | if feature_mask is None: 83 | feature_mask = self.selected_features 84 | 85 | if predicted_label is None: 86 | predicted_label = self.predicted_label 87 | 88 | if samples is None: 89 | samples = self.distortion_samples 90 | 91 | return distortion(self.model, 92 | node_idx=node_idx, 93 | full_feature_matrix=full_feature_matrix, 94 | computation_graph_feature_matrix=computation_graph_feature_matrix, 95 | edge_index=edge_index, 96 | node_mask=node_mask, 97 | feature_mask=feature_mask, 98 | predicted_label=predicted_label, 99 | samples=samples, 100 | random_seed=random_seed, 101 | device=self.device, 102 | ) 103 | 104 | 105 | class Zorro(AbstractGraphExplainer): 106 | 107 | def __init__(self, model, device, log=True, greedy=True, record_process_time=False, add_noise=False, samples=100, 108 | path_to_precomputed_distortions=None): 109 | super(Zorro, self).__init__( 110 | model=model, 111 | device=device, 112 | log=log, 113 | record_process_time=record_process_time, 114 | ) 115 | self.distortion_samples = samples 116 | 117 | self.ensure_improvement = False 118 | 119 | self.add_noise = add_noise 120 | 121 | self.initial_node_improve = [np.nan] 122 | self.initial_feature_improve = [np.nan] 123 | 124 | self.greedy = greedy 125 | if self.greedy: 126 | self.greediness = 10 127 | self.sorted_possible_nodes = [] 128 | self.sorted_possible_features = [] 129 | 130 | self.path_to_precomputed_distortions = path_to_precomputed_distortions 131 | self.precomputed_distortion_info = {} 132 |
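# Usage sketch (variable names assumed): given a trained `model` and a
# torch_geometric data object `data`, an explanation is obtained via
#     explainer = Zorro(model, device, samples=100)
#     sets = explainer.explain_node(0, data.x, data.edge_index, tau=0.15)
# which returns a list of (node_mask, feature_mask, executed_selections) triples.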
133 | def load_initial_distortion(self, node, neighbor_subset): 134 | saved_info = np.load(self.path_to_precomputed_distortions) 135 | 136 | nodes = list(saved_info["nodes"]) 137 | subset = list(saved_info["subset"]) 138 | mapping = saved_info["mapping"] 139 | initial_distortion = saved_info["initial_distortion"] 140 | feature_distortion = saved_info["feature_distortion"] 141 | node_distortion = saved_info["node_distortion"] 142 | 143 | if node not in nodes: 144 | raise ValueError("Node " + str(node) + " not found in precomputed distortions file " 145 | + str(self.path_to_precomputed_distortions)) 146 | 147 | position = nodes.index(node) 148 | 149 | best_feature = None 150 | best_feature_distortion_improve = -1000 151 | raw_unsorted_features = [] 152 | for i in range(feature_distortion.shape[0]): 153 | distortion_improve = feature_distortion[i, position] - initial_distortion[position] 154 | raw_unsorted_features.append((i, distortion_improve)) 155 | if distortion_improve > best_feature_distortion_improve: 156 | best_feature_distortion_improve = distortion_improve 157 | best_feature = i 158 | 159 | best_node = None 160 | best_node_distortion_improve = -1000 161 | raw_unsorted_nodes = [] 162 | for i, neighbor in enumerate(neighbor_subset): 163 | if neighbor not in subset: 164 | raise ValueError("Neighbor " + str(neighbor) + " not found in precomputed neighbors " + str(subset)) 165 | distortion_improve = node_distortion[subset.index(neighbor), position] - initial_distortion[position] 166 | raw_unsorted_nodes.append((i, distortion_improve)) 167 | if distortion_improve > best_node_distortion_improve: 168 | best_node_distortion_improve = distortion_improve 169 | best_node = i 170 | 171 | # store the precomputed info in a dict 172 | self.precomputed_distortion_info["best_node"] = best_node 173 | self.precomputed_distortion_info["best_node_distortion_improve"] = best_node_distortion_improve 174 | self.precomputed_distortion_info["raw_unsorted_nodes"] = raw_unsorted_nodes 175 | self.precomputed_distortion_info["best_feature"] = best_feature 176 | self.precomputed_distortion_info["best_feature_distortion_improve"] = best_feature_distortion_improve 177 | self.precomputed_distortion_info["raw_unsorted_features"] = raw_unsorted_features 178 | self.logger.debug("Successfully loaded precomputed information") 179 | 180 | def argmax_distortion_general(self, 181 | previous_distortion, 182 | possible_elements, 183 | selected_elements, 184 | initialization=False, 185 | save_initial_improve=False, 186 | **distortion_kwargs, 187 | ): 188 | if self.greedy: 189 | # determine if node or features 190 | if selected_elements is not self.selected_nodes and selected_elements is not self.selected_features: 191 | raise ValueError("Neither features nor nodes selected") 192 | if initialization: 193 | if self.epoch == 1 and self.precomputed_distortion_info: 194 | if selected_elements is self.selected_nodes: 195 | best_element = self.precomputed_distortion_info["best_node"] 196 | best_distortion_improve = self.precomputed_distortion_info["best_node_distortion_improve"] 197 | raw_sorted_elements = self.precomputed_distortion_info["raw_unsorted_nodes"] 198 | self.logger.debug("Used precomputed node info") 199 | else: 200 | best_element = self.precomputed_distortion_info["best_feature"] 201 | best_distortion_improve = self.precomputed_distortion_info["best_feature_distortion_improve"] 202 | raw_sorted_elements = self.precomputed_distortion_info["raw_unsorted_features"] 203 | self.logger.debug("Used precomputed feature info") 204 | 205 | else: 206 | best_element, best_distortion_improve, raw_sorted_elements = self.argmax_distortion_general_full( 207 | previous_distortion, 208 | possible_elements, 209 | selected_elements, 210 | save_all_pairs=True, 211 | **distortion_kwargs, 212 | ) 213 | 214 | if selected_elements is self.selected_nodes: 215 | self.sorted_possible_nodes = sorted(raw_sorted_elements, key=lambda x: x[1], reverse=True) 216 | if save_initial_improve: 217 | self.initial_node_improve = raw_sorted_elements 218 | else: 219 | self.sorted_possible_features = sorted(raw_sorted_elements, key=lambda x: x[1], reverse=True) 220 | if save_initial_improve: 221 | self.initial_feature_improve = raw_sorted_elements
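# The exhaustive scan above runs only during initialization; the rankings cached in
# sorted_possible_nodes / sorted_possible_features let later epochs restrict the
# search to the `greediness` most promising candidates (see the branch below).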
222 | 223 | return best_element, best_distortion_improve 224 | 225 | else: 226 | if selected_elements is self.selected_nodes: 227 | sorted_elements = self.sorted_possible_nodes 228 | else: 229 | sorted_elements = self.sorted_possible_features 230 | 231 | restricted_possible_elements = torch.zeros_like(possible_elements, device=self.device) 232 | 233 | counter = 0 234 | for index, initial_distortion_improve in sorted_elements: 235 | if possible_elements[0, index] == 1 and selected_elements[0, index] == 0: 236 | counter += 1 237 | restricted_possible_elements[0, index] = 1 238 | # possible alternative based on initial distortion improve 239 | if counter == self.greediness: 240 | break 241 | 242 | else: 243 | # think about removing those elements 244 | pass 245 | 246 | # add selected elements to possible elements to avoid -1 in the calculation of remaining elements 247 | restricted_possible_elements += selected_elements 248 | 249 | best_element, best_distortion_improve = self.argmax_distortion_general_full( 250 | previous_distortion, 251 | restricted_possible_elements, 252 | selected_elements, 253 | **distortion_kwargs, 254 | ) 255 | 256 | return best_element, best_distortion_improve 257 | 258 | elif save_initial_improve: 259 | best_element, best_distortion_improve, raw_sorted_elements = self.argmax_distortion_general_full( 260 | previous_distortion, 261 | possible_elements, 262 | selected_elements, 263 | save_all_pairs=True, 264 | **distortion_kwargs, 265 | ) 266 | 267 | if selected_elements is self.selected_nodes: 268 | self.initial_node_improve = raw_sorted_elements 269 | else: 270 | self.initial_feature_improve = raw_sorted_elements 271 | 272 | return best_element, best_distortion_improve 273 | else: 274 | return self.argmax_distortion_general_full( 275 | previous_distortion, 276 | possible_elements, 277 | selected_elements, 278 | **distortion_kwargs, 279 | ) 280 | 281 | def argmax_distortion_general_full(self, 282 | previous_distortion, 283 | possible_elements, 284 | selected_elements, 285 | save_all_pairs=False, 286 | **distortion_kwargs, 287 | ): 288 | best_element = None 289 | best_distortion_improve = -1000 290 | 291 | remaining_nodes_to_select = possible_elements - selected_elements 292 | num_remaining = remaining_nodes_to_select.sum() 293 | 294 | # if no node left break 295 | if num_remaining == 0: 296 | return best_element, best_distortion_improve 297 | 298 | if self.log: # pragma: no cover 299 | pbar = tqdm(total=int(num_remaining), position=0) 300 | pbar.set_description(f'Argmax {best_element}, {best_distortion_improve}') 301 | 302 | all_calculated_pairs = [] 303 | 304 | i = 0 305 | while num_remaining > 0: 306 | if selected_elements[0, i] == 0 and possible_elements[0, i] == 1: 307 | num_remaining -= 1 308 | 309 | selected_elements[0, i] = 1 310 | 311 | distortion_improve = self.distortion(**distortion_kwargs) \ 312 | - previous_distortion 313 | 314 | selected_elements[0, i] = 0 315 | 316 | if save_all_pairs: 317 | all_calculated_pairs.append((i, distortion_improve)) 318 | 319 | if distortion_improve > best_distortion_improve: 320 | best_element = i 321 | best_distortion_improve = distortion_improve 322 | if self.log: # pragma: no cover 323 | pbar.set_description(f'Argmax {best_element}, {best_distortion_improve}') 324 | 325 | if self.log: # pragma: no cover 326 | pbar.update(1) 327 | i += 1 328 | 329 | if self.log: # pragma: no cover 330 | pbar.close() 331 | if save_all_pairs: 332 | return best_element, best_distortion_improve, all_calculated_pairs 333 | else: 334 | return 
best_element, best_distortion_improve 335 | 336 | def _determine_minimal_set(self, initial_distortion, tau, possible_nodes, possible_features, 337 | save_initial_improve=False): 338 | current_distortion = initial_distortion 339 | if self.record_process_time: 340 | last_time = time.time() 341 | executed_selections = [[np.nan, np.nan, current_distortion, 0]] 342 | else: 343 | last_time = 0 344 | executed_selections = [[np.nan, np.nan, current_distortion]] 345 | 346 | num_selected_nodes = 0 347 | num_selected_features = 0 348 | 349 | while current_distortion <= 1 - tau: 350 | 351 | if num_selected_nodes == num_selected_features == 0: 352 | best_node, improve_in_distortion_by_node = self.argmax_distortion_general( 353 | current_distortion, 354 | possible_nodes, 355 | self.selected_nodes, 356 | initialization=True, 357 | feature_mask=possible_features,  # assume all features are selected 358 | save_initial_improve=save_initial_improve, 359 | ) 360 | 361 | best_feature, improve_in_distortion_by_feature = self.argmax_distortion_general( 362 | current_distortion, 363 | possible_features, 364 | self.selected_features, 365 | initialization=True, 366 | node_mask=possible_nodes,  # assume all nodes are selected 367 | save_initial_improve=save_initial_improve, 368 | ) 369 | 370 | elif num_selected_features == 0: 371 | best_node, improve_in_distortion_by_node = None, -100 372 | 373 | best_feature, improve_in_distortion_by_feature = self.argmax_distortion_general( 374 | current_distortion, 375 | possible_features, 376 | self.selected_features, 377 | ) 378 | 379 | elif num_selected_nodes == 0: 380 | best_node, improve_in_distortion_by_node = self.argmax_distortion_general( 381 | current_distortion, 382 | possible_nodes, 383 | self.selected_nodes, 384 | ) 385 | 386 | best_feature, improve_in_distortion_by_feature = None, -100 387 | 388 | else: 389 | best_node, improve_in_distortion_by_node = self.argmax_distortion_general( 390 | current_distortion, 391 | possible_nodes, 392 | self.selected_nodes, 393 | ) 394 | 395 | best_feature, improve_in_distortion_by_feature = self.argmax_distortion_general( 396 | current_distortion, 397 | possible_features, 398 | self.selected_features, 399 | ) 400 | 401 | if self.ensure_improvement and \ 402 | improve_in_distortion_by_node < .00000001 and improve_in_distortion_by_feature < .00000001: 403 | pass  # placeholder: stopping on negligible improvement is not implemented 404 | 405 | if best_node is None and best_feature is None: 406 | break 407 | 408 | if best_node is None: 409 | self.selected_features[0, best_feature] = 1 410 | num_selected_features += 1 411 | executed_selection = [np.nan, best_feature] 412 | elif best_feature is None: 413 | self.selected_nodes[0, best_node] = 1 414 | num_selected_nodes += 1 415 | executed_selection = [best_node, np.nan] 416 | elif improve_in_distortion_by_feature >= improve_in_distortion_by_node: 417 | # on equal improve prefer feature 418 | self.selected_features[0, best_feature] = 1 419 | num_selected_features += 1 420 | executed_selection = [np.nan, best_feature] 421 | else: 422 | self.selected_nodes[0, best_node] = 1 423 | num_selected_nodes += 1 424 | executed_selection = [best_node, np.nan] 425 | 426 | current_distortion = self.distortion() 427 | 428 | self.logger.debug("Current distortion: " + str(current_distortion)) 429 | executed_selection.append(current_distortion) 430 | 431 | if self.record_process_time: 432 | executed_selection.append(time.time() - last_time) 433 | last_time = time.time() 434 | 435 | executed_selections.append(executed_selection) 436 | 437 | self.epoch += 1 438 | 439 | if self.log:  # pragma: no cover 440 | self.overall_progress_bar.update(1)
441 | 442 | return executed_selections 443 | 444 | def recursively_get_minimal_sets(self, initial_distortion, tau, possible_nodes, possible_features, 445 | recursion_depth=np.inf, save_initial_improve=False): 446 | 447 | self.logger.debug(" Possible features " + str(int(possible_features.sum()))) 448 | self.logger.debug(" Possible nodes " + str(int(possible_nodes.sum()))) 449 | 450 | # check maximal possible distortion with current possible nodes and features 451 | reachable_distortion = self.distortion( 452 | node_mask=possible_nodes, 453 | feature_mask=possible_features, 454 | ) 455 | self.logger.debug("Maximal reachable distortion in this path " + str(reachable_distortion)) 456 | if reachable_distortion <= 1 - tau: 457 | return None 458 | 459 | if recursion_depth == 0: 460 | return [(np.nan, np.nan, np.nan)] 461 | 462 | executed_selections = self._determine_minimal_set(initial_distortion, tau, possible_nodes, possible_features, 463 | save_initial_improve=save_initial_improve) 464 | 465 | minimal_nodes_and_features_sets = [ 466 | (self.selected_nodes.cpu().numpy(), 467 | self.selected_features.cpu().numpy(), 468 | executed_selections) 469 | ] 470 | 471 | self.logger.debug(" Explanation found") 472 | self.logger.debug(" Selected features " + str(int(minimal_nodes_and_features_sets[0][1].sum()))) 473 | self.logger.debug(" Selected nodes " + str(int(minimal_nodes_and_features_sets[0][0].sum()))) 474 | 475 | self.selected_nodes = torch.zeros((1, self.num_computation_graph_nodes), device=self.device) 476 | self.selected_features = torch.zeros((1, self.num_features), device=self.device) 477 | 478 | reduced_nodes = possible_nodes - torch.as_tensor(minimal_nodes_and_features_sets[0][0], device=self.device) 479 | reduced_features = possible_features - torch.as_tensor(minimal_nodes_and_features_sets[0][1], 480 | device=self.device) 481 | 482 | reduced_node_results = self.recursively_get_minimal_sets( 483 | initial_distortion, 484 | tau, 485 | reduced_nodes, 486 | possible_features, 487 | recursion_depth=recursion_depth - 1, 488 | save_initial_improve=False, 489 | ) 490 | if reduced_node_results is not None: 491 | minimal_nodes_and_features_sets.extend(reduced_node_results) 492 | 493 | self.selected_nodes = torch.zeros((1, self.num_computation_graph_nodes), device=self.device) 494 | self.selected_features = torch.zeros((1, self.num_features), device=self.device) 495 | 496 | reduced_feature_results = self.recursively_get_minimal_sets( 497 | initial_distortion, 498 | tau, 499 | possible_nodes, 500 | reduced_features, 501 | recursion_depth=recursion_depth - 1, 502 | save_initial_improve=False, 503 | ) 504 | if reduced_feature_results is not None: 505 | minimal_nodes_and_features_sets.extend(reduced_feature_results) 506 | 507 | return minimal_nodes_and_features_sets 508 |
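# Each recursive call removes the nodes (resp. features) of the explanation just
# found and searches again, so Zorro can return several disjoint explanations;
# a branch stops once the remaining elements can no longer exceed distortion 1 - tau.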
509 | def explain_node(self, node_idx, full_feature_matrix, edge_index, tau=0.15, recursion_depth=np.inf, 510 | save_initial_improve=False): 511 | r"""Determines minimal sets of nodes and features that retain the model's 512 | prediction for node :attr:`node_idx` up to the distortion threshold 513 | :obj:`1 - tau`. 514 | 515 | Args: 516 | node_idx (int): The node to explain. 517 | full_feature_matrix (Tensor): The node feature matrix. 518 | edge_index (LongTensor): The edge indices. 519 | 520 | :rtype: list of (node_mask, feature_mask, executed_selections) tuples 521 | """ 522 | 523 | if save_initial_improve: 524 | self.initial_node_improve = [np.nan] 525 | self.initial_feature_improve = [np.nan] 526 | 527 | self.model.eval() 528 | 529 | if recursion_depth <= 0: 530 | self.logger.warning("Recursion depth not positive " + str(recursion_depth)) 531 | raise ValueError("Recursion depth not positive " + str(recursion_depth)) 532 | 533 | self.logger.info("------ Start explaining node " + str(node_idx)) 534 | self.logger.debug("Distortion drop (tau): " + str(tau)) 535 | self.logger.debug("Distortion samples: " + str(self.distortion_samples)) 536 | self.logger.debug("Greedy variant: " + str(self.greedy)) 537 | if self.greedy: 538 | self.logger.debug("Greediness: " + str(self.greediness)) 539 | self.logger.debug("Ensure improvement: " + str(self.ensure_improvement)) 540 | 541 | num_edges = edge_index.size(1) 542 | 543 | (num_nodes, self.num_features) = full_feature_matrix.size() 544 | 545 | self.full_feature_matrix = full_feature_matrix 546 | 547 | # Only operate on a k-hop subgraph around `node_idx`. 548 | neighbor_subset, self.computation_graph_feature_matrix, self.computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \ 549 | self.__subgraph__(node_idx, full_feature_matrix, edge_index) 550 | 551 | if self.add_noise: 552 | self.full_feature_matrix = torch.cat( 553 | [self.full_feature_matrix, torch.zeros_like(self.full_feature_matrix)], 554 | dim=0) 555 | 556 | self.node_idx = mapping 557 | 558 | self.num_computation_graph_nodes = self.computation_graph_feature_matrix.size(0) 559 | 560 | # Get the initial prediction. 561 | with torch.no_grad(): 562 | log_logits = self.model(x=self.computation_graph_feature_matrix, 563 | edge_index=self.computation_graph_edge_index) 564 | predicted_labels = log_logits.argmax(dim=-1) 565 | 566 | self.predicted_label = predicted_labels[mapping] 567 | 568 | # self.__set_masks__(computation_graph_feature_matrix, edge_index) 569 | self.to(self.computation_graph_feature_matrix.device) 570 | 571 | if self.log:  # pragma: no cover 572 | self.overall_progress_bar = tqdm(total=int(self.num_computation_graph_nodes * self.num_features), 573 | position=1) 574 | self.overall_progress_bar.set_description(f'Explain node {node_idx}') 575 | 576 | possible_nodes = torch.ones((1, self.num_computation_graph_nodes), device=self.device) 577 | possible_features = torch.ones((1, self.num_features), device=self.device) 578 | 579 | self.selected_nodes = torch.zeros((1, self.num_computation_graph_nodes), device=self.device) 580 | self.selected_features = torch.zeros((1, self.num_features), device=self.device) 581 | 582 | initial_distortion = self.distortion() 583 | 584 | # log the unmasked distortion 585 | self.logger.debug("Initial distortion without any mask: " + str(initial_distortion)) 586 | 587 | if initial_distortion >= 1 - tau: 588 | # no mask needed, global distribution enough, see node 1861 in cora_GINConv 589 | self.logger.info("------ Finished explaining node " + str(node_idx)) 590 | self.logger.debug("# Explanations: Select any nodes and features") 591 | if save_initial_improve: 592 | return [ 593 | (self.selected_nodes.cpu().numpy(), 594 | self.selected_features.cpu().numpy(), 595 | [[np.nan, np.nan, initial_distortion], ] 596 | ) 597 | ], None, None 598 | else: 599 | return [ 600 | (self.selected_nodes.cpu().numpy(), 601 | self.selected_features.cpu().numpy(), 602 | [[np.nan, np.nan, initial_distortion], ] 603 | ) 604 | ] 605 | else: 606 | 607 | # if available load
precomputed distortions 608 | if self.path_to_precomputed_distortions is not None: 609 | self.load_initial_distortion(node_idx, neighbor_subset) 610 | 611 | self.epoch = 1 612 | minimal_nodes_and_features_sets = self.recursively_get_minimal_sets( 613 | initial_distortion, 614 | tau, 615 | possible_nodes, 616 | possible_features, 617 | recursion_depth=recursion_depth, 618 | save_initial_improve=save_initial_improve, 619 | ) 620 | 621 | if self.log: # pragma: no cover 622 | self.overall_progress_bar.close() 623 | 624 | self.logger.info("------ Finished explaining node " + str(node_idx)) 625 | self.logger.debug("# Explanations: " + str(len(minimal_nodes_and_features_sets))) 626 | 627 | if save_initial_improve: 628 | return minimal_nodes_and_features_sets, self.initial_node_improve, self.initial_feature_improve 629 | else: 630 | return minimal_nodes_and_features_sets 631 | 632 | 633 | class SoftZorro(AbstractGraphExplainer): 634 | coeffs = { 635 | 'fidelity': 1, 636 | 'node_size': 0.01, 637 | 'node_ent': 0.1, 638 | 'feature_size': 0.01, 639 | 'feature_ent': 0.1, 640 | } 641 | 642 | def __init__(self, model, device, log=True, record_process_time=False, samples=100, learning_rate=0.01): 643 | super(SoftZorro, self).__init__( 644 | model=model, 645 | device=device, 646 | log=log, 647 | record_process_time=record_process_time, 648 | ) 649 | self.distortion_samples = samples 650 | self.learning_rate = learning_rate 651 | 652 | def loss(self, node, node_mask, feature_mask, full_feature_matrix, computation_graph_feature_matrix, 653 | computation_graph_edge_index, predicted_label, return_soft_distortion=False): 654 | loss = -distortion(self.model, node, 655 | node_mask=node_mask, 656 | feature_mask=feature_mask, 657 | full_feature_matrix=full_feature_matrix, 658 | computation_graph_feature_matrix=computation_graph_feature_matrix, 659 | edge_index=computation_graph_edge_index, 660 | samples=100, 661 | predicted_label=predicted_label, 662 | random_seed=None, 663 | soft_distortion=True, 664 | device=self.device) 665 | 666 | EPS = 1e-15 667 | 668 | if return_soft_distortion: 669 | soft_distortion = loss.clone().detach().cpu().numpy() 670 | # weight of soft fidelity 671 | loss = self.coeffs["fidelity"] * loss 672 | 673 | m = node_mask 674 | loss = loss + self.coeffs["node_size"] * m.sum() 675 | ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS) 676 | loss = loss + self.coeffs["node_ent"] * ent.mean() 677 | 678 | m = feature_mask 679 | loss = loss + self.coeffs["feature_size"] * m.sum() 680 | ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS) 681 | loss = loss + self.coeffs["feature_ent"] * ent.mean() 682 | 683 | if return_soft_distortion: 684 | return loss, soft_distortion 685 | else: 686 | return loss 687 | 688 | def explain_node(self, node_idx, full_feature_matrix, edge_index): 689 | (num_nodes, num_features) = full_feature_matrix.size() 690 | 691 | # Only operate on a k-hop subgraph around `node_idx`. 
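# Nodes outside this k-hop neighbourhood cannot influence the GNN's output for
# `node_idx`, so the masks optimized below only need to cover the computation graph.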
692 | neighbor_subset, computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = self.__subgraph__( 693 | node_idx, full_feature_matrix, edge_index) 694 | 695 | self.model.eval() 696 | log_logits = self.model(x=computation_graph_feature_matrix, 697 | edge_index=computation_graph_edge_index) 698 | predicted_labels = log_logits.argmax(dim=-1) 699 | 700 | predicted_label = predicted_labels[mapping] 701 | 702 | num_computation_graph_nodes = computation_graph_feature_matrix.size(0) 703 | node_mask = torch.rand((1, num_computation_graph_nodes), device=self.device, requires_grad=True) 704 | feature_mask = torch.rand((1, num_features), device=self.device, requires_grad=True) 705 | 706 | optimizer = torch.optim.Adam([node_mask, feature_mask], lr=self.learning_rate) 707 | 708 | self.logger.info("------ Start explaining node " + str(node_idx)) 709 | loss, soft_distortion = self.loss(mapping, node_mask, feature_mask, full_feature_matrix, 710 | computation_graph_feature_matrix, 711 | computation_graph_edge_index, predicted_label, 712 | return_soft_distortion=True) 713 | self.logger.debug("Initial distortion: " + str(-soft_distortion[0])) 714 | self.logger.debug("Initial Loss: " + str(loss.detach().cpu().numpy()[0])) 715 | 716 | execution_time = time.time() 717 | 718 | epochs = 200 719 | for i in range(epochs): 720 | if epochs > i > 0 and i % 25 == 0: 721 | loss, soft_distortion = self.loss(mapping, node_mask, feature_mask, full_feature_matrix, 722 | computation_graph_feature_matrix, 723 | computation_graph_edge_index, predicted_label, 724 | return_soft_distortion=True) 725 | else: 726 | loss = self.loss(mapping, node_mask, feature_mask, full_feature_matrix, 727 | computation_graph_feature_matrix, 728 | computation_graph_edge_index, predicted_label) 729 | optimizer.zero_grad()  # reset gradients, otherwise they accumulate across iterations 730 | loss.backward() 731 | optimizer.step() 732 | with torch.no_grad(): 733 | node_mask.clamp_(min=0, max=1) 734 | feature_mask.clamp_(min=0, max=1) 735 | 736 | if epochs > i > 0 and i % 25 == 0: 737 | self.logger.debug("Epoch: " + str(i)) 738 | self.logger.debug("Distortion: " + str(-soft_distortion[0])) 739 | self.logger.debug("Loss: " + str(loss.detach().cpu().numpy()[0])) 740 | 741 | execution_time = time.time() - execution_time 742 | 743 | self.logger.info("------ Finished explaining node " + str(node_idx)) 744 | loss, soft_distortion = self.loss(mapping, node_mask, feature_mask, full_feature_matrix, 745 | computation_graph_feature_matrix, 746 | computation_graph_edge_index, predicted_label, 747 | return_soft_distortion=True) 748 | self.logger.debug("Final distortion: " + str(-soft_distortion[0])) 749 | self.logger.debug("Final Loss: " + str(loss.detach().cpu().numpy()[0])) 750 | 751 | numpy_node_mask = node_mask.clone().detach().cpu().numpy() 752 | numpy_feature_mask = feature_mask.clone().detach().cpu().numpy() 753 | self.logger.debug("Possible nodes: " + str((numpy_node_mask >= 0).sum())) 754 | self.logger.debug("Non zero nodes: " + str((numpy_node_mask > 0).sum())) 755 | self.logger.debug("Non zero features: " + str((numpy_feature_mask > 0).sum())) 756 | 757 | if self.record_process_time: 758 | return numpy_node_mask, numpy_feature_mask, execution_time 759 | else: 760 | return numpy_node_mask, numpy_feature_mask 761 | 762 | 763 | def save_soft_mask(save_path, node, node_mask, feature_mask, execution_time=np.inf): 764 | path = save_path 765 | 766 | numpy_dict = { 767 | "node": np.array(node), 768 | "node_mask": node_mask, 769 | "feature_mask": feature_mask, 770 | "execution_time": np.array(execution_time) 771 | } 772 | np.savez_compressed(path, **numpy_dict)
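# Round-trip sketch (file names assumed): after explaining node 3,
#     save_soft_mask("results/soft_node_3.npz", 3, node_mask, feature_mask, t)
#     node_mask, feature_mask, t = load_soft_mask("results/soft", 3)
# load_soft_mask appends "_node_<node>.npz" to the given prefix.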
773 | 774 | 775 | def load_soft_mask(path_prefix, node): 776 | path = path_prefix + "_node_" + str(node) + ".npz" 777 | 778 | save = np.load(path) 779 | node_mask = save["node_mask"] 780 | feature_mask = save["feature_mask"] 781 | execution_time = save["execution_time"] 782 | if np.isinf(execution_time): 783 | return node_mask, feature_mask 784 | else: 785 | return node_mask, feature_mask, float(execution_time) 786 | 787 | 788 | def save_minimal_nodes_and_features_sets(save_path, node, minimal_nodes_and_features_sets, 789 | initial_node_improve=None, initial_feature_improve=None): 790 | path = save_path 791 | 792 | if minimal_nodes_and_features_sets is None: 793 | numpy_dict = { 794 | "node": np.array(node), 795 | "number_of_sets": np.array(0), 796 | } 797 | 798 | else: 799 | 800 | numpy_dict = { 801 | "node": np.array(node), 802 | "number_of_sets": np.array(len(minimal_nodes_and_features_sets)), 803 | } 804 | 805 | features_label = "features_" 806 | nodes_label = "nodes_" 807 | selection_label = "selection_" 808 | 809 | for i, (selected_nodes, selected_features, executed_selections) in enumerate(minimal_nodes_and_features_sets): 810 | numpy_dict[nodes_label + str(i)] = selected_nodes 811 | numpy_dict[features_label + str(i)] = selected_features 812 | numpy_dict[selection_label + str(i)] = np.array(executed_selections) 813 | 814 | if initial_node_improve is not None: 815 | numpy_dict["initial_node_improve"] = np.array(initial_node_improve) 816 | 817 | if initial_feature_improve is not None: 818 | numpy_dict["initial_feature_improve"] = np.array(initial_feature_improve) 819 | 820 | np.savez_compressed(path, **numpy_dict) 821 | 822 | 823 | def load_minimal_nodes_and_features_sets(path_prefix, node, check_for_initial_improves=False): 824 | path = path_prefix + "_node_" + str(node) + ".npz" 825 | 826 | save = np.load(path, allow_pickle=False) 827 | 828 | saved_node = save["node"] 829 | if saved_node != node: 830 | raise ValueError("Other node than specified", saved_node, node) 831 | number_of_sets = save["number_of_sets"] 832 | 833 | minimal_nodes_and_features_sets = [] 834 | 835 | if number_of_sets > 0: 836 | 837 | features_label = "features_" 838 | nodes_label = "nodes_" 839 | selection_label = "selection_" 840 | 841 | for i in range(number_of_sets): 842 | selected_nodes = save[nodes_label + str(i)] 843 | selected_features = save[features_label + str(i)] 844 | executed_selections = save[selection_label + str(i)] 845 | 846 | minimal_nodes_and_features_sets.append((selected_nodes, selected_features, executed_selections)) 847 | 848 | if check_for_initial_improves: 849 | try: 850 | initial_node_improve = save["initial_node_improve"] 851 | except KeyError: 852 | initial_node_improve = None 853 | 854 | try: 855 | initial_feature_improve = save["initial_feature_improve"] 856 | except KeyError: 857 | initial_feature_improve = None 858 | 859 | return minimal_nodes_and_features_sets, initial_node_improve, initial_feature_improve 860 | else: 861 | return minimal_nodes_and_features_sets 862 | 863 | 864 | def distortion(model, node_idx=None, full_feature_matrix=None, computation_graph_feature_matrix=None, 865 | edge_index=None, node_mask=None, feature_mask=None, predicted_label=None, samples=None, 866 | random_seed=12345, device="cpu", validity=False, 867 | soft_distortion=False, detailed_mask=None, 868 | ): 869 | # conditional_samples=True only works for int feature matrix!
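# Monte Carlo estimate of the distortion value: for each of `samples` draws,
# masked entries keep their true values while every unmasked entry is replaced
# by the corresponding feature of a uniformly sampled node (via torch.gather);
# the result is the fraction of draws for which the model still predicts
# `predicted_label` for `node_idx`.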
870 | 871 | (num_nodes, num_features) = full_feature_matrix.size() 872 | 873 | num_nodes_computation_graph = computation_graph_feature_matrix.size(0) 874 | 875 | # retrieve complete mask as matrix 876 | if detailed_mask is not None: 877 | mask = detailed_mask 878 | else: 879 | mask = node_mask.T.matmul(feature_mask) 880 | 881 | if validity: 882 | samples = 1 883 | full_feature_matrix = torch.zeros_like(full_feature_matrix) 884 | 885 | correct = 0.0 886 | 887 | rng = torch.Generator(device=device) 888 | if random_seed is not None: 889 | rng.manual_seed(random_seed) 890 | random_indices = torch.randint(num_nodes, (samples, num_nodes_computation_graph, num_features), 891 | generator=rng, 892 | device=device, 893 | ) 894 | random_indices = random_indices.type(torch.int64) 895 | 896 | for i in range(samples): 897 | random_features = torch.gather(full_feature_matrix, 898 | dim=0, 899 | index=random_indices[i, :, :]) 900 | 901 | randomized_features = mask * computation_graph_feature_matrix + (1 - mask) * random_features 902 | 903 | log_logits = model(x=randomized_features, edge_index=edge_index) 904 | if soft_distortion: 905 | correct += log_logits[node_idx].softmax(dim=-1).squeeze()[predicted_label] 906 | else: 907 | distorted_labels = log_logits.argmax(dim=-1) 908 | if distorted_labels[node_idx] == predicted_label: 909 | correct += 1 910 | return correct / samples 911 | 912 | 913 | def multi_node_distortion(model, 914 | nodes, 915 | full_feature_matrix, 916 | computation_graph_feature_matrix, 917 | computation_graph_edge_index, 918 | node_mask, 919 | feature_mask, 920 | predicted_labels, 921 | samples=100, 922 | random_seed=12345, 923 | device="cpu", 924 | ): 925 | (num_nodes, num_features) = full_feature_matrix.size() 926 | 927 | num_nodes_computation_graph = computation_graph_feature_matrix.size(0) 928 | 929 | # retrieve complete mask as matrix 930 | mask = node_mask.T.matmul(feature_mask) 931 | 932 | correct = torch.zeros_like(predicted_labels) 933 | 934 | rng = torch.Generator(device=device) 935 | rng.manual_seed(random_seed) 936 | random_indices = torch.randint(num_nodes, (samples, num_nodes_computation_graph, num_features), 937 | generator=rng, 938 | device=device, 939 | ) 940 | random_indices = random_indices.type(torch.int64) 941 | 942 | for i in range(samples): 943 | random_features = torch.gather(full_feature_matrix, 944 | dim=0, 945 | index=random_indices[i, :, :]) 946 | 947 | randomized_features = mask * computation_graph_feature_matrix + (1 - mask) * random_features 948 | 949 | log_logits = model(x=randomized_features, edge_index=computation_graph_edge_index) 950 | distorted_labels = log_logits.argmax(dim=-1) 951 | 952 | correct[predicted_labels.eq(distorted_labels[nodes])] += 1 953 | 954 | return correct * (1 / float(samples)) 955 | 956 | 957 | def multi_node_precompute_full_distortion(model, 958 | nodes, 959 | full_feature_matrix, 960 | full_edge_index, 961 | save_path, 962 | samples=100, 963 | random_seed=12345, 964 | device="cpu", 965 | ): 966 | # get basic attributes: num_hops, flow 967 | num_hops = Zorro.num_hops(model) 968 | flow = Zorro.flow(model) 969 | 970 | (num_nodes, num_features) = full_feature_matrix.size() 971 | 972 | subset, computation_graph_edge_index, mapping, edge_mask = k_hop_subgraph(torch.tensor(nodes, device=device), 973 | num_hops, 974 | full_edge_index, 975 | relabel_nodes=True, 976 | num_nodes=num_nodes, flow=flow) 977 | 978 | computation_graph_feature_matrix = full_feature_matrix[subset] 979 | 980 | num_nodes_computation_graph = 
computation_graph_feature_matrix.size(0) 981 | 982 | # calculate predicted labels 983 | log_logits = model(x=computation_graph_feature_matrix, 984 | edge_index=computation_graph_edge_index) 985 | predicted_labels = log_logits.argmax(dim=-1) 986 | predicted_labels = predicted_labels[mapping] 987 | 988 | # calculate initial distortion 989 | node_mask = torch.zeros((1, num_nodes_computation_graph), device=device) 990 | feature_mask = torch.zeros((1, num_features), device=device) 991 | initial_distortion = multi_node_distortion(model, 992 | mapping, 993 | full_feature_matrix, 994 | computation_graph_feature_matrix, 995 | computation_graph_edge_index, 996 | node_mask, 997 | feature_mask, 998 | predicted_labels, 999 | samples=samples, 1000 | random_seed=random_seed, 1001 | device=device, 1002 | ) 1003 | 1004 | # calculate the improvement of features 1005 | feature_distortion = torch.zeros((num_features, len(nodes)), device=device) 1006 | node_mask = torch.ones_like(node_mask, device=device) 1007 | for i in tqdm(range(num_features)): 1008 | feature_mask[0, i] += 1 1009 | 1010 | feature_distortion[i] = multi_node_distortion(model, 1011 | mapping, 1012 | full_feature_matrix, 1013 | computation_graph_feature_matrix, 1014 | computation_graph_edge_index, 1015 | node_mask, 1016 | feature_mask, 1017 | predicted_labels, 1018 | samples=samples, 1019 | random_seed=random_seed, 1020 | device=device, 1021 | ) 1022 | feature_mask[0, i] -= 1 1023 | 1024 | # calculate the improvement of nodes 1025 | node_distortion = torch.zeros((num_nodes_computation_graph, len(nodes)), device=device) 1026 | 1027 | feature_mask = torch.ones_like(feature_mask, device=device) 1028 | node_mask = torch.zeros_like(node_mask, device=device) 1029 | for i in tqdm(range(num_nodes_computation_graph)): 1030 | node_mask[0, i] += 1 1031 | 1032 | node_distortion[i] = multi_node_distortion(model, 1033 | mapping, 1034 | full_feature_matrix, 1035 | computation_graph_feature_matrix, 1036 | computation_graph_edge_index, 1037 | node_mask, 1038 | feature_mask, 1039 | predicted_labels, 1040 | samples=samples, 1041 | random_seed=random_seed, 1042 | device=device, 1043 | ) 1044 | node_mask[0, i] -= 1 1045 | 1046 | np.savez_compressed(save_path, 1047 | **{ 1048 | "nodes": nodes, 1049 | "subset": subset.cpu().numpy(), 1050 | "mapping": mapping.cpu().numpy(), 1051 | "initial_distortion": initial_distortion.cpu().numpy(), 1052 | "feature_distortion": feature_distortion.cpu().numpy(), 1053 | "node_distortion": node_distortion.cpu().numpy(), 1054 | } 1055 | ) 1056 | 1057 | return subset, mapping, initial_distortion, feature_distortion, node_distortion 1058 | -------------------------------------------------------------------------------- /generate_gnnexplainer_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import abc 3 | import math 4 | 5 | import networkx as nx 6 | 7 | """ based on https://github.com/RexYing/gnn-model-explainer/ """ 8 | 9 | 10 | class FeatureGen(metaclass=abc.ABCMeta): 11 | """Feature Generator base class.""" 12 | 13 | @abc.abstractmethod 14 | def gen_node_features(self, G): 15 | pass 16 | 17 | 18 | class ConstFeatureGen(FeatureGen): 19 | """Constant Feature class.""" 20 | 21 | def __init__(self, val): 22 | self.val = val 23 | 24 | def gen_node_features(self, G): 25 | feat_dict = {i: {'x': np.array([self.val], dtype=np.float32)} for i in G.nodes()} 26 | nx.set_node_attributes(G, feat_dict) 27 | 28 | 29 | class GaussianFeatureGen(FeatureGen): 30 | """Gaussian Feature 
class.""" 31 | 32 | def __init__(self, mu, sigma): 33 | self.mu = mu 34 | if sigma.ndim < 2: 35 | self.sigma = np.diag(sigma) 36 | else: 37 | self.sigma = sigma 38 | 39 | def gen_node_features(self, G): 40 | feat = np.random.multivariate_normal(self.mu, self.sigma, G.number_of_nodes()) 41 | # Normalize feature 42 | feat = (feat + np.max(np.abs(feat))) / np.max(np.abs(feat)) / 2 43 | feat_dict = { 44 | i: {"x": feat[i]} for i in range(feat.shape[0]) # changed feat to x 45 | } 46 | nx.set_node_attributes(G, feat_dict) 47 | 48 | 49 | def ba(start, width, role_start=0, m=5): 50 | """Builds a BA preferential attachment graph, with index of nodes starting at start 51 | and role_ids at role_start 52 | INPUT: 53 | ------------- 54 | start : starting index for the shape 55 | width : int size of the graph 56 | role_start : starting index for the roles 57 | OUTPUT: 58 | ------------- 59 | graph : a house shape graph, with ids beginning at start 60 | roles : list of the roles of the nodes (indexed starting at 61 | role_start) 62 | """ 63 | graph = nx.barabasi_albert_graph(width, m) 64 | graph.add_nodes_from(range(start, start + width)) 65 | nids = sorted(graph) 66 | mapping = {nid: start + i for i, nid in enumerate(nids)} 67 | graph = nx.relabel_nodes(graph, mapping) 68 | roles = [role_start for i in range(width)] 69 | return graph, roles 70 | 71 | 72 | def house(start, role_start=0): 73 | """Builds a house-like graph, with index of nodes starting at start 74 | and role_ids at role_start 75 | INPUT: 76 | ------------- 77 | start : starting index for the shape 78 | role_start : starting index for the roles 79 | OUTPUT: 80 | ------------- 81 | graph : a house shape graph, with ids beginning at start 82 | roles : list of the roles of the nodes (indexed starting at 83 | role_start) 84 | """ 85 | graph = nx.Graph() 86 | graph.add_nodes_from(range(start, start + 5)) 87 | graph.add_edges_from( 88 | [ 89 | (start, start + 1), 90 | (start + 1, start + 2), 91 | (start + 2, start + 3), 92 | (start + 3, start), 93 | ] 94 | ) 95 | # graph.add_edges_from([(start, start + 2), (start + 1, start + 3)]) 96 | graph.add_edges_from([(start + 4, start), (start + 4, start + 1)]) 97 | roles = [role_start, role_start, role_start + 1, role_start + 1, role_start + 2] 98 | return graph, roles 99 | 100 | 101 | def build_graph( 102 | width_basis, 103 | basis_type, 104 | list_shapes, 105 | start=0, 106 | rdm_basis_plugins=False, 107 | add_random_edges=0, 108 | m=5, 109 | ): 110 | """This function creates a basis (scale-free, path, or cycle) 111 | and attaches elements of the type in the list randomly along the basis. 112 | Possibility to add random edges afterwards. 113 | INPUT: 114 | -------------------------------------------------------------------------------------- 115 | width_basis : width (in terms of number of nodes) of the basis 116 | basis_type : (torus, string, or cycle) 117 | shapes : list of shape list (1st arg: type of shape, 118 | next args:args for building the shape, 119 | except for the start) 120 | start : initial nb for the first node 121 | rdm_basis_plugins: boolean. Should the shapes be randomly placed 122 | along the basis (True) or regularly (False)? 
101 | def build_graph( 102 | width_basis, 103 | basis_type, 104 | list_shapes, 105 | start=0, 106 | rdm_basis_plugins=False, 107 | add_random_edges=0, 108 | m=5, 109 | ): 110 | """This function creates a basis (scale-free, path, or cycle) 111 | and attaches elements of the type in the list randomly along the basis. 112 | Possibility to add random edges afterwards. 113 | INPUT: 114 | -------------------------------------------------------------------------------------- 115 | width_basis : width (in terms of number of nodes) of the basis 116 | basis_type : (torus, string, or cycle) 117 | list_shapes : list of shape lists (1st arg: type of shape, 118 | next args: args for building the shape, 119 | except for the start) 120 | start : initial nb for the first node 121 | rdm_basis_plugins: boolean. Should the shapes be randomly placed 122 | along the basis (True) or regularly (False)? 123 | add_random_edges : nb of edges to randomly add on the structure 124 | m : number of edges to attach to existing node (for BA graph) 125 | OUTPUT: 126 | -------------------------------------------------------------------------------------- 127 | basis : a nx graph with the particular shape 128 | role_ids : labels for each role 129 | plugins : node ids with the attached shapes 130 | """ 131 | if basis_type == "ba": 132 | basis, role_id = ba(start, width_basis, m=m) 133 | # else: 134 | # basis, role_id = eval(basis_type)(start, width_basis) 135 | 136 | n_basis, n_shapes = nx.number_of_nodes(basis), len(list_shapes) 137 | start += n_basis  # indicator of the id of the next node 138 | 139 | # Sample (without replacement) where to attach the new motifs 140 | if rdm_basis_plugins is True: 141 | plugins = np.random.choice(n_basis, n_shapes, replace=False) 142 | else: 143 | spacing = math.floor(n_basis / n_shapes) 144 | plugins = [int(k * spacing) for k in range(n_shapes)] 145 | seen_shapes = {"basis": [0, n_basis]} 146 | 147 | for shape_id, shape in enumerate(list_shapes): 148 | shape_type = shape[0] 149 | args = [start] 150 | if len(shape) > 1: 151 | args += shape[1:] 152 | args += [0] 153 | if shape_type == "house": 154 | graph_s, roles_graph_s = house(*args) 155 | else: 156 | raise ValueError("Shape type " + str(shape_type) + " not supported") 157 | # graph_s, roles_graph_s = eval(shape_type)(*args) 158 | n_s = nx.number_of_nodes(graph_s) 159 | try: 160 | col_start = seen_shapes[shape_type][0] 161 | except KeyError: 162 | col_start = np.max(role_id) + 1 163 | seen_shapes[shape_type] = [col_start, n_s] 164 | # Attach the shape to the basis 165 | basis.add_nodes_from(graph_s.nodes()) 166 | basis.add_edges_from(graph_s.edges()) 167 | basis.add_edges_from([(start, plugins[shape_id])]) 168 | if shape_type == "cycle": 169 | if np.random.random() > 0.5: 170 | a = np.random.randint(1, 4) 171 | b = np.random.randint(1, 4) 172 | basis.add_edges_from([(a + start, b + plugins[shape_id])]) 173 | temp_labels = [r + col_start for r in roles_graph_s] 174 | # temp_labels[0] += 100 * seen_shapes[shape_type][0] 175 | role_id += temp_labels 176 | start += n_s 177 | 178 | if add_random_edges > 0: 179 | # add random edges between nodes: 180 | for p in range(add_random_edges): 181 | src, dest = np.random.choice(nx.number_of_nodes(basis), 2, replace=False) 182 | print(src, dest) 183 | basis.add_edges_from([(src, dest)]) 184 | 185 | return basis, role_id, plugins 186 | 187 | 188 | def perturb(graph_list, p): 189 | """ Perturb the list of (sparse) graphs by adding edges. 190 | Args: 191 | p: proportion of added edges based on current number of edges. 192 | Returns: 193 | A list of graphs that are perturbed from the original graphs. 194 | """ 195 | perturbed_graph_list = [] 196 | for G_original in graph_list: 197 | G = G_original.copy() 198 | edge_count = int(G.number_of_edges() * p) 199 | # randomly add the edges between a pair of nodes without an edge. 200 | for _ in range(edge_count): 201 | while True: 202 | u = np.random.randint(0, G.number_of_nodes()) 203 | v = np.random.randint(0, G.number_of_nodes()) 204 | if (not G.has_edge(u, v)) and (u != v): 205 | break 206 | G.add_edge(u, v) 207 | perturbed_graph_list.append(G) 208 | return perturbed_graph_list 209 | 210 |
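# For example, perturb([G], 0.01) returns a copy of G with roughly 1% additional
# random edges; the original graphs are left untouched.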
211 | def gen_syn1(nb_shapes=80, width_basis=300, feature_generator=None, m=5): 212 | """ Synthetic Graph #1: 213 | Start with Barabasi-Albert graph and attach house-shaped subgraphs. 214 | Args: 215 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 216 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 217 | feature_generator : A `FeatureGen` for node features. If `None`, add constant features to nodes. 218 | m : number of edges to attach to existing node (for BA graph) 219 | Returns: 220 | G : A networkx graph 221 | role_id : A list with length equal to number of nodes in the entire graph (basis 222 | : + shapes). role_id[i] is the ID of the role of node i. It is the label. 223 | name : A graph identifier 224 | """ 225 | basis_type = "ba" 226 | list_shapes = [["house"]] * nb_shapes 227 | 228 | # plt.figure(figsize=(8, 6), dpi=300) 229 | 230 | G, role_id, _ = build_graph( 231 | width_basis, basis_type, list_shapes, start=0, m=m 232 | ) 233 | G = perturb([G], 0.01)[0] 234 | 235 | if feature_generator is None: 236 | feature_generator = ConstFeatureGen(1) 237 | feature_generator.gen_node_features(G) 238 | 239 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 240 | return G, role_id, name 241 | 242 | 243 | def gen_syn2(nb_shapes=100, width_basis=350): 244 | """ Synthetic Graph #2: 245 | Start with Barabasi-Albert graph and add node features indicative of a community label. 246 | Args: 247 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 248 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 249 | Returns: 250 | G : A networkx graph 251 | label : Label of the nodes (determined by role_id and community) 252 | name : A graph identifier 253 | """ 254 | basis_type = "ba" 255 | 256 | random_mu = [0.0] * 8 257 | random_sigma = [1.0] * 8 258 | 259 | # Create the features of the two communities 260 | mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 261 | mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 262 | feat_gen_G1 = GaussianFeatureGen(mu=mu_1, sigma=sigma_1) 263 | feat_gen_G2 = GaussianFeatureGen(mu=mu_2, sigma=sigma_2) 264 | G1, role_id1, name = gen_syn1(feature_generator=feat_gen_G1, m=4) 265 | G2, role_id2, name = gen_syn1(feature_generator=feat_gen_G2, m=4) 266 | 267 | # return G1, role_id1, G2, role_id2 268 | G1_size = G1.number_of_nodes() 269 | num_roles = max(role_id1) + 1 270 | role_id2 = [r + num_roles for r in role_id2] 271 | label = role_id1 + role_id2 272 | 273 | # Edit node ids to avoid collisions on join 274 | g1_map = {n: i for i, n in enumerate(G1.nodes())} 275 | G1 = nx.relabel_nodes(G1, g1_map) 276 | g2_map = {n: i + G1_size for i, n in enumerate(G2.nodes())} 277 | G2 = nx.relabel_nodes(G2, g2_map) 278 | 279 | # Join 280 | n_pert_edges = width_basis 281 | G = join_graph(G1, G2, n_pert_edges) 282 | 283 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) + "_2comm" 284 | 285 | return G, label, name 286 | 287 | 288 | def join_graph(G1, G2, n_pert_edges): 289 | """ Join two graphs along matching nodes, then perturb the resulting graph. 290 | Args: 291 | G1, G2: Networkx graphs to be joined. 292 | n_pert_edges: number of perturbed edges. 293 | Returns: 294 | A new graph, result of merging and perturbing G1 and G2. 295 | """ 296 | assert n_pert_edges > 0 297 | F = nx.compose(G1, G2) 298 | edge_cnt = 0 299 | while edge_cnt < n_pert_edges: 300 | node_1 = np.random.choice(G1.nodes()) 301 | node_2 = np.random.choice(G2.nodes()) 302 | F.add_edge(node_1, node_2) 303 | edge_cnt += 1 304 | return F 305 | 306 |
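# With the defaults of gen_syn1 (width_basis=300, nb_shapes=80), each community has
# 300 + 80 * 5 = 700 nodes, so gen_syn2 produces a 1400-node graph named
# "ba_350_100_2comm"; the house nodes are exactly the ranges 300-699 and 1000-1399
# used as selected_nodes for syn2 in execution.py.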
307 | def preprocess_input_graph(G, labels, normalize_adj=False): 308 | """ Load an existing graph to be converted for the experiments. 309 | Args: 310 | G: Networkx graph to be loaded. 311 | labels: Associated node labels. 312 | normalize_adj: Should the method return a normalized adjacency matrix. 313 | Returns: 314 | A dictionary containing adjacency, node features and labels 315 | """ 316 | adj = np.array(nx.to_numpy_matrix(G)) 317 | if normalize_adj: 318 | sqrt_deg = np.diag(1.0 / np.sqrt(np.sum(adj, axis=0, dtype=float).squeeze())) 319 | adj = np.matmul(np.matmul(sqrt_deg, adj), sqrt_deg) 320 | 321 | existing_node = list(G.nodes)[-1] 322 | feat_dim = G.nodes[existing_node]["x"].shape[0]  # features are stored under "x" by the generators above 323 | f = np.zeros((G.number_of_nodes(), feat_dim), dtype=float) 324 | for i, u in enumerate(G.nodes()): 325 | f[i, :] = G.nodes[u]["x"] 326 | 327 | # add batch dim 328 | adj = np.expand_dims(adj, axis=0) 329 | f = np.expand_dims(f, axis=0) 330 | labels = np.expand_dims(labels, axis=0) 331 | return {"adj": adj, "feat": f, "labels": labels} 332 | 333 | 334 | if __name__ == "__main__": 335 | g, label, name = gen_syn2() 336 | -------------------------------------------------------------------------------- /gnn_explainer.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | import networkx as nx 7 | from torch_geometric.nn import MessagePassing 8 | from torch_geometric.data import Data 9 | from torch_geometric.utils import k_hop_subgraph, to_networkx 10 | from torch_geometric.nn import APPNP 11 | 12 | from explainer import distortion 13 | 14 | EPS = 1e-15 15 | 16 | 17 | class GNNExplainer(torch.nn.Module): 18 | r"""The GNN-Explainer model from the `"GNNExplainer: Generating 19 | Explanations for Graph Neural Networks" 20 | <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph 21 | structures and small subsets of node features that play a crucial role in a 22 | GNN's node-predictions. 23 | 24 | .. note:: 25 | 26 | For an example of using GNN-Explainer, see `examples/gnn_explainer.py` 27 | in the PyTorch Geometric repository. 28 | 29 | 30 | Args: 31 | model (torch.nn.Module): The GNN module to explain. 32 | epochs (int, optional): The number of epochs to train. 33 | (default: :obj:`100`) 34 | lr (float, optional): The learning rate to apply. 35 | (default: :obj:`0.01`) 36 | log (bool, optional): If set to :obj:`False`, will not log any learning 37 | progress. (default: :obj:`True`) 38 | """ 39 | 40 | coeffs = { 41 | 'edge_size': 0.005, 42 | 'node_feat_size': 1.0, 43 | 'edge_ent': 1.0, 44 | 'node_feat_ent': 0.1, 45 | } 46 | 47 | def __init__(self, model, epochs=100, lr=0.01, log=True): 48 | super(GNNExplainer, self).__init__() 49 | self.model = model 50 | self.epochs = epochs 51 | self.lr = lr 52 | self.log = log 53 |
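# Usage sketch (names assumed): for a trained `model`,
#     explainer = GNNExplainer(model, epochs=100, lr=0.01)
#     node_feat_mask, edge_mask = explainer.explain_node(10, data.x, data.edge_index)
# returns soft masks over the input features and over all edges of the graph.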
54 | def __set_masks__(self, x, edge_index, init="normal"): 55 | (N, F), E = x.size(), edge_index.size(1) 56 | 57 | std = 0.1 58 | self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * std) 59 | 60 | std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N)) 61 | self.edge_mask = torch.nn.Parameter(torch.randn(E) * std) 62 | 63 | for module in self.model.modules(): 64 | if isinstance(module, MessagePassing): 65 | module.__explain__ = True 66 | module.__edge_mask__ = self.edge_mask 67 | 68 | def __clear_masks__(self): 69 | for module in self.model.modules(): 70 | if isinstance(module, MessagePassing): 71 | module.__explain__ = False 72 | module.__edge_mask__ = None 73 | self.node_feat_mask = None 74 | self.edge_mask = None 75 | 76 | def __num_hops__(self): 77 | num_hops = 0 78 | for module in self.model.modules(): 79 | if isinstance(module, MessagePassing): 80 | if isinstance(module, APPNP): 81 | num_hops += module.K 82 | else: 83 | num_hops += 1 84 | return num_hops 85 | 86 | def __flow__(self): 87 | for module in self.model.modules(): 88 | if isinstance(module, MessagePassing): 89 | return module.flow 90 | return 'source_to_target' 91 | 92 | def __subgraph__(self, node_idx, x, edge_index, **kwargs): 93 | num_nodes, num_edges = x.size(0), edge_index.size(1) 94 | 95 | subset, edge_index, mapping, edge_mask = k_hop_subgraph( 96 | node_idx, self.__num_hops__(), edge_index, relabel_nodes=True, 97 | num_nodes=num_nodes, flow=self.__flow__()) 98 | 99 | x = x[subset] 100 | for key, item in kwargs.items(): 101 | if torch.is_tensor(item) and item.size(0) == num_nodes: 102 | item = item[subset] 103 | elif torch.is_tensor(item) and item.size(0) == num_edges: 104 | item = item[edge_mask] 105 | kwargs[key] = item 106 | 107 | return x, edge_index, mapping, edge_mask, kwargs 108 | 109 | def __loss__(self, node_idx, log_logits, pred_label): 110 | loss = -log_logits[node_idx, pred_label[node_idx]] 111 | 112 | m = self.edge_mask.sigmoid() 113 | loss = loss + self.coeffs['edge_size'] * m.sum() 114 | ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS) 115 | loss = loss + self.coeffs['edge_ent'] * ent.mean() 116 | 117 | m = self.node_feat_mask.sigmoid() 118 | loss = loss + self.coeffs['node_feat_size'] * m.sum() 119 | ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS) 120 | loss = loss + self.coeffs['node_feat_ent'] * ent.mean() 121 | 122 | return loss 123 | 124 | def explain_node(self, node_idx, x, edge_index, **kwargs): 125 | r"""Learns and returns a node feature mask and an edge mask that play a 126 | crucial role to explain the prediction made by the GNN for node 127 | :attr:`node_idx`. 128 | 129 | Args: 130 | node_idx (int): The node to explain. 131 | x (Tensor): The node feature matrix. 132 | edge_index (LongTensor): The edge indices. 133 | **kwargs (optional): Additional arguments passed to the GNN module. 134 | 135 | :rtype: (:class:`Tensor`, :class:`Tensor`) 136 | """ 137 | 138 | self.model.eval() 139 | self.__clear_masks__() 140 | 141 | num_edges = edge_index.size(1) 142 | 143 | # Only operate on a k-hop subgraph around `node_idx`.
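# hard_edge_mask records which edges of the full graph survive in the subgraph;
# it is used after training to scatter the learned edge mask back to full size.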
144 | x, edge_index, mapping, hard_edge_mask, kwargs = self.__subgraph__( 145 | node_idx, x, edge_index, **kwargs) 146 | 147 | # Get the initial prediction. 148 | with torch.no_grad(): 149 | log_logits = self.model(x=x, edge_index=edge_index, **kwargs) 150 | pred_label = log_logits.argmax(dim=-1) 151 | 152 | self.__set_masks__(x, edge_index) 153 | self.to(x.device) 154 | 155 | optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask], 156 | lr=self.lr) 157 | 158 | if self.log: # pragma: no cover 159 | pbar = tqdm(total=self.epochs) 160 | pbar.set_description(f'Explain node {node_idx}') 161 | 162 | for epoch in range(1, self.epochs + 1): 163 | optimizer.zero_grad() 164 | h = x * self.node_feat_mask.view(1, -1).sigmoid() 165 | log_logits = self.model(x=h, edge_index=edge_index, **kwargs) 166 | loss = self.__loss__(mapping, log_logits, pred_label) 167 | loss.backward() 168 | optimizer.step() 169 | 170 | if self.log: # pragma: no cover 171 | pbar.update(1) 172 | 173 | if self.log: # pragma: no cover 174 | pbar.close() 175 | 176 | node_feat_mask = self.node_feat_mask.detach().sigmoid() 177 | edge_mask = self.edge_mask.new_zeros(num_edges) 178 | edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid() 179 | 180 | self.__clear_masks__() 181 | 182 | return node_feat_mask, edge_mask 183 | 184 | def distortion(self, node_idx, full_feature_matrix, edge_index, feature_mask, edge_mask=None, node_mask=None, 185 | validity=False): 186 | 187 | computation_graph_feature_matrix, computation_graph_edge_index, mapping, hard_edge_mask, kwargs = \ 188 | self.__subgraph__(node_idx, full_feature_matrix, edge_index) 189 | 190 | self.__clear_masks__() 191 | 192 | log_logits = self.model(x=computation_graph_feature_matrix, 193 | edge_index=computation_graph_edge_index) 194 | predicted_labels = log_logits.argmax(dim=-1) 195 | 196 | predicted_label = predicted_labels[mapping] 197 | 198 | # set edge mask 199 | if edge_mask is not None: 200 | param_edge_mask = torch.nn.Parameter(edge_mask[hard_edge_mask]) 201 | for module in self.model.modules(): 202 | if isinstance(module, MessagePassing): 203 | module.__explain__ = True 204 | module.__edge_mask__ = param_edge_mask 205 | 206 | with torch.no_grad(): 207 | num_computation_graph_nodes = computation_graph_feature_matrix.size(0) 208 | if node_mask is None: 209 | # all nodes selected 210 | node_mask = torch.ones((1, num_computation_graph_nodes)) 211 | else: 212 | node_mask = node_mask 213 | 214 | value = distortion(self.model, 215 | node_idx=mapping, 216 | full_feature_matrix=full_feature_matrix, 217 | computation_graph_feature_matrix=computation_graph_feature_matrix, 218 | edge_index=computation_graph_edge_index, 219 | node_mask=node_mask, 220 | feature_mask=feature_mask, 221 | predicted_label=predicted_label, 222 | samples=100, 223 | random_seed=12345, 224 | validity=validity, 225 | # device="cpu", 226 | ) 227 | 228 | self.__clear_masks__() 229 | 230 | return value 231 | -------------------------------------------------------------------------------- /grad_explainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.nn import MessagePassing 4 | 5 | 6 | def execute_model_with_gradient(model, node, x, edge_index): 7 | ypred = model(x, edge_index) 8 | 9 | predicted_labels = ypred.argmax(dim=-1) 10 | predicted_label = predicted_labels[node] 11 | logit = torch.nn.functional.softmax((ypred[node, :]).squeeze(), dim=0) 12 | 13 | logit = logit[predicted_label] 14 | loss 
/grad_explainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.nn import MessagePassing
4 | 
5 | 
6 | def execute_model_with_gradient(model, node, x, edge_index):
7 |     # Backpropagate the negative log-likelihood of the predicted class of `node`.
8 |     ypred = model(x, edge_index)
9 | 
10 |     predicted_labels = ypred.argmax(dim=-1)
11 |     predicted_label = predicted_labels[node]
12 |     logit = torch.nn.functional.softmax((ypred[node, :]).squeeze(), dim=0)
13 | 
14 |     logit = logit[predicted_label]  # probability of the predicted class
15 |     loss = -torch.log(logit)
16 |     loss.backward()
17 | 
18 | 
19 | def grad_edge_explanation(model, node, x, edge_index):
20 |     model.zero_grad()
21 | 
22 |     E = edge_index.size(1)
23 |     edge_mask = torch.nn.Parameter(torch.ones(E))
24 | 
25 |     for module in model.modules():
26 |         if isinstance(module, MessagePassing):
27 |             module.__explain__ = True
28 |             module.__edge_mask__ = edge_mask
29 | 
30 |     edge_mask.requires_grad = True
31 |     x.requires_grad = True
32 | 
33 |     if edge_mask.grad is not None:
34 |         edge_mask.grad.zero_()
35 |     if x.grad is not None:
36 |         x.grad.zero_()
37 | 
38 |     execute_model_with_gradient(model, node, x, edge_index)
39 | 
40 |     adj_grad = edge_mask.grad
41 |     adj_grad = torch.abs(adj_grad)
42 |     # edge_mask is one-dimensional, so .t() is a no-op here; this line simply
43 |     # doubles the gradient magnitudes before the sigmoid.
44 |     masked_adj = adj_grad + adj_grad.t()
45 |     masked_adj = torch.sigmoid(masked_adj)
46 |     masked_adj = masked_adj.cpu().detach().numpy()
47 | 
48 |     feature_mask = torch.abs(x.grad).cpu().detach().numpy()
49 | 
50 |     return np.max(feature_mask, axis=0), masked_adj
51 | 
52 | 
53 | def grad_node_explanation(model, node, x, edge_index):
54 |     model.zero_grad()
55 | 
56 |     num_nodes, num_features = x.size()
57 | 
58 |     node_grad = torch.nn.Parameter(torch.ones(num_nodes))
59 |     feature_grad = torch.nn.Parameter(torch.ones(num_features))
60 | 
61 |     node_grad.requires_grad = True
62 |     feature_grad.requires_grad = True
63 | 
64 |     # rank-one mask: entry (i, j) scales feature j of node i
65 |     mask = node_grad.unsqueeze(0).T.matmul(feature_grad.unsqueeze(0))
66 | 
67 |     execute_model_with_gradient(model, node, mask * x, edge_index)
68 | 
69 |     node_mask = torch.abs(node_grad.grad).cpu().detach().numpy()
70 |     feature_mask = torch.abs(feature_grad.grad).cpu().detach().numpy()
71 | 
72 |     return feature_mask, node_mask
73 | 
74 | 
75 | def gradinput_node_explanation(model, node, x, edge_index):
76 |     model.zero_grad()
77 | 
78 |     x.requires_grad = True
79 |     if x.grad is not None:
80 |         x.grad.zero_()
81 | 
82 |     execute_model_with_gradient(model, node, x, edge_index)
83 | 
84 |     feature_mask = torch.abs(x.grad * x).cpu().detach().numpy()
85 | 
86 |     return np.mean(feature_mask, axis=0), np.mean(feature_mask, axis=1)
87 | 
--------------------------------------------------------------------------------
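A minimal usage sketch for these gradient-based baselines, assuming the bundled Cora GCN checkpoint from the `results` directory; the target node is illustrative.

```python
from models import load_dataset, GCNNet, load_model
from grad_explainer import grad_node_explanation, gradinput_node_explanation

dataset, data, results_path = load_dataset("Cora")
model = GCNNet(dataset)
load_model("results/cora_GCN/gcn_2_layers.pt", model)

node = 10            # illustrative target node
x = data.x.clone()   # avoid toggling requires_grad on the dataset tensor

feature_mask, node_mask = grad_node_explanation(model, node, x, data.edge_index)
print(feature_mask.shape, node_mask.shape)  # (num_features,), (num_nodes,)

# Grad*Input variant: importance aggregated per feature and per node.
per_feature, per_node = gradinput_node_explanation(model, node, x, data.edge_index)
```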
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Sequential, Linear, ReLU
4 | import torch_geometric
5 | from torch_geometric.nn import GCNConv
6 | from torch_geometric.nn import GATConv
7 | from torch_geometric.nn import APPNP
8 | from torch_geometric.nn import GINConv
9 | from torch_geometric.datasets import Planetoid
10 | from torch_geometric.utils import from_networkx
11 | from pathlib import Path
12 | import numpy as np
13 | 
14 | SYN2_PATH = "data/syn2.npz"
15 | SYN1_PATH = "data/syn1.npz"  # assumed counterpart of SYN2_PATH; referenced by create_syn("syn1") but previously undefined
16 | 
17 | 
18 | def load_dataset(data_set, working_directory=None):
19 |     if working_directory is None:
20 |         working_directory = Path(".").resolve()
21 |     if data_set == "Cora":
22 |         dataset = Planetoid(root=working_directory.joinpath('tmp/Cora'), name='Cora')
23 |         data = dataset[0]
24 |         results_path = "cora"
25 |     elif data_set == "CiteSeer":
26 |         dataset = Planetoid(root=working_directory.joinpath('tmp/CiteSeer'), name='CiteSeer')
27 |         data = dataset[0]
28 |         results_path = "citeseer"
29 |     elif data_set == "PubMed":
30 |         dataset = Planetoid(root=working_directory.joinpath('tmp/PubMed'), name='PubMed')
31 |         data = dataset[0]
32 |         results_path = "pubmed"
33 |     elif data_set[:4] == "syn2":
34 |         try:
35 |             save_data = np.load(working_directory.joinpath(SYN2_PATH))
36 |         except FileNotFoundError:
37 |             save_data = create_syn(data_set)
38 | 
39 |         transformed_data = {}
40 |         for name in save_data:
41 |             transformed_data[name] = torch.tensor(save_data[name])
42 |         data = torch_geometric.data.Data.from_dict(transformed_data)
43 | 
44 |         results_path = data_set
45 |         from collections import namedtuple
46 |         Dataset = namedtuple("Dataset", "num_node_features num_classes")
47 |         dataset = Dataset(10, max(data.y.numpy()) + 1)  # syn2 graphs have 10 node features
48 | 
49 |     else:
50 |         raise ValueError("Dataset " + data_set + " not implemented")
51 | 
52 |     return dataset, data, results_path
53 | 
54 | 
55 | def create_syn(dataset_name="syn2"):
56 |     import generate_gnnexplainer_dataset as gn
57 |     if dataset_name == "syn2":
58 |         g, labels, name = gn.gen_syn2()
59 |     elif dataset_name == "syn1":
60 |         g, labels, name = gn.gen_syn1()
61 |     else:
62 |         raise NotImplementedError("Dataset not known")
63 | 
64 |     data = from_networkx(g)
65 | 
66 |     edge_index = data.edge_index.numpy()
67 |     x = data.x.numpy().astype(np.float32)
68 |     y = np.array(labels)
69 | 
70 |     train_ratio = 0.8
71 | 
72 |     num_nodes = x.shape[0]
73 |     num_train = int(num_nodes * train_ratio)
74 |     idx = list(range(num_nodes))
75 | 
76 |     np.random.shuffle(idx)
77 |     train_mask = np.full_like(y, False, dtype=bool)
78 |     train_mask[idx[:num_train]] = True
79 |     test_mask = np.full_like(y, False, dtype=bool)
80 |     test_mask[idx[num_train:]] = True
81 | 
82 |     save_data = {"edge_index": edge_index,
83 |                  "x": x,
84 |                  "y": y,
85 |                  "train_mask": train_mask,
86 |                  "test_mask": test_mask,
87 |                  "num_nodes": g.number_of_nodes()
88 |                  }
89 | 
90 |     if dataset_name == "syn2":
91 |         np.savez_compressed(SYN2_PATH, **save_data)
92 |     elif dataset_name == "syn1":
93 |         np.savez_compressed(SYN1_PATH, **save_data)
94 |     return save_data
95 | 
96 | 
97 | # a slight adaptation of the split method used by Planetoid
98 | def create_train_val_test_mask(data, num_train_per_class=20, num_classes=None, num_val=500, num_test=1000):
99 |     # fix seed for selecting train_mask
100 |     rng = np.random.RandomState(seed=42 * 20200909)
101 | 
102 |     if num_classes is None:
103 |         num_classes = int(data.y.max()) + 1  # labels are zero-based, so the count is max + 1
104 | 
105 |     train_mask = torch.full_like(data.y, False, dtype=torch.bool)
106 |     for c in range(num_classes):
107 |         idx = (data.y == c).nonzero().view(-1)
108 |         idx = idx[rng.permutation(idx.size(0))[:num_train_per_class]]
109 |         train_mask[idx] = True
110 | 
111 |     remaining = (~train_mask).nonzero().view(-1)
112 |     remaining = remaining[rng.permutation(remaining.size(0))]
113 | 
114 |     val_mask = torch.full_like(data.y, False, dtype=torch.bool)
115 |     val_mask[remaining[:num_val]] = True
116 | 
117 |     test_mask = torch.full_like(data.y, False, dtype=torch.bool)
118 |     test_mask[remaining[num_val:num_val + num_test]] = True
119 | 
120 |     return train_mask, val_mask, test_mask
121 | 
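A quick sketch of the loading path above. The Planetoid datasets download on first use; `create_train_val_test_mask` is only needed when one does not want the split that ships with the data.

```python
from models import load_dataset, create_train_val_test_mask

dataset, data, results_path = load_dataset("Cora")
train_mask, val_mask, test_mask = create_train_val_test_mask(
    data, num_classes=dataset.num_classes)
print(results_path, int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))
```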
122 | 
123 | class GCNNet(torch.nn.Module):
124 |     def __init__(self, dataset):
125 |         super(GCNNet, self).__init__()
126 |         self.conv1 = GCNConv(dataset.num_node_features, 16)
127 |         self.conv2 = GCNConv(16, dataset.num_classes)
128 | 
129 |     def forward(self, x, edge_index):
130 |         x = self.conv1(x, edge_index)
131 |         x = F.relu(x)
132 |         x = F.dropout(x, training=self.training)
133 |         x = self.conv2(x, edge_index)
134 | 
135 |         return F.log_softmax(x, dim=1)
136 | 
137 | 
138 | class GCN_syn2(torch.nn.Module):
139 |     # only for syn2
140 |     def __init__(self, dataset):
141 |         super(GCN_syn2, self).__init__()
142 |         hidden_dim = 20
143 |         self.conv1 = GCNConv(dataset.num_node_features, hidden_dim, add_self_loops=False)
144 |         self.conv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops=False)
145 |         self.conv3 = GCNConv(hidden_dim, hidden_dim, add_self_loops=False)
146 |         # prediction head over the concatenated representations of all three layers
147 |         self.lin_pred = Linear(3 * hidden_dim, dataset.num_classes)
148 | 
149 |     def forward(self, x, edge_index):
150 |         x = self.conv1(x, edge_index)
151 |         x = F.relu(x)
152 |         x_all = [x]
153 |         x = self.conv2(x, edge_index)
154 |         x = F.relu(x)
155 |         x_all.append(x)
156 |         x = self.conv3(x, edge_index)
157 |         x = F.relu(x)
158 |         x_all.append(x)
159 | 
160 |         x = torch.cat(x_all, dim=1)
161 |         x = self.lin_pred(x)
162 | 
163 |         return F.log_softmax(x, dim=1)
164 | 
165 | 
166 | class GATNet(torch.nn.Module):
167 |     # based on https://github.com/rusty1s/pytorch_geometric/blob/master/examples/gat.py
168 |     def __init__(self, dataset):
169 |         super(GATNet, self).__init__()
170 |         self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
171 |         # On the Pubmed dataset, use heads=8 in conv2.
172 |         self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False,
173 |                              dropout=0.6)
174 | 
175 |     def forward(self, x, edge_index):
176 |         x = F.dropout(x, p=0.6, training=self.training)
177 |         x = F.elu(self.conv1(x, edge_index))
178 |         x = F.dropout(x, p=0.6, training=self.training)
179 |         x = self.conv2(x, edge_index)
180 |         return F.log_softmax(x, dim=1)
181 | 
182 | 
183 | class APPNP2Net(torch.nn.Module):
184 |     def __init__(self, dataset):
185 |         super(APPNP2Net, self).__init__()
186 |         # default values from https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/citation/appnp.py
187 |         self.dropout = 0.5
188 |         self.hidden = 64
189 |         self.K = 2  # adjusted to two layers
190 |         self.alpha = 0.1
191 |         self.lin1 = Linear(dataset.num_features, self.hidden)
192 |         self.lin2 = Linear(self.hidden, dataset.num_classes)
193 |         self.prop1 = APPNP(self.K, self.alpha)
194 | 
195 |     def reset_parameters(self):
196 |         self.lin1.reset_parameters()
197 |         self.lin2.reset_parameters()
198 | 
199 |     def forward(self, x, edge_index):
200 |         x = F.dropout(x, p=self.dropout, training=self.training)
201 |         x = F.relu(self.lin1(x))
202 |         x = F.dropout(x, p=self.dropout, training=self.training)
203 |         x = self.lin2(x)
204 |         x = self.prop1(x, edge_index)
205 |         return F.log_softmax(x, dim=1)
206 | 
207 | 
208 | class GINConvNet(torch.nn.Module):
209 |     def __init__(self, dataset):
210 |         super(GINConvNet, self).__init__()
211 | 
212 |         num_features = dataset.num_features
213 |         dim = 32
214 | 
215 |         nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
216 |         self.conv1 = GINConv(nn1)
217 |         self.bn1 = torch.nn.BatchNorm1d(dim)
218 | 
219 |         nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
220 |         self.conv2 = GINConv(nn2)
221 |         self.bn2 = torch.nn.BatchNorm1d(dim)
222 | 
223 |         self.fc1 = Linear(dim, dim)
224 |         self.fc2 = Linear(dim, dataset.num_classes)
225 | 
226 |     def forward(self, x, edge_index):
227 |         x = F.relu(self.conv1(x, edge_index))
228 |         x = self.bn1(x)
229 |         x = F.relu(self.conv2(x, edge_index))
230 |         x = self.bn2(x)
231 |         x = F.relu(self.fc1(x))
232 |         x = F.dropout(x, p=0.5, training=self.training)
233 |         x = self.fc2(x)
234 |         return F.log_softmax(x, dim=-1)
235 | 
236 | 
237 | def load_model(path, model):
238 |     # map the state dict to CPU when no GPU is available
239 |     device = "cpu" if not torch.cuda.is_available() else None
240 |     model.load_state_dict(torch.load(path, map_location=device))
241 |     model.eval()
242 | 
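A minimal sketch tying these pieces together; the checkpoint path matches the layout of the bundled `results` directory.

```python
from models import load_dataset, GCNNet, load_model

dataset, data, results_path = load_dataset("Cora")
model = GCNNet(dataset)
load_model("results/cora_GCN/gcn_2_layers.pt", model)  # also sets model.eval()

log_probs = model(data.x, data.edge_index)
print(log_probs.argmax(dim=1)[:10])  # predicted classes of the first ten nodes
```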
243 | 
244 | def train_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4, clip=None, loss_function="nll_loss",
245 |                 epoch_save_path=None, no_output=False):
246 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
247 | 
248 |     accuracies = []
249 | 
250 |     model.train()
251 |     for epoch in range(epochs):
252 |         optimizer.zero_grad()
253 |         out = model(data.x, data.edge_index)
254 |         if loss_function == "nll_loss":
255 |             # the models return log-probabilities, so nll_loss is the matching choice
256 |             loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
257 |         elif loss_function == "cross_entropy":
258 |             # (deprecated size_average argument dropped; "mean" reduction is the default)
259 |             loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
260 |         else:
261 |             raise ValueError("Unknown loss function " + loss_function)
262 |         loss.backward()
263 |         if clip is not None:
264 |             # clip after backward(), so that there are gradients to clip
265 |             torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
266 |         optimizer.step()
267 | 
268 |         if epoch_save_path is not None:
269 |             # insert the epoch number before the ".pt" ending
270 |             save_model(model, epoch_save_path[:-3] + "_epoch_" + str(epoch) + epoch_save_path[-3:])
271 |             accuracies.append(retrieve_accuracy(model, data, value=True))
272 |             print('Accuracy: {:.4f}'.format(accuracies[-1]), "Epoch", epoch)
273 |         elif epoch % 25 == 0 and not no_output:
274 |             print(retrieve_accuracy(model, data))
275 | 
276 |     model.eval()
277 | 
278 |     return accuracies
279 | 
280 | 
281 | def save_model(model, path):
282 |     torch.save(model.state_dict(), path)
283 | 
284 | 
285 | def retrieve_accuracy(model, data, test_mask=None, value=False):
286 |     _, pred = model(data.x, data.edge_index).max(dim=1)
287 |     if test_mask is None:
288 |         test_mask = data.test_mask
289 |     correct = float(pred[test_mask].eq(data.y[test_mask]).sum().item())
290 |     acc = correct / test_mask.sum().item()
291 |     if value:
292 |         return acc
293 |     else:
294 |         return 'Accuracy: {:.4f}'.format(acc)
295 | 
--------------------------------------------------------------------------------
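A minimal end-to-end training sketch using the helpers above; the Cora masks ship with the Planetoid data, so no manual split is needed here, and the output path is illustrative.

```python
from models import load_dataset, GCNNet, train_model, save_model, retrieve_accuracy

dataset, data, results_path = load_dataset("Cora")
model = GCNNet(dataset)
train_model(model, data, epochs=200)
print(retrieve_accuracy(model, data))
save_model(model, "results/cora_GCN/gcn_2_layers_retrained.pt")  # illustrative path
```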
/print_progress_overview.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from execution import MODEL_SAVE_NAMES
3 | import numpy as np
4 | import time
5 | from itertools import product
6 | 
7 | DATASET_RESULT_PATHS = {"Cora": "cora",
8 |                         "CiteSeer": "citeseer",
9 |                         "PubMed": "pubmed",
10 |                         "syn2": "syn2",
11 |                         "syn2_1": "syn2_1",
12 |                         "syn2_2": "syn2_2",
13 |                         "syn2_3": "syn2_3",
14 |                         "syn2_4": "syn2_4",
15 |                         }
16 | 
17 | 
18 | def get_combinations():
19 |     datasets = ["Cora", "CiteSeer", "PubMed"]
20 |     models = ["GCN", "GAT", "GINConv", "APPNP2Net"]
21 | 
22 |     # check which result files are already present
23 |     combinations = []
24 |     raw_combinations = list(product(datasets, models))
25 |     for dataset_name, model in raw_combinations:
26 |         paths = ["_node_{:d}.npz"]
27 |         path_pattern = "_t_3_r_1_node_{:d}.npz"
28 |         if model == "GCN_syn2":  # only relevant if "GCN_syn2" were added to the models above
29 |             path_pattern = "_r_1_node_{:d}.npz"
30 |         paths.append(path_pattern)
31 |         combinations.append((dataset_name, model, paths))
32 | 
33 |     model = "GCN_syn2"
34 |     paths = ["_node_{:d}.npz", "_t_3_node_{:d}.npz"]
35 |     combinations.append(("syn2", model, paths))
36 |     combinations.append(("syn2_1", model, paths))
37 |     combinations.append(("syn2_2", model, paths))
38 |     combinations.append(("syn2_3", model, paths))
39 |     paths = ["_node_{:d}.npz",
40 |              "_t_3_node_{:d}.npz",
41 |              "_t_3_epoch_0_node_{:d}.npz",
42 |              "_t_3_epoch_200_node_{:d}.npz",
43 |              "_t_3_epoch_400_node_{:d}.npz",
44 |              "_t_3_epoch_600_node_{:d}.npz",
45 |              "_t_3_epoch_1400_node_{:d}.npz",
46 |              ]
47 |     combinations.append(("syn2_4", model, paths))
48 | 
49 |     return combinations
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     total_counter = 0
54 | 
55 |     combinations = get_combinations()
56 |     for dataset_name, model, search_paths in combinations:
57 |         working_directory = Path(".")
58 | 
59 |         results_directory = working_directory.joinpath("results")
60 |         results_path_prefix = DATASET_RESULT_PATHS[dataset_name]
61 | 
62 |         if dataset_name[:4] == "syn2":
63 |             selected_nodes = list(range(300, 700)) + list(range(1000, 1400))
64 |         else:
65 |             selected_nodes = np.load(results_directory.joinpath(dataset_name + "_selected_nodes.npy"))
66 |         results_path = results_path_prefix + "_" + model
67 |         results_directory = results_directory.joinpath(results_path)
68 | 
69 |         file_prefix = results_directory.joinpath(MODEL_SAVE_NAMES[model] + "_explanation")
70 | 
71 |         counter = [0] * len(search_paths)
72 | 
73 |         for node in selected_nodes:
74 |             for i, path_pattern in enumerate(search_paths):
75 |                 explanation_file = Path((str(file_prefix) + path_pattern).format(int(node)))
76 |                 if explanation_file.exists():
77 |                     counter[i] += 1
78 | 
79 |         print(dataset_name, model, *counter, sep="\t")
80 |         total_counter += sum(counter)
81 | 
82 |     print(time.ctime(), "Total number of explanations:", total_counter)
83 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.2
2 | networkx==2.4
3 | numba==0.50.0
4 | numpy==1.19.0
5 | pandas==1.0.5
6 | scikit-learn==0.23.1
7 | scipy==1.5.0
8 | seaborn==0.11.0
9 | torch==1.6.0
10 | torch-cluster==1.5.7
11 | torch-geometric==1.6.1
12 | torch-scatter==2.0.5
13 | torch-sparse==0.6.7
14 | torch-spline-conv==1.2.0
15 | torchvision==0.7.0
--------------------------------------------------------------------------------
/results/CiteSeer_selected_nodes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/CiteSeer_selected_nodes.npy
--------------------------------------------------------------------------------
/results/Cora_selected_nodes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/Cora_selected_nodes.npy
--------------------------------------------------------------------------------
/results/PubMed_selected_nodes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/PubMed_selected_nodes.npy
--------------------------------------------------------------------------------
/results/citeseer_APPNP2Net/appnp_2_layers.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/citeseer_APPNP2Net/appnp_2_layers.pt
--------------------------------------------------------------------------------
/results/citeseer_GAT/gat_2_layers.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/citeseer_GAT/gat_2_layers.pt
--------------------------------------------------------------------------------
/results/citeseer_GCN/gcn_2_layers.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/citeseer_GCN/gcn_2_layers.pt
-------------------------------------------------------------------------------- /results/citeseer_GINConv/gin_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/citeseer_GINConv/gin_2_layers.pt -------------------------------------------------------------------------------- /results/cora_APPNP2Net/appnp_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/cora_APPNP2Net/appnp_2_layers.pt -------------------------------------------------------------------------------- /results/cora_GAT/gat_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/cora_GAT/gat_2_layers.pt -------------------------------------------------------------------------------- /results/cora_GCN/gcn_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/cora_GCN/gcn_2_layers.pt -------------------------------------------------------------------------------- /results/cora_GINConv/gin_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/cora_GINConv/gin_2_layers.pt -------------------------------------------------------------------------------- /results/pubmed_APPNP2Net/appnp_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/pubmed_APPNP2Net/appnp_2_layers.pt -------------------------------------------------------------------------------- /results/pubmed_GAT/gat_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/pubmed_GAT/gat_2_layers.pt -------------------------------------------------------------------------------- /results/pubmed_GCN/gcn_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/pubmed_GCN/gcn_2_layers.pt -------------------------------------------------------------------------------- /results/pubmed_GINConv/gin_2_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/pubmed_GINConv/gin_2_layers.pt -------------------------------------------------------------------------------- /results/syn2_1_GCN_syn2/gcn_3_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_1_GCN_syn2/gcn_3_layers.pt -------------------------------------------------------------------------------- /results/syn2_2_GCN_syn2/gcn_3_layers.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_2_GCN_syn2/gcn_3_layers.pt -------------------------------------------------------------------------------- /results/syn2_3_GCN_syn2/gcn_3_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_3_GCN_syn2/gcn_3_layers.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers_epoch_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers_epoch_0.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers_epoch_1400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers_epoch_1400.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers_epoch_200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers_epoch_200.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers_epoch_400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers_epoch_400.pt -------------------------------------------------------------------------------- /results/syn2_4_GCN_syn2/gcn_3_layers_epoch_600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_4_GCN_syn2/gcn_3_layers_epoch_600.pt -------------------------------------------------------------------------------- /results/syn2_GCN_syn2/gcn_3_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funket/zorro/3aeab245e710bc41fdbeffe31e35bffff11ec727/results/syn2_GCN_syn2/gcn_3_layers.pt --------------------------------------------------------------------------------