├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── CHANGELOG.md ├── INSTALLATION.md ├── LICENSE ├── README.md ├── configs.py ├── example.sh ├── explain_pyg.py ├── explainer ├── __init__.py └── explain.py ├── explainer_main.py ├── gengraph.py ├── main.py ├── models.py ├── models_pyg.py ├── notebook ├── GNN-Explainer-Viz-Interactive.ipynb └── GNN-Explainer-Viz.ipynb ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── featgen.py ├── graph_utils.py ├── io_utils.py ├── math_utils.py ├── parser_utils.py ├── synthetic_structsim.py └── train_utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us fix issues with the codebase. 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Platform (please complete the following information):** 21 | - OS: [e.g. OSX] 22 | - Python [e.g. 3.7] 23 | - Version [e.g. latest] 24 | 25 | **Additional context** 26 | Add any other context about the problem here. 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # edit 107 | *.sw? 
108 | data/
109 | log/
110 | syn/
111 | ckpt/
112 | out/
113 | results/
114 | path/
115 |
116 | .idea/
117 | .vscode/
118 |
119 |
120 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## v`0.0.1`
4 | ---
5 |
6 | > 23 May 2019
7 |
8 | - Paper submitted
9 | - Experiments complete
10 |
11 | > 04 September 2019
12 |
13 | - Paper accepted
--------------------------------------------------------------------------------
/INSTALLATION.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## From Source
4 |
5 | Start by grabbing this source code:
6 |
7 | ```
8 | git clone https://github.com/RexYing/gnn-model-explainer
9 | ```
10 |
11 | It is recommended to run this code inside a `virtualenv` with `python3.7`.
12 |
13 | ```
14 | virtualenv venv -p /usr/local/bin/python3
15 | source venv/bin/activate
16 | ```
17 |
18 | ### Requirements
19 |
20 | To install all the requirements, run the following command:
21 |
22 | ```
23 | python -m pip install -r requirements.txt
24 | ```
25 |
26 | If you want to install the packages manually, here's what you'll need:
27 |
28 |
29 | - PyTorch (tested with `1.3`)
30 |
31 | ```
32 | python -m pip install torch torchvision
33 | ```
34 |
35 | For alternative installation methods, see the [PyTorch website](https://pytorch.org/).
36 |
37 | - OpenCV
38 |
39 | ```
40 | python -m pip install opencv-python
41 | ```
42 |
43 | > TODO: It might be worth finding a way to remove the dependency on OpenCV.
44 |
45 | - The standard Python data-science stack:
46 |
47 | ```
48 | python -m pip install matplotlib networkx pandas scikit-learn seaborn
49 | ```
50 |
51 | - `TensorboardX`
52 |
53 | ```
54 | python -m pip install tensorboardX
55 | ```
56 |
57 | To install [RDKit](https://www.rdkit.org/), follow the instructions provided [on the website](https://www.rdkit.org/docs/Install.html).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 |
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 |    1. Definitions.
8 |
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 |
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 |
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 |
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gnn-explainer
2 |
3 | This repository contains the source code for the paper `GNNExplainer: Generating Explanations for Graph Neural Networks` by [Rex Ying](https://cs.stanford.edu/people/rexy/), [Dylan Bourgeois](https://dtsbourg.me/), [Jiaxuan You](https://cs.stanford.edu/~jiaxuan/), [Marinka Zitnik](http://helikoid.si/cms/) & [Jure Leskovec](https://cs.stanford.edu/people/jure/), presented at [NeurIPS 2019](https://nips.cc).
4 |
5 | [[Arxiv]](https://arxiv.org/abs/1903.03894) [[BibTex]](https://dblp.uni-trier.de/rec/bibtex/journals/corr/abs-1903-03894) [[Google Scholar]](https://scholar.google.com/scholar?q=GNNExplainer%3A%20Generating%20Explanations%20for%20Graph%20Neural%20Networks%20Rex%20arXiv%202019)
6 |
7 | ```
8 | @misc{ying2019gnnexplainer,
9 |     title={GNNExplainer: Generating Explanations for Graph Neural Networks},
10 |     author={Rex Ying and Dylan Bourgeois and Jiaxuan You and Marinka Zitnik and Jure Leskovec},
11 |     year={2019},
12 |     eprint={1903.03894},
13 |     archivePrefix={arXiv},
14 |     primaryClass={cs.LG}
15 | }
16 | ```
17 |
18 | ## Using the explainer
19 |
20 | ### Installation
21 |
22 | See [INSTALLATION.md](#)
23 |
24 | ### Replicating the paper's results
25 |
26 | #### Training a GCN model
27 |
28 | This is the model that will be explained. We provide [pre-trained models](#TODO) for all of the experiments
29 | shown in the paper. To re-train these models, run the following:
30 |
31 | ```
32 | python train.py --dataset=EXPERIMENT_NAME
33 | ```
34 |
35 | where `EXPERIMENT_NAME` is the experiment you want to replicate.
36 |
37 | For a complete list of options for training the GCN models:
38 |
39 | ```
40 | python train.py --help
41 | ```
42 |
43 | > TODO: Explain outputs
44 |
45 | #### Explaining a GCN model
46 |
47 | To run the explainer, run the following:
48 |
49 | ```
50 | python explainer_main.py --dataset=EXPERIMENT_NAME
51 | ```
52 |
53 | where `EXPERIMENT_NAME` is the experiment you want to replicate.
54 |
55 |
56 | For a complete list of options provided by the explainer:
57 |
58 | ```
59 | python explainer_main.py --help
60 | ```
61 |
62 | #### Visualizing the explanations
63 |
64 | ##### Tensorboard
65 |
66 | The result of the optimization can be visualized through Tensorboard.
67 |
68 | ```
69 | tensorboard --logdir log
70 | ```
71 |
72 | You should then have access to visualizations served from `localhost`.
73 |
74 | ##### Jupyter Notebook
75 |
76 | We provide an example visualization through Jupyter Notebooks in the `notebook` folder. To try it:
77 |
78 | ```
79 | jupyter notebook
80 | ```
81 |
82 | The default visualizations are provided in `notebook/GNN-Explainer-Viz.ipynb`.
83 |
84 | > Note: For an interactive version, you must enable ipywidgets
85 | >
86 | > ```
87 | > jupyter nbextension enable --py widgetsnbextension
88 | > ```
89 |
90 | You can now play around with the mask threshold in `notebook/GNN-Explainer-Viz-Interactive.ipynb`.
91 | > TODO: Explain outputs + visualizations + baselines
92 |
93 | ##### D3.js
94 |
95 | We provide export functionality so the generated masks can be visualized in other data visualization
96 | frameworks, for example [d3.js](https://d3js.org). We provide [an example visualization in Observable](https://observablehq.com/d/00c5dc74f359e7a1).
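The exported mask is simply the masked adjacency matrix that the explainer saves as a `.npy` file in the log directory. The snippet below is a minimal sketch of how such a file might be converted to node-link JSON for d3.js; the filename and the threshold value are illustrative, not part of the repo (use the path printed after `Saved adjacency matrix to ...`):

```
import json

import networkx as nx
import numpy as np

# Hypothetical path: use the file reported by "Saved adjacency matrix to ..."
masked_adj = np.load("log/masked_adj_example.npy")
G = nx.from_numpy_array(masked_adj)

# Keep only edges whose mask weight clears an (illustrative) threshold.
threshold = 0.1
G.remove_edges_from(
    [(u, v) for u, v, w in G.edges(data="weight") if w < threshold]
)
G.remove_nodes_from(list(nx.isolates(G)))

# node_link_data yields the {"nodes": [...], "links": [...]} shape that
# d3.js force-layout examples typically consume.
data = nx.node_link_data(G)
for link in data["links"]:
    link["weight"] = float(link["weight"])  # ensure JSON-serializable weights
with open("masked_graph.json", "w") as f:
    json.dump(data, f)
```

The resulting `masked_graph.json` can then be loaded directly, e.g. with `d3.json`, in an Observable notebook.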
97 |
98 | #### Included experiments
99 |
100 | | Name | `EXPERIMENT_NAME` | Description |
101 | |----------|:-------------------:|--------------|
102 | | Synthetic #1 | `syn1` | Random BA graph with House attachments. |
103 | | Synthetic #2 | `syn2` | Random BA graph with community features. |
104 | | Synthetic #3 | `syn3` | Random BA graph with grid attachments. |
105 | | Synthetic #4 | `syn4` | Random Tree with cycle attachments. |
106 | | Synthetic #5 | `syn5` | Random Tree with grid attachments. |
107 | | Enron | `enron` | Enron email dataset [source](https://www.cs.cmu.edu/~enron/). |
108 | | PPI | `ppi_essential` | Protein-Protein interaction dataset. |
109 | | | | |
110 | | Reddit* | `REDDIT-BINARY` | Reddit-Binary Graphs ([source](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets)). |
111 | | Mutagenicity* | `Mutagenicity` | Predicting the mutagenicity of molecules ([source](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets)). |
112 | | Tox 21* | `Tox21_AHR` | Predicting a compound's toxicity ([source](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets)). |
113 |
114 | > Datasets with a * are passed with the `--bmname` parameter rather than `--dataset`, as they must be downloaded manually.
115 |
116 | > TODO: Provide all data for experiments packaged so we don't have to split the two.
117 |
118 |
119 | ### Using the explainer on other models
120 | A graph attention model is provided. This repo is under active development and will support other
121 | GNN models in the future.
122 |
123 | ## Changelog
124 |
125 | See [CHANGELOG.md](#)
126 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import utils.parser_utils as parser_utils
3 |
4 | def arg_parse():
5 |     parser = argparse.ArgumentParser(description='GraphPool arguments.')
6 |     io_parser = parser.add_mutually_exclusive_group(required=False)
7 |     io_parser.add_argument('--dataset', dest='dataset',
8 |             help='Input dataset.')
9 |     benchmark_parser = io_parser.add_argument_group()
10 |     benchmark_parser.add_argument('--bmname', dest='bmname',
11 |             help='Name of the benchmark dataset')
12 |     io_parser.add_argument('--pkl', dest='pkl_fname',
13 |             help='Name of the pkl data file')
14 |
15 |     softpool_parser = parser.add_argument_group()
16 |     softpool_parser.add_argument('--assign-ratio', dest='assign_ratio', type=float,
17 |             help='ratio of number of nodes in consecutive layers')
18 |     softpool_parser.add_argument('--num-pool', dest='num_pool', type=int,
19 |             help='number of pooling layers')
20 |     parser.add_argument('--linkpred', dest='linkpred', action='store_const',
21 |             const=True, default=False,
22 |             help='Whether link prediction side objective is used')
23 |
24 |     parser_utils.parse_optimizer(parser)
25 |
26 |     parser.add_argument('--datadir', dest='datadir',
27 |             help='Directory where benchmark is located')
28 |     parser.add_argument('--logdir', dest='logdir',
29 |             help='Tensorboard log directory')
30 |     parser.add_argument('--ckptdir', dest='ckptdir',
31 |             help='Model checkpoint directory')
32 |     parser.add_argument('--cuda', dest='cuda',
33 |             help='CUDA.')
34 |     parser.add_argument('--gpu', dest='gpu', action='store_const',
35 |             const=True, default=False,
36 |             help='Whether to use GPU.')
37 |     parser.add_argument('--max_nodes', dest='max_nodes', type=int,
38 |             help='Maximum number of nodes (ignore graphs with nodes exceeding this number).')
39 |     parser.add_argument('--batch_size', dest='batch_size', type=int,
40 |             help='Batch size.')
41 |     parser.add_argument('--epochs', dest='num_epochs', type=int,
42 |             help='Number of epochs to train.')
43 |     parser.add_argument('--train_ratio', dest='train_ratio', type=float,
44 |             help='Ratio of the number of training graphs to all graphs.')
45 |     parser.add_argument('--num_workers', dest='num_workers', type=int,
46 |             help='Number of workers to load data.')
47 |     parser.add_argument('--feature', dest='feature_type',
48 |             help='Feature used for encoder. Can be: id, deg')
49 |     parser.add_argument('--input_dim', dest='input_dim', type=int,
50 |             help='Input feature dimension')
51 |     parser.add_argument('--hidden_dim', dest='hidden_dim', type=int,
52 |             help='Hidden dimension')
53 |     parser.add_argument('--output_dim', dest='output_dim', type=int,
54 |             help='Output dimension')
55 |     parser.add_argument('--num_classes', dest='num_classes', type=int,
56 |             help='Number of label classes')
57 |     parser.add_argument('--num_gc_layers', dest='num_gc_layers', type=int,
58 |             help='Number of graph convolution layers before each pooling')
59 |     parser.add_argument('--bn', dest='bn', action='store_const',
60 |             const=True, default=False,
61 |             help='Whether batch normalization is used')
62 |     parser.add_argument('--dropout', dest='dropout', type=float,
63 |             help='Dropout rate.')
64 |     parser.add_argument('--nobias', dest='bias', action='store_const',
65 |             const=False, default=True,
66 |             help='Whether to add bias. Defaults to True.')
67 |     parser.add_argument('--weight_decay', dest='weight_decay', type=float,
68 |             help='Weight decay regularization constant.')
69 |
70 |     parser.add_argument('--method', dest='method',
71 |             help='Method. Possible values: base, ')
72 |     parser.add_argument('--name-suffix', dest='name_suffix',
73 |             help='Suffix added to the output filename')
74 |
75 |     parser.set_defaults(datadir='data', # io_parser
76 |                         logdir='log',
77 |                         ckptdir='ckpt',
78 |                         dataset='syn1',
79 |                         opt='adam', # opt_parser
80 |                         opt_scheduler='none',
81 |                         max_nodes=100,
82 |                         cuda='1',
83 |                         feature_type='default',
84 |                         lr=0.001,
85 |                         clip=2.0,
86 |                         batch_size=20,
87 |                         num_epochs=1000,
88 |                         train_ratio=0.8,
89 |                         test_ratio=0.1,
90 |                         num_workers=1,
91 |                         input_dim=10,
92 |                         hidden_dim=20,
93 |                         output_dim=20,
94 |                         num_classes=2,
95 |                         num_gc_layers=3,
96 |                         dropout=0.0,
97 |                         weight_decay=0.005,
98 |                         method='base',
99 |                         name_suffix='',
100 |                         assign_ratio=0.1,
101 |                        )
102 |     return parser.parse_args()
103 |
104 |
--------------------------------------------------------------------------------
/example.sh:
--------------------------------------------------------------------------------
1 | # train the prediction model
2 | python -m train --dataset=syn1
3 | python -m train --dataset=syn2
4 | python -m train --dataset=syn4
5 | python -m train --dataset=syn5
6 |
7 |
8 | # train the explainer
9 | python -m explainer_main --explain-node=300 --dataset=syn1 # change explain-node to 301, 302, etc. (basis graph is of size 300)
10 | python -m explainer_main --explain-node=350 --dataset=syn2 # (basis graph is of size 350; the first 2 feature dimensions are informative, the rest are uninformative)
11 | python -m explainer_main --explain-node=512 --dataset=syn4 # change explain-node to 512, etc. (basis tree is of size 512)
12 | python -m explainer_main --explain-node=512 --dataset=syn5 # (basis tree is of size 512)
--------------------------------------------------------------------------------
/explain_pyg.py:
--------------------------------------------------------------------------------
1 | import gengraph
2 | import random
3 | import torch_geometric
4 | from utils import featgen
5 | import numpy as np
6 | import utils.io_utils as io_utils
7 | from configs import arg_parse
8 | import torch
9 | import torch.nn as nn
10 | from torch.autograd import Variable
11 | from models_pyg import GCNNet
12 | import os
13 | from torch_geometric.utils import from_networkx
14 | from tensorboardX import SummaryWriter
15 |
16 | def test(loader, model, args, labels, test_mask):
17 |     model.eval()
18 |
19 |     train_ratio = args.train_ratio
20 |     correct = 0
21 |     for data in loader:
22 |         with torch.no_grad():
23 |             pred = model(data)
24 |             # print ('pred:', pred)
25 |             pred = pred.argmax(dim=1)
26 |             # print ('pred:', pred)
27 |
28 |             # node classification: only evaluate on nodes in test set
29 |             pred = pred[test_mask]
30 |             # print ('pred:', pred)
31 |             label = labels[test_mask]
32 |             # print ('label:', label)
33 |
34 |         correct += pred.eq(label).sum().item()
35 |
36 |     total = int(test_mask.sum().item())  # number of test nodes
37 |     # print ('correct:', correct)
38 |     return correct / total
39 |
40 | def syn_task1(args, writer=None):
41 |     # data
42 |     print ('Generating graph.')
43 |     G, labels, name = gengraph.gen_syn1(
44 |         feature_generator=featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)))
45 |     # print ('G.node[0]:', G.node[0]['feat'].dtype)
46 |     # print ('Original labels:', labels)
47 |     pyg_G = from_networkx(G)
48 |     num_classes = max(labels)+1
49 |     labels = torch.LongTensor(labels)
50 |     print ('Done generating graph.')
51 |
52 |     # if args.method == 'att':
53 |     #     print('Method: att')
54 |     #     model = models.GcnEncoderNode(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
55 |     #                                   args.num_gc_layers, bn=args.bn, args=args)
56 |
57 |     # else:
58 |     #     print('Method:', args.method)
59 |     #     model = models.GcnEncoderNode(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
60 |     #                                   args.num_gc_layers, bn=args.bn, args=args)
61 |
62 |     model = GCNNet(args.input_dim, args.hidden_dim, num_classes, args.num_gc_layers, args=args)
63 |
64 |     if args.gpu:
65 |         model = model.cuda()
66 |
67 |     train_ratio = args.train_ratio
68 |     num_train = int(train_ratio * G.number_of_nodes())
69 |     num_test = G.number_of_nodes() - num_train
70 |     shuffle_indices = list(range(G.number_of_nodes()))
71 |     shuffle_indices = np.random.permutation(shuffle_indices)
72 |
73 |     train_mask = num_train * [True] + num_test * [False]
74 |     train_mask = torch.BoolTensor([train_mask[i] for i in shuffle_indices])
75 |     test_mask = num_train * [False] + num_test * [True]
76 |     test_mask = torch.BoolTensor([test_mask[i] for i in shuffle_indices])
77 |
78 |     loader = torch_geometric.data.DataLoader([pyg_G], batch_size=1)
79 |     opt = torch.optim.Adam(model.parameters(), lr=args.lr)
80 |     for epoch in range(args.num_epochs):
81 |         total_loss = 0
82 |         model.train()
83 |         for batch in loader:
84 |             # print ('batch:', batch.feat)
85 |             opt.zero_grad()
86 |             pred = model(batch)
87 |
88 |             pred = pred[train_mask]
89 |             # print ('pred:', pred)
90 |             label = labels[train_mask]
91 |             # print ('label:', label)
92 |             loss = model.loss(pred, label)
93 |             print ('loss:', loss)
94 |             loss.backward()
95 |             opt.step()
96 |             total_loss += loss.item() * 1
97 |         total_loss /= num_train
98 |         writer.add_scalar("loss", total_loss, epoch)
99 |
100 |         if epoch % 10 == 0:
101 |             test_acc = test(loader, model, args, labels, test_mask)
102 |             print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(
103 |                 epoch, total_loss, test_acc))
104 |             writer.add_scalar("test accuracy", test_acc, epoch)
105 |
106 | prog_args = arg_parse()
107 | path = os.path.join(prog_args.logdir, io_utils.gen_prefix(prog_args))
108 | syn_task1(prog_args, writer=SummaryWriter(path))
109 |
--------------------------------------------------------------------------------
/explainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RexYing/gnn-model-explainer/bc984829f4f4829e93c760e9bbdc8e73f96e2cc1/explainer/__init__.py
--------------------------------------------------------------------------------
/explainer/explain.py:
--------------------------------------------------------------------------------
1 | """ explain.py
2 |
3 |     Implementation of the explainer.
4 | """
5 |
6 | import math
7 | import time
8 | import os
9 |
10 | import matplotlib
11 | import matplotlib.colors as colors
12 | import matplotlib.pyplot as plt
13 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
14 | from matplotlib.figure import Figure
15 |
16 | import networkx as nx
17 | import numpy as np
18 | import pandas as pd
19 | import seaborn as sns
20 | import tensorboardX.utils
21 |
22 | import torch
23 | import torch.nn as nn
24 | from torch.autograd import Variable
25 |
26 | import sklearn.metrics as metrics
27 | from sklearn.metrics import roc_auc_score, recall_score, precision_score, precision_recall_curve
28 | from sklearn.cluster import DBSCAN
29 |
30 | import pdb
31 |
32 | import utils.io_utils as io_utils
33 | import utils.train_utils as train_utils
34 | import utils.graph_utils as graph_utils
35 |
36 |
37 | use_cuda = torch.cuda.is_available()
38 | FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
39 | LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
40 | Tensor = FloatTensor
41 |
42 | class Explainer:
43 |     def __init__(
44 |         self,
45 |         model,
46 |         adj,
47 |         feat,
48 |         label,
49 |         pred,
50 |         train_idx,
51 |         args,
52 |         writer=None,
53 |         print_training=True,
54 |         graph_mode=False,
55 |         graph_idx=False,
56 |     ):
57 |         self.model = model
58 |         self.model.eval()
59 |         self.adj = adj
60 |         self.feat = feat
61 |         self.label = label
62 |         self.pred = pred
63 |         self.train_idx = train_idx
64 |         self.n_hops = args.num_gc_layers
65 |         self.graph_mode = graph_mode
66 |         self.graph_idx = graph_idx
67 |         self.neighborhoods = None if self.graph_mode else graph_utils.neighborhoods(adj=self.adj, n_hops=self.n_hops, use_cuda=use_cuda)
68 |         self.args = args
69 |         self.writer = writer
70 |         self.print_training = print_training
71 |
72 |
73 |     # Main method
74 |     def explain(
75 |         self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp"
76 |     ):
77 |         """Explain a single node prediction
78 |         """
79 |         # index of the query node in the new adj
80 |         if graph_mode:
81 |             node_idx_new = node_idx
82 |             sub_adj = self.adj[graph_idx]
83 |             sub_feat = self.feat[graph_idx, :]
84 |             sub_label = self.label[graph_idx]
85 |             neighbors
= np.asarray(range(self.adj.shape[0])) 86 | else: 87 | print("node label: ", self.label[graph_idx][node_idx]) 88 | node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood( 89 | node_idx, graph_idx 90 | ) 91 | print("neigh graph idx: ", node_idx, node_idx_new) 92 | sub_label = np.expand_dims(sub_label, axis=0) 93 | 94 | sub_adj = np.expand_dims(sub_adj, axis=0) 95 | sub_feat = np.expand_dims(sub_feat, axis=0) 96 | 97 | adj = torch.tensor(sub_adj, dtype=torch.float) 98 | x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float) 99 | label = torch.tensor(sub_label, dtype=torch.long) 100 | 101 | if self.graph_mode: 102 | pred_label = np.argmax(self.pred[0][graph_idx], axis=0) 103 | print("Graph predicted label: ", pred_label) 104 | else: 105 | pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1) 106 | print("Node predicted label: ", pred_label[node_idx_new]) 107 | 108 | explainer = ExplainModule( 109 | adj=adj, 110 | x=x, 111 | model=self.model, 112 | label=label, 113 | args=self.args, 114 | writer=self.writer, 115 | graph_idx=self.graph_idx, 116 | graph_mode=self.graph_mode, 117 | ) 118 | if self.args.gpu: 119 | explainer = explainer.cuda() 120 | 121 | self.model.eval() 122 | 123 | 124 | # gradient baseline 125 | if model == "grad": 126 | explainer.zero_grad() 127 | # pdb.set_trace() 128 | adj_grad = torch.abs( 129 | explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0] 130 | )[graph_idx] 131 | masked_adj = adj_grad + adj_grad.t() 132 | masked_adj = nn.functional.sigmoid(masked_adj) 133 | masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze() 134 | else: 135 | explainer.train() 136 | begin_time = time.time() 137 | for epoch in range(self.args.num_epochs): 138 | explainer.zero_grad() 139 | explainer.optimizer.zero_grad() 140 | ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained) 141 | loss = explainer.loss(ypred, pred_label, node_idx_new, epoch) 142 | loss.backward() 143 | 144 | explainer.optimizer.step() 145 | if explainer.scheduler is not None: 146 | explainer.scheduler.step() 147 | 148 | mask_density = explainer.mask_density() 149 | if self.print_training: 150 | print( 151 | "epoch: ", 152 | epoch, 153 | "; loss: ", 154 | loss.item(), 155 | "; mask density: ", 156 | mask_density.item(), 157 | "; pred: ", 158 | ypred, 159 | ) 160 | single_subgraph_label = sub_label.squeeze() 161 | 162 | if self.writer is not None: 163 | self.writer.add_scalar("mask/density", mask_density, epoch) 164 | self.writer.add_scalar( 165 | "optimization/lr", 166 | explainer.optimizer.param_groups[0]["lr"], 167 | epoch, 168 | ) 169 | if epoch % 25 == 0: 170 | explainer.log_mask(epoch) 171 | explainer.log_masked_adj( 172 | node_idx_new, epoch, label=single_subgraph_label 173 | ) 174 | explainer.log_adj_grad( 175 | node_idx_new, pred_label, epoch, label=single_subgraph_label 176 | ) 177 | 178 | if epoch == 0: 179 | if self.model.att: 180 | # explain node 181 | print("adj att size: ", adj_atts.size()) 182 | adj_att = torch.sum(adj_atts[0], dim=2) 183 | # adj_att = adj_att[neighbors][:, neighbors] 184 | node_adj_att = adj_att * adj.float().cuda() 185 | io_utils.log_matrix( 186 | self.writer, node_adj_att[0], "att/matrix", epoch 187 | ) 188 | node_adj_att = node_adj_att[0].cpu().detach().numpy() 189 | G = io_utils.denoise_graph( 190 | node_adj_att, 191 | node_idx_new, 192 | threshold=3.8, # threshold_num=20, 193 | max_component=True, 194 | ) 195 | io_utils.log_graph( 196 | self.writer, 197 | G, 198 | name="att/graph", 199 | 
identify_self=not self.graph_mode, 200 | nodecolor="label", 201 | edge_vmax=None, 202 | args=self.args, 203 | ) 204 | if model != "exp": 205 | break 206 | 207 | print("finished training in ", time.time() - begin_time) 208 | if model == "exp": 209 | masked_adj = ( 210 | explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze() 211 | ) 212 | else: 213 | adj_atts = nn.functional.sigmoid(adj_atts).squeeze() 214 | masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze() 215 | 216 | fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + ( 217 | 'node_idx_'+str(node_idx)+'graph_idx_'+str(self.graph_idx)+'.npy') 218 | with open(os.path.join(self.args.logdir, fname), 'wb') as outfile: 219 | np.save(outfile, np.asarray(masked_adj.copy())) 220 | print("Saved adjacency matrix to ", fname) 221 | return masked_adj 222 | 223 | 224 | # NODE EXPLAINER 225 | def explain_nodes(self, node_indices, args, graph_idx=0): 226 | """ 227 | Explain nodes 228 | 229 | Args: 230 | - node_indices : Indices of the nodes to be explained 231 | - args : Program arguments (mainly for logging paths) 232 | - graph_idx : Index of the graph to explain the nodes from (if multiple). 233 | """ 234 | masked_adjs = [ 235 | self.explain(node_idx, graph_idx=graph_idx) for node_idx in node_indices 236 | ] 237 | ref_idx = node_indices[0] 238 | ref_adj = masked_adjs[0] 239 | curr_idx = node_indices[1] 240 | curr_adj = masked_adjs[1] 241 | new_ref_idx, _, ref_feat, _, _ = self.extract_neighborhood(ref_idx) 242 | new_curr_idx, _, curr_feat, _, _ = self.extract_neighborhood(curr_idx) 243 | 244 | G_ref = io_utils.denoise_graph(ref_adj, new_ref_idx, ref_feat, threshold=0.1) 245 | denoised_ref_feat = np.array( 246 | [G_ref.nodes[node]["feat"] for node in G_ref.nodes()] 247 | ) 248 | denoised_ref_adj = nx.to_numpy_matrix(G_ref) 249 | # ref center node 250 | ref_node_idx = list(G_ref.nodes()).index(new_ref_idx) 251 | 252 | G_curr = io_utils.denoise_graph( 253 | curr_adj, new_curr_idx, curr_feat, threshold=0.1 254 | ) 255 | denoised_curr_feat = np.array( 256 | [G_curr.nodes[node]["feat"] for node in G_curr.nodes()] 257 | ) 258 | denoised_curr_adj = nx.to_numpy_matrix(G_curr) 259 | # curr center node 260 | curr_node_idx = list(G_curr.nodes()).index(new_curr_idx) 261 | 262 | P, aligned_adj, aligned_feat = self.align( 263 | denoised_ref_feat, 264 | denoised_ref_adj, 265 | ref_node_idx, 266 | denoised_curr_feat, 267 | denoised_curr_adj, 268 | curr_node_idx, 269 | args=args, 270 | ) 271 | io_utils.log_matrix(self.writer, P, "align/P", 0) 272 | 273 | G_ref = nx.convert_node_labels_to_integers(G_ref) 274 | io_utils.log_graph(self.writer, G_ref, "align/ref") 275 | G_curr = nx.convert_node_labels_to_integers(G_curr) 276 | io_utils.log_graph(self.writer, G_curr, "align/before") 277 | 278 | P = P.cpu().detach().numpy() 279 | aligned_adj = aligned_adj.cpu().detach().numpy() 280 | aligned_feat = aligned_feat.cpu().detach().numpy() 281 | 282 | aligned_idx = np.argmax(P[:, curr_node_idx]) 283 | print("aligned self: ", aligned_idx) 284 | G_aligned = io_utils.denoise_graph( 285 | aligned_adj, aligned_idx, aligned_feat, threshold=0.5 286 | ) 287 | io_utils.log_graph(self.writer, G_aligned, "mask/aligned") 288 | 289 | # io_utils.log_graph(self.writer, aligned_adj.cpu().detach().numpy(), new_curr_idx, 290 | # 'align/aligned', epoch=1) 291 | 292 | return masked_adjs 293 | 294 | 295 | def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"): 296 | masked_adjs = [ 297 | self.explain(node_idx, graph_idx=graph_idx, 
model=model)
298 |             for node_idx in node_indices
299 |         ]
300 |         # pdb.set_trace()
301 |         graphs = []
302 |         feats = []
303 |         adjs = []
304 |         pred_all = []
305 |         real_all = []
306 |         for i, idx in enumerate(node_indices):
307 |             new_idx, _, feat, _, _ = self.extract_neighborhood(idx)
308 |             G = io_utils.denoise_graph(masked_adjs[i], new_idx, feat, threshold_num=20)
309 |             pred, real = self.make_pred_real(masked_adjs[i], new_idx)
310 |             pred_all.append(pred)
311 |             real_all.append(real)
312 |             denoised_feat = np.array([G.nodes[node]["feat"] for node in G.nodes()])
313 |             denoised_adj = nx.to_numpy_matrix(G)
314 |             graphs.append(G)
315 |             feats.append(denoised_feat)
316 |             adjs.append(denoised_adj)
317 |             io_utils.log_graph(
318 |                 self.writer,
319 |                 G,
320 |                 "graph/{}_{}_{}".format(self.args.dataset, model, i),
321 |                 identify_self=True,
322 |                 args=self.args
323 |             )
324 |
325 |         pred_all = np.concatenate((pred_all), axis=0)
326 |         real_all = np.concatenate((real_all), axis=0)
327 |
328 |         auc_all = roc_auc_score(real_all, pred_all)
329 |         precision, recall, thresholds = precision_recall_curve(real_all, pred_all)
330 |
331 |         plt.switch_backend("agg")
332 |         plt.plot(recall, precision)
333 |         plt.savefig("log/pr/pr_" + self.args.dataset + "_" + model + ".png")
334 |
335 |         plt.close()
336 |
346 |         with open("log/pr/auc_" + self.args.dataset + "_" + model + ".txt", "w") as f:
347 |             f.write(
348 |                 "dataset: {}, model: {}, auc: {}\n".format(
349 |                     self.args.dataset, model, str(auc_all)
350 |                 )
351 |             )
352 |
353 |         return masked_adjs
354 |
355 |     # GRAPH EXPLAINER
356 |     def explain_graphs(self, graph_indices):
357 |         """
358 |         Explain graphs.
359 |         """
360 |         masked_adjs = []
361 |
362 |         for graph_idx in graph_indices:
363 |             masked_adj = self.explain(node_idx=0, graph_idx=graph_idx, graph_mode=True)
364 |             G_denoised = io_utils.denoise_graph(
365 |                 masked_adj,
366 |                 0,
367 |                 threshold_num=20,
368 |                 feat=self.feat[graph_idx],
369 |                 max_component=False,
370 |             )
371 |             label = self.label[graph_idx]
372 |             io_utils.log_graph(
373 |                 self.writer,
374 |                 G_denoised,
375 |                 "graph/graphidx_{}_label={}".format(graph_idx, label),
376 |                 identify_self=False,
377 |                 nodecolor="feat",
378 |                 args=self.args
379 |             )
380 |             masked_adjs.append(masked_adj)
381 |
382 |             G_orig = io_utils.denoise_graph(
383 |                 self.adj[graph_idx],
384 |                 0,
385 |                 feat=self.feat[graph_idx],
386 |                 threshold=None,
387 |                 max_component=False,
388 |             )
389 |
390 |             io_utils.log_graph(
391 |                 self.writer,
392 |                 G_orig,
393 |                 "graph/graphidx_{}".format(graph_idx),
394 |                 identify_self=False,
395 |                 nodecolor="feat",
396 |                 args=self.args
397 |             )
398 |
399 |         # plot cmap for graphs' node features
400 |         io_utils.plot_cmap_tb(self.writer, "tab20", 20, "tab20_cmap")
401 |
402 |         return masked_adjs
403 |
404 |     def log_representer(self, rep_val, sim_val, alpha, graph_idx=0):
""" 406 | rep_val = rep_val.cpu().detach().numpy() 407 | sim_val = sim_val.cpu().detach().numpy() 408 | alpha = alpha.cpu().detach().numpy() 409 | sorted_rep = sorted(range(len(rep_val)), key=lambda k: rep_val[k]) 410 | print(sorted_rep) 411 | topk = 5 412 | most_neg_idx = [sorted_rep[i] for i in range(topk)] 413 | most_pos_idx = [sorted_rep[-i - 1] for i in range(topk)] 414 | rep_idx = [most_pos_idx, most_neg_idx] 415 | 416 | if self.graph_mode: 417 | pred = np.argmax(self.pred[0][graph_idx], axis=0) 418 | else: 419 | pred = np.argmax(self.pred[graph_idx][self.train_idx], axis=1) 420 | print(metrics.confusion_matrix(self.label[graph_idx][self.train_idx], pred)) 421 | plt.switch_backend("agg") 422 | fig = plt.figure(figsize=(5, 3), dpi=600) 423 | for i in range(2): 424 | for j in range(topk): 425 | idx = self.train_idx[rep_idx[i][j]] 426 | print( 427 | "node idx: ", 428 | idx, 429 | "; node label: ", 430 | self.label[graph_idx][idx], 431 | "; pred: ", 432 | pred, 433 | ) 434 | 435 | idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood( 436 | idx, graph_idx 437 | ) 438 | G = nx.from_numpy_matrix(sub_adj) 439 | node_colors = [1 for i in range(G.number_of_nodes())] 440 | node_colors[idx_new] = 0 441 | # node_color='#336699', 442 | 443 | ax = plt.subplot(2, topk, i * topk + j + 1) 444 | nx.draw( 445 | G, 446 | pos=nx.spring_layout(G), 447 | with_labels=True, 448 | font_size=4, 449 | node_color=node_colors, 450 | cmap=plt.get_cmap("Set1"), 451 | vmin=0, 452 | vmax=8, 453 | edge_vmin=0.0, 454 | edge_vmax=1.0, 455 | width=0.5, 456 | node_size=25, 457 | alpha=0.7, 458 | ) 459 | ax.xaxis.set_visible(False) 460 | fig.canvas.draw() 461 | self.writer.add_image( 462 | "local/representer_neigh", tensorboardX.utils.figure_to_image(fig), 0 463 | ) 464 | 465 | def representer(self): 466 | """ 467 | experiment using representer theorem for finding supporting instances. 468 | https://papers.nips.cc/paper/8141-representer-point-selection-for-explaining-deep-neural-networks.pdf 469 | """ 470 | self.model.train() 471 | self.model.zero_grad() 472 | adj = torch.tensor(self.adj, dtype=torch.float) 473 | x = torch.tensor(self.feat, requires_grad=True, dtype=torch.float) 474 | label = torch.tensor(self.label, dtype=torch.long) 475 | if self.args.gpu: 476 | adj, x, label = adj.cuda(), x.cuda(), label.cuda() 477 | 478 | preds, _ = self.model(x, adj) 479 | preds.retain_grad() 480 | self.embedding = self.model.embedding_tensor 481 | loss = self.model.loss(preds, label) 482 | loss.backward() 483 | self.preds_grad = preds.grad 484 | pred_idx = np.expand_dims(np.argmax(self.pred, axis=2), axis=2) 485 | pred_idx = torch.LongTensor(pred_idx) 486 | if self.args.gpu: 487 | pred_idx = pred_idx.cuda() 488 | self.alpha = self.preds_grad 489 | 490 | 491 | # Utilities 492 | def extract_neighborhood(self, node_idx, graph_idx=0): 493 | """Returns the neighborhood of a given ndoe.""" 494 | neighbors_adj_row = self.neighborhoods[graph_idx][node_idx, :] 495 | # index of the query node in the new adj 496 | node_idx_new = sum(neighbors_adj_row[:node_idx]) 497 | neighbors = np.nonzero(neighbors_adj_row)[0] 498 | sub_adj = self.adj[graph_idx][neighbors][:, neighbors] 499 | sub_feat = self.feat[graph_idx, neighbors] 500 | sub_label = self.label[graph_idx][neighbors] 501 | return node_idx_new, sub_adj, sub_feat, sub_label, neighbors 502 | 503 | def align( 504 | self, ref_feat, ref_adj, ref_node_idx, curr_feat, curr_adj, curr_node_idx, args 505 | ): 506 | """ Tries to find an alignment between two graphs. 
507 | """ 508 | ref_adj = torch.FloatTensor(ref_adj) 509 | curr_adj = torch.FloatTensor(curr_adj) 510 | 511 | ref_feat = torch.FloatTensor(ref_feat) 512 | curr_feat = torch.FloatTensor(curr_feat) 513 | 514 | P = nn.Parameter(torch.FloatTensor(ref_adj.shape[0], curr_adj.shape[0])) 515 | with torch.no_grad(): 516 | nn.init.constant_(P, 1.0 / ref_adj.shape[0]) 517 | P[ref_node_idx, :] = 0.0 518 | P[:, curr_node_idx] = 0.0 519 | P[ref_node_idx, curr_node_idx] = 1.0 520 | opt = torch.optim.Adam([P], lr=0.01, betas=(0.5, 0.999)) 521 | for i in range(args.align_steps): 522 | opt.zero_grad() 523 | feat_loss = torch.norm(P @ curr_feat - ref_feat) 524 | 525 | aligned_adj = P @ curr_adj @ torch.transpose(P, 0, 1) 526 | align_loss = torch.norm(aligned_adj - ref_adj) 527 | loss = feat_loss + align_loss 528 | loss.backward() # Calculate gradients 529 | self.writer.add_scalar("optimization/align_loss", loss, i) 530 | print("iter: ", i, "; loss: ", loss) 531 | opt.step() 532 | 533 | return P, aligned_adj, P @ curr_feat 534 | 535 | def make_pred_real(self, adj, start): 536 | # house graph 537 | if self.args.dataset == "syn1" or self.args.dataset == "syn2": 538 | # num_pred = max(G.number_of_edges(), 6) 539 | pred = adj[np.triu(adj) > 0] 540 | real = adj.copy() 541 | 542 | if real[start][start + 1] > 0: 543 | real[start][start + 1] = 10 544 | if real[start + 1][start + 2] > 0: 545 | real[start + 1][start + 2] = 10 546 | if real[start + 2][start + 3] > 0: 547 | real[start + 2][start + 3] = 10 548 | if real[start][start + 3] > 0: 549 | real[start][start + 3] = 10 550 | if real[start][start + 4] > 0: 551 | real[start][start + 4] = 10 552 | if real[start + 1][start + 4]: 553 | real[start + 1][start + 4] = 10 554 | real = real[np.triu(real) > 0] 555 | real[real != 10] = 0 556 | real[real == 10] = 1 557 | 558 | # cycle graph 559 | elif self.args.dataset == "syn4": 560 | pred = adj[np.triu(adj) > 0] 561 | real = adj.copy() 562 | # pdb.set_trace() 563 | if real[start][start + 1] > 0: 564 | real[start][start + 1] = 10 565 | if real[start + 1][start + 2] > 0: 566 | real[start + 1][start + 2] = 10 567 | if real[start + 2][start + 3] > 0: 568 | real[start + 2][start + 3] = 10 569 | if real[start + 3][start + 4] > 0: 570 | real[start + 3][start + 4] = 10 571 | if real[start + 4][start + 5] > 0: 572 | real[start + 4][start + 5] = 10 573 | if real[start][start + 5]: 574 | real[start][start + 5] = 10 575 | real = real[np.triu(real) > 0] 576 | real[real != 10] = 0 577 | real[real == 10] = 1 578 | 579 | return pred, real 580 | 581 | 582 | class ExplainModule(nn.Module): 583 | def __init__( 584 | self, 585 | adj, 586 | x, 587 | model, 588 | label, 589 | args, 590 | graph_idx=0, 591 | writer=None, 592 | use_sigmoid=True, 593 | graph_mode=False, 594 | ): 595 | super(ExplainModule, self).__init__() 596 | self.adj = adj 597 | self.x = x 598 | self.model = model 599 | self.label = label 600 | self.graph_idx = graph_idx 601 | self.args = args 602 | self.writer = writer 603 | self.mask_act = args.mask_act 604 | self.use_sigmoid = use_sigmoid 605 | self.graph_mode = graph_mode 606 | 607 | init_strategy = "normal" 608 | num_nodes = adj.size()[1] 609 | self.mask, self.mask_bias = self.construct_edge_mask( 610 | num_nodes, init_strategy=init_strategy 611 | ) 612 | 613 | self.feat_mask = self.construct_feat_mask(x.size(-1), init_strategy="constant") 614 | params = [self.mask, self.feat_mask] 615 | if self.mask_bias is not None: 616 | params.append(self.mask_bias) 617 | # For masking diagonal entries 618 | self.diag_mask = 
torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes) 619 | if args.gpu: 620 | self.diag_mask = self.diag_mask.cuda() 621 | 622 | self.scheduler, self.optimizer = train_utils.build_optimizer(args, params) 623 | 624 | self.coeffs = { 625 | "size": 0.005, 626 | "feat_size": 1.0, 627 | "ent": 1.0, 628 | "feat_ent": 0.1, 629 | "grad": 0, 630 | "lap": 1.0, 631 | } 632 | 633 | def construct_feat_mask(self, feat_dim, init_strategy="normal"): 634 | mask = nn.Parameter(torch.FloatTensor(feat_dim)) 635 | if init_strategy == "normal": 636 | std = 0.1 637 | with torch.no_grad(): 638 | mask.normal_(1.0, std) 639 | elif init_strategy == "constant": 640 | with torch.no_grad(): 641 | nn.init.constant_(mask, 0.0) 642 | # mask[0] = 2 643 | return mask 644 | 645 | def construct_edge_mask(self, num_nodes, init_strategy="normal", const_val=1.0): 646 | mask = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes)) 647 | if init_strategy == "normal": 648 | std = nn.init.calculate_gain("relu") * math.sqrt( 649 | 2.0 / (num_nodes + num_nodes) 650 | ) 651 | with torch.no_grad(): 652 | mask.normal_(1.0, std) 653 | # mask.clamp_(0.0, 1.0) 654 | elif init_strategy == "const": 655 | nn.init.constant_(mask, const_val) 656 | 657 | if self.args.mask_bias: 658 | mask_bias = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes)) 659 | nn.init.constant_(mask_bias, 0.0) 660 | else: 661 | mask_bias = None 662 | 663 | return mask, mask_bias 664 | 665 | def _masked_adj(self): 666 | sym_mask = self.mask 667 | if self.mask_act == "sigmoid": 668 | sym_mask = torch.sigmoid(self.mask) 669 | elif self.mask_act == "ReLU": 670 | sym_mask = nn.ReLU()(self.mask) 671 | sym_mask = (sym_mask + sym_mask.t()) / 2 672 | adj = self.adj.cuda() if self.args.gpu else self.adj 673 | masked_adj = adj * sym_mask 674 | if self.args.mask_bias: 675 | bias = (self.mask_bias + self.mask_bias.t()) / 2 676 | bias = nn.ReLU6()(bias * 6) / 6 677 | masked_adj += (bias + bias.t()) / 2 678 | return masked_adj * self.diag_mask 679 | 680 | def mask_density(self): 681 | mask_sum = torch.sum(self._masked_adj()).cpu() 682 | adj_sum = torch.sum(self.adj) 683 | return mask_sum / adj_sum 684 | 685 | def forward(self, node_idx, unconstrained=False, mask_features=True, marginalize=False): 686 | x = self.x.cuda() if self.args.gpu else self.x 687 | 688 | if unconstrained: 689 | sym_mask = torch.sigmoid(self.mask) if self.use_sigmoid else self.mask 690 | self.masked_adj = ( 691 | torch.unsqueeze((sym_mask + sym_mask.t()) / 2, 0) * self.diag_mask 692 | ) 693 | else: 694 | self.masked_adj = self._masked_adj() 695 | if mask_features: 696 | feat_mask = ( 697 | torch.sigmoid(self.feat_mask) 698 | if self.use_sigmoid 699 | else self.feat_mask 700 | ) 701 | if marginalize: 702 | std_tensor = torch.ones_like(x, dtype=torch.float) / 2 703 | mean_tensor = torch.zeros_like(x, dtype=torch.float) - x 704 | z = torch.normal(mean=mean_tensor, std=std_tensor) 705 | x = x + z * (1 - feat_mask) 706 | else: 707 | x = x * feat_mask 708 | 709 | ypred, adj_att = self.model(x, self.masked_adj) 710 | if self.graph_mode: 711 | res = nn.Softmax(dim=0)(ypred[0]) 712 | else: 713 | node_pred = ypred[self.graph_idx, node_idx, :] 714 | res = nn.Softmax(dim=0)(node_pred) 715 | return res, adj_att 716 | 717 | def adj_feat_grad(self, node_idx, pred_label_node): 718 | self.model.zero_grad() 719 | self.adj.requires_grad = True 720 | self.x.requires_grad = True 721 | if self.adj.grad is not None: 722 | self.adj.grad.zero_() 723 | self.x.grad.zero_() 724 | if self.args.gpu: 725 | adj = self.adj.cuda() 726 | x = 
self.x.cuda()
727 |             label = self.label.cuda()
728 |         else:
729 |             x, adj = self.x, self.adj
730 |         ypred, _ = self.model(x, adj)
731 |         if self.graph_mode:
732 |             logit = nn.Softmax(dim=0)(ypred[0])
733 |         else:
734 |             logit = nn.Softmax(dim=0)(ypred[self.graph_idx, node_idx, :])
735 |         logit = logit[pred_label_node]
736 |         loss = -torch.log(logit)
737 |         loss.backward()
738 |         return self.adj.grad, self.x.grad
739 |
740 |     def loss(self, pred, pred_label, node_idx, epoch):
741 |         """
742 |         Args:
743 |             pred: prediction made by current model
744 |             pred_label: the label predicted by the original model.
745 |         """
746 |         mi_obj = False
747 |         if mi_obj:
748 |             pred_loss = -torch.sum(pred * torch.log(pred))
749 |         else:
750 |             pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
751 |             gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
752 |             logit = pred[gt_label_node]
753 |             pred_loss = -torch.log(logit)
754 |         # size
755 |         mask = self.mask
756 |         if self.mask_act == "sigmoid":
757 |             mask = torch.sigmoid(self.mask)
758 |         elif self.mask_act == "ReLU":
759 |             mask = nn.ReLU()(self.mask)
760 |         size_loss = self.coeffs["size"] * torch.sum(mask)
761 |
762 |         # pre_mask_sum = torch.sum(self.feat_mask)
763 |         feat_mask = (
764 |             torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
765 |         )
766 |         feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
767 |
768 |         # entropy
769 |         mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
770 |         mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
771 |
772 |         feat_mask_ent = - feat_mask \
773 |                         * torch.log(feat_mask) \
774 |                         - (1 - feat_mask) \
775 |                         * torch.log(1 - feat_mask)
776 |
777 |         feat_mask_ent_loss = self.coeffs["feat_ent"] * torch.mean(feat_mask_ent)
778 |
779 |         # laplacian
780 |         D = torch.diag(torch.sum(self.masked_adj[0], 0))
781 |         m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
782 |         L = D - m_adj
783 |         pred_label_t = torch.tensor(pred_label, dtype=torch.float)
784 |         if self.args.gpu:
785 |             pred_label_t = pred_label_t.cuda()
786 |             L = L.cuda()
787 |         if self.graph_mode:
788 |             lap_loss = 0
789 |         else:
790 |             lap_loss = (self.coeffs["lap"]
791 |                 * (pred_label_t @ L @ pred_label_t)
792 |                 / self.adj.numel()
793 |             )
794 |
795 |         # grad
796 |         # adj
797 |         # adj_grad, x_grad = self.adj_feat_grad(node_idx, pred_label_node)
798 |         # adj_grad = adj_grad[self.graph_idx]
799 |         # x_grad = x_grad[self.graph_idx]
800 |         # if self.args.gpu:
801 |         #     adj_grad = adj_grad.cuda()
802 |         # grad_loss = self.coeffs['grad'] * -torch.mean(torch.abs(adj_grad) * mask)
803 |
804 |         # feat
805 |         # x_grad_sum = torch.sum(x_grad, 1)
806 |         # grad_feat_loss = self.coeffs['featgrad'] * -torch.mean(x_grad_sum * mask)
807 |
808 |         loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss
809 |         if self.writer is not None:
810 |             self.writer.add_scalar("optimization/size_loss", size_loss, epoch)
811 |             self.writer.add_scalar("optimization/feat_size_loss", feat_size_loss, epoch)
812 |             self.writer.add_scalar("optimization/mask_ent_loss", mask_ent_loss, epoch)
813 |             self.writer.add_scalar(
814 |                 "optimization/feat_mask_ent_loss", feat_mask_ent_loss, epoch
815 |             )
816 |             # self.writer.add_scalar('optimization/grad_loss', grad_loss, epoch)
817 |             self.writer.add_scalar("optimization/pred_loss", pred_loss, epoch)
818 |             self.writer.add_scalar("optimization/lap_loss", lap_loss, epoch)
819 |             self.writer.add_scalar("optimization/overall_loss", loss, epoch)
820 |         return loss
821 |
822 |     def
log_mask(self, epoch): 823 | plt.switch_backend("agg") 824 | fig = plt.figure(figsize=(4, 3), dpi=400) 825 | plt.imshow(self.mask.cpu().detach().numpy(), cmap=plt.get_cmap("BuPu")) 826 | cbar = plt.colorbar() 827 | cbar.solids.set_edgecolor("face") 828 | 829 | plt.tight_layout() 830 | fig.canvas.draw() 831 | self.writer.add_image( 832 | "mask/mask", tensorboardX.utils.figure_to_image(fig), epoch 833 | ) 834 | 835 | # fig = plt.figure(figsize=(4,3), dpi=400) 836 | # plt.imshow(self.feat_mask.cpu().detach().numpy()[:,np.newaxis], cmap=plt.get_cmap('BuPu')) 837 | # cbar = plt.colorbar() 838 | # cbar.solids.set_edgecolor("face") 839 | 840 | # plt.tight_layout() 841 | # fig.canvas.draw() 842 | # self.writer.add_image('mask/feat_mask', tensorboardX.utils.figure_to_image(fig), epoch) 843 | io_utils.log_matrix( 844 | self.writer, torch.sigmoid(self.feat_mask), "mask/feat_mask", epoch 845 | ) 846 | 847 | fig = plt.figure(figsize=(4, 3), dpi=400) 848 | # use [0] to remove the batch dim 849 | plt.imshow(self.masked_adj[0].cpu().detach().numpy(), cmap=plt.get_cmap("BuPu")) 850 | cbar = plt.colorbar() 851 | cbar.solids.set_edgecolor("face") 852 | 853 | plt.tight_layout() 854 | fig.canvas.draw() 855 | self.writer.add_image( 856 | "mask/adj", tensorboardX.utils.figure_to_image(fig), epoch 857 | ) 858 | 859 | if self.args.mask_bias: 860 | fig = plt.figure(figsize=(4, 3), dpi=400) 861 | # use [0] to remove the batch dim 862 | plt.imshow(self.mask_bias.cpu().detach().numpy(), cmap=plt.get_cmap("BuPu")) 863 | cbar = plt.colorbar() 864 | cbar.solids.set_edgecolor("face") 865 | 866 | plt.tight_layout() 867 | fig.canvas.draw() 868 | self.writer.add_image( 869 | "mask/bias", tensorboardX.utils.figure_to_image(fig), epoch 870 | ) 871 | 872 | def log_adj_grad(self, node_idx, pred_label, epoch, label=None): 873 | log_adj = False 874 | 875 | if self.graph_mode: 876 | predicted_label = pred_label 877 | # adj_grad, x_grad = torch.abs(self.adj_feat_grad(node_idx, predicted_label)[0])[0] 878 | adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label) 879 | adj_grad = torch.abs(adj_grad)[0] 880 | x_grad = torch.sum(x_grad[0], 0, keepdim=True).t() 881 | else: 882 | predicted_label = pred_label[node_idx] 883 | # adj_grad = torch.abs(self.adj_feat_grad(node_idx, predicted_label)[0])[self.graph_idx] 884 | adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label) 885 | adj_grad = torch.abs(adj_grad)[self.graph_idx] 886 | x_grad = x_grad[self.graph_idx][node_idx][:, np.newaxis] 887 | # x_grad = torch.sum(x_grad[self.graph_idx], 0, keepdim=True).t() 888 | adj_grad = (adj_grad + adj_grad.t()) / 2 889 | adj_grad = (adj_grad * self.adj).squeeze() 890 | if log_adj: 891 | io_utils.log_matrix(self.writer, adj_grad, "grad/adj_masked", epoch) 892 | self.adj.requires_grad = False 893 | io_utils.log_matrix(self.writer, self.adj.squeeze(), "grad/adj_orig", epoch) 894 | 895 | masked_adj = self.masked_adj[0].cpu().detach().numpy() 896 | 897 | # only for graph mode since many node neighborhoods for syn tasks are relatively large for 898 | # visualization 899 | if self.graph_mode: 900 | G = io_utils.denoise_graph( 901 | masked_adj, node_idx, feat=self.x[0], threshold=None, max_component=False 902 | ) 903 | io_utils.log_graph( 904 | self.writer, 905 | G, 906 | name="grad/graph_orig", 907 | epoch=epoch, 908 | identify_self=False, 909 | label_node_feat=True, 910 | nodecolor="feat", 911 | edge_vmax=None, 912 | args=self.args, 913 | ) 914 | io_utils.log_matrix(self.writer, x_grad, "grad/feat", epoch) 915 | 916 | adj_grad = 
adj_grad.detach().numpy() 917 | if self.graph_mode: 918 | print("GRAPH model") 919 | G = io_utils.denoise_graph( 920 | adj_grad, 921 | node_idx, 922 | feat=self.x[0], 923 | threshold=0.0003, # threshold_num=20, 924 | max_component=True, 925 | ) 926 | io_utils.log_graph( 927 | self.writer, 928 | G, 929 | name="grad/graph", 930 | epoch=epoch, 931 | identify_self=False, 932 | label_node_feat=True, 933 | nodecolor="feat", 934 | edge_vmax=None, 935 | args=self.args, 936 | ) 937 | else: 938 | # G = io_utils.denoise_graph(adj_grad, node_idx, label=label, threshold=0.5) 939 | G = io_utils.denoise_graph(adj_grad, node_idx, threshold_num=12) 940 | io_utils.log_graph( 941 | self.writer, G, name="grad/graph", epoch=epoch, args=self.args 942 | ) 943 | 944 | # if graph attention, also visualize att 945 | 946 | def log_masked_adj(self, node_idx, epoch, name="mask/graph", label=None): 947 | # use [0] to remove the batch dim 948 | masked_adj = self.masked_adj[0].cpu().detach().numpy() 949 | if self.graph_mode: 950 | G = io_utils.denoise_graph( 951 | masked_adj, 952 | node_idx, 953 | feat=self.x[0], 954 | threshold=0.2, # threshold_num=20, 955 | max_component=True, 956 | ) 957 | io_utils.log_graph( 958 | self.writer, 959 | G, 960 | name=name, 961 | identify_self=False, 962 | nodecolor="feat", 963 | epoch=epoch, 964 | label_node_feat=True, 965 | edge_vmax=None, 966 | args=self.args, 967 | ) 968 | else: 969 | G = io_utils.denoise_graph( 970 | masked_adj, node_idx, threshold_num=12, max_component=True 971 | ) 972 | io_utils.log_graph( 973 | self.writer, 974 | G, 975 | name=name, 976 | identify_self=True, 977 | nodecolor="label", 978 | epoch=epoch, 979 | edge_vmax=None, 980 | args=self.args, 981 | ) 982 | 983 | -------------------------------------------------------------------------------- /explainer_main.py: -------------------------------------------------------------------------------- 1 | """ explainer_main.py 2 | 3 | Main user interface for the explainer module. 4 | """ 5 | import argparse 6 | import os 7 | 8 | import sklearn.metrics as metrics 9 | 10 | from tensorboardX import SummaryWriter 11 | 12 | import pickle 13 | import shutil 14 | import torch 15 | 16 | import models 17 | import utils.io_utils as io_utils 18 | import utils.parser_utils as parser_utils 19 | from explainer import explain 20 | 21 | 22 | 23 | def arg_parse(): 24 | parser = argparse.ArgumentParser(description="GNN Explainer arguments.") 25 | io_parser = parser.add_mutually_exclusive_group(required=False) 26 | io_parser.add_argument("--dataset", dest="dataset", help="Input dataset.") 27 | benchmark_parser = io_parser.add_argument_group() 28 | benchmark_parser.add_argument( 29 | "--bmname", dest="bmname", help="Name of the benchmark dataset" 30 | ) 31 | io_parser.add_argument("--pkl", dest="pkl_fname", help="Name of the pkl data file") 32 | 33 | parser_utils.parse_optimizer(parser) 34 | 35 | parser.add_argument("--clean-log", action="store_true", help="If true, cleans the specified log directory before running.") 36 | parser.add_argument("--logdir", dest="logdir", help="Tensorboard log directory") 37 | parser.add_argument("--ckptdir", dest="ckptdir", help="Model checkpoint directory") 38 | parser.add_argument("--cuda", dest="cuda", help="CUDA.") 39 | parser.add_argument( 40 | "--gpu", 41 | dest="gpu", 42 | action="store_const", 43 | const=True, 44 | default=False, 45 | help="whether to use GPU.", 46 | ) 47 | parser.add_argument( 48 | "--epochs", dest="num_epochs", type=int, help="Number of epochs to train." 
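# NOTE: for the explainer these "epochs" are mask-optimization steps, not
# GNN training epochs; the default (100) is set in parser.set_defaults below.
# Hypothetical example invocation, assuming a syn1 checkpoint exists in ckpt/
# (the node index 301 is illustrative only):
#   python explainer_main.py --dataset=syn1 --gpu --explain-node=301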
49 | )
50 | parser.add_argument(
51 | "--hidden-dim", dest="hidden_dim", type=int, help="Hidden dimension"
52 | )
53 | parser.add_argument(
54 | "--output-dim", dest="output_dim", type=int, help="Output dimension"
55 | )
56 | parser.add_argument(
57 | "--num-gc-layers",
58 | dest="num_gc_layers",
59 | type=int,
60 | help="Number of graph convolution layers before each pooling",
61 | )
62 | parser.add_argument(
63 | "--bn",
64 | dest="bn",
65 | action="store_const",
66 | const=True,
67 | default=False,
68 | help="Whether batch normalization is used",
69 | )
70 | parser.add_argument("--dropout", dest="dropout", type=float, help="Dropout rate.")
71 | parser.add_argument(
72 | "--nobias",
73 | dest="bias",
74 | action="store_const",
75 | const=False,
76 | default=True,
77 | help="Whether to add bias. Default to True.",
78 | )
79 | parser.add_argument(
80 | "--no-writer",
81 | dest="writer",
82 | action="store_const",
83 | const=False,
84 | default=True,
85 | help="Disable the Tensorboard writer. Writing is enabled by default.",
86 | )
87 | # Explainer
88 | parser.add_argument("--mask-act", dest="mask_act", type=str, help="Mask activation function: sigmoid or ReLU.")
89 | parser.add_argument(
90 | "--mask-bias",
91 | dest="mask_bias",
92 | action="store_const",
93 | const=True,
94 | default=False,
95 | help="Whether to add a learnable bias to the edge mask. Default to False.",
96 | )
97 | parser.add_argument(
98 | "--explain-node", dest="explain_node", type=int, help="Node to explain."
99 | )
100 | parser.add_argument(
101 | "--graph-idx", dest="graph_idx", type=int, help="Graph to explain."
102 | )
103 | parser.add_argument(
104 | "--graph-mode",
105 | dest="graph_mode",
106 | action="store_const",
107 | const=True,
108 | default=False,
109 | help="Whether to run Explainer on a graph classification task.",
110 | )
111 | parser.add_argument(
112 | "--multigraph-class",
113 | dest="multigraph_class",
114 | type=int,
115 | help="Run the Explainer on multiple graphs of this class from the graph classification task.",
116 | )
117 | parser.add_argument(
118 | "--multinode-class",
119 | dest="multinode_class",
120 | type=int,
121 | help="Run the Explainer on multiple nodes of this class from the node classification task.",
122 | )
123 | parser.add_argument(
124 | "--align-steps",
125 | dest="align_steps",
126 | type=int,
127 | help="Number of iterations to find P, the alignment matrix.",
128 | )
129 | 
130 | parser.add_argument(
131 | "--method", dest="method", type=str, help="Method. Possible values: base, att."
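# "base" explains a checkpoint trained with plain GraphConv layers; "att"
# assumes one trained with attention (models.py sets self.att = True when
# args.method == "att", so GraphConv reweights adj by a learned attention).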
132 | )
133 | parser.add_argument(
134 | "--name-suffix", dest="name_suffix", help="Suffix added to the output filename"
135 | )
136 | parser.add_argument(
137 | "--explainer-suffix",
138 | dest="explainer_suffix",
139 | help="Suffix added to the explainer log",
140 | )
141 | 
142 | # TODO: Check argument usage
143 | parser.set_defaults(
144 | logdir="log",
145 | ckptdir="ckpt",
146 | dataset="syn1",
147 | opt="adam",
148 | opt_scheduler="none",
149 | cuda="0",
150 | lr=0.1,
151 | clip=2.0,
152 | batch_size=20,
153 | num_epochs=100,
154 | hidden_dim=20,
155 | output_dim=20,
156 | num_gc_layers=3,
157 | dropout=0.0,
158 | method="base",
159 | name_suffix="",
160 | explainer_suffix="",
161 | align_steps=1000,
162 | explain_node=None,
163 | graph_idx=-1,
164 | mask_act="sigmoid",
165 | multigraph_class=-1,
166 | multinode_class=-1,
167 | )
168 | return parser.parse_args()
169 | 
170 | 
171 | def main():
172 | # Load a configuration
173 | prog_args = arg_parse()
174 | 
175 | if prog_args.gpu:
176 | os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
177 | print("CUDA", prog_args.cuda)
178 | else:
179 | print("Using CPU")
180 | 
181 | # Configure the logging directory
182 | if prog_args.writer:
183 | path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
184 | if os.path.isdir(path) and prog_args.clean_log:
185 | print("Removing existing log dir: ", path)
186 | if input("Are you sure you want to remove this directory? (y/n): ").lower().strip()[:1] != "y": raise SystemExit(1)  # `sys` is never imported in this file, so sys.exit() would itself raise a NameError
187 | shutil.rmtree(path)
188 | writer = SummaryWriter(path)
189 | else:
190 | writer = None
191 | 
192 | # Load a model checkpoint
193 | ckpt = io_utils.load_ckpt(prog_args)
194 | cg_dict = ckpt["cg"] # get computation graph
195 | input_dim = cg_dict["feat"].shape[2]
196 | num_classes = cg_dict["pred"].shape[2]
197 | print("Loaded model from {}".format(prog_args.ckptdir))
198 | print("input dim: ", input_dim, "; num classes: ", num_classes)
199 | 
200 | # Determine explainer mode
201 | graph_mode = (
202 | prog_args.graph_mode
203 | or prog_args.multigraph_class >= 0
204 | or prog_args.graph_idx >= 0
205 | )
206 | 
207 | # build model
208 | print("Method: ", prog_args.method)
209 | if graph_mode:
210 | # Explain Graph prediction
211 | model = models.GcnEncoderGraph(
212 | input_dim=input_dim,
213 | hidden_dim=prog_args.hidden_dim,
214 | embedding_dim=prog_args.output_dim,
215 | label_dim=num_classes,
216 | num_layers=prog_args.num_gc_layers,
217 | bn=prog_args.bn,
218 | args=prog_args,
219 | )
220 | else:
221 | if prog_args.dataset == "ppi_essential":
222 | # class weight in CE loss for handling imbalanced label classes
223 | prog_args.loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float).cuda()
224 | # Explain Node prediction
225 | model = models.GcnEncoderNode(
226 | input_dim=input_dim,
227 | hidden_dim=prog_args.hidden_dim,
228 | embedding_dim=prog_args.output_dim,
229 | label_dim=num_classes,
230 | num_layers=prog_args.num_gc_layers,
231 | bn=prog_args.bn,
232 | args=prog_args,
233 | )
234 | if prog_args.gpu:
235 | model = model.cuda()
236 | # load state_dict (obtained by model.state_dict() when saving checkpoint)
237 | model.load_state_dict(ckpt["model_state"])
238 | 
239 | # Create explainer
240 | explainer = explain.Explainer(
241 | model=model,
242 | adj=cg_dict["adj"],
243 | feat=cg_dict["feat"],
244 | label=cg_dict["label"],
245 | pred=cg_dict["pred"],
246 | train_idx=cg_dict["train_idx"],
247 | args=prog_args,
248 | writer=writer,
249 | print_training=True,
250 | graph_mode=graph_mode,
251 | 
graph_idx=prog_args.graph_idx, 252 | ) 253 | 254 | # TODO: API should definitely be cleaner 255 | # Let's define exactly which modes we support 256 | # We could even move each mode to a different method (even file) 257 | if prog_args.explain_node is not None: 258 | explainer.explain(prog_args.explain_node, unconstrained=False) 259 | elif graph_mode: 260 | if prog_args.multigraph_class >= 0: 261 | print(cg_dict["label"]) 262 | # only run for graphs with label specified by multigraph_class 263 | labels = cg_dict["label"].numpy() 264 | graph_indices = [] 265 | for i, l in enumerate(labels): 266 | if l == prog_args.multigraph_class: 267 | graph_indices.append(i) 268 | if len(graph_indices) > 30: 269 | break 270 | print( 271 | "Graph indices for label ", 272 | prog_args.multigraph_class, 273 | " : ", 274 | graph_indices, 275 | ) 276 | explainer.explain_graphs(graph_indices=graph_indices) 277 | 278 | elif prog_args.graph_idx == -1: 279 | # just run for a customized set of indices 280 | explainer.explain_graphs(graph_indices=[1, 2, 3, 4]) 281 | else: 282 | explainer.explain( 283 | node_idx=0, 284 | graph_idx=prog_args.graph_idx, 285 | graph_mode=True, 286 | unconstrained=False, 287 | ) 288 | io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap") 289 | else: 290 | if prog_args.multinode_class >= 0: 291 | print(cg_dict["label"]) 292 | # only run for nodes with label specified by multinode_class 293 | labels = cg_dict["label"][0] # already numpy matrix 294 | 295 | node_indices = [] 296 | for i, l in enumerate(labels): 297 | if len(node_indices) > 4: 298 | break 299 | if l == prog_args.multinode_class: 300 | node_indices.append(i) 301 | print( 302 | "Node indices for label ", 303 | prog_args.multinode_class, 304 | " : ", 305 | node_indices, 306 | ) 307 | explainer.explain_nodes(node_indices, prog_args) 308 | 309 | else: 310 | # explain a set of nodes 311 | masked_adj = explainer.explain_nodes_gnn_stats( 312 | range(400, 700, 5), prog_args 313 | ) 314 | 315 | if __name__ == "__main__": 316 | main() 317 | 318 | -------------------------------------------------------------------------------- /gengraph.py: -------------------------------------------------------------------------------- 1 | """gengraph.py 2 | 3 | Generating and manipulaton the synthetic graphs needed for the paper's experiments. 4 | """ 5 | 6 | import os 7 | 8 | from matplotlib import pyplot as plt 9 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 10 | from matplotlib.figure import Figure 11 | import matplotlib.colors as colors 12 | 13 | # Set matplotlib backend to file writing 14 | plt.switch_backend("agg") 15 | 16 | import networkx as nx 17 | 18 | import numpy as np 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | from utils import synthetic_structsim 23 | from utils import featgen 24 | import utils.io_utils as io_utils 25 | 26 | 27 | #################################### 28 | # 29 | # Experiment utilities 30 | # 31 | #################################### 32 | def perturb(graph_list, p): 33 | """ Perturb the list of (sparse) graphs by adding/removing edges. 34 | Args: 35 | p: proportion of added edges based on current number of edges. 36 | Returns: 37 | A list of graphs that are perturbed from the original graphs. 38 | """ 39 | perturbed_graph_list = [] 40 | for G_original in graph_list: 41 | G = G_original.copy() 42 | edge_count = int(G.number_of_edges() * p) 43 | # randomly add the edges between a pair of nodes without an edge. 
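# Rejection-sample an endpoint pair: draw (u, v) uniformly until the pair is
# neither a self-loop nor an existing edge, then add the edge. E.g. with
# p=0.01, a graph with 1,000 edges gains int(1000 * 0.01) = 10 noise edges.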
44 | for _ in range(edge_count): 45 | while True: 46 | u = np.random.randint(0, G.number_of_nodes()) 47 | v = np.random.randint(0, G.number_of_nodes()) 48 | if (not G.has_edge(u, v)) and (u != v): 49 | break 50 | G.add_edge(u, v) 51 | perturbed_graph_list.append(G) 52 | return perturbed_graph_list 53 | 54 | 55 | def join_graph(G1, G2, n_pert_edges): 56 | """ Join two graphs along matching nodes, then perturb the resulting graph. 57 | Args: 58 | G1, G2: Networkx graphs to be joined. 59 | n_pert_edges: number of perturbed edges. 60 | Returns: 61 | A new graph, result of merging and perturbing G1 and G2. 62 | """ 63 | assert n_pert_edges > 0 64 | F = nx.compose(G1, G2) 65 | edge_cnt = 0 66 | while edge_cnt < n_pert_edges: 67 | node_1 = np.random.choice(G1.nodes()) 68 | node_2 = np.random.choice(G2.nodes()) 69 | F.add_edge(node_1, node_2) 70 | edge_cnt += 1 71 | return F 72 | 73 | 74 | def preprocess_input_graph(G, labels, normalize_adj=False): 75 | """ Load an existing graph to be converted for the experiments. 76 | Args: 77 | G: Networkx graph to be loaded. 78 | labels: Associated node labels. 79 | normalize_adj: Should the method return a normalized adjacency matrix. 80 | Returns: 81 | A dictionary containing adjacency, node features and labels 82 | """ 83 | adj = np.array(nx.to_numpy_matrix(G)) 84 | if normalize_adj: 85 | sqrt_deg = np.diag(1.0 / np.sqrt(np.sum(adj, axis=0, dtype=float).squeeze())) 86 | adj = np.matmul(np.matmul(sqrt_deg, adj), sqrt_deg) 87 | 88 | existing_node = list(G.nodes)[-1] 89 | feat_dim = G.nodes[existing_node]["feat"].shape[0] 90 | f = np.zeros((G.number_of_nodes(), feat_dim), dtype=float) 91 | for i, u in enumerate(G.nodes()): 92 | f[i, :] = G.nodes[u]["feat"] 93 | 94 | # add batch dim 95 | adj = np.expand_dims(adj, axis=0) 96 | f = np.expand_dims(f, axis=0) 97 | labels = np.expand_dims(labels, axis=0) 98 | return {"adj": adj, "feat": f, "labels": labels} 99 | 100 | 101 | #################################### 102 | # 103 | # Generating synthetic graphs 104 | # 105 | ################################### 106 | def gen_syn1(nb_shapes=80, width_basis=300, feature_generator=None, m=5): 107 | """ Synthetic Graph #1: 108 | 109 | Start with Barabasi-Albert graph and attach house-shaped subgraphs. 110 | 111 | Args: 112 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 113 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 114 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 115 | m : number of edges to attach to existing node (for BA graph) 116 | 117 | Returns: 118 | G : A networkx graph 119 | role_id : A list with length equal to number of nodes in the entire graph (basis 120 | : + shapes). role_id[i] is the ID of the role of node i. It is the label. 
121 | name : A graph identifier 122 | """ 123 | basis_type = "ba" 124 | list_shapes = [["house"]] * nb_shapes 125 | 126 | plt.figure(figsize=(8, 6), dpi=300) 127 | 128 | G, role_id, _ = synthetic_structsim.build_graph( 129 | width_basis, basis_type, list_shapes, start=0, m=5 130 | ) 131 | G = perturb([G], 0.01)[0] 132 | 133 | if feature_generator is None: 134 | feature_generator = featgen.ConstFeatureGen(1) 135 | feature_generator.gen_node_features(G) 136 | 137 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 138 | return G, role_id, name 139 | 140 | 141 | def gen_syn2(nb_shapes=100, width_basis=350): 142 | """ Synthetic Graph #2: 143 | 144 | Start with Barabasi-Albert graph and add node features indicative of a community label. 145 | 146 | Args: 147 | nb_shapes : The number of shapes (here 'houses') that should be added to the base graph. 148 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 149 | 150 | Returns: 151 | G : A networkx graph 152 | label : Label of the nodes (determined by role_id and community) 153 | name : A graph identifier 154 | """ 155 | basis_type = "ba" 156 | 157 | random_mu = [0.0] * 8 158 | random_sigma = [1.0] * 8 159 | 160 | # Create two grids 161 | mu_1, sigma_1 = np.array([-1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 162 | mu_2, sigma_2 = np.array([1.0] * 2 + random_mu), np.array([0.5] * 2 + random_sigma) 163 | feat_gen_G1 = featgen.GaussianFeatureGen(mu=mu_1, sigma=sigma_1) 164 | feat_gen_G2 = featgen.GaussianFeatureGen(mu=mu_2, sigma=sigma_2) 165 | G1, role_id1, name = gen_syn1(feature_generator=feat_gen_G1, m=4) 166 | G2, role_id2, name = gen_syn1(feature_generator=feat_gen_G2, m=4) 167 | G1_size = G1.number_of_nodes() 168 | num_roles = max(role_id1) + 1 169 | role_id2 = [r + num_roles for r in role_id2] 170 | label = role_id1 + role_id2 171 | 172 | # Edit node ids to avoid collisions on join 173 | g1_map = {n: i for i, n in enumerate(G1.nodes())} 174 | G1 = nx.relabel_nodes(G1, g1_map) 175 | g2_map = {n: i + G1_size for i, n in enumerate(G2.nodes())} 176 | G2 = nx.relabel_nodes(G2, g2_map) 177 | 178 | # Join 179 | n_pert_edges = width_basis 180 | G = join_graph(G1, G2, n_pert_edges) 181 | 182 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) + "_2comm" 183 | 184 | return G, label, name 185 | 186 | 187 | def gen_syn3(nb_shapes=80, width_basis=300, feature_generator=None, m=5): 188 | """ Synthetic Graph #3: 189 | 190 | Start with Barabasi-Albert graph and attach grid-shaped subgraphs. 191 | 192 | Args: 193 | nb_shapes : The number of shapes (here 'grid') that should be added to the base graph. 194 | width_basis : The width of the basis graph (here 'Barabasi-Albert' random graph). 195 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes. 196 | m : number of edges to attach to existing node (for BA graph) 197 | 198 | Returns: 199 | G : A networkx graph 200 | role_id : Role ID for each node in synthetic graph. 
201 | name : A graph identifier
202 | """
203 | basis_type = "ba"
204 | list_shapes = [["grid", 3]] * nb_shapes
205 | 
206 | plt.figure(figsize=(8, 6), dpi=300)
207 | 
208 | G, role_id, _ = synthetic_structsim.build_graph(
209 | width_basis, basis_type, list_shapes, start=0, m=m  # m: edges attached per new node in the BA basis
210 | )
211 | G = perturb([G], 0.01)[0]
212 | 
213 | if feature_generator is None:
214 | feature_generator = featgen.ConstFeatureGen(1)
215 | feature_generator.gen_node_features(G)
216 | 
217 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
218 | return G, role_id, name
219 | 
220 | 
221 | def gen_syn4(nb_shapes=60, width_basis=8, feature_generator=None, m=4):
222 | """ Synthetic Graph #4:
223 | 
224 | Start with a tree and attach cycle-shaped subgraphs.
225 | 
226 | Args:
227 | nb_shapes : The number of shapes (here 'cycles') that should be added to the base graph.
228 | width_basis : The width of the basis graph (here a random 'Tree').
229 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes.
230 | m : Unused by this generator.
231 | 
232 | Returns:
233 | G : A networkx graph
234 | role_id : Role ID for each node in synthetic graph
235 | name : A graph identifier
236 | """
237 | basis_type = "tree"
238 | list_shapes = [["cycle", 6]] * nb_shapes
239 | 
240 | fig = plt.figure(figsize=(8, 6), dpi=300)
241 | 
242 | G, role_id, plugins = synthetic_structsim.build_graph(
243 | width_basis, basis_type, list_shapes, start=0
244 | )
245 | G = perturb([G], 0.01)[0]
246 | 
247 | if feature_generator is None:
248 | feature_generator = featgen.ConstFeatureGen(1)
249 | feature_generator.gen_node_features(G)
250 | 
251 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
252 | 
253 | path = os.path.join("log/syn4_base_h20_o20")
254 | writer = SummaryWriter(path)
255 | io_utils.log_graph(writer, G, "graph/full")
256 | 
257 | return G, role_id, name
258 | 
259 | 
260 | def gen_syn5(nb_shapes=80, width_basis=8, feature_generator=None, m=3):
261 | """ Synthetic Graph #5:
262 | 
263 | Start with a tree and attach grid-shaped subgraphs.
264 | 
265 | Args:
266 | nb_shapes : The number of shapes (here 'grids') that should be added to the base graph.
267 | width_basis : The width of the basis graph (here a random 'Tree').
268 | feature_generator : A `FeatureGenerator` for node features. If `None`, add constant features to nodes.
269 | m : The dimension of each attached grid (the grid is m-by-m).
270 | 271 | Returns: 272 | G : A networkx graph 273 | role_id : Role ID for each node in synthetic graph 274 | name : A graph identifier 275 | """ 276 | basis_type = "tree" 277 | list_shapes = [["grid", m]] * nb_shapes 278 | 279 | plt.figure(figsize=(8, 6), dpi=300) 280 | 281 | G, role_id, _ = synthetic_structsim.build_graph( 282 | width_basis, basis_type, list_shapes, start=0 283 | ) 284 | G = perturb([G], 0.1)[0] 285 | 286 | if feature_generator is None: 287 | feature_generator = featgen.ConstFeatureGen(1) 288 | feature_generator.gen_node_features(G) 289 | 290 | name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes) 291 | 292 | path = os.path.join("log/syn5_base_h20_o20") 293 | writer = SummaryWriter(path) 294 | 295 | return G, role_id, name 296 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import cv2 5 | import numpy as np 6 | 7 | from explainer import explain 8 | 9 | from utils import math_utils 10 | from utils import io_utils 11 | 12 | use_cuda = torch.cuda.is_available() 13 | 14 | 15 | # TODO: Is this still used? 16 | 17 | ##### 18 | # 19 | # 1) Load trained GNN model 20 | # 2) Load a query computation graph 21 | # 22 | ##### 23 | MODEL_PATH = "gcn-vanilla.pt" 24 | CG_PATH = "1.pt" 25 | model = io_utils.load_model(MODEL_PATH) 26 | original_cg = io_utils.load_cg(CG_PATH) 27 | 28 | 29 | ##### 30 | # 31 | # Set parameters of explainer 32 | # 33 | ##### 34 | tv_beta = 3 35 | learning_rate = 0.1 36 | max_iterations = 500 37 | l1_coeff = 0.01 38 | tv_coeff = 0.2 39 | 40 | 41 | # Initialize cg mask 42 | blurred_cg1 = cv2.GaussianBlur(original_cg, (11, 11), 5) 43 | blurred_cg2 = np.float32(cv2.medianBlur(original_cg, 11)) / 255 44 | mask_init = np.ones((28, 28), dtype=np.float32) 45 | 46 | # Convert to torch variables 47 | cg = io_utils.preprocess_cg(original_cg) 48 | blurred_cg = io_utils.preprocess_cg(blurred_cg2) 49 | mask = io_utils.numpy_to_torch(mask_init) 50 | 51 | if use_cuda: 52 | upsample = torch.nn.UpsamplingBilinear2d(size=(224, 224)).cuda() 53 | else: 54 | upsample = torch.nn.UpsamplingBilinear2d(size=(224, 224)) 55 | optimizer = torch.optim.Adam([mask], lr=learning_rate) 56 | 57 | target = torch.nn.Softmax()(model(cg)) 58 | category = np.argmax(target.cpu().data.numpy()) 59 | print("Category with highest probability", category) 60 | print("Optimizing.. 
") 61 | 62 | for i in range(max_iterations): 63 | upsampled_mask = upsample(mask) 64 | 65 | # Use the mask to perturb the input computation graph 66 | perturbed_input = cg.mul(upsampled_mask) + blurred_cg.mul(1 - upsampled_mask) 67 | 68 | noise = np.zeros((224, 224, 3), dtype=np.float32) 69 | cv2.randn(noise, 0, 0.2) 70 | noise = io_utils.numpy_to_torch(noise) 71 | perturbed_input = perturbed_input + noise 72 | 73 | outputs = torch.nn.Softmax()(model(perturbed_input)) 74 | loss = ( 75 | l1_coeff * torch.mean(torch.abs(1 - mask)) 76 | + tv_coeff * math_utils.tv_norm(mask, tv_beta) 77 | + outputs[0, category] 78 | ) 79 | 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # Optional: clamping seems to give better results 85 | mask.data.clamp_(0, 1) 86 | 87 | upsampled_mask = upsample(mask) 88 | io_utils.save(upsampled_mask) 89 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | import numpy as np 7 | 8 | # GCN basic operation 9 | class GraphConv(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | output_dim, 14 | add_self=False, 15 | normalize_embedding=False, 16 | dropout=0.0, 17 | bias=True, 18 | gpu=True, 19 | att=False, 20 | ): 21 | super(GraphConv, self).__init__() 22 | self.att = att 23 | self.add_self = add_self 24 | self.dropout = dropout 25 | if dropout > 0.001: 26 | self.dropout_layer = nn.Dropout(p=dropout) 27 | self.normalize_embedding = normalize_embedding 28 | self.input_dim = input_dim 29 | self.output_dim = output_dim 30 | if not gpu: 31 | self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim)) 32 | if add_self: 33 | self.self_weight = nn.Parameter( 34 | torch.FloatTensor(input_dim, output_dim) 35 | ) 36 | if att: 37 | self.att_weight = nn.Parameter(torch.FloatTensor(input_dim, input_dim)) 38 | else: 39 | self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda()) 40 | if add_self: 41 | self.self_weight = nn.Parameter( 42 | torch.FloatTensor(input_dim, output_dim).cuda() 43 | ) 44 | if att: 45 | self.att_weight = nn.Parameter( 46 | torch.FloatTensor(input_dim, input_dim).cuda() 47 | ) 48 | if bias: 49 | if not gpu: 50 | self.bias = nn.Parameter(torch.FloatTensor(output_dim)) 51 | else: 52 | self.bias = nn.Parameter(torch.FloatTensor(output_dim).cuda()) 53 | else: 54 | self.bias = None 55 | 56 | # self.softmax = nn.Softmax(dim=-1) 57 | 58 | def forward(self, x, adj): 59 | if self.dropout > 0.001: 60 | x = self.dropout_layer(x) 61 | # deg = torch.sum(adj, -1, keepdim=True) 62 | if self.att: 63 | x_att = torch.matmul(x, self.att_weight) 64 | # import pdb 65 | # pdb.set_trace() 66 | att = x_att @ x_att.permute(0, 2, 1) 67 | # att = self.softmax(att) 68 | adj = adj * att 69 | 70 | y = torch.matmul(adj, x) 71 | y = torch.matmul(y, self.weight) 72 | if self.add_self: 73 | self_emb = torch.matmul(x, self.self_weight) 74 | y += self_emb 75 | if self.bias is not None: 76 | y = y + self.bias 77 | if self.normalize_embedding: 78 | y = F.normalize(y, p=2, dim=2) 79 | # print(y[0][0]) 80 | return y, adj 81 | 82 | 83 | class GcnEncoderGraph(nn.Module): 84 | def __init__( 85 | self, 86 | input_dim, 87 | hidden_dim, 88 | embedding_dim, 89 | label_dim, 90 | num_layers, 91 | pred_hidden_dims=[], 92 | concat=True, 93 | bn=True, 94 | dropout=0.0, 95 | add_self=False, 96 | args=None, 97 | ): 98 
| super(GcnEncoderGraph, self).__init__() 99 | self.concat = concat 100 | add_self = add_self 101 | self.bn = bn 102 | self.num_layers = num_layers 103 | self.num_aggs = 1 104 | 105 | self.bias = True 106 | self.gpu = args.gpu 107 | if args.method == "att": 108 | self.att = True 109 | else: 110 | self.att = False 111 | if args is not None: 112 | self.bias = args.bias 113 | 114 | self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers( 115 | input_dim, 116 | hidden_dim, 117 | embedding_dim, 118 | num_layers, 119 | add_self, 120 | normalize=True, 121 | dropout=dropout, 122 | ) 123 | self.act = nn.ReLU() 124 | self.label_dim = label_dim 125 | 126 | if concat: 127 | self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim 128 | else: 129 | self.pred_input_dim = embedding_dim 130 | self.pred_model = self.build_pred_layers( 131 | self.pred_input_dim, pred_hidden_dims, label_dim, num_aggs=self.num_aggs 132 | ) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, GraphConv): 136 | init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain("relu")) 137 | if m.att: 138 | init.xavier_uniform_( 139 | m.att_weight.data, gain=nn.init.calculate_gain("relu") 140 | ) 141 | if m.add_self: 142 | init.xavier_uniform_( 143 | m.self_weight.data, gain=nn.init.calculate_gain("relu") 144 | ) 145 | if m.bias is not None: 146 | init.constant_(m.bias.data, 0.0) 147 | 148 | def build_conv_layers( 149 | self, 150 | input_dim, 151 | hidden_dim, 152 | embedding_dim, 153 | num_layers, 154 | add_self, 155 | normalize=False, 156 | dropout=0.0, 157 | ): 158 | conv_first = GraphConv( 159 | input_dim=input_dim, 160 | output_dim=hidden_dim, 161 | add_self=add_self, 162 | normalize_embedding=normalize, 163 | bias=self.bias, 164 | gpu=self.gpu, 165 | att=self.att, 166 | ) 167 | conv_block = nn.ModuleList( 168 | [ 169 | GraphConv( 170 | input_dim=hidden_dim, 171 | output_dim=hidden_dim, 172 | add_self=add_self, 173 | normalize_embedding=normalize, 174 | dropout=dropout, 175 | bias=self.bias, 176 | gpu=self.gpu, 177 | att=self.att, 178 | ) 179 | for i in range(num_layers - 2) 180 | ] 181 | ) 182 | conv_last = GraphConv( 183 | input_dim=hidden_dim, 184 | output_dim=embedding_dim, 185 | add_self=add_self, 186 | normalize_embedding=normalize, 187 | bias=self.bias, 188 | gpu=self.gpu, 189 | att=self.att, 190 | ) 191 | return conv_first, conv_block, conv_last 192 | 193 | def build_pred_layers( 194 | self, pred_input_dim, pred_hidden_dims, label_dim, num_aggs=1 195 | ): 196 | pred_input_dim = pred_input_dim * num_aggs 197 | if len(pred_hidden_dims) == 0: 198 | pred_model = nn.Linear(pred_input_dim, label_dim) 199 | else: 200 | pred_layers = [] 201 | for pred_dim in pred_hidden_dims: 202 | pred_layers.append(nn.Linear(pred_input_dim, pred_dim)) 203 | pred_layers.append(self.act) 204 | pred_input_dim = pred_dim 205 | pred_layers.append(nn.Linear(pred_dim, label_dim)) 206 | pred_model = nn.Sequential(*pred_layers) 207 | return pred_model 208 | 209 | def construct_mask(self, max_nodes, batch_num_nodes): 210 | """ For each num_nodes in batch_num_nodes, the first num_nodes entries of the 211 | corresponding column are 1's, and the rest are 0's (to be masked out). 
212 | Dimension of mask: [batch_size x max_nodes x 1] 213 | """ 214 | # masks 215 | packed_masks = [torch.ones(int(num)) for num in batch_num_nodes] 216 | batch_size = len(batch_num_nodes) 217 | out_tensor = torch.zeros(batch_size, max_nodes) 218 | for i, mask in enumerate(packed_masks): 219 | out_tensor[i, : batch_num_nodes[i]] = mask 220 | return out_tensor.unsqueeze(2).cuda() 221 | 222 | def apply_bn(self, x): 223 | """ Batch normalization of 3D tensor x 224 | """ 225 | bn_module = nn.BatchNorm1d(x.size()[1]) 226 | if self.gpu: 227 | bn_module = bn_module.cuda() 228 | return bn_module(x) 229 | 230 | def gcn_forward( 231 | self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None 232 | ): 233 | 234 | """ Perform forward prop with graph convolution. 235 | Returns: 236 | Embedding matrix with dimension [batch_size x num_nodes x embedding] 237 | The embedding dim is self.pred_input_dim 238 | """ 239 | 240 | x, adj_att = conv_first(x, adj) 241 | x = self.act(x) 242 | if self.bn: 243 | x = self.apply_bn(x) 244 | x_all = [x] 245 | adj_att_all = [adj_att] 246 | # out_all = [] 247 | # out, _ = torch.max(x, dim=1) 248 | # out_all.append(out) 249 | for i in range(len(conv_block)): 250 | x, _ = conv_block[i](x, adj) 251 | x = self.act(x) 252 | if self.bn: 253 | x = self.apply_bn(x) 254 | x_all.append(x) 255 | adj_att_all.append(adj_att) 256 | x, adj_att = conv_last(x, adj) 257 | x_all.append(x) 258 | adj_att_all.append(adj_att) 259 | # x_tensor: [batch_size x num_nodes x embedding] 260 | x_tensor = torch.cat(x_all, dim=2) 261 | if embedding_mask is not None: 262 | x_tensor = x_tensor * embedding_mask 263 | self.embedding_tensor = x_tensor 264 | 265 | # adj_att_tensor: [batch_size x num_nodes x num_nodes x num_gc_layers] 266 | adj_att_tensor = torch.stack(adj_att_all, dim=3) 267 | return x_tensor, adj_att_tensor 268 | 269 | def forward(self, x, adj, batch_num_nodes=None, **kwargs): 270 | # mask 271 | max_num_nodes = adj.size()[1] 272 | if batch_num_nodes is not None: 273 | self.embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 274 | else: 275 | self.embedding_mask = None 276 | 277 | # conv 278 | x, adj_att = self.conv_first(x, adj) 279 | x = self.act(x) 280 | if self.bn: 281 | x = self.apply_bn(x) 282 | out_all = [] 283 | out, _ = torch.max(x, dim=1) 284 | out_all.append(out) 285 | adj_att_all = [adj_att] 286 | for i in range(self.num_layers - 2): 287 | x, adj_att = self.conv_block[i](x, adj) 288 | x = self.act(x) 289 | if self.bn: 290 | x = self.apply_bn(x) 291 | out, _ = torch.max(x, dim=1) 292 | out_all.append(out) 293 | if self.num_aggs == 2: 294 | out = torch.sum(x, dim=1) 295 | out_all.append(out) 296 | adj_att_all.append(adj_att) 297 | x, adj_att = self.conv_last(x, adj) 298 | adj_att_all.append(adj_att) 299 | # x = self.act(x) 300 | out, _ = torch.max(x, dim=1) 301 | out_all.append(out) 302 | if self.num_aggs == 2: 303 | out = torch.sum(x, dim=1) 304 | out_all.append(out) 305 | if self.concat: 306 | output = torch.cat(out_all, dim=1) 307 | else: 308 | output = out 309 | 310 | # adj_att_tensor: [batch_size x num_nodes x num_nodes x num_gc_layers] 311 | adj_att_tensor = torch.stack(adj_att_all, dim=3) 312 | 313 | self.embedding_tensor = output 314 | ypred = self.pred_model(output) 315 | # print(output.size()) 316 | return ypred, adj_att_tensor 317 | 318 | def loss(self, pred, label, type="softmax"): 319 | # softmax + CE 320 | if type == "softmax": 321 | return F.cross_entropy(pred, label, size_average=True) 322 | elif type == "margin": 323 | batch_size = 
pred.size()[0] 324 | label_onehot = torch.zeros(batch_size, self.label_dim).long().cuda() 325 | label_onehot.scatter_(1, label.view(-1, 1), 1) 326 | return torch.nn.MultiLabelMarginLoss()(pred, label_onehot) 327 | 328 | # return F.binary_cross_entropy(F.sigmoid(pred[:,0]), label.float()) 329 | 330 | 331 | class GcnEncoderNode(GcnEncoderGraph): 332 | def __init__( 333 | self, 334 | input_dim, 335 | hidden_dim, 336 | embedding_dim, 337 | label_dim, 338 | num_layers, 339 | pred_hidden_dims=[], 340 | concat=True, 341 | bn=True, 342 | dropout=0.0, 343 | args=None, 344 | ): 345 | super(GcnEncoderNode, self).__init__( 346 | input_dim, 347 | hidden_dim, 348 | embedding_dim, 349 | label_dim, 350 | num_layers, 351 | pred_hidden_dims, 352 | concat, 353 | bn, 354 | dropout, 355 | args=args, 356 | ) 357 | if hasattr(args, "loss_weight"): 358 | print("Loss weight: ", args.loss_weight) 359 | self.celoss = nn.CrossEntropyLoss(weight=args.loss_weight) 360 | else: 361 | self.celoss = nn.CrossEntropyLoss() 362 | 363 | def forward(self, x, adj, batch_num_nodes=None, **kwargs): 364 | # mask 365 | max_num_nodes = adj.size()[1] 366 | if batch_num_nodes is not None: 367 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 368 | else: 369 | embedding_mask = None 370 | 371 | self.adj_atts = [] 372 | self.embedding_tensor, adj_att = self.gcn_forward( 373 | x, adj, self.conv_first, self.conv_block, self.conv_last, embedding_mask 374 | ) 375 | pred = self.pred_model(self.embedding_tensor) 376 | return pred, adj_att 377 | 378 | def loss(self, pred, label): 379 | pred = torch.transpose(pred, 1, 2) 380 | return self.celoss(pred, label) 381 | 382 | 383 | class SoftPoolingGcnEncoder(GcnEncoderGraph): 384 | def __init__( 385 | self, 386 | max_num_nodes, 387 | input_dim, 388 | hidden_dim, 389 | embedding_dim, 390 | label_dim, 391 | num_layers, 392 | assign_hidden_dim, 393 | assign_ratio=0.25, 394 | assign_num_layers=-1, 395 | num_pooling=1, 396 | pred_hidden_dims=[50], 397 | concat=True, 398 | bn=True, 399 | dropout=0.0, 400 | linkpred=True, 401 | assign_input_dim=-1, 402 | args=None, 403 | ): 404 | """ 405 | Args: 406 | num_layers: number of gc layers before each pooling 407 | num_nodes: number of nodes for each graph in batch 408 | linkpred: flag to turn on link prediction side objective 409 | """ 410 | 411 | super(SoftPoolingGcnEncoder, self).__init__( 412 | input_dim, 413 | hidden_dim, 414 | embedding_dim, 415 | label_dim, 416 | num_layers, 417 | pred_hidden_dims=pred_hidden_dims, 418 | concat=concat, 419 | args=args, 420 | ) 421 | add_self = not concat 422 | self.num_pooling = num_pooling 423 | self.linkpred = linkpred 424 | self.assign_ent = True 425 | 426 | # GC 427 | self.conv_first_after_pool = [] 428 | self.conv_block_after_pool = [] 429 | self.conv_last_after_pool = [] 430 | for i in range(num_pooling): 431 | # use self to register the modules in self.modules() 432 | self.conv_first2, self.conv_block2, self.conv_last2 = self.build_conv_layers( 433 | self.pred_input_dim, 434 | hidden_dim, 435 | embedding_dim, 436 | num_layers, 437 | add_self, 438 | normalize=True, 439 | dropout=dropout, 440 | ) 441 | self.conv_first_after_pool.append(self.conv_first2) 442 | self.conv_block_after_pool.append(self.conv_block2) 443 | self.conv_last_after_pool.append(self.conv_last2) 444 | 445 | # assignment 446 | assign_dims = [] 447 | if assign_num_layers == -1: 448 | assign_num_layers = num_layers 449 | if assign_input_dim == -1: 450 | assign_input_dim = input_dim 451 | 452 | self.assign_conv_first_modules = [] 453 
| self.assign_conv_block_modules = [] 454 | self.assign_conv_last_modules = [] 455 | self.assign_pred_modules = [] 456 | assign_dim = int(max_num_nodes * assign_ratio) 457 | for i in range(num_pooling): 458 | assign_dims.append(assign_dim) 459 | self.assign_conv_first, self.assign_conv_block, self.assign_conv_last = self.build_conv_layers( 460 | assign_input_dim, 461 | assign_hidden_dim, 462 | assign_dim, 463 | assign_num_layers, 464 | add_self, 465 | normalize=True, 466 | ) 467 | assign_pred_input_dim = ( 468 | assign_hidden_dim * (num_layers - 1) + assign_dim 469 | if concat 470 | else assign_dim 471 | ) 472 | self.assign_pred = self.build_pred_layers( 473 | assign_pred_input_dim, [], assign_dim, num_aggs=1 474 | ) 475 | 476 | # next pooling layer 477 | assign_input_dim = embedding_dim 478 | assign_dim = int(assign_dim * assign_ratio) 479 | 480 | self.assign_conv_first_modules.append(self.assign_conv_first) 481 | self.assign_conv_block_modules.append(self.assign_conv_block) 482 | self.assign_conv_last_modules.append(self.assign_conv_last) 483 | self.assign_pred_modules.append(self.assign_pred) 484 | 485 | self.pred_model = self.build_pred_layers( 486 | self.pred_input_dim * (num_pooling + 1), 487 | pred_hidden_dims, 488 | label_dim, 489 | num_aggs=self.num_aggs, 490 | ) 491 | 492 | for m in self.modules(): 493 | if isinstance(m, GraphConv): 494 | m.weight.data = init.xavier_uniform( 495 | m.weight.data, gain=nn.init.calculate_gain("relu") 496 | ) 497 | if m.bias is not None: 498 | m.bias.data = init.constant(m.bias.data, 0.0) 499 | 500 | def forward(self, x, adj, batch_num_nodes, **kwargs): 501 | if "assign_x" in kwargs: 502 | x_a = kwargs["assign_x"] 503 | else: 504 | x_a = x 505 | 506 | # mask 507 | max_num_nodes = adj.size()[1] 508 | if batch_num_nodes is not None: 509 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 510 | else: 511 | embedding_mask = None 512 | 513 | out_all = [] 514 | 515 | # self.assign_tensor = self.gcn_forward(x_a, adj, 516 | # self.assign_conv_first_modules[0], self.assign_conv_block_modules[0], self.assign_conv_last_modules[0], 517 | # embedding_mask) 518 | ## [batch_size x num_nodes x next_lvl_num_nodes] 519 | # self.assign_tensor = nn.Softmax(dim=-1)(self.assign_pred(self.assign_tensor)) 520 | # if embedding_mask is not None: 521 | # self.assign_tensor = self.assign_tensor * embedding_mask 522 | # [batch_size x num_nodes x embedding_dim] 523 | embedding_tensor = self.gcn_forward( 524 | x, adj, self.conv_first, self.conv_block, self.conv_last, embedding_mask 525 | ) 526 | 527 | out, _ = torch.max(embedding_tensor, dim=1) 528 | out_all.append(out) 529 | if self.num_aggs == 2: 530 | out = torch.sum(embedding_tensor, dim=1) 531 | out_all.append(out) 532 | 533 | for i in range(self.num_pooling): 534 | if batch_num_nodes is not None and i == 0: 535 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 536 | else: 537 | embedding_mask = None 538 | 539 | self.assign_tensor = self.gcn_forward( 540 | x_a, 541 | adj, 542 | self.assign_conv_first_modules[i], 543 | self.assign_conv_block_modules[i], 544 | self.assign_conv_last_modules[i], 545 | embedding_mask, 546 | ) 547 | # [batch_size x num_nodes x next_lvl_num_nodes] 548 | self.assign_tensor = nn.Softmax(dim=-1)( 549 | self.assign_pred(self.assign_tensor) 550 | ) 551 | if embedding_mask is not None: 552 | self.assign_tensor = self.assign_tensor * embedding_mask 553 | 554 | # update pooled features and adj matrix 555 | x = torch.matmul( 556 | torch.transpose(self.assign_tensor, 1, 
2), embedding_tensor 557 | ) 558 | adj = torch.transpose(self.assign_tensor, 1, 2) @ adj @ self.assign_tensor 559 | x_a = x 560 | 561 | embedding_tensor = self.gcn_forward( 562 | x, 563 | adj, 564 | self.conv_first_after_pool[i], 565 | self.conv_block_after_pool[i], 566 | self.conv_last_after_pool[i], 567 | ) 568 | 569 | out, _ = torch.max(embedding_tensor, dim=1) 570 | out_all.append(out) 571 | if self.num_aggs == 2: 572 | # out = torch.mean(embedding_tensor, dim=1) 573 | out = torch.sum(embedding_tensor, dim=1) 574 | out_all.append(out) 575 | 576 | if self.concat: 577 | output = torch.cat(out_all, dim=1) 578 | else: 579 | output = out 580 | ypred = self.pred_model(output) 581 | return ypred 582 | 583 | def loss(self, pred, label, adj=None, batch_num_nodes=None, adj_hop=1): 584 | """ 585 | Args: 586 | batch_num_nodes: numpy array of number of nodes in each graph in the minibatch. 587 | """ 588 | eps = 1e-7 589 | loss = super(SoftPoolingGcnEncoder, self).loss(pred, label) 590 | if self.linkpred: 591 | max_num_nodes = adj.size()[1] 592 | pred_adj0 = self.assign_tensor @ torch.transpose(self.assign_tensor, 1, 2) 593 | tmp = pred_adj0 594 | pred_adj = pred_adj0 595 | for adj_pow in range(adj_hop - 1): 596 | tmp = tmp @ pred_adj0 597 | pred_adj = pred_adj + tmp 598 | pred_adj = torch.min(pred_adj, torch.Tensor(1).cuda()) 599 | # print('adj1', torch.sum(pred_adj0) / torch.numel(pred_adj0)) 600 | # print('adj2', torch.sum(pred_adj) / torch.numel(pred_adj)) 601 | # self.link_loss = F.nll_loss(torch.log(pred_adj), adj) 602 | self.link_loss = -adj * torch.log(pred_adj + eps) - (1 - adj) * torch.log( 603 | 1 - pred_adj + eps 604 | ) 605 | if batch_num_nodes is None: 606 | num_entries = max_num_nodes * max_num_nodes * adj.size()[0] 607 | print("Warning: calculating link pred loss without masking") 608 | else: 609 | num_entries = np.sum(batch_num_nodes * batch_num_nodes) 610 | embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes) 611 | adj_mask = embedding_mask @ torch.transpose(embedding_mask, 1, 2) 612 | self.link_loss[1 - adj_mask.byte()] = 0.0 613 | 614 | self.link_loss = torch.sum(self.link_loss) / float(num_entries) 615 | # print('linkloss: ', self.link_loss) 616 | return loss + self.link_loss 617 | return loss 618 | 619 | -------------------------------------------------------------------------------- /models_pyg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch_geometric.transforms as T 4 | from torch_geometric.nn import GCNConv, GATConv 5 | 6 | class GCNNet(torch.nn.Module): 7 | def __init__(self, input_dim, hidden_dim, label_dim, num_layers, 8 | pred_hidden_dims=[], concat=True, bn=True, dropout=0.0, add_self=False, args=None): 9 | super(GCNNet, self).__init__() 10 | self.input_dim = input_dim 11 | print ('GCNNet input_dim:', self.input_dim) 12 | self.hidden_dim = hidden_dim 13 | print ('GCNNet hidden_dim:', self.hidden_dim) 14 | self.label_dim = label_dim 15 | print ('GCNNet label_dim:', self.label_dim) 16 | self.num_layers = num_layers 17 | print ('GCNNet num_layers:', self.num_layers) 18 | 19 | # self.concat = concat 20 | # self.bn = bn 21 | # self.add_self = add_self 22 | self.args = args 23 | self.dropout = dropout 24 | self.act = F.relu 25 | 26 | self.convs = torch.nn.ModuleList() 27 | self.convs.append(GCNConv(self.input_dim, self.hidden_dim)) 28 | for layer in range(self.num_layers - 2): 29 | self.convs.append(GCNConv(self.hidden_dim, self.hidden_dim)) 30 | 
self.convs.append(GCNConv(self.hidden_dim, self.label_dim)) 31 | print ('len(self.convs):', len(self.convs)) 32 | 33 | def forward(self, data): 34 | x, edge_index, batch = data.feat, data.edge_index, data.batch 35 | 36 | for i in range(self.num_layers): 37 | x = self.convs[i](x, edge_index) 38 | x = F.relu(x) 39 | x = F.dropout(x, p=self.dropout, training=self.training) 40 | return F.log_softmax(x, dim=1) 41 | 42 | def loss(self, pred, label): 43 | return F.nll_loss(pred, label) 44 | -------------------------------------------------------------------------------- /notebook/GNN-Explainer-Viz-Interactive.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# GNN Explainer\n", 8 | "\n", 9 | "This notebook is designed to visualize the results of the GNN Explainer.\n", 10 | "\n", 11 | "Use it after one has trained the model using train.py, and has run the explainer optimization (explainer_main.py).\n", 12 | "The main purpose is to visualize the trained mask by interactively tuning the threshold. In many scientific applications, the explanation size is unknown a priori. This tool can help user visualize the selected subgraph, with respect to different values of the thresholds, and find the right size for a good explanation." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from ipywidgets import interact, interactive, fixed, interact_manual\n", 22 | "import ipywidgets as widgets" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 121, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import numpy as np\n", 32 | "import os\n", 33 | "import networkx as nx\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import json" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 109, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "%matplotlib inline" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "Configuring the experiment you want to visualize. 
These values should match the configuration:\n", 52 | "\n", 53 | "> TODO: Unify configuration of experiments in yaml" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "logdir = '../log/'\n", 63 | "expdir = 'syn2_base_h20_o20_explain'" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# Load the produced masks" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 74, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "dirs = os.listdir(os.path.join(logdir, expdir))" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 48, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "masked_adjsyn2_base_h20_o20_explain.npy\n", 94 | "masked_adj.npy\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "masks = []\n", 100 | "# This would print all the files and directories\n", 101 | "for file in dirs:\n", 102 | " if file.split('.')[-1] == 'npy':\n", 103 | " print(file)\n", 104 | " masks.append(file)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "Utility to save masks:" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 127, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "from networkx.readwrite import json_graph\n", 121 | "\n", 122 | "def save_mask(G, fname, fmt='json', suffix=''):\n", 123 | " pth = os.path.join(logdir, expdir, fname+'-filt-'+suffix+'.'+fmt)\n", 124 | " if fmt == 'json':\n", 125 | " dt = json_graph.node_link_data(G)\n", 126 | " with open(pth, 'w') as f:\n", 127 | " json.dump(dt, f)\n", 128 | " elif fmt == 'pdf':\n", 129 | " plt.savefig(pth)\n", 130 | " elif fmt == 'npy':\n", 131 | " np.save(pth, nx.to_numpy_array(G))" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "Plotting utilities:" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 54, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "def show_adjacency_full(mask, ax=None):\n", 148 | " adj = np.load(os.path.join(logdir, expdir, mask), allow_pickle=True)\n", 149 | " if ax is None:\n", 150 | " plt.figure()\n", 151 | " plt.imshow(adj);\n", 152 | " else:\n", 153 | " ax.imshow(adj)\n", 154 | " return adj" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 55, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "def read_adjacency_full(mask, ax=None):\n", 164 | " adj = np.load(os.path.join(logdir, expdir, mask), allow_pickle=True)\n", 165 | " return adj" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 57, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "application/vnd.jupyter.widget-view+json": { 176 | "model_id": "d87c030bb64e4c4b940f655300d411b2", 177 | "version_major": 2, 178 | "version_minor": 0 179 | }, 180 | "text/plain": [ 181 | "interactive(children=(FloatSlider(value=0.5, description='thresh', max=1.5, min=-0.5), Output()), _dom_classes…" 182 | ] 183 | }, 184 | "metadata": {}, 185 | "output_type": "display_data" 186 | } 187 | ], 188 | "source": [ 189 | "filt_adj = read_adjacency_full(masks[0])\n", 190 | "@interact\n", 191 | "def filter_adj(thresh=0.5):\n", 192 | " filt_adj[filt_adj best_val_result["acc"] - 1e-7: 216 | best_val_result["acc"] = 
val_result["acc"] 217 | best_val_result["epoch"] = epoch 218 | best_val_result["loss"] = avg_loss 219 | if test_dataset is not None: 220 | test_result = evaluate(test_dataset, model, args, name="Test") 221 | test_result["epoch"] = epoch 222 | if writer is not None: 223 | writer.add_scalar("acc/train_acc", result["acc"], epoch) 224 | writer.add_scalar("acc/val_acc", val_result["acc"], epoch) 225 | writer.add_scalar("loss/best_val_loss", best_val_result["loss"], epoch) 226 | if test_dataset is not None: 227 | writer.add_scalar("acc/test_acc", test_result["acc"], epoch) 228 | 229 | print("Best val result: ", best_val_result) 230 | best_val_epochs.append(best_val_result["epoch"]) 231 | best_val_accs.append(best_val_result["acc"]) 232 | if test_dataset is not None: 233 | print("Test result: ", test_result) 234 | test_epochs.append(test_result["epoch"]) 235 | test_accs.append(test_result["acc"]) 236 | 237 | matplotlib.style.use("seaborn") 238 | plt.switch_backend("agg") 239 | plt.figure() 240 | plt.plot(train_epochs, math_utils.exp_moving_avg(train_accs, 0.85), "-", lw=1) 241 | if test_dataset is not None: 242 | plt.plot(best_val_epochs, best_val_accs, "bo", test_epochs, test_accs, "go") 243 | plt.legend(["train", "val", "test"]) 244 | else: 245 | plt.plot(best_val_epochs, best_val_accs, "bo") 246 | plt.legend(["train", "val"]) 247 | plt.savefig(io_utils.gen_train_plt_name(args), dpi=600) 248 | plt.close() 249 | matplotlib.style.use("default") 250 | 251 | print(all_adjs.shape, all_feats.shape, all_labels.shape) 252 | 253 | cg_data = { 254 | "adj": all_adjs, 255 | "feat": all_feats, 256 | "label": all_labels, 257 | "pred": np.expand_dims(predictions, axis=0), 258 | "train_idx": list(range(len(dataset))), 259 | } 260 | io_utils.save_checkpoint(model, optimizer, args, num_epochs=-1, cg_dict=cg_data) 261 | return model, val_accs 262 | 263 | 264 | def train_node_classifier(G, labels, model, args, writer=None): 265 | # train/test split only for nodes 266 | num_nodes = G.number_of_nodes() 267 | num_train = int(num_nodes * args.train_ratio) 268 | idx = [i for i in range(num_nodes)] 269 | 270 | np.random.shuffle(idx) 271 | train_idx = idx[:num_train] 272 | test_idx = idx[num_train:] 273 | 274 | data = gengraph.preprocess_input_graph(G, labels) 275 | labels_train = torch.tensor(data["labels"][:, train_idx], dtype=torch.long) 276 | adj = torch.tensor(data["adj"], dtype=torch.float) 277 | x = torch.tensor(data["feat"], requires_grad=True, dtype=torch.float) 278 | scheduler, optimizer = train_utils.build_optimizer( 279 | args, model.parameters(), weight_decay=args.weight_decay 280 | ) 281 | model.train() 282 | ypred = None 283 | for epoch in range(args.num_epochs): 284 | begin_time = time.time() 285 | model.zero_grad() 286 | 287 | if args.gpu: 288 | ypred, adj_att = model(x.cuda(), adj.cuda()) 289 | else: 290 | ypred, adj_att = model(x, adj) 291 | ypred_train = ypred[:, train_idx, :] 292 | if args.gpu: 293 | loss = model.loss(ypred_train, labels_train.cuda()) 294 | else: 295 | loss = model.loss(ypred_train, labels_train) 296 | loss.backward() 297 | nn.utils.clip_grad_norm(model.parameters(), args.clip) 298 | 299 | optimizer.step() 300 | #for param_group in optimizer.param_groups: 301 | # print(param_group["lr"]) 302 | elapsed = time.time() - begin_time 303 | 304 | result_train, result_test = evaluate_node( 305 | ypred.cpu(), data["labels"], train_idx, test_idx 306 | ) 307 | if writer is not None: 308 | writer.add_scalar("loss/avg_loss", loss, epoch) 309 | writer.add_scalars( 310 | "prec", 311 | {"train": 
result_train["prec"], "test": result_test["prec"]}, 312 | epoch, 313 | ) 314 | writer.add_scalars( 315 | "recall", 316 | {"train": result_train["recall"], "test": result_test["recall"]}, 317 | epoch, 318 | ) 319 | writer.add_scalars( 320 | "acc", {"train": result_train["acc"], "test": result_test["acc"]}, epoch 321 | ) 322 | 323 | if epoch % 10 == 0: 324 | print( 325 | "epoch: ", 326 | epoch, 327 | "; loss: ", 328 | loss.item(), 329 | "; train_acc: ", 330 | result_train["acc"], 331 | "; test_acc: ", 332 | result_test["acc"], 333 | "; train_prec: ", 334 | result_train["prec"], 335 | "; test_prec: ", 336 | result_test["prec"], 337 | "; epoch time: ", 338 | "{0:0.2f}".format(elapsed), 339 | ) 340 | 341 | if scheduler is not None: 342 | scheduler.step() 343 | print(result_train["conf_mat"]) 344 | print(result_test["conf_mat"]) 345 | 346 | # computation graph 347 | model.eval() 348 | if args.gpu: 349 | ypred, _ = model(x.cuda(), adj.cuda()) 350 | else: 351 | ypred, _ = model(x, adj) 352 | cg_data = { 353 | "adj": data["adj"], 354 | "feat": data["feat"], 355 | "label": data["labels"], 356 | "pred": ypred.cpu().detach().numpy(), 357 | "train_idx": train_idx, 358 | } 359 | # import pdb 360 | # pdb.set_trace() 361 | io_utils.save_checkpoint(model, optimizer, args, num_epochs=-1, cg_dict=cg_data) 362 | 363 | 364 | def train_node_classifier_multigraph(G_list, labels, model, args, writer=None): 365 | train_idx_all, test_idx_all = [], [] 366 | # train/test split only for nodes 367 | num_nodes = G_list[0].number_of_nodes() 368 | num_train = int(num_nodes * args.train_ratio) 369 | idx = [i for i in range(num_nodes)] 370 | np.random.shuffle(idx) 371 | train_idx = idx[:num_train] 372 | train_idx_all.append(train_idx) 373 | test_idx = idx[num_train:] 374 | test_idx_all.append(test_idx) 375 | 376 | data = gengraph.preprocess_input_graph(G_list[0], labels[0]) 377 | all_labels = data["labels"] 378 | labels_train = torch.tensor(data["labels"][:, train_idx], dtype=torch.long) 379 | adj = torch.tensor(data["adj"], dtype=torch.float) 380 | x = torch.tensor(data["feat"], requires_grad=True, dtype=torch.float) 381 | 382 | for i in range(1, len(G_list)): 383 | np.random.shuffle(idx) 384 | train_idx = idx[:num_train] 385 | train_idx_all.append(train_idx) 386 | test_idx = idx[num_train:] 387 | test_idx_all.append(test_idx) 388 | data = gengraph.preprocess_input_graph(G_list[i], labels[i]) 389 | all_labels = np.concatenate((all_labels, data["labels"]), axis=0) 390 | labels_train = torch.cat( 391 | [ 392 | labels_train, 393 | torch.tensor(data["labels"][:, train_idx], dtype=torch.long), 394 | ], 395 | dim=0, 396 | ) 397 | adj = torch.cat([adj, torch.tensor(data["adj"], dtype=torch.float)]) 398 | x = torch.cat( 399 | [x, torch.tensor(data["feat"], requires_grad=True, dtype=torch.float)] 400 | ) 401 | 402 | scheduler, optimizer = train_utils.build_optimizer( 403 | args, model.parameters(), weight_decay=args.weight_decay 404 | ) 405 | model.train() 406 | ypred = None 407 | for epoch in range(args.num_epochs): 408 | begin_time = time.time() 409 | model.zero_grad() 410 | 411 | if args.gpu: 412 | ypred = model(x.cuda(), adj.cuda()) 413 | else: 414 | ypred = model(x, adj) 415 | # normal indexing 416 | ypred_train = ypred[:, train_idx, :] 417 | # in multigraph setting we can't directly access all dimensions so we need to gather all the training instances 418 | all_train_idx = [item for sublist in train_idx_all for item in sublist] 419 | ypred_train_cmp = torch.cat( 420 | [ypred[i, train_idx_all[i], :] for i in range(10)], dim=0 
421 | ).reshape(10, 146, 6) 422 | if args.gpu: 423 | loss = model.loss(ypred_train_cmp, labels_train.cuda()) 424 | else: 425 | loss = model.loss(ypred_train_cmp, labels_train) 426 | loss.backward() 427 | nn.utils.clip_grad_norm(model.parameters(), args.clip) 428 | 429 | optimizer.step() 430 | #for param_group in optimizer.param_groups: 431 | # print(param_group["lr"]) 432 | elapsed = time.time() - begin_time 433 | 434 | result_train, result_test = evaluate_node( 435 | ypred.cpu(), all_labels, train_idx_all, test_idx_all 436 | ) 437 | if writer is not None: 438 | writer.add_scalar("loss/avg_loss", loss, epoch) 439 | writer.add_scalars( 440 | "prec", 441 | {"train": result_train["prec"], "test": result_test["prec"]}, 442 | epoch, 443 | ) 444 | writer.add_scalars( 445 | "recall", 446 | {"train": result_train["recall"], "test": result_test["recall"]}, 447 | epoch, 448 | ) 449 | writer.add_scalars( 450 | "acc", {"train": result_train["acc"], "test": result_test["acc"]}, epoch 451 | ) 452 | 453 | print( 454 | "epoch: ", 455 | epoch, 456 | "; loss: ", 457 | loss.item(), 458 | "; train_acc: ", 459 | result_train["acc"], 460 | "; test_acc: ", 461 | result_test["acc"], 462 | "; epoch time: ", 463 | "{0:0.2f}".format(elapsed), 464 | ) 465 | 466 | if scheduler is not None: 467 | scheduler.step() 468 | print(result_train["conf_mat"]) 469 | print(result_test["conf_mat"]) 470 | 471 | # computation graph 472 | model.eval() 473 | if args.gpu: 474 | ypred = model(x.cuda(), adj.cuda()) 475 | else: 476 | ypred = model(x, adj) 477 | cg_data = { 478 | "adj": adj.cpu().detach().numpy(), 479 | "feat": x.cpu().detach().numpy(), 480 | "label": all_labels, 481 | "pred": ypred.cpu().detach().numpy(), 482 | "train_idx": train_idx_all, 483 | } 484 | io_utils.save_checkpoint(model, optimizer, args, num_epochs=-1, cg_dict=cg_data) 485 | 486 | 487 | 488 | ############################# 489 | # 490 | # Evaluate Trained Model 491 | # 492 | ############################# 493 | def evaluate(dataset, model, args, name="Validation", max_num_examples=None): 494 | model.eval() 495 | 496 | labels = [] 497 | preds = [] 498 | for batch_idx, data in enumerate(dataset): 499 | adj = Variable(data["adj"].float(), requires_grad=False).cuda() 500 | h0 = Variable(data["feats"].float()).cuda() 501 | labels.append(data["label"].long().numpy()) 502 | batch_num_nodes = data["num_nodes"].int().numpy() 503 | assign_input = Variable( 504 | data["assign_feats"].float(), requires_grad=False 505 | ).cuda() 506 | 507 | ypred, att_adj = model(h0, adj, batch_num_nodes, assign_x=assign_input) 508 | _, indices = torch.max(ypred, 1) 509 | preds.append(indices.cpu().data.numpy()) 510 | 511 | if max_num_examples is not None: 512 | if (batch_idx + 1) * args.batch_size > max_num_examples: 513 | break 514 | 515 | labels = np.hstack(labels) 516 | preds = np.hstack(preds) 517 | 518 | result = { 519 | "prec": metrics.precision_score(labels, preds, average="macro"), 520 | "recall": metrics.recall_score(labels, preds, average="macro"), 521 | "acc": metrics.accuracy_score(labels, preds), 522 | } 523 | print(name, " accuracy:", result["acc"]) 524 | return result 525 | 526 | 527 | def evaluate_node(ypred, labels, train_idx, test_idx): 528 | _, pred_labels = torch.max(ypred, 2) 529 | pred_labels = pred_labels.numpy() 530 | 531 | pred_train = np.ravel(pred_labels[:, train_idx]) 532 | pred_test = np.ravel(pred_labels[:, test_idx]) 533 | labels_train = np.ravel(labels[:, train_idx]) 534 | labels_test = np.ravel(labels[:, test_idx]) 535 | 536 | result_train = { 537 | "prec": 
metrics.precision_score(labels_train, pred_train, average="macro"), 538 | "recall": metrics.recall_score(labels_train, pred_train, average="macro"), 539 | "acc": metrics.accuracy_score(labels_train, pred_train), 540 | "conf_mat": metrics.confusion_matrix(labels_train, pred_train), 541 | } 542 | result_test = { 543 | "prec": metrics.precision_score(labels_test, pred_test, average="macro"), 544 | "recall": metrics.recall_score(labels_test, pred_test, average="macro"), 545 | "acc": metrics.accuracy_score(labels_test, pred_test), 546 | "conf_mat": metrics.confusion_matrix(labels_test, pred_test), 547 | } 548 | return result_train, result_test 549 | 550 | 551 | 552 | ############################# 553 | # 554 | # Run Experiments 555 | # 556 | ############################# 557 | def ppi_essential_task(args, writer=None): 558 | feat_file = "G-MtfPathways_gene-motifs.csv" 559 | # G = io_utils.read_biosnap('data/ppi_essential', 'PP-Pathways_ppi.csv', 'G-HumanEssential.tsv', 560 | # feat_file=feat_file) 561 | G = io_utils.read_biosnap( 562 | "data/ppi_essential", 563 | "hi-union-ppi.tsv", 564 | "G-HumanEssential.tsv", 565 | feat_file=feat_file, 566 | ) 567 | labels = np.array([G.nodes[u]["label"] for u in G.nodes()]) 568 | num_classes = max(labels) + 1 569 | input_dim = G.nodes[next(iter(G.nodes()))]["feat"].shape[0] 570 | 571 | if args.method == "attn": 572 | print("Method: attn") 573 | else: 574 | print("Method:", args.method) 575 | args.loss_weight = torch.tensor([1, 5.0], dtype=torch.float).cuda() 576 | model = models.GcnEncoderNode( 577 | input_dim, 578 | args.hidden_dim, 579 | args.output_dim, 580 | num_classes, 581 | args.num_gc_layers, 582 | bn=args.bn, 583 | args=args, 584 | ) 585 | if args.gpu: 586 | model = model.cuda() 587 | 588 | train_node_classifier(G, labels, model, args, writer=writer) 589 | 590 | 591 | def syn_task1(args, writer=None): 592 | # data 593 | G, labels, name = gengraph.gen_syn1( 594 | feature_generator=featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 595 | ) 596 | num_classes = max(labels) + 1 597 | 598 | if args.method == "att": 599 | print("Method: att") 600 | model = models.GcnEncoderNode( 601 | args.input_dim, 602 | args.hidden_dim, 603 | args.output_dim, 604 | num_classes, 605 | args.num_gc_layers, 606 | bn=args.bn, 607 | args=args, 608 | ) 609 | else: 610 | print("Method:", args.method) 611 | model = models.GcnEncoderNode( 612 | args.input_dim, 613 | args.hidden_dim, 614 | args.output_dim, 615 | num_classes, 616 | args.num_gc_layers, 617 | bn=args.bn, 618 | args=args, 619 | ) 620 | if args.gpu: 621 | model = model.cuda() 622 | 623 | train_node_classifier(G, labels, model, args, writer=writer) 624 | 625 | 626 | def syn_task2(args, writer=None): 627 | # data 628 | G, labels, name = gengraph.gen_syn2() 629 | input_dim = len(G.nodes[0]["feat"]) 630 | num_classes = max(labels) + 1 631 | 632 | if args.method == "attn": 633 | print("Method: attn") 634 | else: 635 | print("Method:", args.method) 636 | model = models.GcnEncoderNode( 637 | input_dim, 638 | args.hidden_dim, 639 | args.output_dim, 640 | num_classes, 641 | args.num_gc_layers, 642 | bn=args.bn, 643 | args=args, 644 | ) 645 | if args.gpu: 646 | model = model.cuda() 647 | 648 | train_node_classifier(G, labels, model, args, writer=writer) 649 | 650 | 651 | def syn_task3(args, writer=None): 652 | # data 653 | G, labels, name = gengraph.gen_syn3( 654 | feature_generator=featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 655 | ) 656 | print(labels) 657 | num_classes = max(labels) + 1 658 | 659 
| if args.method == "attn": 660 | print("Method: attn") 661 | else: 662 | print("Method:", args.method) 663 | model = models.GcnEncoderNode( 664 | args.input_dim, 665 | args.hidden_dim, 666 | args.output_dim, 667 | num_classes, 668 | args.num_gc_layers, 669 | bn=args.bn, 670 | args=args, 671 | ) 672 | if args.gpu: 673 | model = model.cuda() 674 | 675 | train_node_classifier(G, labels, model, args, writer=writer) 676 | 677 | 678 | def syn_task4(args, writer=None): 679 | # data 680 | G, labels, name = gengraph.gen_syn4( 681 | feature_generator=featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 682 | ) 683 | print(labels) 684 | num_classes = max(labels) + 1 685 | 686 | if args.method == "attn": 687 | print("Method: attn") 688 | else: 689 | print("Method:", args.method) 690 | model = models.GcnEncoderNode( 691 | args.input_dim, 692 | args.hidden_dim, 693 | args.output_dim, 694 | num_classes, 695 | args.num_gc_layers, 696 | bn=args.bn, 697 | args=args, 698 | ) 699 | 700 | if args.gpu: 701 | model = model.cuda() 702 | 703 | train_node_classifier(G, labels, model, args, writer=writer) 704 | 705 | 706 | def syn_task5(args, writer=None): 707 | # data 708 | G, labels, name = gengraph.gen_syn5( 709 | feature_generator=featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 710 | ) 711 | print(labels) 712 | print("Number of nodes: ", G.number_of_nodes()) 713 | num_classes = max(labels) + 1 714 | 715 | if args.method == "attn": 716 | print("Method: attn") 717 | else: 718 | print("Method: base") 719 | model = models.GcnEncoderNode( 720 | args.input_dim, 721 | args.hidden_dim, 722 | args.output_dim, 723 | num_classes, 724 | args.num_gc_layers, 725 | bn=args.bn, 726 | args=args, 727 | ) 728 | 729 | if args.gpu: 730 | model = model.cuda() 731 | 732 | train_node_classifier(G, labels, model, args, writer=writer) 733 | 734 | 735 | def pkl_task(args, feat=None): 736 | with open(os.path.join(args.datadir, args.pkl_fname), "rb") as pkl_file: 737 | data = pickle.load(pkl_file) 738 | graphs = data[0] 739 | labels = data[1] 740 | test_graphs = data[2] 741 | test_labels = data[3] 742 | 743 | for i in range(len(graphs)): 744 | graphs[i].graph["label"] = labels[i] 745 | for i in range(len(test_graphs)): 746 | test_graphs[i].graph["label"] = test_labels[i] 747 | 748 | if feat is None: 749 | featgen_const = featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 750 | for G in graphs: 751 | featgen_const.gen_node_features(G) 752 | for G in test_graphs: 753 | featgen_const.gen_node_features(G) 754 | 755 | train_dataset, test_dataset, max_num_nodes = prepare_data( 756 | graphs, args, test_graphs=test_graphs 757 | ) 758 | model = models.GcnEncoderGraph( 759 | args.input_dim, 760 | args.hidden_dim, 761 | args.output_dim, 762 | args.num_classes, 763 | args.num_gc_layers, 764 | bn=args.bn, 765 | ).cuda() 766 | train(train_dataset, model, args, test_dataset=test_dataset) 767 | evaluate(test_dataset, model, args, "Validation") 768 | 769 | 770 | def enron_task_multigraph(args, idx=None, writer=None): 771 | labels_dict = { 772 | "None": 5, 773 | "Employee": 0, 774 | "Vice President": 1, 775 | "Manager": 2, 776 | "Trader": 3, 777 | "CEO+Managing Director+Director+President": 4, 778 | } 779 | max_enron_id = 183 780 | if idx is None: 781 | G_list = [] 782 | labels_list = [] 783 | for i in range(10): 784 | net = pickle.load( 785 | open("data/gnn-explainer-enron/enron_slice_{}.pkl".format(i), "rb") 786 | ) 787 | net.add_nodes_from(range(max_enron_id)) 788 | labels = [n[1].get("role", "None") for n in 
net.nodes(data=True)] 789 | labels_num = [labels_dict[l] for l in labels] 790 | featgen_const = featgen.ConstFeatureGen( 791 | np.ones(args.input_dim, dtype=float) 792 | ) 793 | featgen_const.gen_node_features(net) 794 | G_list.append(net) 795 | labels_list.append(labels_num) 796 | # train_dataset, test_dataset, max_num_nodes = prepare_data(G_list, args) 797 | model = models.GcnEncoderNode( 798 | args.input_dim, 799 | args.hidden_dim, 800 | args.output_dim, 801 | args.num_classes, 802 | args.num_gc_layers, 803 | bn=args.bn, 804 | args=args, 805 | ) 806 | if args.gpu: 807 | model = model.cuda() 808 | print(labels_num) 809 | train_node_classifier_multigraph( 810 | G_list, labels_list, model, args, writer=writer 811 | ) 812 | else: 813 | print("Running Enron full task") 814 | 815 | 816 | def enron_task(args, idx=None, writer=None): 817 | labels_dict = { 818 | "None": 5, 819 | "Employee": 0, 820 | "Vice President": 1, 821 | "Manager": 2, 822 | "Trader": 3, 823 | "CEO+Managing Director+Director+President": 4, 824 | } 825 | max_enron_id = 183 826 | if idx is None: 827 | G_list = [] 828 | labels_list = [] 829 | for i in range(10): 830 | net = pickle.load( 831 | open("data/gnn-explainer-enron/enron_slice_{}.pkl".format(i), "rb") 832 | ) 833 | # net.add_nodes_from(range(max_enron_id)) 834 | # labels=[n[1].get('role', 'None') for n in net.nodes(data=True)] 835 | # labels_num = [labels_dict[l] for l in labels] 836 | featgen_const = featgen.ConstFeatureGen( 837 | np.ones(args.input_dim, dtype=float) 838 | ) 839 | featgen_const.gen_node_features(net) 840 | G_list.append(net) 841 | print(net.number_of_nodes()) 842 | # labels_list.append(labels_num) 843 | 844 | G = nx.disjoint_union_all(G_list) 845 | model = models.GcnEncoderNode( 846 | args.input_dim, 847 | args.hidden_dim, 848 | args.output_dim, 849 | len(labels_dict), 850 | args.num_gc_layers, 851 | bn=args.bn, 852 | args=args, 853 | ) 854 | labels = [n[1].get("role", "None") for n in G.nodes(data=True)] 855 | labels_num = [labels_dict[l] for l in labels] 856 | for i in range(5): 857 | print("Label ", i, ": ", labels_num.count(i)) 858 | 859 | print("Total num nodes: ", len(labels_num)) 860 | print(labels_num) 861 | 862 | if args.gpu: 863 | model = model.cuda() 864 | train_node_classifier(G, labels_num, model, args, writer=writer) 865 | else: 866 | print("Running Enron full task") 867 | 868 | 869 | def benchmark_task(args, writer=None, feat="node-label"): 870 | graphs = io_utils.read_graphfile( 871 | args.datadir, args.bmname, max_nodes=args.max_nodes 872 | ) 873 | print(max([G.graph["label"] for G in graphs])) 874 | 875 | if feat == "node-feat" and "feat_dim" in graphs[0].graph: 876 | print("Using node features") 877 | input_dim = graphs[0].graph["feat_dim"] 878 | elif feat == "node-label" and "label" in graphs[0].nodes[0]: 879 | print("Using node labels") 880 | for G in graphs: 881 | for u in G.nodes(): 882 | G.nodes[u]["feat"] = np.array(G.nodes[u]["label"]) 883 | # make it -1/1 instead of 0/1 884 | # feat = np.array(G.nodes[u]['label']) 885 | # G.nodes[u]['feat'] = feat * 2 - 1 886 | else: 887 | print("Using constant labels") 888 | featgen_const = featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 889 | for G in graphs: 890 | featgen_const.gen_node_features(G) 891 | 892 | train_dataset, val_dataset, test_dataset, max_num_nodes, input_dim, assign_input_dim = prepare_data( 893 | graphs, args, max_nodes=args.max_nodes 894 | ) 895 | if args.method == "soft-assign": 896 | print("Method: soft-assign") 897 | model = 
models.SoftPoolingGcnEncoder( 898 | max_num_nodes, 899 | input_dim, 900 | args.hidden_dim, 901 | args.output_dim, 902 | args.num_classes, 903 | args.num_gc_layers, 904 | args.hidden_dim, 905 | assign_ratio=args.assign_ratio, 906 | num_pooling=args.num_pool, 907 | bn=args.bn, 908 | dropout=args.dropout, 909 | linkpred=args.linkpred, 910 | args=args, 911 | assign_input_dim=assign_input_dim, 912 | ).cuda() 913 | else: 914 | print("Method: base") 915 | model = models.GcnEncoderGraph( 916 | input_dim, 917 | args.hidden_dim, 918 | args.output_dim, 919 | args.num_classes, 920 | args.num_gc_layers, 921 | bn=args.bn, 922 | dropout=args.dropout, 923 | args=args, 924 | ).cuda() 925 | 926 | train( 927 | train_dataset, 928 | model, 929 | args, 930 | val_dataset=val_dataset, 931 | test_dataset=test_dataset, 932 | writer=writer, 933 | ) 934 | evaluate(test_dataset, model, args, "Validation") 935 | 936 | 937 | def benchmark_task_val(args, writer=None, feat="node-label"): 938 | all_vals = [] 939 | graphs = io_utils.read_graphfile( 940 | args.datadir, args.bmname, max_nodes=args.max_nodes 941 | ) 942 | 943 | if feat == "node-feat" and "feat_dim" in graphs[0].graph: 944 | print("Using node features") 945 | input_dim = graphs[0].graph["feat_dim"] 946 | elif feat == "node-label" and "label" in graphs[0].nodes[0]: 947 | print("Using node labels") 948 | for G in graphs: 949 | for u in G.nodes(): 950 | G.nodes[u]["feat"] = np.array(G.nodes[u]["label"]) 951 | else: 952 | print("Using constant labels") 953 | featgen_const = featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float)) 954 | for G in graphs: 955 | featgen_const.gen_node_features(G) 956 | 957 | # 10 splits 958 | for i in range(10): 959 | train_dataset, val_dataset, max_num_nodes, input_dim, assign_input_dim = cross_val.prepare_val_data( 960 | graphs, args, i, max_nodes=args.max_nodes 961 | ) 962 | print("Method: base") 963 | model = models.GcnEncoderGraph( 964 | input_dim, 965 | args.hidden_dim, 966 | args.output_dim, 967 | args.num_classes, 968 | args.num_gc_layers, 969 | bn=args.bn, 970 | dropout=args.dropout, 971 | args=args, 972 | ).cuda() 973 | 974 | _, val_accs = train( 975 | train_dataset, 976 | model, 977 | args, 978 | val_dataset=val_dataset, 979 | test_dataset=None, 980 | writer=writer, 981 | ) 982 | all_vals.append(np.array(val_accs)) 983 | all_vals = np.vstack(all_vals) 984 | all_vals = np.mean(all_vals, axis=0) 985 | print(all_vals) 986 | print(np.max(all_vals)) 987 | print(np.argmax(all_vals)) 988 | 989 | 990 | def arg_parse(): 991 | parser = argparse.ArgumentParser(description="GraphPool arguments.") 992 | io_parser = parser.add_mutually_exclusive_group(required=False) 993 | io_parser.add_argument("--dataset", dest="dataset", help="Input dataset.") 994 | benchmark_parser = io_parser.add_argument_group() 995 | benchmark_parser.add_argument( 996 | "--bmname", dest="bmname", help="Name of the benchmark dataset" 997 | ) 998 | io_parser.add_argument("--pkl", dest="pkl_fname", help="Name of the pkl data file") 999 | 1000 | softpool_parser = parser.add_argument_group() 1001 | softpool_parser.add_argument( 1002 | "--assign-ratio", 1003 | dest="assign_ratio", 1004 | type=float, 1005 | help="ratio of number of nodes in consecutive layers", 1006 | ) 1007 | softpool_parser.add_argument( 1008 | "--num-pool", dest="num_pool", type=int, help="number of pooling layers" 1009 | ) 1010 | parser.add_argument( 1011 | "--linkpred", 1012 | dest="linkpred", 1013 | action="store_const", 1014 | const=True, 1015 | default=False, 1016 | help="Whether link 
prediction side objective is used", 1017 | ) 1018 | 1019 | parser_utils.parse_optimizer(parser) 1020 | 1021 | parser.add_argument( 1022 | "--datadir", dest="datadir", help="Directory where benchmark is located" 1023 | ) 1024 | parser.add_argument("--logdir", dest="logdir", help="Tensorboard log directory") 1025 | parser.add_argument("--ckptdir", dest="ckptdir", help="Model checkpoint directory") 1026 | parser.add_argument("--cuda", dest="cuda", help="CUDA device id.") 1027 | parser.add_argument( 1028 | "--gpu", 1029 | dest="gpu", 1030 | action="store_const", 1031 | const=True, 1032 | default=False, 1033 | help="Whether to use GPU.", 1034 | ) 1035 | parser.add_argument( 1036 | "--max-nodes", 1037 | dest="max_nodes", 1038 | type=int, 1039 | help="Maximum number of nodes (graphs with more nodes are ignored).", 1040 | ) 1041 | parser.add_argument("--batch-size", dest="batch_size", type=int, help="Batch size.") 1042 | parser.add_argument( 1043 | "--epochs", dest="num_epochs", type=int, help="Number of epochs to train." 1044 | ) 1045 | parser.add_argument( 1046 | "--train-ratio", 1047 | dest="train_ratio", 1048 | type=float, 1049 | help="Ratio of training graphs to all graphs.", 1050 | ) 1051 | parser.add_argument( 1052 | "--num_workers", 1053 | dest="num_workers", 1054 | type=int, 1055 | help="Number of workers to load data.", 1056 | ) 1057 | parser.add_argument( 1058 | "--feature", 1059 | dest="feature_type", 1060 | help="Feature used for encoder. Can be: id, deg", 1061 | ) 1062 | parser.add_argument( 1063 | "--input-dim", dest="input_dim", type=int, help="Input feature dimension" 1064 | ) 1065 | parser.add_argument( 1066 | "--hidden-dim", dest="hidden_dim", type=int, help="Hidden dimension" 1067 | ) 1068 | parser.add_argument( 1069 | "--output-dim", dest="output_dim", type=int, help="Output dimension" 1070 | ) 1071 | parser.add_argument( 1072 | "--num-classes", dest="num_classes", type=int, help="Number of label classes" 1073 | ) 1074 | parser.add_argument( 1075 | "--num-gc-layers", 1076 | dest="num_gc_layers", 1077 | type=int, 1078 | help="Number of graph convolution layers before each pooling", 1079 | ) 1080 | parser.add_argument( 1081 | "--bn", 1082 | dest="bn", 1083 | action="store_const", 1084 | const=True, 1085 | default=False, 1086 | help="Whether batch normalization is used", 1087 | ) 1088 | parser.add_argument("--dropout", dest="dropout", type=float, help="Dropout rate.") 1089 | parser.add_argument( 1090 | "--nobias", 1091 | dest="bias", 1092 | action="store_const", 1093 | const=False, 1094 | default=True, 1095 | help="Whether to add bias. Default to True.", 1096 | ) 1097 | parser.add_argument( 1098 | "--weight-decay", 1099 | dest="weight_decay", 1100 | type=float, 1101 | help="Weight decay regularization constant.", 1102 | ) 1103 | 1104 | parser.add_argument( 1105 | "--method", dest="method", help="Method. Possible values: base, att, soft-assign." 1106 | ) 1107 | parser.add_argument( 1108 | "--name-suffix", dest="name_suffix", help="Suffix added to the output filename" 1109 | ) 1110 | 1111 | parser.set_defaults( 1112 | datadir="data", # io_parser 1113 | logdir="log", 1114 | ckptdir="ckpt", 1115 | dataset="syn1", 1116 | opt="adam", # opt_parser 1117 | opt_scheduler="none", 1118 | max_nodes=100, 1119 | cuda="1", 1120 | feature_type="default", 1121 | lr=0.001, 1122 | clip=2.0, 1123 | batch_size=20, 1124 | num_epochs=1000, 1125 | train_ratio=0.8, 1126 | test_ratio=0.1, 1127 | num_workers=1, 1128 | input_dim=10, 1129 | hidden_dim=20, 1130 | output_dim=20, 1131 | num_classes=2, 1132 | num_gc_layers=3, 1133 | dropout=0.0, 1134 | weight_decay=0.005, 1135 | method="base", 1136 | name_suffix="", 1137 | assign_ratio=0.1, 1138 | ) 1139 | return parser.parse_args() 1140 | 1141 | 1142 | def main(): 1143 | prog_args = configs.arg_parse() 1144 | 1145 | path = os.path.join(prog_args.logdir, io_utils.gen_prefix(prog_args)) 1146 | writer = SummaryWriter(path) 1147 | 1148 | if prog_args.gpu: 1149 | os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda 1150 | print("CUDA", prog_args.cuda) 1151 | else: 1152 | print("Using CPU") 1153 | 1154 | # use --bmname=[dataset_name] for Reddit-Binary, Mutagenicity 1155 | if prog_args.bmname is not None: 1156 | benchmark_task(prog_args, writer=writer) 1157 | elif prog_args.pkl_fname is not None: 1158 | pkl_task(prog_args) 1159 | elif prog_args.dataset is not None: 1160 | if prog_args.dataset == "syn1": 1161 | syn_task1(prog_args, writer=writer) 1162 | elif prog_args.dataset == "syn2": 1163 | syn_task2(prog_args, writer=writer) 1164 | elif prog_args.dataset == "syn3": 1165 | syn_task3(prog_args, writer=writer) 1166 | elif prog_args.dataset == "syn4": 1167 | syn_task4(prog_args, writer=writer) 1168 | elif prog_args.dataset == "syn5": 1169 | syn_task5(prog_args, writer=writer) 1170 | elif prog_args.dataset == "enron": 1171 | enron_task(prog_args, writer=writer) 1172 | elif prog_args.dataset == "ppi_essential": 1173 | ppi_essential_task(prog_args, writer=writer) 1174 | 1175 | writer.close() 1176 | 1177 | 1178 | if __name__ == "__main__": 1179 | main() 1180 | 1181 |
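# Usage sketch (illustrative, not part of the original file): the flag
# names come from arg_parse() above and the dispatch in main(); the run
# name below is a made-up example.
#
#   $ python train.py --dataset=syn1 --gpu
#
# or, programmatically:
#
#   prog_args = arg_parse()                  # defaults: dataset="syn1", method="base"
#   writer = SummaryWriter("log/syn1_demo")  # hypothetical log directory
#   syn_task1(prog_args, writer=writer)      # trains a GcnEncoderNode on gen_syn1
#   writer.close()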
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RexYing/gnn-model-explainer/bc984829f4f4829e93c760e9bbdc8e73f96e2cc1/utils/__init__.py -------------------------------------------------------------------------------- /utils/featgen.py: -------------------------------------------------------------------------------- 1 | """ featgen.py 2 | 3 | Node feature generators. 
4 | 5 | """ 6 | import networkx as nx 7 | import numpy as np 8 | import random 9 | 10 | import abc 11 | 12 | 13 | class FeatureGen(metaclass=abc.ABCMeta): 14 | """Feature Generator base class.""" 15 | @abc.abstractmethod 16 | def gen_node_features(self, G): 17 | pass 18 | 19 | 20 | class ConstFeatureGen(FeatureGen): 21 | """Constant Feature class.""" 22 | def __init__(self, val): 23 | self.val = val 24 | 25 | def gen_node_features(self, G): 26 | feat_dict = {i:{'feat': np.array(self.val, dtype=np.float32)} for i in G.nodes()} 27 | print ('feat_dict[0]["feat"]:', feat_dict[0]['feat'].dtype) 28 | nx.set_node_attributes(G, feat_dict) 29 | print ('G.nodes[0]["feat"]:', G.nodes[0]['feat'].dtype) 30 | 31 | 32 | class GaussianFeatureGen(FeatureGen): 33 | """Gaussian Feature class.""" 34 | def __init__(self, mu, sigma): 35 | self.mu = mu 36 | if sigma.ndim < 2: 37 | self.sigma = np.diag(sigma) 38 | else: 39 | self.sigma = sigma 40 | 41 | def gen_node_features(self, G): 42 | feat = np.random.multivariate_normal(self.mu, self.sigma, G.number_of_nodes()) 43 | feat_dict = { 44 | i: {"feat": feat[i]} for i in range(feat.shape[0]) 45 | } 46 | nx.set_node_attributes(G, feat_dict) 47 | 48 | 49 | class GridFeatureGen(FeatureGen): 50 | """Grid Feature class.""" 51 | def __init__(self, mu, sigma, com_choices): 52 | self.mu = mu # Mean 53 | self.sigma = sigma # Variance 54 | self.com_choices = com_choices # List of possible community labels 55 | 56 | def gen_node_features(self, G): 57 | # Generate community assignment 58 | community_dict = { 59 | n: self.com_choices[0] if G.degree(n) < 4 else self.com_choices[1] 60 | for n in G.nodes() 61 | } 62 | 63 | # Generate random variable 64 | s = np.random.normal(self.mu, self.sigma, G.number_of_nodes()) 65 | 66 | # Generate features 67 | feat_dict = { 68 | n: {"feat": np.asarray([community_dict[n], s[i]])} 69 | for i, n in enumerate(G.nodes()) 70 | } 71 | 72 | nx.set_node_attributes(G, feat_dict) 73 | return community_dict 74 | 75 | -------------------------------------------------------------------------------- /utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | """graph_utils.py 2 | 3 | Utility for sampling graphs from a dataset. 
4 | """ 5 | import networkx as nx 6 | import numpy as np 7 | import torch 8 | import torch.utils.data 9 | 10 | 11 | class GraphSampler(torch.utils.data.Dataset): 12 | """ Sample graphs and nodes in graph 13 | """ 14 | 15 | def __init__( 16 | self, 17 | G_list, 18 | features="default", 19 | normalize=True, 20 | assign_feat="default", 21 | max_num_nodes=0, 22 | ): 23 | self.adj_all = [] 24 | self.len_all = [] 25 | self.feature_all = [] 26 | self.label_all = [] 27 | 28 | self.assign_feat_all = [] 29 | 30 | if max_num_nodes == 0: 31 | self.max_num_nodes = max([G.number_of_nodes() for G in G_list]) 32 | else: 33 | self.max_num_nodes = max_num_nodes 34 | 35 | existing_node = list(G_list[0].nodes())[-1] 36 | self.feat_dim = G_list[0].nodes[existing_node]["feat"].shape[0] 37 | 38 | for G in G_list: 39 | adj = np.array(nx.to_numpy_matrix(G)) 40 | if normalize: 41 | sqrt_deg = np.diag( 42 | 1.0 / np.sqrt(np.sum(adj, axis=0, dtype=float).squeeze()) 43 | ) 44 | adj = np.matmul(np.matmul(sqrt_deg, adj), sqrt_deg) 45 | self.adj_all.append(adj) 46 | self.len_all.append(G.number_of_nodes()) 47 | self.label_all.append(G.graph["label"]) 48 | # feat matrix: max_num_nodes x feat_dim 49 | if features == "default": 50 | f = np.zeros((self.max_num_nodes, self.feat_dim), dtype=float) 51 | for i, u in enumerate(G.nodes()): 52 | f[i, :] = G.nodes[u]["feat"] 53 | self.feature_all.append(f) 54 | elif features == "id": 55 | self.feature_all.append(np.identity(self.max_num_nodes)) 56 | elif features == "deg-num": 57 | degs = np.sum(np.array(adj), 1) 58 | degs = np.expand_dims( 59 | np.pad(degs, [0, self.max_num_nodes - G.number_of_nodes()], 0), 60 | axis=1, 61 | ) 62 | self.feature_all.append(degs) 63 | elif features == "deg": 64 | self.max_deg = 10 65 | degs = np.sum(np.array(adj), 1).astype(int) 66 | degs[degs > self.max_deg] = self.max_deg 67 | feat = np.zeros((len(degs), self.max_deg + 1)) 68 | feat[np.arange(len(degs)), degs] = 1 69 | feat = np.pad( 70 | feat, 71 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 72 | "constant", 73 | constant_values=0, 74 | ) 75 | 76 | f = np.zeros((self.max_num_nodes, self.feat_dim), dtype=float) 77 | for i, u in enumerate(G.nodes()): 78 | f[i, :] = G.nodes[u]["feat"] 79 | 80 | feat = np.concatenate((feat, f), axis=1) 81 | 82 | self.feature_all.append(feat) 83 | elif features == "struct": 84 | self.max_deg = 10 85 | degs = np.sum(np.array(adj), 1).astype(int) 86 | degs[degs > 10] = 10 87 | feat = np.zeros((len(degs), self.max_deg + 1)) 88 | feat[np.arange(len(degs)), degs] = 1 89 | degs = np.pad( 90 | feat, 91 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 92 | "constant", 93 | constant_values=0, 94 | ) 95 | 96 | clusterings = np.array(list(nx.clustering(G).values())) 97 | clusterings = np.expand_dims( 98 | np.pad( 99 | clusterings, 100 | [0, self.max_num_nodes - G.number_of_nodes()], 101 | "constant", 102 | ), 103 | axis=1, 104 | ) 105 | g_feat = np.hstack([degs, clusterings]) 106 | if "feat" in G.nodes[0]: 107 | node_feats = np.array( 108 | [G.nodes[i]["feat"] for i in range(G.number_of_nodes())] 109 | ) 110 | node_feats = np.pad( 111 | node_feats, 112 | ((0, self.max_num_nodes - G.number_of_nodes()), (0, 0)), 113 | "constant", 114 | ) 115 | g_feat = np.hstack([g_feat, node_feats]) 116 | 117 | self.feature_all.append(g_feat) 118 | 119 | if assign_feat == "id": 120 | self.assign_feat_all.append( 121 | np.hstack((np.identity(self.max_num_nodes), self.feature_all[-1])) 122 | ) 123 | else: 124 | self.assign_feat_all.append(self.feature_all[-1]) 125 | 126 | 
self.feat_dim = self.feature_all[0].shape[1] 127 | self.assign_feat_dim = self.assign_feat_all[0].shape[1] 128 | 129 | def __len__(self): 130 | return len(self.adj_all) 131 | 132 | def __getitem__(self, idx): 133 | adj = self.adj_all[idx] 134 | num_nodes = adj.shape[0] 135 | adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes)) 136 | adj_padded[:num_nodes, :num_nodes] = adj 137 | 138 | # use all nodes for aggregation (baseline) 139 | return { 140 | "adj": adj_padded, 141 | "feats": self.feature_all[idx].copy(), 142 | "label": self.label_all[idx], 143 | "num_nodes": num_nodes, 144 | "assign_feats": self.assign_feat_all[idx].copy(), 145 | } 146 | 147 | def neighborhoods(adj, n_hops, use_cuda): 148 | """Returns the binary matrix of nodes reachable within n_hops steps: the sum of the first n_hops powers of adj, thresholded at zero.""" 149 | adj = torch.tensor(adj, dtype=torch.float) 150 | if use_cuda: 151 | adj = adj.cuda() 152 | hop_adj = power_adj = adj 153 | for i in range(n_hops - 1): 154 | power_adj = power_adj @ adj 155 | prev_hop_adj = hop_adj 156 | hop_adj = hop_adj + power_adj 157 | hop_adj = (hop_adj > 0).float() 158 | return hop_adj.cpu().numpy().astype(int)
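# Quick sanity check for neighborhoods() (illustrative addition; the
# 4-node path graph and the expected row are worked out by hand).
if __name__ == "__main__":
    path = np.array(
        [[0, 1, 0, 0],
         [1, 0, 1, 0],
         [0, 1, 0, 1],
         [0, 0, 1, 0]], dtype=float
    )
    hops = neighborhoods(path, n_hops=2, use_cuda=False)
    # hops[0] == [1, 1, 1, 0]: node 2 becomes reachable from node 0, and
    # the diagonal turns on because adj @ adj counts length-2 round trips.
    print(hops)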
-------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | """ io_utils.py 2 | 3 | Utilities for reading and writing logs. 4 | """ 5 | import os 6 | import statistics 7 | import re 8 | import csv 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import scipy as sc 13 | 14 | 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | 18 | import pickle 19 | import torch 20 | import networkx as nx 21 | import tensorboardX 22 | 23 | import cv2 24 | 25 | 26 | import torch.nn as nn 27 | from torch.autograd import Variable 28 | 29 | # Only necessary to rebuild the Chemistry example 30 | # from rdkit import Chem 31 | 32 | import utils.featgen as featgen 33 | 34 | use_cuda = torch.cuda.is_available() 35 | 36 | 37 | def gen_prefix(args): 38 | '''Generate label prefix for a graph model. 39 | ''' 40 | if args.bmname is not None: 41 | name = args.bmname 42 | else: 43 | name = args.dataset 44 | name += "_" + args.method 45 | 46 | name += "_h" + str(args.hidden_dim) + "_o" + str(args.output_dim) 47 | if not args.bias: 48 | name += "_nobias" 49 | if len(args.name_suffix) > 0: 50 | name += "_" + args.name_suffix 51 | return name 52 | 53 | 54 | def gen_explainer_prefix(args): 55 | '''Generate label prefix for a graph explainer model. 56 | ''' 57 | name = gen_prefix(args) + "_explain" 58 | if len(args.explainer_suffix) > 0: 59 | name += "_" + args.explainer_suffix 60 | return name 61 | 62 | 63 | def create_filename(save_dir, args, isbest=False, num_epochs=-1): 64 | """ 65 | Args: 66 | args : the arguments parsed in the parser 67 | isbest : whether the saved model is the best-performing one 68 | num_epochs : epoch number of the model (when isbest=False) 69 | """ 70 | filename = os.path.join(save_dir, gen_prefix(args)) 71 | os.makedirs(filename, exist_ok=True) 72 | 73 | if isbest: 74 | filename = os.path.join(filename, "best") 75 | elif num_epochs > 0: 76 | filename = os.path.join(filename, str(num_epochs)) 77 | 78 | return filename + ".pth.tar" 79 | 80 | 81 | def save_checkpoint(model, optimizer, args, num_epochs=-1, isbest=False, cg_dict=None): 82 | """Save pytorch model checkpoint. 83 | 84 | Args: 85 | - model : The PyTorch model to save. 86 | - optimizer : The optimizer used to train the model. 87 | - args : Parsed arguments with meta-data about the model. 88 | - num_epochs : Number of training epochs. 89 | - isbest : True if the model has the highest accuracy so far. 90 | - cg_dict : A dictionary of the sampled computation graphs. 91 | """ 92 | filename = create_filename(args.ckptdir, args, isbest, num_epochs=num_epochs) 93 | torch.save( 94 | { 95 | "epoch": num_epochs, 96 | "model_type": args.method, 97 | "optimizer": optimizer, 98 | "model_state": model.state_dict(), 99 | "optimizer_state": optimizer.state_dict(), 100 | "cg": cg_dict, 101 | }, 102 | filename, 103 | ) 104 | 105 | 106 | def load_ckpt(args, isbest=False): 107 | '''Load a pre-trained pytorch model from checkpoint. 108 | ''' 109 | print("loading model") 110 | filename = create_filename(args.ckptdir, args, isbest) 111 | print(filename) 112 | if os.path.isfile(filename): 113 | print("=> loading checkpoint '{}'".format(filename)) 114 | ckpt = torch.load(filename) 115 | else: 116 | print("Checkpoint does not exist!") 117 | print("Checked path -- {}".format(filename)) 118 | print("Make sure you have provided the correct path!") 119 | print("You may have forgotten to train a model for this dataset.") 120 | print() 121 | print("To train one of the paper's models, run the following") 122 | print(">> python train.py --dataset=DATASET_NAME") 123 | print() 124 | raise Exception("File not found.") 125 | return ckpt 126 | 127 | def preprocess_cg(cg): 128 | """Pre-process computation graph.""" 129 | if use_cuda: 130 | preprocessed_cg_tensor = torch.from_numpy(cg).cuda() 131 | else: 132 | preprocessed_cg_tensor = torch.from_numpy(cg) 133 | 134 | preprocessed_cg_tensor.unsqueeze_(0) 135 | return Variable(preprocessed_cg_tensor, requires_grad=False) 136 | 137 | def load_model(path): 138 | """Load a pytorch model.""" 139 | model = torch.load(path) 140 | model.eval() 141 | if use_cuda: 142 | model.cuda() 143 | 144 | for p in model.features.parameters(): 145 | p.requires_grad = False 146 | for p in model.classifier.parameters(): 147 | p.requires_grad = False 148 | 149 | return model 150 | 151 | 152 | def load_cg(path): 153 | """Load a computation graph.""" 154 | cg = pickle.load(open(path, "rb")) 155 | return cg 156 | 157 | 158 | def save(mask_cg): 159 | """Save a rendering of the computation graph mask.""" 160 | mask = mask_cg.cpu().data.numpy()[0] 161 | mask = np.transpose(mask, (1, 2, 0)) 162 | 163 | mask = (mask - np.min(mask)) / np.max(mask) 164 | mask = 1 - mask 165 | 166 | cv2.imwrite("mask.png", np.uint8(255 * mask)) 167 | 
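# Illustrative round trip for save_checkpoint/load_ckpt above (not part of
# the original module; `model`, `optimizer`, `args` and `cg_data` are
# assumed to come from train.py):
#
#   save_checkpoint(model, optimizer, args, num_epochs=-1, cg_dict=cg_data)
#   ckpt = load_ckpt(args)                      # reads ckpt/<prefix>.pth.tar
#   model.load_state_dict(ckpt["model_state"])  # restore trained weights
#   cg_data = ckpt["cg"]                        # adj/feat/label/pred arrays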
168 | def log_matrix(writer, mat, name, epoch, fig_size=(8, 6), dpi=200): 169 | """Save an image of a matrix to disk. 170 | 171 | Args: 172 | - writer : A file writer. 173 | - mat : The matrix to write. 174 | - name : Name of the file to save. 175 | - epoch : Epoch number. 176 | - fig_size : Size of the figure to save. 177 | - dpi : Resolution. 178 | """ 179 | plt.switch_backend("agg") 180 | fig = plt.figure(figsize=fig_size, dpi=dpi) 181 | mat = mat.cpu().detach().numpy() 182 | if mat.ndim == 1: 183 | mat = mat[:, np.newaxis] 184 | plt.imshow(mat, cmap=plt.get_cmap("BuPu")) 185 | cbar = plt.colorbar() 186 | cbar.solids.set_edgecolor("face") 187 | 188 | plt.tight_layout() 189 | fig.canvas.draw() 190 | writer.add_image(name, tensorboardX.utils.figure_to_image(fig), epoch) 191 | 192 | 193 | def denoise_graph(adj, node_idx, feat=None, label=None, threshold=None, threshold_num=None, max_component=True): 194 | """Clean a graph by thresholding its edge weights. 195 | 196 | Args: 197 | - adj : Adjacency matrix. 198 | - node_idx : Index of the node to highlight (tagged with the "self" attribute). 199 | - feat : An array of node features. 200 | - label : A list of node labels. 201 | - threshold : The weight threshold. 202 | - threshold_num : Keep only the threshold_num edges with the largest weights. 203 | - max_component : If True, keep only the largest connected component. 204 | """ 205 | num_nodes = adj.shape[-1] 206 | G = nx.Graph() 207 | G.add_nodes_from(range(num_nodes)) 208 | G.nodes[node_idx]["self"] = 1 209 | if feat is not None: 210 | for node in G.nodes(): 211 | G.nodes[node]["feat"] = feat[node] 212 | if label is not None: 213 | for node in G.nodes(): 214 | G.nodes[node]["label"] = label[node] 215 | 216 | if threshold_num is not None: 217 | # this is for symmetric graphs: edges are repeated twice in adj 218 | adj_threshold_num = threshold_num * 2 219 | #adj += np.random.rand(adj.shape[0], adj.shape[1]) * 1e-4 220 | neigh_size = len(adj[adj > 0]) 221 | threshold_num = min(neigh_size, adj_threshold_num) 222 | threshold = np.sort(adj[adj > 0])[-threshold_num] 223 | 224 | if threshold is not None: 225 | weighted_edge_list = [ 226 | (i, j, adj[i, j]) 227 | for i in range(num_nodes) 228 | for j in range(num_nodes) 229 | if adj[i, j] >= threshold 230 | ] 231 | else: 232 | weighted_edge_list = [ 233 | (i, j, adj[i, j]) 234 | for i in range(num_nodes) 235 | for j in range(num_nodes) 236 | if adj[i, j] > 1e-6 237 | ] 238 | G.add_weighted_edges_from(weighted_edge_list) 239 | if max_component: 240 | largest_cc = max(nx.connected_components(G), key=len) 241 | G = G.subgraph(largest_cc).copy() 242 | else: 243 | # remove zero degree nodes 244 | G.remove_nodes_from(list(nx.isolates(G))) 245 | return G 246 | 247 | # TODO: unify log_graph and log_graph2 248 | def log_graph( 249 | writer, 250 | Gc, 251 | name, 252 | identify_self=True, 253 | nodecolor="label", 254 | epoch=0, 255 | fig_size=(4, 3), 256 | dpi=300, 257 | label_node_feat=False, 258 | edge_vmax=None, 259 | args=None, 260 | ): 261 | """ 262 | Args: 263 | nodecolor: the color of the nodes; can be determined by 'label' or 'feat'. 
For feat, it needs to 264 | be one-hot' 265 | """ 266 | cmap = plt.get_cmap("Set1") 267 | plt.switch_backend("agg") 268 | fig = plt.figure(figsize=fig_size, dpi=dpi) 269 | 270 | node_colors = [] 271 | # edge_colors = [min(max(w, 0.0), 1.0) for (u,v,w) in Gc.edges.data('weight', default=1)] 272 | edge_colors = [w for (u, v, w) in Gc.edges.data("weight", default=1)] 273 | 274 | # maximum value for node color 275 | vmax = 8 276 | for i in Gc.nodes(): 277 | if nodecolor == "feat" and "feat" in Gc.nodes[i]: 278 | num_classes = Gc.nodes[i]["feat"].size()[0] 279 | if num_classes >= 10: 280 | cmap = plt.get_cmap("tab20") 281 | vmax = 19 282 | elif num_classes >= 8: 283 | cmap = plt.get_cmap("tab10") 284 | vmax = 9 285 | break 286 | 287 | feat_labels = {} 288 | for i in Gc.nodes(): 289 | if identify_self and "self" in Gc.nodes[i]: 290 | node_colors.append(0) 291 | elif nodecolor == "label" and "label" in Gc.nodes[i]: 292 | node_colors.append(Gc.nodes[i]["label"] + 1) 293 | elif nodecolor == "feat" and "feat" in Gc.nodes[i]: 294 | # print(Gc.nodes[i]['feat']) 295 | feat = Gc.nodes[i]["feat"].detach().numpy() 296 | # idx with pos val in 1D array 297 | feat_class = 0 298 | for j in range(len(feat)): 299 | if feat[j] == 1: 300 | feat_class = j 301 | break 302 | node_colors.append(feat_class) 303 | feat_labels[i] = feat_class 304 | else: 305 | node_colors.append(1) 306 | if not label_node_feat: 307 | feat_labels = None 308 | 309 | plt.switch_backend("agg") 310 | fig = plt.figure(figsize=fig_size, dpi=dpi) 311 | 312 | if Gc.number_of_nodes() == 0: 313 | raise Exception("empty graph") 314 | if Gc.number_of_edges() == 0: 315 | raise Exception("empty edge") 316 | # remove_nodes = [] 317 | # for u in Gc.nodes(): 318 | # if Gc 319 | pos_layout = nx.kamada_kawai_layout(Gc, weight=None) 320 | # pos_layout = nx.spring_layout(Gc, weight=None) 321 | 322 | weights = [d for (u, v, d) in Gc.edges(data="weight", default=1)] 323 | if edge_vmax is None: 324 | edge_vmax = statistics.median_high( 325 | [d for (u, v, d) in Gc.edges(data="weight", default=1)] 326 | ) 327 | min_color = min([d for (u, v, d) in Gc.edges(data="weight", default=1)]) 328 | # color range: gray to black 329 | edge_vmin = 2 * min_color - edge_vmax 330 | nx.draw( 331 | Gc, 332 | pos=pos_layout, 333 | with_labels=False, 334 | font_size=4, 335 | labels=feat_labels, 336 | node_color=node_colors, 337 | vmin=0, 338 | vmax=vmax, 339 | cmap=cmap, 340 | edge_color=edge_colors, 341 | edge_cmap=plt.get_cmap("Greys"), 342 | edge_vmin=edge_vmin, 343 | edge_vmax=edge_vmax, 344 | width=1.0, 345 | node_size=50, 346 | alpha=0.8, 347 | ) 348 | fig.axes[0].xaxis.set_visible(False) 349 | fig.canvas.draw() 350 | 351 | logdir = "log" if not hasattr(args, "logdir") or not args.logdir else str(args.logdir) 352 | if nodecolor != "feat": 353 | name += gen_explainer_prefix(args) 354 | save_path = os.path.join(logdir, name + "_" + str(epoch) + ".pdf") 355 | print(logdir + "/" + name + gen_explainer_prefix(args) + "_" + str(epoch) + ".pdf") 356 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 357 | plt.savefig(save_path, format="pdf") 358 | 359 | img = tensorboardX.utils.figure_to_image(fig) 360 | writer.add_image(name, img, epoch) 361 | 362 | 363 | def plot_cmap(cmap, ncolor): 364 | """ 365 | A convenient function to plot colors of a matplotlib cmap 366 | Credit goes to http://gvallver.perso.univ-pau.fr/?p=712 367 | 368 | Args: 369 | ncolor (int): number of color to show 370 | cmap: a cmap object or a matplotlib color name 371 | """ 372 | 373 | if isinstance(cmap, 
str): 374 | name = cmap 375 | try: 376 | cm = plt.get_cmap(cmap) 377 | except ValueError: 378 | print("WARNINGS :", cmap, " is not a known colormap") 379 | cm = plt.cm.gray 380 | else: 381 | cm = cmap 382 | name = cm.name 383 | 384 | with matplotlib.rc_context(matplotlib.rcParamsDefault): 385 | fig = plt.figure(figsize=(12, 1), frameon=False) 386 | ax = fig.add_subplot(111) 387 | ax.pcolor(np.linspace(1, ncolor, ncolor).reshape(1, ncolor), cmap=cm) 388 | ax.set_title(name) 389 | xt = ax.set_xticks([]) 390 | yt = ax.set_yticks([]) 391 | return fig 392 | 393 | 394 | def plot_cmap_tb(writer, cmap, ncolor, name): 395 | """Plot the color map used for plot.""" 396 | fig = plot_cmap(cmap, ncolor) 397 | img = tensorboardX.utils.figure_to_image(fig) 398 | writer.add_image(name, img, 0) 399 | 400 | 401 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 402 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 403 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 404 | indices = torch.from_numpy( 405 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 406 | ) 407 | values = torch.from_numpy(sparse_mx.data) 408 | shape = torch.Size(sparse_mx.shape) 409 | return torch.sparse.FloatTensor(indices, values, shape) 410 | 411 | def numpy_to_torch(img, requires_grad=True): 412 | if len(img.shape) < 3: 413 | output = np.float32([img]) 414 | else: 415 | output = np.transpose(img, (2, 0, 1)) 416 | 417 | output = torch.from_numpy(output) 418 | if use_cuda: 419 | output = output.cuda() 420 | 421 | output.unsqueeze_(0) 422 | v = Variable(output, requires_grad=requires_grad) 423 | return v 424 | 425 | 426 | def read_graphfile(datadir, dataname, max_nodes=None, edge_labels=False): 427 | """ Read data from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 428 | graph index starts with 1 in file 429 | 430 | Returns: 431 | List of networkx objects with graph and node labels 432 | """ 433 | prefix = os.path.join(datadir, dataname, dataname) 434 | filename_graph_indic = prefix + "_graph_indicator.txt" 435 | # index of graphs that a given node belongs to 436 | graph_indic = {} 437 | with open(filename_graph_indic) as f: 438 | i = 1 439 | for line in f: 440 | line = line.strip("\n") 441 | graph_indic[i] = int(line) 442 | i += 1 443 | 444 | filename_nodes = prefix + "_node_labels.txt" 445 | node_labels = [] 446 | min_label_val = None 447 | try: 448 | with open(filename_nodes) as f: 449 | for line in f: 450 | line = line.strip("\n") 451 | l = int(line) 452 | node_labels += [l] 453 | if min_label_val is None or min_label_val > l: 454 | min_label_val = l 455 | # assume that node labels are consecutive 456 | num_unique_node_labels = max(node_labels) - min_label_val + 1 457 | node_labels = [l - min_label_val for l in node_labels] 458 | except IOError: 459 | print("No node labels") 460 | 461 | filename_node_attrs = prefix + "_node_attributes.txt" 462 | node_attrs = [] 463 | try: 464 | with open(filename_node_attrs) as f: 465 | for line in f: 466 | line = line.strip("\s\n") 467 | attrs = [ 468 | float(attr) for attr in re.split("[,\s]+", line) if not attr == "" 469 | ] 470 | node_attrs.append(np.array(attrs)) 471 | except IOError: 472 | print("No node attributes") 473 | 474 | label_has_zero = False 475 | filename_graphs = prefix + "_graph_labels.txt" 476 | graph_labels = [] 477 | 478 | label_vals = [] 479 | with open(filename_graphs) as f: 480 | for line in f: 481 | line = line.strip("\n") 482 | val = int(line) 483 | if val not in label_vals: 484 | label_vals.append(val) 485 | 
graph_labels.append(val) 486 | 487 | label_map_to_int = {val: i for i, val in enumerate(label_vals)} 488 | graph_labels = np.array([label_map_to_int[l] for l in graph_labels]) 489 | 490 | if edge_labels: 491 | # For Tox21_AHR we want to know edge labels 492 | filename_edges = prefix + "_edge_labels.txt" 493 | edge_labels = [] 494 | 495 | edge_label_vals = [] 496 | with open(filename_edges) as f: 497 | for line in f: 498 | line = line.strip("\n") 499 | val = int(line) 500 | if val not in edge_label_vals: 501 | edge_label_vals.append(val) 502 | edge_labels.append(val) 503 | 504 | edge_label_map_to_int = {val: i for i, val in enumerate(edge_label_vals)} 505 | 506 | filename_adj = prefix + "_A.txt" 507 | adj_list = {i: [] for i in range(1, len(graph_labels) + 1)} 508 | # edge_label_list={i:[] for i in range(1,len(graph_labels)+1)} 509 | index_graph = {i: [] for i in range(1, len(graph_labels) + 1)} 510 | num_edges = 0 511 | with open(filename_adj) as f: 512 | for line in f: 513 | line = line.strip("\n").split(",") 514 | e0, e1 = (int(line[0].strip(" ")), int(line[1].strip(" "))) 515 | adj_list[graph_indic[e0]].append((e0, e1)) 516 | index_graph[graph_indic[e0]] += [e0, e1] 517 | # edge_label_list[graph_indic[e0]].append(edge_labels[num_edges]) 518 | num_edges += 1 519 | for k in index_graph.keys(): 520 | index_graph[k] = [u - 1 for u in set(index_graph[k])] 521 | 522 | graphs = [] 523 | for i in range(1, 1 + len(adj_list)): 524 | # indexed from 1 here 525 | G = nx.from_edgelist(adj_list[i]) 526 | 527 | if max_nodes is not None and G.number_of_nodes() > max_nodes: 528 | continue 529 | 530 | # add features and labels 531 | G.graph["label"] = graph_labels[i - 1] 532 | 533 | # Special label for aromaticity experiment 534 | # aromatic_edge = 2 535 | # G.graph['aromatic'] = aromatic_edge in edge_label_list[i] 536 | 537 | for u in G.nodes(): 538 | if len(node_labels) > 0: 539 | node_label_one_hot = [0] * num_unique_node_labels 540 | node_label = node_labels[u - 1] 541 | node_label_one_hot[node_label] = 1 542 | G.nodes[u]["label"] = node_label_one_hot 543 | if len(node_attrs) > 0: 544 | G.nodes[u]["feat"] = node_attrs[u - 1] 545 | if len(node_attrs) > 0: 546 | G.graph["feat_dim"] = node_attrs[0].shape[0] 547 | 548 | # relabeling 549 | mapping = {} 550 | it = 0 551 | if float(nx.__version__) < 2.0: 552 | for n in G.nodes(): 553 | mapping[n] = it 554 | it += 1 555 | else: 556 | for n in G.nodes: 557 | mapping[n] = it 558 | it += 1 559 | 560 | # indexed from 0 561 | graphs.append(nx.relabel_nodes(G, mapping)) 562 | return graphs 563 | 564 | 565 | def read_biosnap(datadir, edgelist_file, label_file, feat_file=None, concat=True): 566 | """ Read data from BioSnap 567 | 568 | Returns: 569 | List of networkx objects with graph and node labels 570 | """ 571 | G = nx.Graph() 572 | delimiter = "\t" if "tsv" in edgelist_file else "," 573 | print(delimiter) 574 | df = pd.read_csv( 575 | os.path.join(datadir, edgelist_file), delimiter=delimiter, header=None 576 | ) 577 | data = list(map(tuple, df.values.tolist())) 578 | G.add_edges_from(data) 579 | print("Total nodes: ", G.number_of_nodes()) 580 | 581 | G = max(nx.connected_component_subgraphs(G), key=len) 582 | print("Total nodes in largest connected component: ", G.number_of_nodes()) 583 | 584 | df = pd.read_csv(os.path.join(datadir, label_file), delimiter="\t", usecols=[0, 1]) 585 | data = list(map(tuple, df.values.tolist())) 586 | 587 | missing_node = 0 588 | for line in data: 589 | if int(line[0]) not in G: 590 | missing_node += 1 591 | else: 592 | 
G.nodes[int(line[0])]["label"] = int(line[1] == "Essential") 593 | 594 | print("missing node: ", missing_node) 595 | 596 | missing_label = 0 597 | remove_nodes = [] 598 | for u in G.nodes(): 599 | if "label" not in G.nodes[u]: 600 | missing_label += 1 601 | remove_nodes.append(u) 602 | G.remove_nodes_from(remove_nodes) 603 | print("missing_label: ", missing_label) 604 | 605 | if feat_file is None: 606 | feature_generator = featgen.ConstFeatureGen(np.ones(10, dtype=float)) 607 | feature_generator.gen_node_features(G) 608 | else: 609 | df = pd.read_csv(os.path.join(datadir, feat_file), delimiter=",") 610 | data = np.array(df.values) 611 | print("Feat shape: ", data.shape) 612 | 613 | for row in data: 614 | if int(row[0]) in G: 615 | if concat: 616 | node = int(row[0]) 617 | onehot = np.zeros(10) 618 | onehot[min(G.degree[node], 10) - 1] = 1.0 619 | G.nodes[node]["feat"] = np.hstack( 620 | (np.log(row[1:] + 0.1), [1.0], onehot) 621 | ) 622 | else: 623 | G.nodes[int(row[0])]["feat"] = np.log(row[1:] + 0.1) 624 | 625 | missing_feat = 0 626 | remove_nodes = [] 627 | for u in G.nodes(): 628 | if "feat" not in G.nodes[u]: 629 | missing_feat += 1 630 | remove_nodes.append(u) 631 | G.remove_nodes_from(remove_nodes) 632 | print("missing feat: ", missing_feat) 633 | 634 | return G 635 | 636 | 637 | def build_aromaticity_dataset(): 638 | filename = "data/tox21_10k_data_all.sdf" 639 | basename = filename.split(".")[0] 640 | collector = [] 641 | sdprovider = Chem.SDMolSupplier(filename) 642 | for i,mol in enumerate(sdprovider): 643 | try: 644 | moldict = {} 645 | moldict['smiles'] = Chem.MolToSmiles(mol) 646 | #Parse Data 647 | for propname in mol.GetPropNames(): 648 | moldict[propname] = mol.GetProp(propname) 649 | nb_bonds = len(mol.GetBonds()) 650 | is_aromatic = False; aromatic_bonds = [] 651 | for j in range(nb_bonds): 652 | if mol.GetBondWithIdx(j).GetIsAromatic(): 653 | aromatic_bonds.append(j) 654 | is_aromatic = True 655 | moldict['aromaticity'] = is_aromatic 656 | moldict['aromatic_bonds'] = aromatic_bonds 657 | collector.append(moldict) 658 | except: 659 | print("Molecule %s failed"%i) 660 | data = pd.DataFrame(collector) 661 | data.to_csv(basename + '_pandas.csv') 662 | 663 | 664 | def gen_train_plt_name(args): 665 | return "results/" + gen_prefix(args) + ".png" 666 | 667 | 668 | def log_assignment(assign_tensor, writer, epoch, batch_idx): 669 | plt.switch_backend("agg") 670 | fig = plt.figure(figsize=(8, 6), dpi=300) 671 | 672 | # has to be smaller than args.batch_size 673 | for i in range(len(batch_idx)): 674 | plt.subplot(2, 2, i + 1) 675 | plt.imshow( 676 | assign_tensor.cpu().data.numpy()[batch_idx[i]], cmap=plt.get_cmap("BuPu") 677 | ) 678 | cbar = plt.colorbar() 679 | cbar.solids.set_edgecolor("face") 680 | plt.tight_layout() 681 | fig.canvas.draw() 682 | 683 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 684 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 685 | writer.add_image("assignment", data, epoch) 686 | 687 | # TODO: unify log_graph and log_graph2 688 | def log_graph2(adj, batch_num_nodes, writer, epoch, batch_idx, assign_tensor=None): 689 | plt.switch_backend("agg") 690 | fig = plt.figure(figsize=(8, 6), dpi=300) 691 | 692 | for i in range(len(batch_idx)): 693 | ax = plt.subplot(2, 2, i + 1) 694 | num_nodes = batch_num_nodes[batch_idx[i]] 695 | adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy() 696 | G = nx.from_numpy_matrix(adj_matrix) 697 | nx.draw( 698 | G, 699 | pos=nx.spring_layout(G), 700 | 
with_labels=True, 701 | node_color="#336699", 702 | edge_color="grey", 703 | width=0.5, 704 | node_size=300, 705 | alpha=0.7, 706 | ) 707 | ax.xaxis.set_visible(False) 708 | 709 | plt.tight_layout() 710 | fig.canvas.draw() 711 | 712 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 713 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 714 | writer.add_image("graphs", data, epoch) 715 | 716 | # log a label-less version 717 | # fig = plt.figure(figsize=(8,6), dpi=300) 718 | # for i in range(len(batch_idx)): 719 | # ax = plt.subplot(2, 2, i+1) 720 | # num_nodes = batch_num_nodes[batch_idx[i]] 721 | # adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy() 722 | # G = nx.from_numpy_matrix(adj_matrix) 723 | # nx.draw(G, pos=nx.spring_layout(G), with_labels=False, node_color='#336699', 724 | # edge_color='grey', width=0.5, node_size=25, 725 | # alpha=0.8) 726 | 727 | # plt.tight_layout() 728 | # fig.canvas.draw() 729 | 730 | # data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 731 | # data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 732 | # writer.add_image('graphs_no_label', data, epoch) 733 | 734 | # colored according to assignment 735 | assignment = assign_tensor.cpu().data.numpy() 736 | fig = plt.figure(figsize=(8, 6), dpi=300) 737 | 738 | num_clusters = assignment.shape[2] 739 | all_colors = np.array(range(num_clusters)) 740 | 741 | for i in range(len(batch_idx)): 742 | ax = plt.subplot(2, 2, i + 1) 743 | num_nodes = batch_num_nodes[batch_idx[i]] 744 | adj_matrix = adj[batch_idx[i], :num_nodes, :num_nodes].cpu().data.numpy() 745 | 746 | label = np.argmax(assignment[batch_idx[i]], axis=1).astype(int) 747 | label = label[: batch_num_nodes[batch_idx[i]]] 748 | node_colors = all_colors[label] 749 | 750 | G = nx.from_numpy_matrix(adj_matrix) 751 | nx.draw( 752 | G, 753 | pos=nx.spring_layout(G), 754 | with_labels=False, 755 | node_color=node_colors, 756 | edge_color="grey", 757 | width=0.4, 758 | node_size=50, 759 | cmap=plt.get_cmap("Set1"), 760 | vmin=0, 761 | vmax=num_clusters - 1, 762 | alpha=0.8, 763 | ) 764 | 765 | plt.tight_layout() 766 | fig.canvas.draw() 767 | 768 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 769 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 770 | writer.add_image("graphs_colored", data, epoch) 771 | -------------------------------------------------------------------------------- /utils/math_utils.py: -------------------------------------------------------------------------------- 1 | """ math_utils.py 2 | 3 | Math utilities. 4 | """ 5 | 6 | import torch 7 | 8 | def exp_moving_avg(x, decay=0.9): 9 | '''Exponentially decaying moving average. 10 | ''' 11 | shadow = x[0] 12 | a = [shadow] 13 | for v in x[1:]: 14 | shadow -= (1-decay) * (shadow-v) 15 | a.append(shadow) 16 | return a 17 | 18 | def tv_norm(input, tv_beta): 19 | '''Total variation norm 20 | ''' 21 | img = input[0, 0, :] 22 | row_grad = torch.mean(torch.abs((img[:-1, :] - img[1:, :])).pow(tv_beta)) 23 | col_grad = torch.mean(torch.abs((img[:, :-1] - img[:, 1:])).pow(tv_beta)) 24 | return row_grad + col_grad -------------------------------------------------------------------------------- /utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | """ parser_utils.py 2 | 3 | Parsing utilities. 
4 | """ 5 | import argparse 6 | 7 | def parse_optimizer(parser): 8 | '''Set optimizer parameters''' 9 | opt_parser = parser.add_argument_group() 10 | opt_parser.add_argument('--opt', dest='opt', type=str, 11 | help='Type of optimizer') 12 | opt_parser.add_argument('--opt-scheduler', dest='opt_scheduler', type=str, 13 | help='Type of optimizer scheduler. By default none') 14 | opt_parser.add_argument('--opt-restart', dest='opt_restart', type=int, 15 | help='Number of epochs before restart (by default set to 0 which means no restart)') 16 | opt_parser.add_argument('--opt-decay-step', dest='opt_decay_step', type=int, 17 | help='Number of epochs before decay') 18 | opt_parser.add_argument('--opt-decay-rate', dest='opt_decay_rate', type=float, 19 | help='Learning rate decay ratio') 20 | opt_parser.add_argument('--lr', dest='lr', type=float, 21 | help='Learning rate.') 22 | opt_parser.add_argument('--clip', dest='clip', type=float, 23 | help='Gradient clipping.') 24 | 25 | -------------------------------------------------------------------------------- /utils/synthetic_structsim.py: -------------------------------------------------------------------------------- 1 | """synthetic_structsim.py 2 | 3 | Utilities for generating certain graph shapes. 4 | """ 5 | import math 6 | 7 | import networkx as nx 8 | import numpy as np 9 | 10 | # Following GraphWave's representation of structural similarity 11 | 12 | 13 | def clique(start, nb_nodes, nb_to_remove=0, role_start=0): 14 | """ Defines a clique (complete graph on nb_nodes nodes, 15 | with nb_to_remove edges that will have to be removed), 16 | index of nodes starting at start 17 | and role_ids at role_start 18 | INPUT: 19 | ------------- 20 | start : starting index for the shape 21 | nb_nodes : int correspondingraph to the nb of nodes in the clique 22 | role_start : starting index for the roles 23 | nb_to_remove: int-- numb of edges to remove (unif at RDM) 24 | OUTPUT: 25 | ------------- 26 | graph : a house shape graph, with ids beginning at start 27 | roles : list of the roles of the nodes (indexed starting at 28 | role_start) 29 | """ 30 | a = np.ones((nb_nodes, nb_nodes)) 31 | np.fill_diagonal(a, 0) 32 | graph = nx.from_numpy_matrix(a) 33 | edge_list = graph.edges().keys() 34 | roles = [role_start] * nb_nodes 35 | if nb_to_remove > 0: 36 | lst = np.random.choice(len(edge_list), nb_to_remove, replace=False) 37 | print(edge_list, lst) 38 | to_delete = [edge_list[e] for e in lst] 39 | graph.remove_edges_from(to_delete) 40 | for e in lst: 41 | print(edge_list[e][0]) 42 | print(len(roles)) 43 | roles[edge_list[e][0]] += 1 44 | roles[edge_list[e][1]] += 1 45 | mapping_graph = {k: (k + start) for k in range(nb_nodes)} 46 | graph = nx.relabel_nodes(graph, mapping_graph) 47 | return graph, roles 48 | 49 | 50 | def cycle(start, len_cycle, role_start=0): 51 | """Builds a cycle graph, with index of nodes starting at start 52 | and role_ids at role_start 53 | INPUT: 54 | ------------- 55 | start : starting index for the shape 56 | role_start : starting index for the roles 57 | OUTPUT: 58 | ------------- 59 | graph : a house shape graph, with ids beginning at start 60 | roles : list of the roles of the nodes (indexed starting at 61 | role_start) 62 | """ 63 | graph = nx.Graph() 64 | graph.add_nodes_from(range(start, start + len_cycle)) 65 | for i in range(len_cycle - 1): 66 | graph.add_edges_from([(start + i, start + i + 1)]) 67 | graph.add_edges_from([(start + len_cycle - 1, start)]) 68 | roles = [role_start] * len_cycle 69 | return graph, roles 70 | 71 | 72 
def diamond(start, role_start=0):
    """Builds a diamond graph (a 4-cycle plus two hub nodes connected
    to all four cycle nodes), with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    start      : starting index for the shape
    role_start : starting index for the roles
    OUTPUT:
    -------------
    graph      : a diamond graph, with ids beginning at start
    roles      : list of the roles of the nodes (indexed starting at
                 role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + 6))
    graph.add_edges_from(
        [
            (start, start + 1),
            (start + 1, start + 2),
            (start + 2, start + 3),
            (start + 3, start),
        ]
    )
    graph.add_edges_from(
        [
            (start + 4, start),
            (start + 4, start + 1),
            (start + 4, start + 2),
            (start + 4, start + 3),
        ]
    )
    graph.add_edges_from(
        [
            (start + 5, start),
            (start + 5, start + 1),
            (start + 5, start + 2),
            (start + 5, start + 3),
        ]
    )
    roles = [role_start] * 6
    return graph, roles


def tree(start, height, r=2, role_start=0):
    """Builds a balanced r-ary tree of the given height
    INPUT:
    -------------
    start      : starting index for the shape
    height     : int height of the tree
    r          : int number of branches per node
    role_start : starting index for the roles
    OUTPUT:
    -------------
    graph      : a tree shape graph, with ids beginning at start
    roles      : list of the roles of the nodes (indexed starting at role_start)
    """
    graph = nx.balanced_tree(r, height)
    # shift node ids so they begin at start, as the docstring promises
    graph = nx.relabel_nodes(graph, {n: n + start for n in graph.nodes()})
    roles = [role_start] * graph.number_of_nodes()
    return graph, roles


def fan(start, nb_branches, role_start=0):
    """Builds a fan-like graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    nb_branches : int corresponding to the number of fan branches
    start       : starting index for the shape
    role_start  : starting index for the roles
    OUTPUT:
    -------------
    graph       : a fan shape graph, with ids beginning at start
    roles       : list of the roles of the nodes (indexed starting at
                  role_start)
    """
    graph, roles = star(start, nb_branches, role_start=role_start)
    for k in range(1, nb_branches - 1):
        roles[k] += 1
        roles[k + 1] += 1
        graph.add_edges_from([(start + k, start + k + 1)])
    return graph, roles


def ba(start, width, role_start=0, m=5):
    """Builds a BA preferential attachment graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    start      : starting index for the shape
    width      : int size of the graph
    role_start : starting index for the roles
    m          : number of edges attached from each new node
    OUTPUT:
    -------------
    graph      : a BA graph, with ids beginning at start
    roles      : list of the roles of the nodes (indexed starting at
                 role_start)
    """
    graph = nx.barabasi_albert_graph(width, m)
    # relabel nodes 0..width-1 to start..start+width-1
    nids = sorted(graph)
    mapping = {nid: start + i for i, nid in enumerate(nids)}
    graph = nx.relabel_nodes(graph, mapping)
    roles = [role_start] * width
    return graph, roles
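A hypothetical sketch of chaining a basis and a motif (the sizes are illustrative):

```
basis, basis_roles = ba(start=0, width=20, m=5)
motif, motif_roles = diamond(start=20)   # ids continue where the basis ends
print(sorted(motif.nodes()))             # [20, 21, 22, 23, 24, 25]
```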
def house(start, role_start=0):
    """Builds a house-like graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    start      : starting index for the shape
    role_start : starting index for the roles
    OUTPUT:
    -------------
    graph      : a house shape graph, with ids beginning at start
    roles      : list of the roles of the nodes (indexed starting at
                 role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + 5))
    graph.add_edges_from(
        [
            (start, start + 1),
            (start + 1, start + 2),
            (start + 2, start + 3),
            (start + 3, start),
        ]
    )
    # graph.add_edges_from([(start, start + 2), (start + 1, start + 3)])
    graph.add_edges_from([(start + 4, start), (start + 4, start + 1)])
    roles = [role_start, role_start, role_start + 1, role_start + 1, role_start + 2]
    return graph, roles


def grid(start, dim=2, role_start=0):
    """Builds a dim-by-dim grid (2x2 by default), with node ids
    starting at start and all roles set to role_start
    """
    grid_G = nx.grid_graph([dim, dim])
    grid_G = nx.convert_node_labels_to_integers(grid_G, first_label=start)
    roles = [role_start for i in grid_G.nodes()]
    return grid_G, roles


def star(start, nb_branches, role_start=0):
    """Builds a star graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    nb_branches : int corresponding to the number of star branches
    start       : starting index for the shape
    role_start  : starting index for the roles
    OUTPUT:
    -------------
    graph       : a star graph, with ids beginning at start
    roles       : list of the roles of the nodes (indexed starting at
                  role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + nb_branches + 1))
    for k in range(1, nb_branches + 1):
        graph.add_edges_from([(start, start + k)])
    roles = [role_start + 1] * (nb_branches + 1)
    roles[0] = role_start  # the center keeps the base role
    return graph, roles


def path(start, width, role_start=0):
    """Builds a path graph, with index of nodes starting at start
    and role_ids at role_start
    INPUT:
    -------------
    start      : starting index for the shape
    width      : int length of the path
    role_start : starting index for the roles
    OUTPUT:
    -------------
    graph      : a path graph, with ids beginning at start
    roles      : list of the roles of the nodes (indexed starting at
                 role_start)
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(start, start + width))
    for i in range(width - 1):
        graph.add_edges_from([(start + i, start + i + 1)])
    roles = [role_start] * width
    roles[0] = role_start + 1   # endpoints get a distinct role
    roles[-1] = role_start + 1
    return graph, roles
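Role ids mark structurally distinct positions within a shape; a hypothetical check:

```
g, roles = house(start=0)
print(roles)   # [0, 0, 1, 1, 2]: roof-adjacent pair, opposite base pair, roof tip
```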
def build_graph(
    width_basis,
    basis_type,
    list_shapes,
    start=0,
    rdm_basis_plugins=False,
    add_random_edges=0,
    m=5,
):
    """This function creates a basis (scale-free, path, or cycle)
    and attaches elements of the types in list_shapes along the basis.
    Possibility to add random edges afterwards.
    INPUT:
    --------------------------------------------------------------------------------------
    width_basis      : width (in terms of number of nodes) of the basis
    basis_type       : name of a generator defined in this module
                       (e.g. "ba", "tree", "cycle")
    list_shapes      : list of shape descriptions (1st arg: type of shape,
                       next args: args for building the shape,
                       except for the start)
    start            : initial id for the first node
    rdm_basis_plugins: boolean. Should the shapes be randomly placed
                       along the basis (True) or regularly (False)?
    add_random_edges : nb of edges to randomly add on the structure
    m                : number of edges to attach to existing node (for BA graph)
    OUTPUT:
    --------------------------------------------------------------------------------------
    basis            : a nx graph with the particular shape
    role_ids         : labels for each role
    plugins          : node ids with the attached shapes
    """
    # look up the basis generator defined above by name
    if basis_type == "ba":
        basis, role_id = eval(basis_type)(start, width_basis, m=m)
    else:
        basis, role_id = eval(basis_type)(start, width_basis)

    n_basis, n_shapes = nx.number_of_nodes(basis), len(list_shapes)
    start += n_basis  # id of the next node to be added

    # choose (without replacement) where to attach the new motifs
    if rdm_basis_plugins:
        plugins = np.random.choice(n_basis, n_shapes, replace=False)
    else:
        spacing = math.floor(n_basis / n_shapes)
        plugins = [int(k * spacing) for k in range(n_shapes)]
    seen_shapes = {"basis": [0, n_basis]}

    for shape_id, shape in enumerate(list_shapes):
        shape_type = shape[0]
        args = [start]
        if len(shape) > 1:
            args += shape[1:]
        args += [0]
        graph_s, roles_graph_s = eval(shape_type)(*args)
        n_s = nx.number_of_nodes(graph_s)
        try:
            col_start = seen_shapes[shape_type][0]
        except KeyError:
            col_start = np.max(role_id) + 1
            seen_shapes[shape_type] = [col_start, n_s]
        # attach the shape to the basis
        basis.add_nodes_from(graph_s.nodes())
        basis.add_edges_from(graph_s.edges())
        basis.add_edges_from([(start, plugins[shape_id])])
        if shape_type == "cycle":
            # occasionally add a second, noisier connection
            if np.random.random() > 0.5:
                a = np.random.randint(1, 4)
                b = np.random.randint(1, 4)
                basis.add_edges_from([(a + start, b + plugins[shape_id])])
        temp_labels = [r + col_start for r in roles_graph_s]
        # temp_labels[0] += 100 * seen_shapes[shape_type][0]
        role_id += temp_labels
        start += n_s

    if add_random_edges > 0:
        # add random edges between nodes
        for p in range(add_random_edges):
            src, dest = np.random.choice(nx.number_of_nodes(basis), 2, replace=False)
            basis.add_edges_from([(src, dest)])

    return basis, role_id, plugins
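A hypothetical sketch in the spirit of a BA-Shapes-style synthetic graph (the sizes are illustrative, not prescribed here): attach 80 house motifs to a 300-node BA basis.

```
shapes = [["house"]] * 80
G, role_ids, plugins = build_graph(
    width_basis=300, basis_type="ba", list_shapes=shapes, start=0, m=5
)
print(G.number_of_nodes())   # 300 basis nodes + 80 * 5 house nodes = 700
```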
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
'''train_utils.py

Some training utilities.
'''
import torch.optim as optim

def build_optimizer(args, params, weight_decay=0.0):
    '''Build the optimizer (and optionally a scheduler) described by args.'''
    # only optimize parameters that require gradients
    filter_fn = filter(lambda p: p.requires_grad, params)
    if args.opt == 'adam':
        optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError('Unknown optimizer: %s' % args.opt)
    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError('Unknown scheduler: %s' % args.opt_scheduler)
    return scheduler, optimizer
--------------------------------------------------------------------------------
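A closing sketch tying `parse_optimizer` and `build_optimizer` together (the model and flag values are hypothetical):

```
import argparse
import torch.nn as nn
from utils.parser_utils import parse_optimizer
from utils.train_utils import build_optimizer

parser = argparse.ArgumentParser()
parse_optimizer(parser)
args = parser.parse_args(['--opt', 'adam', '--lr', '0.001', '--opt-scheduler', 'none'])

model = nn.Linear(10, 2)
scheduler, optimizer = build_optimizer(args, model.parameters(), weight_decay=5e-4)
```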