├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── best_model.py
├── data
│   └── README.md
├── launch.py
├── models
│   ├── __init__.py
│   ├── base_models.py
│   ├── edge_models.py
│   ├── node_edge_models.py
│   └── node_models.py
├── third_party
│   ├── __init__.py
│   └── gcn
│       ├── .gitignore
│       ├── LICENSE
│       ├── README.md
│       ├── __init__.py
│       ├── gcn
│       │   ├── __init__.py
│       │   ├── data
│       │   │   ├── ind.citeseer.allx
│       │   │   ├── ind.citeseer.ally
│       │   │   ├── ind.citeseer.graph
│       │   │   ├── ind.citeseer.test.index
│       │   │   ├── ind.citeseer.tx
│       │   │   ├── ind.citeseer.ty
│       │   │   ├── ind.citeseer.x
│       │   │   ├── ind.citeseer.y
│       │   │   ├── ind.cora.allx
│       │   │   ├── ind.cora.ally
│       │   │   ├── ind.cora.graph
│       │   │   ├── ind.cora.test.index
│       │   │   ├── ind.cora.tx
│       │   │   ├── ind.cora.ty
│       │   │   ├── ind.cora.x
│       │   │   ├── ind.cora.y
│       │   │   ├── ind.pubmed.allx
│       │   │   ├── ind.pubmed.ally
│       │   │   ├── ind.pubmed.graph
│       │   │   ├── ind.pubmed.test.index
│       │   │   ├── ind.pubmed.tx
│       │   │   ├── ind.pubmed.ty
│       │   │   ├── ind.pubmed.x
│       │   │   └── ind.pubmed.y
│       │   ├── inits.py
│       │   ├── layers.py
│       │   ├── metrics.py
│       │   ├── models.py
│       │   ├── train.py
│       │   └── utils.py
│       └── setup.py
├── train.py
└── utils
    ├── __init__.py
    ├── data_utils.py
    ├── link_prediction_utils.py
    ├── model_utils.py
    └── train_utils.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 | 
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 | 1. Definitions.
9 | 
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 | 
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 | 
196 | http://www.apache.org/licenses/LICENSE-2.0
197 | 
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Convolutional Neural Networks (GCNN) models
2 | 
3 | This repository contains a TensorFlow implementation of GCNN models for node
4 | classification, link prediction and joint node classification and link
5 | prediction, to supplement the survey paper by Chami et al.
6 | 
7 | NOTE: This is not an officially supported Google product.
8 | 
9 | ## Code organization
10 | 
11 | * `train.py`: trains a model with FLAGS parameters. Run `train --helpshort` for more information.
12 | 
13 | * `launch.py`: trains several models with varied combinations of parameters. Specify the parameters to search over in the `launch.py` file. Run `launch --helpshort` for more information.
14 | 
15 | 
16 | * `best_model.py`: parses the logs of multiple trainings run with `launch.py` and finds the best model parameters based on validation accuracy. Run `best_model --helpshort` for more information.
17 | 
18 | * `models/`
19 |     * `base_models.py`: base model functionality (data utils, loss functions, metrics, etc.)
20 | 
21 |     * `node_models.py`: forward pass implementation of node classification models (including Gat, Gcn, Mlp and SemiEmb)
22 | 
23 |     * `edge_models.py`: forward pass implementation of link prediction models (including Gae and Vgae)
24 | 
25 |     * `node_edge_models.py`: forward pass implementation of joint node classification and link prediction models
26 | 
27 | * `utils/`
28 | 
29 |     * `model_utils.py`: layer implementations.
30 | 
31 |     * `link_prediction_utils.py`: implementation of some link prediction heuristics such as common neighbours or Adamic-Adar
32 | 
33 |     * `data_utils.py`: data processing utility functions
34 | 
35 |     * `train_utils.py`: training utility functions
36 | 
37 | * `data/`: contains data files for citation data (cora, citeseer, pubmed) and PPI
38 | 
39 | ## Code usage
40 | 
41 | 0. Install the required libraries.
42 | 
43 | 1. Set environment variables:
44 |     `GCNN_HOME=$(pwd)`
45 |     `export PATH="$GCNN_HOME:$PATH"`
46 | 
47 | 2. Put the datasets in the `data/` folder.
48 | 
49 | 3. Train GAT on cora with default parameters:
50 | 
51 |     `SAVE_DIRECTORY="/tmp/models/cora/Gat"`
52 |     `python train.py --save_dir=$SAVE_DIRECTORY --dataset=cora --model_name=Gat`
53 | 
54 | 4. Check the results:
55 | 
56 |     `cat $SAVE_DIRECTORY/*.log`
57 | 
58 |     This model should give approximately 83% test accuracy.
59 | 
60 | 5. Launch multiple experiments
61 | 
62 |     To launch multiple experiments for hyper-parameter search, use the `launch.py` script. Update the parameters to search over in the `launch.py` file. For instance, to train Gcn on cora with multiple parameter combinations:
63 | 
64 |     `LAUNCH_DIR="/tmp/launch"`
65 | 
66 |     `python launch.py --launch_save_dir=$LAUNCH_DIR --launch_model_name=Gcn --launch_dataset=cora --launch_n_runs=3`
67 | 
68 |     This will create subdirectories `$LAUNCH_DIR/dataset_name/prop_edges_removed/model_name` where the log files will be saved (see the illustrative layout below).
69 | 
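    The exact log filenames come from `config.get_filename_suffix(run_id)`; as an illustrative sketch (the parameter string below is hypothetical), the command above produces a layout like

    `/tmp/launch/cora/0/Gcn/<param-string>-0.log`
    `/tmp/launch/cora/50/Gcn/<param-string>-2.log`

    where the trailing integer is the run id, which `best_model.py` later strips off (`filename.split('-')[:-1]`) to group the runs of one parameter combination.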
70 | 6. Retrieve the best model parameters
71 | 
72 |     `python best_model.py --dir=$LAUNCH_DIR --models=Gcn --target=node_acc --datasets=cora`
73 | 
74 |     This will create a `best_params` file in `$LAUNCH_DIR` with the best parameters for each (dataset, model, proportion of edges dropped) combination based on validation metrics.
75 | 
76 |     `cat $LAUNCH_DIR/best_params`
77 | 
78 | ## More examples
79 | 
80 | * Reproduce Gat results on cora (83.5% average test accuracy):
81 | 
82 |     `python train.py --model_name=Gat --lr=0.005 --node_l2_reg=0.0005 --dataset=cora --p_drop_node=0.6 --n_att_node=8,1 --n_hidden_node=8 --save_dir=/tmp/models/cora/gat
83 |     --epochs=10000 --patience=100 --normalize_adj=False --sparse_features=True`
84 | 
85 | * Reproduce Gcn results on cora (81.5% average test accuracy):
86 | 
87 |     `python train.py --model_name=Gcn --epochs=200 --patience=10 --lr=0.01 --node_l2_reg=0.0005
88 |     --dataset=cora --p_drop_node=0.5 --n_hidden_node=16
89 |     --save_dir=/tmp/models/cora/gcn --normalize_adj=True --sparse_features=True`
90 | 
91 | * Better Gcn results on cora (83.1% average test accuracy):
92 | 
93 |     `python train.py --model_name=Gcn --epochs=10000 --patience=100 --lr=0.005 --node_l2_reg=0.0005
94 |     --dataset=cora --p_drop_node=0.6 --input_dim=1433 --n_hidden_node=128
95 |     --save_dir=/tmp/models/cora/gcn_best --normalize_adj=True --sparse_features=True`
96 | 
97 | * Train Gae on cora with 10% of edges removed:
98 | 
99 |     `python train.py --model_name=Gae --epochs=10000 --patience=50 --lr=0.005 --p_drop_edge=0. --n_hidden_edge=256-128 --save_dir=/tmp/models/cora/Gae --edge_l2_reg=0 --att_mechanism=dot --normalize_adj=True --edge_loss=w_sigmoid_ce --dataset=cora --sparse_features=True --drop_edge_prop=10`
100 | 
101 | ## Implementing a new model
102 | 
103 | To add a new model:
104 | 
105 | * Create a model class inheriting from one of the base classes (NodeModel, EdgeModel or NodeEdgeModel) and implement the inference step in the corresponding file (`node_models.py`, `edge_models.py` or `node_edge_models.py`)
106 | 
107 | * Add the model name to the list of models in `train.py`
108 | 
109 | ## Adding another dataset
110 | 
111 | To add another dataset:
112 | 
113 | * Write a `load_${dataset_str}_data()` function and add it to the `load_data(dataset_str, data_path)` function. The `dataset_str` will be the FLAG value for this dataset.
114 | 
115 | * Save the data files in the `data/` folder.
116 | 
117 | ## References
118 | 
119 | [GAT original code](https://github.com/PetarV-/GAT)
120 | 
121 | [GCN original code](https://github.com/tkipf/gcn/tree/master/gcn)
122 | 
123 | [GAE original code](https://github.com/tkipf/gae/blob/master/gae)
124 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | -------------------------------------------------------------------------------- /best_model.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 14 | 15 | 16 | """Averages validation metric over multiple runs and returns best model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | import os 22 | from absl import app 23 | from absl import flags 24 | import numpy as np 25 | import scipy.stats as stats 26 | import tensorflow as tf 27 | 28 | flags.DEFINE_string('dir', '/tmp/launch', 'path were models are saved.') 29 | flags.DEFINE_string('target', 'node_acc', 'target metric to use.') 30 | flags.DEFINE_string('datasets', 'cora', 'datasets to use.') 31 | flags.DEFINE_string('drop_prop', '0-10-20-30-40-50-60-70-80-90', 32 | 'proportion of edges dropped') 33 | flags.DEFINE_string('save_file', 'best_params', 'name of files to same the' 34 | 'results.') 35 | flags.DEFINE_string('models', 'Gcn', 'name of model directories to parse.') 36 | FLAGS = flags.FLAGS 37 | 38 | 39 | def get_val_test_acc(data): 40 | """Parses log file to retrieve test and val accuracy.""" 41 | data = [x.split() for x in data if len(x.split()) > 1] 42 | val_acc_idx = data[-4].index('val_{}'.format(FLAGS.target)) 43 | test_acc_idx = data[-3].index('test_{}'.format(FLAGS.target)) 44 | val_acc = data[-4][val_acc_idx + 2] 45 | test_acc = data[-3][test_acc_idx + 2] 46 | return float(val_acc) * 100, float(test_acc) * 100 47 | 48 | 49 | def main(_): 50 | log_file = tf.gfile.Open(os.path.join(FLAGS.dir, FLAGS.save_file), 'w') 51 | for dataset in FLAGS.datasets.split('-'): 52 | for prop in FLAGS.drop_prop.split('-'): 53 | dir_path = os.path.join(FLAGS.dir, dataset, prop) 54 | if tf.gfile.IsDirectory(dir_path): 55 | print(dir_path) 56 | for model_name in tf.gfile.ListDirectory(dir_path): 57 | if model_name in FLAGS.models.split('-'): 58 | model_dir = os.path.join(dir_path, model_name) 59 | train_log_files = [ 60 | filename for filename in tf.gfile.ListDirectory(model_dir) 61 | if 'log' in filename 62 | ] 63 | eval_stats = {} 64 | for filename in train_log_files: 65 | data = tf.gfile.Open(os.path.join(model_dir, 66 | filename)).readlines() 67 | nb_lines = len(data) 68 | if nb_lines > 0: 69 | if 'Training done' in data[-1]: 70 | val_acc, test_acc = get_val_test_acc(data) 71 | params = '-'.join(filename.split('-')[:-1]) 72 | if params in eval_stats: 73 | eval_stats[params]['val'].append(val_acc) 74 | eval_stats[params]['test'].append(test_acc) 75 | else: 76 | eval_stats[params] = {'val': [val_acc], 'test': [test_acc]} 77 | best_val_metric = -1 78 | best_params = None 79 | for params in eval_stats: 80 | val_metric = np.mean(eval_stats[params]['val']) 81 | if val_metric > best_val_metric: 82 | best_val_metric = val_metric 83 | best_params = params 84 | # print(eval_stats) 85 | log_file.write('\n' + model_dir + 
'\n') 86 | log_file.write('Best params: {}\n'.format(best_params)) 87 | log_file.write('val_{}: {} +- {}\n'.format( 88 | FLAGS.target, round(np.mean(eval_stats[best_params]['val']), 2), 89 | round(stats.sem(eval_stats[best_params]['val']), 2))) 90 | log_file.write('test_{}: {} +- {}\n'.format( 91 | FLAGS.target, round( 92 | np.mean(eval_stats[best_params]['test']), 2), 93 | round(stats.sem(eval_stats[best_params]['test']), 2))) 94 | 95 | 96 | if __name__ == '__main__': 97 | app.run(main) 98 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data Directory 2 | 3 | Place datasets here. By default, the following datasets are supported: 4 | 5 | ## Cora 6 | ## Citeseer 7 | ## PPI 8 | 9 | See third_party/gcn/gcn/data for the relevant files. 10 | 11 | ## Pubmed 12 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 14 | 15 | 16 | """Train models with different combinations of parameters.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from itertools import product 22 | 23 | import os 24 | from absl import app 25 | from absl import flags 26 | 27 | from train import Config 28 | from train import TrainTest 29 | 30 | flags.DEFINE_string('launch_save_dir', '/tmp/launch', 31 | 'Where to save the results.') 32 | flags.DEFINE_string('launch_model_name', 'Gcn', 'Model to train.') 33 | flags.DEFINE_string('launch_dataset', 'cora', 'Dataset to use.') 34 | flags.DEFINE_string('launch_datapath', 35 | 'data/', 36 | 'Path to data folder.') 37 | flags.DEFINE_boolean('launch_sparse_features', True, 38 | 'True if node features are sparse.') 39 | flags.DEFINE_boolean('launch_normalize_adj', True, 40 | 'True to normalize adjacency matrix') 41 | flags.DEFINE_integer('launch_n_runs', 5, 42 | 'number of runs for each combination of parameters.') 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def get_params(): 47 | ############################### CHANGE PARAMS HERE ########################## 48 | return { 49 | # training parameters 50 | 'lr': [0.01], 51 | 'epochs': [10000], 52 | 'patience': [10], 53 | 'node_l2_reg': [0.001, 0.0005], 54 | 'edge_l2_reg': [0.], 55 | 'edge_reg': [0], 56 | 'p_drop_node': [0.5], 57 | 'p_drop_edge': [0], 58 | 59 | # model parameters 60 | 'n_hidden_node': ['128', '64'], 61 | 'n_att_node': ['8-8'], 62 | 'n_hidden_edge': ['128-64'], 63 | 'n_att_edge': ['8-1'], 64 | 'topk': [0], 65 | 'att_mechanism': ['l2'], 66 | 'edge_loss': ['w_sigmoid_ce'], 67 | 'cheby_k_loc': [1], 68 | 'semi_emb_k': [-1], 69 | 70 | # data parameters 71 | 'drop_edge_prop': [0, 50], 72 | 'normalize_adj': [True] 73 | } 74 | 
  #############################################################################
75 | 
76 | 
77 | def get_config(run_params, data):
78 |   """Parses configuration parameters for training."""
79 |   config = Config()
80 |   for param in run_params:
81 |     if 'n_hidden' in param or 'n_att' in param:
82 |       # Numbers of layers and attention heads are defined as strings, so we
83 |       # parse them differently.
84 |       setattr(config, param, list(map(int, run_params[param].split('-'))))
85 |     else:
86 |       setattr(config, param, run_params[param])
87 |   config.set_num_nodes_edges(data)
88 |   return config
89 | 
90 | 
91 | def main(_):
92 |   params = get_params()
93 |   trainer = TrainTest(FLAGS.launch_model_name)
94 |   print('Loading dataset...')
95 |   trainer.load_dataset(FLAGS.launch_dataset, FLAGS.launch_sparse_features,
96 |                        FLAGS.launch_datapath)
97 |   print('Dataset loaded!')
98 |   # iterate over all combinations of parameters (the Cartesian product of the
99 |   # lists in get_params; with the defaults above, 2 x 2 x 2 = 8 combinations)
100 |   all_params = product(*params.values())
101 |   for run_params in all_params:
102 |     run_params = dict(zip(params, run_params))
103 |     # mask edges and process the adjacency and node features
104 |     trainer.mask_edges(trainer.data['adj_true'], run_params['drop_edge_prop'])
105 |     trainer.process_adj(FLAGS.launch_normalize_adj)
106 |     config = get_config(run_params, trainer.data)
107 |     # multiple runs per parameter combination
108 |     save_dir = os.path.join(FLAGS.launch_save_dir, FLAGS.launch_dataset,
109 |                             str(run_params['drop_edge_prop']),
110 |                             FLAGS.launch_model_name)
111 |     for run_id in range(FLAGS.launch_n_runs):
112 |       filename_suffix = config.get_filename_suffix(run_id)
113 |       trainer.run(config, save_dir, filename_suffix)
114 | 
115 | 
116 | if __name__ == '__main__':
117 |   app.run(main)
118 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | 
--------------------------------------------------------------------------------
/models/base_models.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | 
15 | """Base model classes.
16 | 
17 | Main functionality for node classification models, link prediction
18 | models and joint node classification and link prediction models.
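BaseModel implements the shared pieces (losses, regularizers, optimizer);
the NodeModel, EdgeModel and NodeEdgeModel subclasses below specialize the
placeholders, graph building and metrics for each task.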
19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import tensorflow as tf 26 | 27 | 28 | class BaseModel(object): 29 | """Base model class. Defines basic functionnalities for all models.""" 30 | 31 | def __init__(self, config): 32 | """Initialize base model. 33 | 34 | Args: 35 | config: object of Config class defined in train.py, 36 | stores configuration parameters to build and train the model 37 | """ 38 | self.input_dim = config.input_dim 39 | self.lr = config.lr 40 | self.edge_reg = config.edge_reg 41 | self.edge_l2_reg = config.edge_l2_reg 42 | self.node_l2_reg = config.node_l2_reg 43 | self.nb_nodes = config.nb_nodes 44 | self.nb_edges = config.nb_edges 45 | self.sparse_features = config.sparse_features 46 | self.edge_loss = config.edge_loss 47 | self.att_mechanism = config.att_mechanism 48 | self.multilabel = config.multilabel 49 | 50 | def _create_placeholders(self): 51 | raise NotImplementedError 52 | 53 | def compute_inference(self, features, adj_matrix, is_training): 54 | raise NotImplementedError 55 | 56 | def build_graph(self): 57 | raise NotImplementedError 58 | 59 | def _create_optimizer(self, loss): 60 | """Create train operation.""" 61 | opt = tf.train.AdamOptimizer(learning_rate=self.lr) 62 | train_op = opt.minimize(loss) 63 | return train_op 64 | 65 | def _compute_node_loss(self, logits, labels): 66 | """Node classification loss with sigmoid cross entropy.""" 67 | if self.multilabel: 68 | loss = tf.nn.sigmoid_cross_entropy_with_logits( 69 | labels=labels, logits=logits) 70 | else: 71 | loss = tf.nn.softmax_cross_entropy_with_logits( 72 | labels=labels, logits=logits) 73 | return tf.reduce_mean(loss) 74 | 75 | def _compute_node_l2_loss(self): 76 | """L2 regularization loss for parameters in node classification model.""" 77 | all_variables = tf.trainable_variables() 78 | non_reg = ['bias', 'embeddings', 'beta', 'edge-model'] 79 | node_l2_loss = tf.add_n([ 80 | tf.nn.l2_loss(v) 81 | for v in all_variables 82 | if all([var_name not in v.name for var_name in non_reg]) 83 | ]) 84 | return node_l2_loss 85 | 86 | def _compute_edge_l2_loss(self): 87 | """L2 regularization loss for parameters in link prediction model.""" 88 | all_variables = tf.trainable_variables() 89 | edge_l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in all_variables if \ 90 | 'edge-model' in v.name]) 91 | return edge_l2_loss 92 | 93 | def _compute_edge_loss_neg_sampling(self, adj_pred, adj_true): 94 | """Link prediction CE loss with negative sampling.""" 95 | keep_prob = self.nb_edges / (self.nb_nodes**2 - self.nb_edges) 96 | loss_mask = tf.nn.dropout( 97 | 1 - adj_true, keep_prob=keep_prob) * keep_prob 98 | loss_mask += adj_true 99 | boolean_mask = tf.greater(loss_mask, 0.) 
100 | masked_pred = tf.boolean_mask(adj_pred, boolean_mask) 101 | masked_true = tf.boolean_mask(adj_true, boolean_mask) 102 | edge_loss = tf.nn.sigmoid_cross_entropy_with_logits( 103 | labels=masked_true, 104 | logits=masked_pred, 105 | ) 106 | return tf.reduce_mean(edge_loss) 107 | 108 | def _compute_edge_loss_weighted_ce(self, adj_pred, adj_true): 109 | """Link prediction loss with weighted sigmoid cross entropy.""" 110 | pos_weight = float((self.nb_nodes**2) - self.nb_edges) / self.nb_edges 111 | edge_loss = tf.nn.weighted_cross_entropy_with_logits( 112 | targets=adj_true, 113 | logits=adj_pred, 114 | pos_weight=pos_weight) 115 | return tf.reduce_mean(edge_loss) 116 | 117 | def _compute_edge_loss(self, adj_pred, adj_true): 118 | if self.edge_loss == 'weighted': 119 | return self._compute_edge_loss_weighted_ce(adj_pred, adj_true) 120 | else: 121 | return self._compute_edge_loss_neg_sampling(adj_pred, adj_true) 122 | 123 | 124 | class NodeModel(BaseModel): 125 | """Base model class for semi-supevised node classification.""" 126 | 127 | def __init__(self, config): 128 | """Initializes NodeModel for semi-supervised node classification. 129 | 130 | Args: 131 | config: object of Config class defined in train.py, 132 | stores configuration parameters to build and train the model 133 | """ 134 | super(NodeModel, self).__init__(config) 135 | self.p_drop = config.p_drop_node 136 | self.n_att = config.n_att_node 137 | self.n_hidden = config.n_hidden_node 138 | 139 | def _create_placeholders(self): 140 | """Create placeholders.""" 141 | with tf.name_scope('input'): 142 | self.placeholders = { 143 | 'adj_train': 144 | tf.sparse_placeholder(tf.float32), # normalized 145 | 'node_labels': 146 | tf.placeholder(tf.float32, shape=[None, self.n_hidden[-1]]), 147 | 'node_mask': 148 | tf.placeholder(tf.float32, shape=[ 149 | None, 150 | ]), 151 | 'is_training': 152 | tf.placeholder(tf.bool), 153 | } 154 | if self.sparse_features: 155 | self.placeholders['features'] = tf.sparse_placeholder(tf.float32) 156 | else: 157 | self.placeholders['features'] = tf.placeholder( 158 | tf.float32, shape=[None, self.input_dim]) 159 | 160 | def make_feed_dict(self, data, split, is_training): 161 | """Build feed dictionnary to train the model.""" 162 | feed_dict = { 163 | self.placeholders['adj_train']: data['adj_train_norm'], 164 | self.placeholders['features']: data['features'], 165 | self.placeholders['node_labels']: data['node_labels'], 166 | self.placeholders['node_mask']: data[split]['node_mask'], 167 | self.placeholders['is_training']: is_training 168 | } 169 | return feed_dict 170 | 171 | def build_graph(self): 172 | """Build tensorflow graph and create training, testing ops.""" 173 | self._create_placeholders() 174 | logits = self.compute_inference(self.placeholders['features'], 175 | self.placeholders['adj_train'], 176 | self.placeholders['is_training']) 177 | boolean_mask = tf.greater(self.placeholders['node_mask'], 0.) 
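    # node_mask flags the labeled nodes of the current split (train/val/test);
    # the loss and accuracy below are computed only over those nodes.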
178 | masked_pred = tf.boolean_mask(logits, boolean_mask) 179 | masked_true = tf.boolean_mask(self.placeholders['node_labels'], 180 | boolean_mask) 181 | loss = self._compute_node_loss(masked_pred, masked_true) 182 | loss += self.node_l2_reg * self._compute_node_l2_loss() 183 | train_op = self._create_optimizer(loss) 184 | metric_op, metric_update_op = self._create_metrics( 185 | masked_pred, masked_true) 186 | return loss, train_op, metric_op, metric_update_op 187 | 188 | def _create_metrics(self, logits, node_labels): 189 | """Create evaluation metrics for node classification.""" 190 | with tf.name_scope('metrics'): 191 | metrics = {} 192 | if self.multilabel: 193 | predictions = tf.cast( 194 | tf.greater(tf.nn.sigmoid(logits), 0.5), tf.float32) 195 | metrics['recall'], rec_op = tf.metrics.recall( 196 | labels=node_labels, predictions=predictions) 197 | metrics['precision'], prec_op = tf.metrics.precision( 198 | labels=node_labels, predictions=predictions) 199 | metrics['f1'] = 2 * metrics['precision'] * metrics['recall'] / ( 200 | metrics['precision'] + metrics['recall'] 201 | ) 202 | update_ops = [rec_op, prec_op] 203 | else: 204 | metrics['node_acc'], acc_op = tf.metrics.accuracy( 205 | labels=tf.argmax(node_labels, 1), predictions=tf.argmax(logits, 1)) 206 | update_ops = [acc_op] 207 | return metrics, update_ops 208 | 209 | 210 | class EdgeModel(BaseModel): 211 | """Base model class for link prediction.""" 212 | 213 | def __init__(self, config): 214 | """Initializes Edge model for link prediction. 215 | 216 | Args: 217 | config: object of Config class defined in train.py, 218 | stores configuration parameters to build and train the model 219 | """ 220 | super(EdgeModel, self).__init__(config) 221 | self.p_drop = config.p_drop_edge 222 | self.n_att = config.n_att_edge 223 | self.n_hidden = config.n_hidden_edge 224 | 225 | def _create_placeholders(self): 226 | """Create placeholders.""" 227 | with tf.name_scope('input'): 228 | self.placeholders = { 229 | # to compute metrics 230 | 'adj_true': tf.placeholder(tf.float32, shape=[None, None]), 231 | # to compute loss 232 | 'adj_train': tf.placeholder(tf.float32, shape=[None, None]), 233 | # for inference step 234 | 'adj_train_norm': tf.sparse_placeholder(tf.float32), # normalized 235 | 'edge_mask': tf.sparse_placeholder(tf.float32), 236 | 'is_training': tf.placeholder(tf.bool), 237 | } 238 | if self.sparse_features: 239 | self.placeholders['features'] = tf.sparse_placeholder(tf.float32) 240 | else: 241 | self.placeholders['features'] = tf.placeholder( 242 | tf.float32, shape=[None, self.input_dim]) 243 | 244 | def make_feed_dict(self, data, split, is_training): 245 | """Build feed dictionnary to train the model.""" 246 | feed_dict = { 247 | self.placeholders['features']: data['features'], 248 | self.placeholders['adj_true']: data['adj_true'], 249 | self.placeholders['adj_train']: data['adj_train'], 250 | self.placeholders['adj_train_norm']: data['adj_train_norm'], 251 | self.placeholders['edge_mask']: data[split]['edge_mask'], 252 | self.placeholders['is_training']: is_training 253 | } 254 | return feed_dict 255 | 256 | def build_graph(self): 257 | """Build tensorflow graph and create training, testing ops.""" 258 | self._create_placeholders() 259 | adj_pred = self.compute_inference(self.placeholders['features'], 260 | self.placeholders['adj_train_norm'], 261 | self.placeholders['is_training']) 262 | adj_train = tf.reshape(self.placeholders['adj_train'], (-1,)) 263 | loss = self._compute_edge_loss(tf.reshape(adj_pred, (-1,)), adj_train) 
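    # The reconstruction loss above is computed over the full flattened
    # adjacency, while the ROC/PR metrics below only score the held-out
    # entries selected by the edge_mask placeholder.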
264 | loss += self.edge_l2_reg * self._compute_edge_l2_loss() 265 | train_op = self._create_optimizer(loss) 266 | masked_true = tf.reshape(tf.gather_nd( 267 | self.placeholders['adj_true'], self.placeholders['edge_mask'].indices), 268 | (-1,)) 269 | masked_pred = tf.reshape(tf.gather_nd( 270 | adj_pred, self.placeholders['edge_mask'].indices), (-1,)) 271 | metric_op, metric_update_op = self._create_metrics(masked_pred, masked_true) 272 | return loss, train_op, metric_op, metric_update_op 273 | 274 | def _create_metrics(self, adj_pred, adj_true): 275 | """Create evaluation metrics for node classification.""" 276 | with tf.name_scope('metrics'): 277 | metrics = {} 278 | metrics['edge_roc_auc'], roc_op = tf.metrics.auc( 279 | labels=adj_true, 280 | predictions=tf.sigmoid(adj_pred), 281 | curve='ROC' 282 | ) 283 | metrics['edge_pr_auc'], pr_op = tf.metrics.auc( 284 | labels=adj_true, 285 | predictions=tf.sigmoid(adj_pred), 286 | curve='PR' 287 | ) 288 | update_ops = [roc_op, pr_op] 289 | return metrics, update_ops 290 | 291 | 292 | class NodeEdgeModel(BaseModel): 293 | """Model class for semi-supevised node classification and link prediction.""" 294 | 295 | def __init__(self, config): 296 | """Initializes model. 297 | 298 | Args: 299 | config: object of Config class defined in train.py, 300 | stores configuration parameters to build and train the model 301 | """ 302 | super(NodeEdgeModel, self).__init__(config) 303 | self.n_att_edge = config.n_att_edge 304 | self.n_hidden_edge = config.n_hidden_edge 305 | self.p_drop_edge = config.p_drop_edge 306 | self.n_att_node = config.n_att_node 307 | self.n_hidden_node = config.n_hidden_node 308 | self.p_drop_node = config.p_drop_node 309 | self.topk = config.topk 310 | 311 | def _create_placeholders(self): 312 | """Create placeholders.""" 313 | with tf.name_scope('input'): 314 | self.placeholders = { 315 | 'adj_true': tf.placeholder(tf.float32, shape=[None, None]), 316 | # to compute loss 317 | 'adj_train': tf.placeholder(tf.float32, shape=[None, None]), 318 | # for inference step 319 | 'adj_train_norm': tf.sparse_placeholder(tf.float32), # normalized 320 | 'edge_mask': tf.sparse_placeholder(tf.float32), 321 | 'node_labels': 322 | tf.placeholder(tf.float32, shape=[None, self.n_hidden_node[-1]]), 323 | 'node_mask': 324 | tf.placeholder(tf.float32, shape=[ 325 | None, 326 | ]), 327 | 'is_training': 328 | tf.placeholder(tf.bool), 329 | } 330 | if self.sparse_features: 331 | self.placeholders['features'] = tf.sparse_placeholder(tf.float32) 332 | else: 333 | self.placeholders['features'] = tf.placeholder( 334 | tf.float32, shape=[None, self.input_dim]) 335 | 336 | def make_feed_dict(self, data, split, is_training): 337 | """Build feed dictionnary to train the model.""" 338 | feed_dict = { 339 | self.placeholders['features']: data['features'], 340 | self.placeholders['adj_true']: data['adj_true'], 341 | self.placeholders['adj_train']: data['adj_train'], 342 | self.placeholders['adj_train_norm']: data['adj_train_norm'], 343 | self.placeholders['edge_mask']: data[split]['edge_mask'], 344 | self.placeholders['node_labels']: data['node_labels'], 345 | self.placeholders['node_mask']: data[split]['node_mask'], 346 | self.placeholders['is_training']: is_training 347 | } 348 | return feed_dict 349 | 350 | def build_graph(self): 351 | """Build tensorflow graph and create training, testing ops.""" 352 | self._create_placeholders() 353 | logits, adj_pred = self.compute_inference( 354 | self.placeholders['features'], 355 | self.placeholders['adj_train_norm'], 356 | 
self.placeholders['is_training']) 357 | adj_train = tf.reshape(self.placeholders['adj_train'], (-1,)) 358 | boolean_node_mask = tf.greater(self.placeholders['node_mask'], 0.) 359 | masked_node_pred = tf.boolean_mask(logits, boolean_node_mask) 360 | masked_node_true = tf.boolean_mask(self.placeholders['node_labels'], 361 | boolean_node_mask) 362 | loss = self._compute_node_loss(masked_node_pred, 363 | masked_node_true) 364 | loss += self.node_l2_reg * self._compute_node_l2_loss() 365 | loss += self.edge_reg * self._compute_edge_loss( 366 | tf.reshape(adj_pred, (-1,)), adj_train) 367 | loss += self.edge_l2_reg * self._compute_edge_l2_loss() 368 | self.grad = tf.gradients(loss, self.adj_matrix_pred) 369 | train_op = self._create_optimizer(loss) 370 | masked_adj_true = tf.reshape(tf.gather_nd( 371 | self.placeholders['adj_true'], 372 | self.placeholders['edge_mask'].indices), (-1,)) 373 | masked_adj_pred = tf.reshape(tf.gather_nd( 374 | adj_pred, self.placeholders['edge_mask'].indices), (-1,)) 375 | metric_op, metric_update_op = self._create_metrics( 376 | masked_adj_pred, masked_adj_true, masked_node_pred, masked_node_true) 377 | return loss, train_op, metric_op, metric_update_op 378 | 379 | def _create_metrics(self, adj_pred, adj_true, node_pred, node_labels): 380 | """Create evaluation metrics for node classification.""" 381 | with tf.name_scope('metrics'): 382 | metrics = {} 383 | metrics['edge_roc_auc'], roc_op = tf.metrics.auc( 384 | labels=adj_true, 385 | predictions=tf.sigmoid(adj_pred), 386 | curve='ROC' 387 | ) 388 | metrics['edge_pr_auc'], pr_op = tf.metrics.auc( 389 | labels=adj_true, 390 | predictions=tf.sigmoid(adj_pred), 391 | curve='PR' 392 | ) 393 | metrics['node_acc'], acc_op = tf.metrics.accuracy( 394 | labels=tf.argmax(node_labels, 1), 395 | predictions=tf.argmax(node_pred, 1)) 396 | update_ops = [roc_op, pr_op, acc_op] 397 | return metrics, update_ops 398 | -------------------------------------------------------------------------------- /models/edge_models.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 14 | 15 | 16 | """Inference step for link prediction models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from models.base_models import EdgeModel 23 | import tensorflow as tf 24 | from utils.model_utils import compute_adj 25 | from utils.model_utils import gat_module 26 | from utils.model_utils import gcn_module 27 | from utils.model_utils import mlp_module 28 | 29 | 30 | class Gae(EdgeModel): 31 | """Graph Auto-Encoder (GAE) (Kipf & al) for link prediction. 
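  The encoder is a GCN (gcn_module) over the normalized adjacency; the
  decoder scores every node pair with compute_adj (dot-product or l2
  attention, depending on att_mechanism).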
32 | 33 | arXiv link: https://arxiv.org/abs/1611.07308 34 | """ 35 | 36 | def compute_inference(self, node_features, adj_matrix, is_training): 37 | """Forward step for GAE model.""" 38 | sparse = self.sparse_features 39 | in_dim = self.input_dim 40 | with tf.variable_scope('edge-model'): 41 | h0 = gcn_module(node_features, adj_matrix, self.n_hidden, self.p_drop, 42 | is_training, in_dim, sparse) 43 | adj_matrix_pred = compute_adj(h0, self.att_mechanism, self.p_drop, 44 | is_training) 45 | self.adj_matrix_pred = tf.nn.sigmoid(adj_matrix_pred) 46 | return adj_matrix_pred 47 | 48 | 49 | class Egat(EdgeModel): 50 | """Edge-GAT for link prediction.""" 51 | 52 | def compute_inference(self, node_features, adj_matrix, is_training): 53 | """Forward step for GAE model.""" 54 | sparse = self.sparse_features 55 | in_dim = self.input_dim 56 | with tf.variable_scope('edge-model'): 57 | h0 = gat_module( 58 | node_features, 59 | adj_matrix, 60 | self.n_hidden, 61 | self.n_att, 62 | self.p_drop, 63 | is_training, 64 | in_dim, 65 | sparse, 66 | average_last=True) 67 | adj_matrix_pred = compute_adj(h0, self.att_mechanism, self.p_drop, 68 | is_training) 69 | self.adj_matrix_pred = tf.nn.sigmoid(adj_matrix_pred) 70 | return adj_matrix_pred 71 | 72 | 73 | class Vgae(EdgeModel): 74 | """Variational Graph Auto-Encoder (VGAE) (Kipf & al) for link prediction. 75 | 76 | arXiv link: https://arxiv.org/abs/1611.07308 77 | """ 78 | 79 | def compute_inference(self, node_features, adj_matrix, is_training): 80 | """Forward step for GAE model.""" 81 | sparse = self.sparse_features 82 | in_dim = self.input_dim 83 | with tf.variable_scope('edge-model'): 84 | h0 = gcn_module(node_features, adj_matrix, self.n_hidden[:-1], 85 | self.p_drop, is_training, in_dim, sparse) 86 | # N x F 87 | with tf.variable_scope('mean'): 88 | z_mean = gcn_module(h0, adj_matrix, self.n_hidden[-1:], self.p_drop, 89 | is_training, self.n_hidden[-2], False) 90 | self.z_mean = z_mean 91 | with tf.variable_scope('std'): 92 | # N x F 93 | z_log_std = gcn_module(h0, adj_matrix, self.n_hidden[-1:], self.p_drop, 94 | is_training, self.n_hidden[-2], False) 95 | self.z_log_std = z_log_std 96 | # add noise during training 97 | noise = tf.random_normal([self.nb_nodes, self.n_hidden[-1] 98 | ]) * tf.exp(z_log_std) 99 | z = tf.cond(is_training, lambda: tf.add(z_mean, noise), 100 | lambda: z_mean) 101 | # N x N 102 | adj_matrix_pred = compute_adj(z, self.att_mechanism, self.p_drop, 103 | is_training) 104 | self.adj_matrix_pred = tf.nn.sigmoid(adj_matrix_pred) 105 | return adj_matrix_pred 106 | 107 | def _compute_edge_loss(self, adj_pred, adj_train): 108 | """Overrides _compute_edge_loss to add Variational Inference objective.""" 109 | log_lik = super(Vgae, self)._compute_edge_loss(adj_pred, adj_train) 110 | norm = self.nb_nodes**2 / float((self.nb_nodes**2 - self.nb_edges) * 2) 111 | kl_mat = 0.5 * tf.reduce_sum( 112 | 1 + 2 * self.z_log_std - tf.square(self.z_mean) - tf.square( 113 | tf.exp(self.z_log_std)), 1) 114 | kl = tf.reduce_mean(kl_mat) / self.nb_nodes 115 | edge_loss = norm * log_lik - kl 116 | return edge_loss 117 | 118 | 119 | class Emlp(EdgeModel): 120 | """Simple baseline for link prediction. 121 | 122 | Creates a tensorflow graph to train and evaluate EMLP on graph data. 
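  Unlike Gae, the encoder ignores the graph: the embeddings come from an MLP
  over the node features alone (the adjacency argument of compute_inference
  is unused), which makes it a graph-free baseline.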
123 | """ 124 | 125 | def compute_inference(self, node_features, _, is_training): 126 | """Forward step for GAE model.""" 127 | sparse = self.sparse_features 128 | in_dim = self.input_dim 129 | with tf.variable_scope('edge-model'): 130 | h0 = mlp_module( 131 | node_features, 132 | self.n_hidden, 133 | self.p_drop, 134 | is_training, 135 | in_dim, 136 | sparse, 137 | use_bias=False) 138 | adj_matrix_pred = compute_adj(h0, self.att_mechanism, self.p_drop, 139 | is_training) 140 | self.adj_matrix_pred = tf.nn.sigmoid(adj_matrix_pred) 141 | return adj_matrix_pred 142 | -------------------------------------------------------------------------------- /models/node_edge_models.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 14 | 15 | 16 | """Inference step for joint node classification and link prediction models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | from models.base_models import NodeEdgeModel 24 | from models.edge_models import Gae 25 | from models.node_models import Gat 26 | from models.node_models import Gcn 27 | import tensorflow as tf 28 | from utils.model_utils import compute_adj 29 | from utils.model_utils import gat_module 30 | from utils.model_utils import gcn_module 31 | from utils.model_utils import get_sp_topk 32 | from utils.model_utils import mask_edges 33 | 34 | 35 | class GaeGat(NodeEdgeModel): 36 | """GAE for link prediction and GAT for node classification.""" 37 | 38 | def __init__(self, config): 39 | """Initializes EGCNGAT model.""" 40 | super(GaeGat, self).__init__(config) 41 | self.edge_model = Gae(config) 42 | self.node_model = Gat(config) 43 | 44 | def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 45 | adj_matrix_pred = self.edge_model.compute_inference( 46 | node_features_in, sp_adj_matrix, is_training) 47 | self.adj_matrix_pred = adj_matrix_pred 48 | adj_mask = get_sp_topk(adj_matrix_pred, sp_adj_matrix, self.nb_nodes, 49 | self.topk) 50 | self.adj_mask = adj_mask 51 | # masked_adj_matrix_pred = tf.multiply(adj_mask, 52 | # tf.nn.sigmoid(adj_matrix_pred)) 53 | masked_adj_matrix_pred = mask_edges(tf.nn.sigmoid(adj_matrix_pred), 54 | adj_mask) 55 | sp_adj_pred = tf.contrib.layers.dense_to_sparse(masked_adj_matrix_pred) 56 | logits = self.node_model.compute_inference(node_features_in, sp_adj_pred, 57 | is_training) 58 | return logits, adj_matrix_pred 59 | 60 | 61 | class GaeGcn(NodeEdgeModel): 62 | """GAE for link prediction and GCN for node classification.""" 63 | 64 | def __init__(self, config): 65 | """Initializes EGCNGCN model.""" 66 | super(GaeGcn, self).__init__(config) 67 | self.edge_model = Gae(config) 68 | self.node_model = Gcn(config) 69 | 70 | def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 71 | adj_matrix_pred = self.edge_model.compute_inference( 72 | 
node_features_in, sp_adj_matrix, is_training) 73 | self.adj_matrix_pred = adj_matrix_pred 74 | adj_mask = get_sp_topk(adj_matrix_pred, sp_adj_matrix, self.nb_nodes, 75 | self.topk) 76 | sp_adj_pred = tf.contrib.layers.dense_to_sparse( 77 | tf.multiply(adj_mask, tf.nn.leaky_relu(adj_matrix_pred))) 78 | sp_adj_pred = tf.sparse_softmax(sp_adj_pred) 79 | logits = self.node_model.compute_inference(node_features_in, sp_adj_pred, 80 | is_training) 81 | return logits, adj_matrix_pred 82 | 83 | 84 | ############################ EXPERIMENTAL MODELS ############################# 85 | 86 | 87 | class GatGraphite(NodeEdgeModel): 88 | """Gae for link prediction and GCN for node classification.""" 89 | 90 | def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 91 | with tf.variable_scope('edge-model'): 92 | z_latent = gat_module( 93 | node_features_in, 94 | sp_adj_matrix, 95 | self.n_hidden_edge, 96 | self.n_att_edge, 97 | self.p_drop_edge, 98 | is_training, 99 | self.input_dim, 100 | self.sparse_features, 101 | average_last=False) 102 | adj_matrix_pred = compute_adj(z_latent, self.att_mechanism, 103 | self.p_drop_edge, is_training) 104 | self.adj_matrix_pred = adj_matrix_pred 105 | with tf.variable_scope('node-model'): 106 | concat = True 107 | if concat: 108 | z_latent = tf.sparse_concat( 109 | axis=1, 110 | sp_inputs=[ 111 | tf.contrib.layers.dense_to_sparse(z_latent), node_features_in 112 | ], 113 | ) 114 | sparse_features = True 115 | input_dim = self.n_hidden_edge[-1] * self.n_att_edge[ 116 | -1] + self.input_dim 117 | else: 118 | sparse_features = False 119 | input_dim = self.n_hidden_edge[-1] * self.n_att_edge[-1] 120 | logits = gat_module( 121 | z_latent, 122 | sp_adj_matrix, 123 | self.n_hidden_node, 124 | self.n_att_node, 125 | self.p_drop_node, 126 | is_training, 127 | input_dim, 128 | sparse_features=sparse_features, 129 | average_last=False) 130 | 131 | return logits, adj_matrix_pred 132 | 133 | 134 | class GaeGatConcat(NodeEdgeModel): 135 | """EGCN for link prediction and GCN for node classification.""" 136 | 137 | def __init__(self, config): 138 | """Initializes EGCN_GAT model.""" 139 | super(GaeGatConcat, self).__init__(config) 140 | self.edge_model = Gae(config) 141 | self.node_model = Gat(config) 142 | 143 | def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 144 | with tf.variable_scope('edge-model'): 145 | z_latent = gcn_module(node_features_in, sp_adj_matrix, self.n_hidden_edge, 146 | self.p_drop_edge, is_training, self.input_dim, 147 | self.sparse_features) 148 | adj_matrix_pred = compute_adj(z_latent, self.att_mechanism, 149 | self.p_drop_edge, is_training) 150 | self.adj_matrix_pred = adj_matrix_pred 151 | with tf.variable_scope('node-model'): 152 | z_latent = tf.sparse_concat( 153 | axis=1, 154 | sp_inputs=[ 155 | tf.contrib.layers.dense_to_sparse(z_latent), node_features_in 156 | ]) 157 | sparse_features = True 158 | input_dim = self.n_hidden_edge[-1] + self.input_dim 159 | sp_adj_train = tf.SparseTensor( 160 | indices=sp_adj_matrix.indices, 161 | values=tf.ones_like(sp_adj_matrix.values), 162 | dense_shape=sp_adj_matrix.dense_shape) 163 | logits = gat_module( 164 | z_latent, 165 | sp_adj_train, 166 | self.n_hidden_node, 167 | self.n_att_node, 168 | self.p_drop_node, 169 | is_training, 170 | input_dim, 171 | sparse_features=sparse_features, 172 | average_last=True) 173 | return logits, adj_matrix_pred 174 | 175 | 176 | class GaeGcnConcat(NodeEdgeModel): 177 | """EGCN for link prediction and GCN for node classification.""" 178 | 179 | 
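  # Same design as GaeGatConcat above, with a GCN node classifier instead of
  # a GAT one: the GAE latent embeddings are concatenated with the raw input
  # features, so the node model's input dimension becomes
  # n_hidden_edge[-1] + input_dim.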
def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 180 | with tf.variable_scope('edge-model'): 181 | z_latent = gcn_module(node_features_in, sp_adj_matrix, self.n_hidden_edge, 182 | self.p_drop_edge, is_training, self.input_dim, 183 | self.sparse_features) 184 | adj_matrix_pred = compute_adj(z_latent, self.att_mechanism, 185 | self.p_drop_edge, is_training) 186 | self.adj_matrix_pred = adj_matrix_pred 187 | with tf.variable_scope('node-model'): 188 | z_latent = tf.sparse_concat( 189 | axis=1, 190 | sp_inputs=[ 191 | tf.contrib.layers.dense_to_sparse(z_latent), node_features_in 192 | ]) 193 | sparse_features = True 194 | input_dim = self.n_hidden_edge[-1] + self.input_dim 195 | logits = gcn_module( 196 | z_latent, 197 | sp_adj_matrix, 198 | self.n_hidden_node, 199 | self.p_drop_node, 200 | is_training, 201 | input_dim, 202 | sparse_features=sparse_features) 203 | return logits, adj_matrix_pred 204 | 205 | 206 | class Gcat(NodeEdgeModel): 207 | """One-iteration Graph Convolution Attention (GCAT) model.""" 208 | 209 | def __init__(self, config): 210 | """Initializes GCAT model.""" 211 | super(Gcat, self).__init__(config) 212 | self.edge_model = Gae(config) 213 | self.node_model = Gcn(config) 214 | 215 | def compute_inference(self, node_features_in, sp_adj_matrix, is_training): 216 | """Forward pass for GCAT model.""" 217 | adj_matrix_pred = self.edge_model.compute_inference( 218 | node_features_in, sp_adj_matrix, is_training) 219 | sp_adj_mask = tf.SparseTensor( 220 | indices=sp_adj_matrix.indices, 221 | values=tf.ones_like(sp_adj_matrix.values), 222 | dense_shape=sp_adj_matrix.dense_shape) 223 | sp_adj_att = sp_adj_mask * adj_matrix_pred 224 | sp_adj_att = tf.SparseTensor( 225 | indices=sp_adj_att.indices, 226 | values=tf.nn.leaky_relu(sp_adj_att.values), 227 | dense_shape=sp_adj_att.dense_shape) 228 | sp_adj_att = tf.sparse_softmax(sp_adj_att) 229 | logits = self.node_model.compute_inference(node_features_in, sp_adj_att, 230 | is_training) 231 | return logits, adj_matrix_pred 232 | -------------------------------------------------------------------------------- /models/node_models.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License.
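Before moving on to the node-classification models, a minimal sketch of the sparse-attention trick used by the `Gcat` model above: dense predicted edge scores are restricted to the observed sparsity pattern, passed through a LeakyReLU, and row-normalized. This is a sketch under TF 1.x; `scores` and `sp_adj` are hypothetical inputs, not names from this repository.

```python
import tensorflow as tf

def masked_sparse_attention(scores, sp_adj):
  """Row-normalized attention over the observed edges only.

  Args:
    scores: dense [N, N] tensor of predicted edge scores.
    sp_adj: tf.SparseTensor holding the observed adjacency pattern.
  """
  # Binary mask carrying only the sparsity pattern of the graph.
  sp_mask = tf.SparseTensor(
      indices=sp_adj.indices,
      values=tf.ones_like(sp_adj.values),
      dense_shape=sp_adj.dense_shape)
  # Elementwise product drops scores of unobserved edges (result stays sparse).
  sp_scores = sp_mask * scores
  sp_scores = tf.SparseTensor(
      indices=sp_scores.indices,
      values=tf.nn.leaky_relu(sp_scores.values),
      dense_shape=sp_scores.dense_shape)
  # Softmax over each row's stored entries yields attention weights.
  return tf.sparse_softmax(sp_scores)
```

The result plays the role of `sp_adj_att` in `Gcat.compute_inference` above, where it replaces the normalized adjacency fed to the node GCN.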
14 | 15 | 16 | """Inference step for node classification models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from models.base_models import NodeModel 23 | import tensorflow as tf 24 | from utils.model_utils import cheby_module 25 | from utils.model_utils import compute_adj 26 | from utils.model_utils import gat_module 27 | from utils.model_utils import gcn_module 28 | from utils.model_utils import gcn_pool_layer 29 | from utils.model_utils import mlp_module 30 | from utils.model_utils import sp_gat_layer 31 | from utils.model_utils import sp_gcn_layer 32 | 33 | 34 | class Gat(NodeModel): 35 | """Graph Attention (GAT) Model (Velickovic et al.). 36 | 37 | arXiv link: https://arxiv.org/abs/1710.10903 38 | """ 39 | 40 | def compute_inference(self, node_features, adj_matrix, is_training): 41 | """Forward step for GAT model.""" 42 | sparse = self.sparse_features 43 | in_dim = self.input_dim 44 | average_last = True 45 | with tf.variable_scope('node-model'): 46 | logits = gat_module(node_features, adj_matrix, self.n_hidden, self.n_att, 47 | self.p_drop, is_training, in_dim, sparse, 48 | average_last) 49 | return logits 50 | 51 | 52 | class Gcn(NodeModel): 53 | """Graph convolution network (Kipf et al.). 54 | 55 | arXiv link: https://arxiv.org/abs/1609.02907 56 | """ 57 | 58 | def compute_inference(self, node_features, adj_matrix, is_training): 59 | """Forward step for graph convolution model.""" 60 | with tf.variable_scope('node-model'): 61 | logits = gcn_module(node_features, adj_matrix, self.n_hidden, self.p_drop, 62 | is_training, self.input_dim, self.sparse_features) 63 | return logits 64 | 65 | 66 | class Mlp(NodeModel): 67 | """Multi-layer perceptron model.""" 68 | 69 | def compute_inference(self, node_features, adj_matrix, is_training): 70 | """Forward step for MLP model.""" 71 | with tf.variable_scope('node-model'): 72 | logits = mlp_module(node_features, self.n_hidden, self.p_drop, 73 | is_training, self.input_dim, self.sparse_features, 74 | use_bias=True) 75 | return logits 76 | 77 | 78 | class SemiEmb(NodeModel): 79 | """Deep Learning via Semi-Supervised Embedding (Weston et al.). 80 | 81 | paper: http://icml2008.cs.helsinki.fi/papers/340.pdf 82 | """ 83 | 84 | def __init__(self, config): 85 | super(SemiEmb, self).__init__(config) 86 | self.semi_emb_k = config.semi_emb_k 87 | 88 | def compute_inference(self, node_features, adj_matrix, is_training): 89 | with tf.variable_scope('node-model'): 90 | hidden_repr = mlp_module(node_features, self.n_hidden, self.p_drop, 91 | is_training, self.input_dim, 92 | self.sparse_features, use_bias=True, 93 | return_hidden=True) 94 | logits = hidden_repr[-1] 95 | hidden_repr_reg = hidden_repr[self.semi_emb_k] 96 | l2_scores = compute_adj(hidden_repr_reg, self.att_mechanism, self.p_drop, 97 | is_training=False) 98 | self.l2_scores = tf.gather_nd(l2_scores, adj_matrix.indices) 99 | return logits 100 | 101 | def _compute_node_loss(self, logits, labels): 102 | supervised_loss = super(SemiEmb, self)._compute_node_loss(logits, labels) 103 | # supervised_loss = tf.nn.softmax_cross_entropy_with_logits( 104 | # labels=labels, logits=logits) 105 | # supervised_loss = tf.reduce_sum(supervised_loss) / self.nb_nodes 106 | reg_loss = tf.reduce_mean(self.l2_scores) 107 | return supervised_loss + self.edge_reg * reg_loss 108 | 109 | 110 | class Cheby(NodeModel): 111 | """Chebyshev polynomials for Spectral Graph Convolutions (Defferrard et al.).
112 | 113 | arXiv link: https://arxiv.org/abs/1606.09375 114 | """ 115 | 116 | def __init__(self, config): 117 | super(Cheby, self).__init__(config) 118 | self.cheby_k_loc = config.cheby_k_loc 119 | 120 | def compute_inference(self, node_features, normalized_laplacian, is_training): 121 | with tf.variable_scope('node-model'): 122 | dense_normalized_laplacian = tf.sparse_to_dense( 123 | sparse_indices=normalized_laplacian.indices, 124 | output_shape=normalized_laplacian.dense_shape, 125 | sparse_values=normalized_laplacian.values) 126 | cheby_polynomials = [tf.eye(self.nb_nodes), dense_normalized_laplacian] 127 | self.cheby = cheby_polynomials 128 | for _ in range(2, self.cheby_k_loc+1): 129 | cheby_polynomials.append(2 * tf.sparse_tensor_dense_matmul( 130 | normalized_laplacian, cheby_polynomials[-1]) - cheby_polynomials[-2] 131 | ) 132 | logits = cheby_module(node_features, cheby_polynomials, self.n_hidden, 133 | self.p_drop, is_training, self.input_dim, 134 | self.sparse_features) 135 | return logits 136 | 137 | 138 | ############################ EXPERIMENTAL MODELS ############################# 139 | 140 | 141 | class Hgat(NodeModel): 142 | """Hierarchical Graph Attention (GAT) Model.""" 143 | 144 | def compute_inference(self, node_features, adj_matrix, is_training): 145 | """Forward step for GAT model.""" 146 | in_dim = self.input_dim 147 | att = [] 148 | for j in range(4): 149 | with tf.variable_scope('gat-layer1-att{}'.format(j)): 150 | att.append( 151 | sp_gat_layer(node_features, adj_matrix, in_dim, 8, self.p_drop, 152 | is_training, True)) 153 | hidden_2 = [] 154 | hidden_2.append(tf.nn.elu(tf.concat(att[:2], axis=-1))) 155 | hidden_2.append(tf.nn.elu(tf.concat(att[2:], axis=-1))) 156 | att = [] 157 | for j in range(2): 158 | with tf.variable_scope('gat-layer2-att{}'.format(j)): 159 | att.append( 160 | sp_gat_layer(hidden_2[j], adj_matrix, 16, 7, self.p_drop, 161 | is_training, False)) 162 | return tf.add_n(att) / 2. 
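The polynomial recursion built inside `Cheby.compute_inference` above is easier to see in isolation. A self-contained numpy sketch, assuming `l_norm` is the rescaled Laplacian (2L/lambda_max - I, as in Defferrard et al.) and `k` plays the role of `cheby_k_loc`:

```python
import numpy as np

def chebyshev_polynomials(l_norm, k):
  """Returns [T_0, ..., T_k] with T_0 = I, T_1 = l_norm,
  and T_m = 2 * l_norm @ T_{m-1} - T_{m-2} for m >= 2."""
  polys = [np.eye(l_norm.shape[0]), l_norm]
  for _ in range(2, k + 1):
    polys.append(2.0 * l_norm @ polys[-1] - polys[-2])
  return polys[:k + 1]  # also correct for k = 0 or k = 1
```

Each `T_m` aggregates information from up to m hops away, which is what gives the order-k filter of Defferrard et al. its k-hop receptive field; `cheby_module` above consumes this list of polynomials.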
163 | 164 | 165 | class Pgcn(NodeModel): 166 | """Pooling Graph Convolution Network.""" 167 | 168 | def compute_inference(self, node_features, adj_matrix, is_training): 169 | adj_matrix_dense = tf.sparse_to_dense( 170 | sparse_indices=adj_matrix.indices, 171 | output_shape=adj_matrix.dense_shape, 172 | sparse_values=adj_matrix.values, 173 | validate_indices=False) 174 | adj_matrix_dense = tf.cast(tf.greater(adj_matrix_dense, 0), tf.float32) 175 | adj_matrix_dense = tf.expand_dims(adj_matrix_dense, -1) # N x N x 1 176 | in_dim = self.input_dim 177 | sparse = self.sparse_features 178 | for i, out_dim in enumerate(self.n_hidden[:-1]): 179 | if i > 0: 180 | sparse = False 181 | with tf.variable_scope('gcn-pool-{}'.format(i)): 182 | node_features = gcn_pool_layer( 183 | node_features, 184 | adj_matrix_dense, 185 | in_dim=in_dim, 186 | out_dim=out_dim, 187 | sparse=sparse, 188 | is_training=is_training, 189 | p_drop=self.p_drop) 190 | node_features = tf.reshape(node_features, (-1, out_dim)) 191 | node_features = tf.contrib.layers.bias_add(node_features) 192 | node_features = tf.nn.elu(node_features) 193 | in_dim = out_dim 194 | with tf.variable_scope('gcn-layer-last'): 195 | logits = sp_gcn_layer(node_features, adj_matrix, in_dim, 196 | self.n_hidden[-1], self.p_drop, is_training, False) 197 | return logits 198 | 199 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/gcn/.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | *.idea 3 | *.png 4 | *.pdf 5 | tmp/ 6 | *.txt 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # IPython Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | *.pickle 99 | -------------------------------------------------------------------------------- /third_party/gcn/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2016 Thomas Kipf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /third_party/gcn/README.md: -------------------------------------------------------------------------------- 1 | # Graph Convolutional Networks 2 | 3 | This is a TensorFlow implementation of Graph Convolutional Networks for the task of (semi-supervised) classification of nodes in a graph, as described in our paper: 4 | 5 | Thomas N. Kipf, Max Welling, [Semi-Supervised Classification with Graph Convolutional Networks](http://arxiv.org/abs/1609.02907) (ICLR 2017) 6 | 7 | For a high-level explanation, have a look at our blog post: 8 | 9 | Thomas Kipf, [Graph Convolutional Networks](http://tkipf.github.io/graph-convolutional-networks/) (2016) 10 | 11 | ## Installation 12 | 13 | ```bash 14 | python setup.py install 15 | ``` 16 | 17 | ## Requirements 18 | * tensorflow (>0.12) 19 | * networkx 20 | 21 | ## Run the demo 22 | 23 | ```bash 24 | python train.py 25 | ``` 26 | 27 | ## Data 28 | 29 | In order to use your own data, you have to provide 30 | * an N by N adjacency matrix (N is the number of nodes), 31 | * an N by D feature matrix (D is the number of features per node), and 32 | * an N by E binary label matrix (E is the number of classes). 
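For concreteness, a minimal sketch of those three inputs on a toy four-node graph (scipy/numpy assumed; the variable names are illustrative, not the loader's API):

```python
import numpy as np
import scipy.sparse as sp

N, D, E = 4, 3, 2                      # nodes, features per node, classes
adj = sp.csr_matrix(np.array([         # N x N symmetric adjacency
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0]], dtype=np.float32))
features = sp.csr_matrix(np.random.rand(N, D).astype(np.float32))  # N x D
labels = np.eye(E, dtype=np.float32)[[0, 1, 0, 1]]                 # N x E one-hot
```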
33 | 34 | Have a look at the `load_data()` function in `utils.py` for an example. 35 | 36 | In this example, we load citation network data (Cora, Citeseer or Pubmed). The original datasets can be found here: http://linqs.cs.umd.edu/projects/projects/lbc/. In our version (see `data` folder) we use dataset splits provided by https://github.com/kimiyoung/planetoid (Zhilin Yang, William W. Cohen, Ruslan Salakhutdinov, [Revisiting Semi-Supervised Learning with Graph Embeddings](https://arxiv.org/abs/1603.08861), ICML 2016). 37 | 38 | You can specify a dataset as follows: 39 | 40 | ```bash 41 | python train.py --dataset citeseer 42 | ``` 43 | 44 | (or by editing `train.py`) 45 | 46 | ## Models 47 | 48 | You can choose between the following models: 49 | * `gcn`: Graph convolutional network (Thomas N. Kipf, Max Welling, [Semi-Supervised Classification with Graph Convolutional Networks](http://arxiv.org/abs/1609.02907), 2016) 50 | * `gcn_cheby`: Chebyshev polynomial version of graph convolutional network as described in (Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst, [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375), NIPS 2016) 51 | * `dense`: Basic multi-layer perceptron that supports sparse inputs 52 | 53 | ## Graph classification 54 | 55 | Our framework also supports batch-wise classification of multiple graph instances (of potentially different size) with an adjacency matrix each. It is best to concatenate respective feature matrices and build a (sparse) block-diagonal matrix where each block corresponds to the adjacency matrix of one graph instance. For pooling (in case of graph-level outputs as opposed to node-level outputs) it is best to specify a simple pooling matrix that collects features from their respective graph instances, as illustrated below: 56 | 57 | ![graph_classification](https://user-images.githubusercontent.com/7347296/34198790-eb5bec96-e56b-11e7-90d5-157800e042de.png) 58 | 59 | 60 | ## Cite 61 | 62 | Please cite our paper if you use this code in your own work: 63 | 64 | ``` 65 | @inproceedings{kipf2017semi, 66 | title={Semi-Supervised Classification with Graph Convolutional Networks}, 67 | author={Kipf, Thomas N. 
and Welling, Max}, 68 | booktitle={International Conference on Learning Representations (ICLR)}, 69 | year={2017} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /third_party/gcn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/__init__.py -------------------------------------------------------------------------------- /third_party/gcn/gcn/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.allx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.ally -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.graph -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.test.index: -------------------------------------------------------------------------------- 1 | 2488 2 | 2644 3 | 3261 4 | 2804 5 | 3176 6 | 2432 7 | 3310 8 | 2410 9 | 2812 10 | 2520 11 | 2994 12 | 3282 13 | 2680 14 | 2848 15 | 2670 16 | 3005 17 | 2977 18 | 2592 19 | 2967 20 | 2461 21 | 3184 22 | 2852 23 | 2768 24 | 2905 25 | 2851 26 | 3129 27 | 3164 28 | 2438 29 | 2793 30 | 2763 31 | 2528 32 | 2954 33 | 2347 34 | 2640 35 | 3265 36 | 2874 37 | 2446 38 | 2856 39 | 3149 40 | 2374 41 | 3097 42 | 3301 43 | 2664 44 | 2418 45 | 2655 46 | 2464 47 | 2596 48 | 3262 49 | 3278 50 | 2320 51 | 2612 52 | 2614 53 | 2550 54 | 2626 55 | 2772 56 | 3007 57 | 2733 58 | 2516 59 | 2476 60 | 2798 61 | 2561 62 | 2839 63 | 2685 64 | 2391 65 | 2705 66 | 3098 67 | 2754 68 | 3251 69 | 2767 70 | 2630 71 | 2727 72 | 2513 73 | 2701 74 | 3264 75 | 2792 76 | 2821 77 | 3260 78 | 2462 79 | 3307 80 | 2639 81 | 2900 82 | 3060 83 | 2672 84 | 3116 85 | 2731 86 | 3316 87 | 2386 88 | 2425 89 | 2518 90 | 3151 91 | 2586 92 | 2797 93 | 2479 94 | 3117 95 | 2580 96 | 3182 97 | 2459 98 | 2508 99 | 3052 100 | 3230 101 | 3215 102 | 2803 103 | 2969 104 | 2562 105 | 2398 106 | 3325 107 | 2343 108 | 3030 109 | 2414 110 | 2776 111 | 2383 112 | 3173 113 | 2850 114 | 2499 115 | 3312 116 | 2648 117 | 2784 118 | 2898 119 | 3056 120 | 2484 121 | 3179 122 | 3132 123 | 2577 124 | 2563 125 | 2867 126 | 3317 127 | 2355 128 | 3207 129 | 3178 130 | 2968 131 | 3319 132 | 2358 133 | 2764 134 | 3001 135 | 2683 136 | 3271 137 | 2321 138 | 2567 139 | 2502 140 | 3246 141 | 2715 142 | 3066 143 | 2390 144 | 2381 145 | 3162 146 | 2741 147 | 2498 148 | 2790 149 | 3038 150 | 3321 151 | 2481 152 
| 3050 153 | 3161 154 | 3122 155 | 2801 156 | 2957 157 | 3177 158 | 2965 159 | 2621 160 | 3208 161 | 2921 162 | 2802 163 | 2357 164 | 2677 165 | 2519 166 | 2860 167 | 2696 168 | 2368 169 | 3241 170 | 2858 171 | 2419 172 | 2762 173 | 2875 174 | 3222 175 | 3064 176 | 2827 177 | 3044 178 | 2471 179 | 3062 180 | 2982 181 | 2736 182 | 2322 183 | 2709 184 | 2766 185 | 2424 186 | 2602 187 | 2970 188 | 2675 189 | 3299 190 | 2554 191 | 2964 192 | 2597 193 | 2753 194 | 2979 195 | 2523 196 | 2912 197 | 2896 198 | 2317 199 | 3167 200 | 2813 201 | 2482 202 | 2557 203 | 3043 204 | 3244 205 | 2985 206 | 2460 207 | 2363 208 | 3272 209 | 3045 210 | 3192 211 | 2453 212 | 2656 213 | 2834 214 | 2443 215 | 3202 216 | 2926 217 | 2711 218 | 2633 219 | 2384 220 | 2752 221 | 3285 222 | 2817 223 | 2483 224 | 2919 225 | 2924 226 | 2661 227 | 2698 228 | 2361 229 | 2662 230 | 2819 231 | 3143 232 | 2316 233 | 3196 234 | 2739 235 | 2345 236 | 2578 237 | 2822 238 | 3229 239 | 2908 240 | 2917 241 | 2692 242 | 3200 243 | 2324 244 | 2522 245 | 3322 246 | 2697 247 | 3163 248 | 3093 249 | 3233 250 | 2774 251 | 2371 252 | 2835 253 | 2652 254 | 2539 255 | 2843 256 | 3231 257 | 2976 258 | 2429 259 | 2367 260 | 3144 261 | 2564 262 | 3283 263 | 3217 264 | 3035 265 | 2962 266 | 2433 267 | 2415 268 | 2387 269 | 3021 270 | 2595 271 | 2517 272 | 2468 273 | 3061 274 | 2673 275 | 2348 276 | 3027 277 | 2467 278 | 3318 279 | 2959 280 | 3273 281 | 2392 282 | 2779 283 | 2678 284 | 3004 285 | 2634 286 | 2974 287 | 3198 288 | 2342 289 | 2376 290 | 3249 291 | 2868 292 | 2952 293 | 2710 294 | 2838 295 | 2335 296 | 2524 297 | 2650 298 | 3186 299 | 2743 300 | 2545 301 | 2841 302 | 2515 303 | 2505 304 | 3181 305 | 2945 306 | 2738 307 | 2933 308 | 3303 309 | 2611 310 | 3090 311 | 2328 312 | 3010 313 | 3016 314 | 2504 315 | 2936 316 | 3266 317 | 3253 318 | 2840 319 | 3034 320 | 2581 321 | 2344 322 | 2452 323 | 2654 324 | 3199 325 | 3137 326 | 2514 327 | 2394 328 | 2544 329 | 2641 330 | 2613 331 | 2618 332 | 2558 333 | 2593 334 | 2532 335 | 2512 336 | 2975 337 | 3267 338 | 2566 339 | 2951 340 | 3300 341 | 2869 342 | 2629 343 | 2747 344 | 3055 345 | 2831 346 | 3105 347 | 3168 348 | 3100 349 | 2431 350 | 2828 351 | 2684 352 | 3269 353 | 2910 354 | 2865 355 | 2693 356 | 2884 357 | 3228 358 | 2783 359 | 3247 360 | 2770 361 | 3157 362 | 2421 363 | 2382 364 | 2331 365 | 3203 366 | 3240 367 | 2351 368 | 3114 369 | 2986 370 | 2688 371 | 2439 372 | 2996 373 | 3079 374 | 3103 375 | 3296 376 | 2349 377 | 2372 378 | 3096 379 | 2422 380 | 2551 381 | 3069 382 | 2737 383 | 3084 384 | 3304 385 | 3022 386 | 2542 387 | 3204 388 | 2949 389 | 2318 390 | 2450 391 | 3140 392 | 2734 393 | 2881 394 | 2576 395 | 3054 396 | 3089 397 | 3125 398 | 2761 399 | 3136 400 | 3111 401 | 2427 402 | 2466 403 | 3101 404 | 3104 405 | 3259 406 | 2534 407 | 2961 408 | 3191 409 | 3000 410 | 3036 411 | 2356 412 | 2800 413 | 3155 414 | 3224 415 | 2646 416 | 2735 417 | 3020 418 | 2866 419 | 2426 420 | 2448 421 | 3226 422 | 3219 423 | 2749 424 | 3183 425 | 2906 426 | 2360 427 | 2440 428 | 2946 429 | 2313 430 | 2859 431 | 2340 432 | 3008 433 | 2719 434 | 3058 435 | 2653 436 | 3023 437 | 2888 438 | 3243 439 | 2913 440 | 3242 441 | 3067 442 | 2409 443 | 3227 444 | 2380 445 | 2353 446 | 2686 447 | 2971 448 | 2847 449 | 2947 450 | 2857 451 | 3263 452 | 3218 453 | 2861 454 | 3323 455 | 2635 456 | 2966 457 | 2604 458 | 2456 459 | 2832 460 | 2694 461 | 3245 462 | 3119 463 | 2942 464 | 3153 465 | 2894 466 | 2555 467 | 3128 468 | 2703 469 | 2323 470 | 2631 471 | 2732 472 | 2699 473 | 2314 474 | 2590 475 | 
3127 476 | 2891 477 | 2873 478 | 2814 479 | 2326 480 | 3026 481 | 3288 482 | 3095 483 | 2706 484 | 2457 485 | 2377 486 | 2620 487 | 2526 488 | 2674 489 | 3190 490 | 2923 491 | 3032 492 | 2334 493 | 3254 494 | 2991 495 | 3277 496 | 2973 497 | 2599 498 | 2658 499 | 2636 500 | 2826 501 | 3148 502 | 2958 503 | 3258 504 | 2990 505 | 3180 506 | 2538 507 | 2748 508 | 2625 509 | 2565 510 | 3011 511 | 3057 512 | 2354 513 | 3158 514 | 2622 515 | 3308 516 | 2983 517 | 2560 518 | 3169 519 | 3059 520 | 2480 521 | 3194 522 | 3291 523 | 3216 524 | 2643 525 | 3172 526 | 2352 527 | 2724 528 | 2485 529 | 2411 530 | 2948 531 | 2445 532 | 2362 533 | 2668 534 | 3275 535 | 3107 536 | 2496 537 | 2529 538 | 2700 539 | 2541 540 | 3028 541 | 2879 542 | 2660 543 | 3324 544 | 2755 545 | 2436 546 | 3048 547 | 2623 548 | 2920 549 | 3040 550 | 2568 551 | 3221 552 | 3003 553 | 3295 554 | 2473 555 | 3232 556 | 3213 557 | 2823 558 | 2897 559 | 2573 560 | 2645 561 | 3018 562 | 3326 563 | 2795 564 | 2915 565 | 3109 566 | 3086 567 | 2463 568 | 3118 569 | 2671 570 | 2909 571 | 2393 572 | 2325 573 | 3029 574 | 2972 575 | 3110 576 | 2870 577 | 3284 578 | 2816 579 | 2647 580 | 2667 581 | 2955 582 | 2333 583 | 2960 584 | 2864 585 | 2893 586 | 2458 587 | 2441 588 | 2359 589 | 2327 590 | 3256 591 | 3099 592 | 3073 593 | 3138 594 | 2511 595 | 2666 596 | 2548 597 | 2364 598 | 2451 599 | 2911 600 | 3237 601 | 3206 602 | 3080 603 | 3279 604 | 2934 605 | 2981 606 | 2878 607 | 3130 608 | 2830 609 | 3091 610 | 2659 611 | 2449 612 | 3152 613 | 2413 614 | 2722 615 | 2796 616 | 3220 617 | 2751 618 | 2935 619 | 3238 620 | 2491 621 | 2730 622 | 2842 623 | 3223 624 | 2492 625 | 3074 626 | 3094 627 | 2833 628 | 2521 629 | 2883 630 | 3315 631 | 2845 632 | 2907 633 | 3083 634 | 2572 635 | 3092 636 | 2903 637 | 2918 638 | 3039 639 | 3286 640 | 2587 641 | 3068 642 | 2338 643 | 3166 644 | 3134 645 | 2455 646 | 2497 647 | 2992 648 | 2775 649 | 2681 650 | 2430 651 | 2932 652 | 2931 653 | 2434 654 | 3154 655 | 3046 656 | 2598 657 | 2366 658 | 3015 659 | 3147 660 | 2944 661 | 2582 662 | 3274 663 | 2987 664 | 2642 665 | 2547 666 | 2420 667 | 2930 668 | 2750 669 | 2417 670 | 2808 671 | 3141 672 | 2997 673 | 2995 674 | 2584 675 | 2312 676 | 3033 677 | 3070 678 | 3065 679 | 2509 680 | 3314 681 | 2396 682 | 2543 683 | 2423 684 | 3170 685 | 2389 686 | 3289 687 | 2728 688 | 2540 689 | 2437 690 | 2486 691 | 2895 692 | 3017 693 | 2853 694 | 2406 695 | 2346 696 | 2877 697 | 2472 698 | 3210 699 | 2637 700 | 2927 701 | 2789 702 | 2330 703 | 3088 704 | 3102 705 | 2616 706 | 3081 707 | 2902 708 | 3205 709 | 3320 710 | 3165 711 | 2984 712 | 3185 713 | 2707 714 | 3255 715 | 2583 716 | 2773 717 | 2742 718 | 3024 719 | 2402 720 | 2718 721 | 2882 722 | 2575 723 | 3281 724 | 2786 725 | 2855 726 | 3014 727 | 2401 728 | 2535 729 | 2687 730 | 2495 731 | 3113 732 | 2609 733 | 2559 734 | 2665 735 | 2530 736 | 3293 737 | 2399 738 | 2605 739 | 2690 740 | 3133 741 | 2799 742 | 2533 743 | 2695 744 | 2713 745 | 2886 746 | 2691 747 | 2549 748 | 3077 749 | 3002 750 | 3049 751 | 3051 752 | 3087 753 | 2444 754 | 3085 755 | 3135 756 | 2702 757 | 3211 758 | 3108 759 | 2501 760 | 2769 761 | 3290 762 | 2465 763 | 3025 764 | 3019 765 | 2385 766 | 2940 767 | 2657 768 | 2610 769 | 2525 770 | 2941 771 | 3078 772 | 2341 773 | 2916 774 | 2956 775 | 2375 776 | 2880 777 | 3009 778 | 2780 779 | 2370 780 | 2925 781 | 2332 782 | 3146 783 | 2315 784 | 2809 785 | 3145 786 | 3106 787 | 2782 788 | 2760 789 | 2493 790 | 2765 791 | 2556 792 | 2890 793 | 2400 794 | 2339 795 | 3201 796 | 2818 797 | 3248 798 | 
3280 799 | 2570 800 | 2569 801 | 2937 802 | 3174 803 | 2836 804 | 2708 805 | 2820 806 | 3195 807 | 2617 808 | 3197 809 | 2319 810 | 2744 811 | 2615 812 | 2825 813 | 2603 814 | 2914 815 | 2531 816 | 3193 817 | 2624 818 | 2365 819 | 2810 820 | 3239 821 | 3159 822 | 2537 823 | 2844 824 | 2758 825 | 2938 826 | 3037 827 | 2503 828 | 3297 829 | 2885 830 | 2608 831 | 2494 832 | 2712 833 | 2408 834 | 2901 835 | 2704 836 | 2536 837 | 2373 838 | 2478 839 | 2723 840 | 3076 841 | 2627 842 | 2369 843 | 2669 844 | 3006 845 | 2628 846 | 2788 847 | 3276 848 | 2435 849 | 3139 850 | 3235 851 | 2527 852 | 2571 853 | 2815 854 | 2442 855 | 2892 856 | 2978 857 | 2746 858 | 3150 859 | 2574 860 | 2725 861 | 3188 862 | 2601 863 | 2378 864 | 3075 865 | 2632 866 | 2794 867 | 3270 868 | 3071 869 | 2506 870 | 3126 871 | 3236 872 | 3257 873 | 2824 874 | 2989 875 | 2950 876 | 2428 877 | 2405 878 | 3156 879 | 2447 880 | 2787 881 | 2805 882 | 2720 883 | 2403 884 | 2811 885 | 2329 886 | 2474 887 | 2785 888 | 2350 889 | 2507 890 | 2416 891 | 3112 892 | 2475 893 | 2876 894 | 2585 895 | 2487 896 | 3072 897 | 3082 898 | 2943 899 | 2757 900 | 2388 901 | 2600 902 | 3294 903 | 2756 904 | 3142 905 | 3041 906 | 2594 907 | 2998 908 | 3047 909 | 2379 910 | 2980 911 | 2454 912 | 2862 913 | 3175 914 | 2588 915 | 3031 916 | 3012 917 | 2889 918 | 2500 919 | 2791 920 | 2854 921 | 2619 922 | 2395 923 | 2807 924 | 2740 925 | 2412 926 | 3131 927 | 3013 928 | 2939 929 | 2651 930 | 2490 931 | 2988 932 | 2863 933 | 3225 934 | 2745 935 | 2714 936 | 3160 937 | 3124 938 | 2849 939 | 2676 940 | 2872 941 | 3287 942 | 3189 943 | 2716 944 | 3115 945 | 2928 946 | 2871 947 | 2591 948 | 2717 949 | 2546 950 | 2777 951 | 3298 952 | 2397 953 | 3187 954 | 2726 955 | 2336 956 | 3268 957 | 2477 958 | 2904 959 | 2846 960 | 3121 961 | 2899 962 | 2510 963 | 2806 964 | 2963 965 | 3313 966 | 2679 967 | 3302 968 | 2663 969 | 3053 970 | 2469 971 | 2999 972 | 3311 973 | 2470 974 | 2638 975 | 3120 976 | 3171 977 | 2689 978 | 2922 979 | 2607 980 | 2721 981 | 2993 982 | 2887 983 | 2837 984 | 2929 985 | 2829 986 | 3234 987 | 2649 988 | 2337 989 | 2759 990 | 2778 991 | 2771 992 | 2404 993 | 2589 994 | 3123 995 | 3209 996 | 2729 997 | 3252 998 | 2606 999 | 2579 1000 | 2552 1001 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.tx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.ty -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.x -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.citeseer.y: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.citeseer.y -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.allx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.ally -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.graph -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 
2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 
2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 
2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.tx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.ty -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.x -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.cora.y -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.allx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.ally: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.ally -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.graph -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.test.index: -------------------------------------------------------------------------------- 1 | 18747 2 | 19392 3 | 19181 4 | 18843 5 | 19221 6 | 18962 7 | 19560 8 | 19097 9 | 18966 10 | 19014 11 | 18756 12 | 19313 13 | 19000 14 | 19569 15 | 19359 16 | 18854 17 | 18970 18 | 19073 19 | 19661 20 | 19180 21 | 19377 22 | 18750 23 | 19401 24 | 18788 25 | 19224 26 | 19447 27 | 19017 28 | 19241 29 | 18890 30 | 18908 31 | 18965 32 | 19001 33 | 18849 34 | 19641 35 | 18852 36 | 19222 37 | 19172 38 | 18762 39 | 19156 40 | 19162 41 | 18856 42 | 18763 43 | 19318 44 | 18826 45 | 19712 46 | 19192 47 | 19695 48 | 19030 49 | 19523 50 | 19249 51 | 19079 52 | 19232 53 | 19455 54 | 18743 55 | 18800 56 | 19071 57 | 18885 58 | 19593 59 | 19394 60 | 19390 61 | 18832 62 | 19445 63 | 18838 64 | 19632 65 | 19548 66 | 19546 67 | 18825 68 | 19498 69 | 19266 70 | 19117 71 | 19595 72 | 19252 73 | 18730 74 | 18913 75 | 18809 76 | 19452 77 | 19520 78 | 19274 79 | 19555 80 | 19388 81 | 18919 82 | 19099 83 | 19637 84 | 19403 85 | 18720 86 | 19526 87 | 18905 88 | 19451 89 | 19408 90 | 18923 91 | 18794 92 | 19322 93 | 19431 94 | 18912 95 | 18841 96 | 19239 97 | 19125 98 | 19258 99 | 19565 100 | 18898 101 | 19482 102 | 19029 103 | 18778 104 | 19096 105 | 19684 106 | 19552 107 | 18765 108 | 19361 109 | 19171 110 | 19367 111 | 19623 112 | 19402 113 | 19327 114 | 19118 115 | 18888 116 | 18726 117 | 19510 118 | 18831 119 | 19490 120 | 19576 121 | 19050 122 | 18729 123 | 18896 124 | 19246 125 | 19012 126 | 18862 127 | 18873 128 | 19193 129 | 19693 130 | 19474 131 | 18953 132 | 19115 133 | 19182 134 | 19269 135 | 19116 136 | 18837 137 | 18872 138 | 19007 139 | 19212 140 | 18798 141 | 19102 142 | 18772 143 | 19660 144 | 19511 145 | 18914 146 | 18886 147 | 19672 148 | 19360 149 | 19213 150 | 18810 151 | 19420 152 | 19512 153 | 18719 154 | 19432 155 | 19350 156 | 19127 157 | 18782 158 | 19587 159 | 18924 160 | 19488 161 | 18781 162 | 19340 163 | 19190 164 | 19383 165 | 19094 166 | 18835 167 | 19487 168 | 19230 169 | 18791 170 | 18882 171 | 18937 172 | 18928 173 | 18755 174 | 18802 175 | 19516 176 | 18795 177 | 18786 178 | 19273 179 | 19349 180 | 19398 181 | 19626 182 | 19130 183 | 19351 184 | 19489 185 | 19446 186 | 18959 187 | 19025 188 | 18792 189 | 18878 190 | 19304 191 | 19629 192 | 19061 193 | 18785 194 | 19194 195 | 19179 196 | 19210 197 | 19417 198 | 19583 199 | 19415 200 | 19443 201 | 18739 202 | 19662 203 | 18904 204 | 18910 205 | 18901 206 | 18960 207 | 18722 208 | 18827 209 | 19290 210 | 18842 211 | 19389 212 | 19344 213 | 18961 214 | 19098 215 | 19147 216 | 19334 217 | 19358 218 | 18829 219 | 18984 220 | 18931 221 | 18742 222 | 19320 223 | 19111 224 | 19196 225 | 18887 226 | 18991 227 | 19469 228 | 18990 229 | 18876 230 | 19261 231 | 19270 232 | 19522 233 | 19088 234 | 19284 235 | 19646 236 | 19493 237 | 19225 238 | 19615 239 | 19449 240 | 19043 241 | 19674 242 | 19391 243 | 18918 244 | 19155 245 | 19110 246 | 18815 247 | 
19131 248 | 18834 249 | 19715 250 | 19603 251 | 19688 252 | 19133 253 | 19053 254 | 19166 255 | 19066 256 | 18893 257 | 18757 258 | 19582 259 | 19282 260 | 19257 261 | 18869 262 | 19467 263 | 18954 264 | 19371 265 | 19151 266 | 19462 267 | 19598 268 | 19653 269 | 19187 270 | 19624 271 | 19564 272 | 19534 273 | 19581 274 | 19478 275 | 18985 276 | 18746 277 | 19342 278 | 18777 279 | 19696 280 | 18824 281 | 19138 282 | 18728 283 | 19643 284 | 19199 285 | 18731 286 | 19168 287 | 18948 288 | 19216 289 | 19697 290 | 19347 291 | 18808 292 | 18725 293 | 19134 294 | 18847 295 | 18828 296 | 18996 297 | 19106 298 | 19485 299 | 18917 300 | 18911 301 | 18776 302 | 19203 303 | 19158 304 | 18895 305 | 19165 306 | 19382 307 | 18780 308 | 18836 309 | 19373 310 | 19659 311 | 18947 312 | 19375 313 | 19299 314 | 18761 315 | 19366 316 | 18754 317 | 19248 318 | 19416 319 | 19658 320 | 19638 321 | 19034 322 | 19281 323 | 18844 324 | 18922 325 | 19491 326 | 19272 327 | 19341 328 | 19068 329 | 19332 330 | 19559 331 | 19293 332 | 18804 333 | 18933 334 | 18935 335 | 19405 336 | 18936 337 | 18945 338 | 18943 339 | 18818 340 | 18797 341 | 19570 342 | 19464 343 | 19428 344 | 19093 345 | 19433 346 | 18986 347 | 19161 348 | 19255 349 | 19157 350 | 19046 351 | 19292 352 | 19434 353 | 19298 354 | 18724 355 | 19410 356 | 19694 357 | 19214 358 | 19640 359 | 19189 360 | 18963 361 | 19218 362 | 19585 363 | 19041 364 | 19550 365 | 19123 366 | 19620 367 | 19376 368 | 19561 369 | 18944 370 | 19706 371 | 19056 372 | 19283 373 | 18741 374 | 19319 375 | 19144 376 | 19542 377 | 18821 378 | 19404 379 | 19080 380 | 19303 381 | 18793 382 | 19306 383 | 19678 384 | 19435 385 | 19519 386 | 19566 387 | 19278 388 | 18946 389 | 19536 390 | 19020 391 | 19057 392 | 19198 393 | 19333 394 | 19649 395 | 19699 396 | 19399 397 | 19654 398 | 19136 399 | 19465 400 | 19321 401 | 19577 402 | 18907 403 | 19665 404 | 19386 405 | 19596 406 | 19247 407 | 19473 408 | 19568 409 | 19355 410 | 18925 411 | 19586 412 | 18982 413 | 19616 414 | 19495 415 | 19612 416 | 19023 417 | 19438 418 | 18817 419 | 19692 420 | 19295 421 | 19414 422 | 19676 423 | 19472 424 | 19107 425 | 19062 426 | 19035 427 | 18883 428 | 19409 429 | 19052 430 | 19606 431 | 19091 432 | 19651 433 | 19475 434 | 19413 435 | 18796 436 | 19369 437 | 19639 438 | 19701 439 | 19461 440 | 19645 441 | 19251 442 | 19063 443 | 19679 444 | 19545 445 | 19081 446 | 19363 447 | 18995 448 | 19549 449 | 18790 450 | 18855 451 | 18833 452 | 18899 453 | 19395 454 | 18717 455 | 19647 456 | 18768 457 | 19103 458 | 19245 459 | 18819 460 | 18779 461 | 19656 462 | 19076 463 | 18745 464 | 18971 465 | 19197 466 | 19711 467 | 19074 468 | 19128 469 | 19466 470 | 19139 471 | 19309 472 | 19324 473 | 18814 474 | 19092 475 | 19627 476 | 19060 477 | 18806 478 | 18929 479 | 18737 480 | 18942 481 | 18906 482 | 18858 483 | 19456 484 | 19253 485 | 19716 486 | 19104 487 | 19667 488 | 19574 489 | 18903 490 | 19237 491 | 18864 492 | 19556 493 | 19364 494 | 18952 495 | 19008 496 | 19323 497 | 19700 498 | 19170 499 | 19267 500 | 19345 501 | 19238 502 | 18909 503 | 18892 504 | 19109 505 | 19704 506 | 18902 507 | 19275 508 | 19680 509 | 18723 510 | 19242 511 | 19112 512 | 19169 513 | 18956 514 | 19343 515 | 19650 516 | 19541 517 | 19698 518 | 19521 519 | 19087 520 | 18976 521 | 19038 522 | 18775 523 | 18968 524 | 19671 525 | 19412 526 | 19407 527 | 19573 528 | 19027 529 | 18813 530 | 19357 531 | 19460 532 | 19673 533 | 19481 534 | 19036 535 | 19614 536 | 18787 537 | 19195 538 | 18732 539 | 18884 540 | 19613 541 | 19657 542 | 19575 543 | 
19226 544 | 19589 545 | 19234 546 | 19617 547 | 19707 548 | 19484 549 | 18740 550 | 19424 551 | 18784 552 | 19419 553 | 19159 554 | 18865 555 | 19105 556 | 19315 557 | 19480 558 | 19664 559 | 19378 560 | 18803 561 | 19605 562 | 18870 563 | 19042 564 | 19426 565 | 18848 566 | 19223 567 | 19509 568 | 19532 569 | 18752 570 | 19691 571 | 18718 572 | 19209 573 | 19362 574 | 19090 575 | 19492 576 | 19567 577 | 19687 578 | 19018 579 | 18830 580 | 19530 581 | 19554 582 | 19119 583 | 19442 584 | 19558 585 | 19527 586 | 19427 587 | 19291 588 | 19543 589 | 19422 590 | 19142 591 | 18897 592 | 18950 593 | 19425 594 | 19002 595 | 19588 596 | 18978 597 | 19551 598 | 18930 599 | 18736 600 | 19101 601 | 19215 602 | 19150 603 | 19263 604 | 18949 605 | 18974 606 | 18759 607 | 19335 608 | 19200 609 | 19129 610 | 19328 611 | 19437 612 | 18988 613 | 19429 614 | 19368 615 | 19406 616 | 19049 617 | 18811 618 | 19296 619 | 19256 620 | 19385 621 | 19602 622 | 18770 623 | 19337 624 | 19580 625 | 19476 626 | 19045 627 | 19132 628 | 19089 629 | 19120 630 | 19265 631 | 19483 632 | 18767 633 | 19227 634 | 18934 635 | 19069 636 | 18820 637 | 19006 638 | 19459 639 | 18927 640 | 19037 641 | 19280 642 | 19441 643 | 18823 644 | 19015 645 | 19114 646 | 19618 647 | 18957 648 | 19176 649 | 18853 650 | 19648 651 | 19201 652 | 19444 653 | 19279 654 | 18751 655 | 19302 656 | 19505 657 | 18733 658 | 19601 659 | 19533 660 | 18863 661 | 19708 662 | 19387 663 | 19346 664 | 19152 665 | 19206 666 | 18851 667 | 19338 668 | 19681 669 | 19380 670 | 19055 671 | 18766 672 | 19085 673 | 19591 674 | 19547 675 | 18958 676 | 19146 677 | 18840 678 | 19051 679 | 19021 680 | 19207 681 | 19235 682 | 19086 683 | 18979 684 | 19300 685 | 18939 686 | 19100 687 | 19619 688 | 19287 689 | 18980 690 | 19277 691 | 19326 692 | 19108 693 | 18920 694 | 19625 695 | 19374 696 | 19078 697 | 18734 698 | 19634 699 | 19339 700 | 18877 701 | 19423 702 | 19652 703 | 19683 704 | 19044 705 | 18983 706 | 19330 707 | 19529 708 | 19714 709 | 19468 710 | 19075 711 | 19540 712 | 18839 713 | 19022 714 | 19286 715 | 19537 716 | 19175 717 | 19463 718 | 19167 719 | 19705 720 | 19562 721 | 19244 722 | 19486 723 | 19611 724 | 18801 725 | 19178 726 | 19590 727 | 18846 728 | 19450 729 | 19205 730 | 19381 731 | 18941 732 | 19670 733 | 19185 734 | 19504 735 | 19633 736 | 18997 737 | 19113 738 | 19397 739 | 19636 740 | 19709 741 | 19289 742 | 19264 743 | 19353 744 | 19584 745 | 19126 746 | 18938 747 | 19669 748 | 18964 749 | 19276 750 | 18774 751 | 19173 752 | 19231 753 | 18973 754 | 18769 755 | 19064 756 | 19040 757 | 19668 758 | 18738 759 | 19082 760 | 19655 761 | 19236 762 | 19352 763 | 19609 764 | 19628 765 | 18951 766 | 19384 767 | 19122 768 | 18875 769 | 18992 770 | 18753 771 | 19379 772 | 19254 773 | 19301 774 | 19506 775 | 19135 776 | 19010 777 | 19682 778 | 19400 779 | 19579 780 | 19316 781 | 19553 782 | 19208 783 | 19635 784 | 19644 785 | 18891 786 | 19024 787 | 18989 788 | 19250 789 | 18850 790 | 19317 791 | 18915 792 | 19607 793 | 18799 794 | 18881 795 | 19479 796 | 19031 797 | 19365 798 | 19164 799 | 18744 800 | 18760 801 | 19502 802 | 19058 803 | 19517 804 | 18735 805 | 19448 806 | 19243 807 | 19453 808 | 19285 809 | 18857 810 | 19439 811 | 19016 812 | 18975 813 | 19503 814 | 18998 815 | 18981 816 | 19186 817 | 18994 818 | 19240 819 | 19631 820 | 19070 821 | 19174 822 | 18900 823 | 19065 824 | 19220 825 | 19229 826 | 18880 827 | 19308 828 | 19372 829 | 19496 830 | 18771 831 | 19325 832 | 19538 833 | 19033 834 | 18874 835 | 19077 836 | 19211 837 | 18764 838 | 19458 839 | 
19571 840 | 19121 841 | 19019 842 | 19059 843 | 19497 844 | 18969 845 | 19666 846 | 19297 847 | 19219 848 | 19622 849 | 19184 850 | 18977 851 | 19702 852 | 19539 853 | 19329 854 | 19095 855 | 19675 856 | 18972 857 | 19514 858 | 19703 859 | 19188 860 | 18866 861 | 18812 862 | 19314 863 | 18822 864 | 18845 865 | 19494 866 | 19411 867 | 18916 868 | 19686 869 | 18967 870 | 19294 871 | 19143 872 | 19204 873 | 18805 874 | 19689 875 | 19233 876 | 18758 877 | 18748 878 | 19011 879 | 19685 880 | 19336 881 | 19608 882 | 19454 883 | 19124 884 | 18868 885 | 18807 886 | 19544 887 | 19621 888 | 19228 889 | 19154 890 | 19141 891 | 19145 892 | 19153 893 | 18860 894 | 19163 895 | 19393 896 | 19268 897 | 19160 898 | 19305 899 | 19259 900 | 19471 901 | 19524 902 | 18783 903 | 19396 904 | 18894 905 | 19430 906 | 19690 907 | 19348 908 | 19597 909 | 19592 910 | 19677 911 | 18889 912 | 19331 913 | 18773 914 | 19137 915 | 19009 916 | 18932 917 | 19599 918 | 18816 919 | 19054 920 | 19067 921 | 19477 922 | 19191 923 | 18921 924 | 18940 925 | 19578 926 | 19183 927 | 19004 928 | 19072 929 | 19710 930 | 19005 931 | 19610 932 | 18955 933 | 19457 934 | 19148 935 | 18859 936 | 18993 937 | 19642 938 | 19047 939 | 19418 940 | 19535 941 | 19600 942 | 19312 943 | 19039 944 | 19028 945 | 18879 946 | 19003 947 | 19026 948 | 19013 949 | 19149 950 | 19177 951 | 19217 952 | 18987 953 | 19354 954 | 19525 955 | 19202 956 | 19084 957 | 19032 958 | 18749 959 | 18867 960 | 19048 961 | 18999 962 | 19260 963 | 19630 964 | 18727 965 | 19356 966 | 19083 967 | 18926 968 | 18789 969 | 19370 970 | 18861 971 | 19311 972 | 19557 973 | 19531 974 | 19436 975 | 19140 976 | 19310 977 | 19501 978 | 18721 979 | 19604 980 | 19713 981 | 19262 982 | 19563 983 | 19507 984 | 19440 985 | 19572 986 | 19513 987 | 19515 988 | 19518 989 | 19421 990 | 19470 991 | 19499 992 | 19663 993 | 19508 994 | 18871 995 | 19528 996 | 19500 997 | 19307 998 | 19288 999 | 19594 1000 | 19271 1001 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.tx -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.ty -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.x -------------------------------------------------------------------------------- /third_party/gcn/gcn/data/ind.pubmed.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gcnn-survey-paper/591af8d6c4374378831cab2cdec79575e2540d79/third_party/gcn/gcn/data/ind.pubmed.y -------------------------------------------------------------------------------- /third_party/gcn/gcn/inits.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 
3 | 4 | 5 | def uniform(shape, scale=0.05, name=None): 6 | """Uniform init.""" 7 | initial = tf.random_uniform(shape, minval=-scale, maxval=scale, dtype=tf.float32) 8 | return tf.Variable(initial, name=name) 9 | 10 | 11 | def glorot(shape, name=None): 12 | """Glorot & Bengio (AISTATS 2010) init.""" 13 | init_range = np.sqrt(6.0/(shape[0]+shape[1])) 14 | initial = tf.random_uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32) 15 | return tf.Variable(initial, name=name) 16 | 17 | 18 | def zeros(shape, name=None): 19 | """All zeros.""" 20 | initial = tf.zeros(shape, dtype=tf.float32) 21 | return tf.Variable(initial, name=name) 22 | 23 | 24 | def ones(shape, name=None): 25 | """All ones.""" 26 | initial = tf.ones(shape, dtype=tf.float32) 27 | return tf.Variable(initial, name=name) -------------------------------------------------------------------------------- /third_party/gcn/gcn/layers.py: -------------------------------------------------------------------------------- 1 | from gcn.inits import * 2 | import tensorflow as tf 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | # global unique layer ID dictionary for layer name assignment 8 | _LAYER_UIDS = {} 9 | 10 | 11 | def get_layer_uid(layer_name=''): 12 | """Helper function, assigns unique layer IDs.""" 13 | if layer_name not in _LAYER_UIDS: 14 | _LAYER_UIDS[layer_name] = 1 15 | return 1 16 | else: 17 | _LAYER_UIDS[layer_name] += 1 18 | return _LAYER_UIDS[layer_name] 19 | 20 | 21 | def sparse_dropout(x, keep_prob, noise_shape): 22 | """Dropout for sparse tensors.""" 23 | random_tensor = keep_prob 24 | random_tensor += tf.random_uniform(noise_shape) 25 | dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool) 26 | pre_out = tf.sparse_retain(x, dropout_mask) 27 | return pre_out * (1./keep_prob) 28 | 29 | 30 | def dot(x, y, sparse=False): 31 | """Wrapper for tf.matmul (sparse vs dense).""" 32 | if sparse: 33 | res = tf.sparse_tensor_dense_matmul(x, y) 34 | else: 35 | res = tf.matmul(x, y) 36 | return res 37 | 38 | 39 | class Layer(object): 40 | """Base layer class. Defines basic API for all layer objects. 41 | Implementation inspired by keras (http://keras.io). 42 | 43 | # Properties 44 | name: String, defines the variable scope of the layer. 45 | logging: Boolean, switches Tensorflow histogram logging on/off 46 | 47 | # Methods 48 | _call(inputs): Defines computation graph of layer 49 | (i.e. 
takes input, returns output) 50 | __call__(inputs): Wrapper for _call() 51 | _log_vars(): Log all variables 52 | """ 53 | 54 | def __init__(self, **kwargs): 55 | allowed_kwargs = {'name', 'logging'} 56 | for kwarg in kwargs.keys(): 57 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 58 | name = kwargs.get('name') 59 | if not name: 60 | layer = self.__class__.__name__.lower() 61 | name = layer + '_' + str(get_layer_uid(layer)) 62 | self.name = name 63 | self.vars = {} 64 | logging = kwargs.get('logging', False) 65 | self.logging = logging 66 | self.sparse_inputs = False 67 | 68 | def _call(self, inputs): 69 | return inputs 70 | 71 | def __call__(self, inputs): 72 | with tf.name_scope(self.name): 73 | if self.logging and not self.sparse_inputs: 74 | tf.summary.histogram(self.name + '/inputs', inputs) 75 | outputs = self._call(inputs) 76 | if self.logging: 77 | tf.summary.histogram(self.name + '/outputs', outputs) 78 | return outputs 79 | 80 | def _log_vars(self): 81 | for var in self.vars: 82 | tf.summary.histogram(self.name + '/vars/' + var, self.vars[var]) 83 | 84 | 85 | class Dense(Layer): 86 | """Dense layer.""" 87 | def __init__(self, input_dim, output_dim, placeholders, dropout=0., sparse_inputs=False, 88 | act=tf.nn.relu, bias=False, featureless=False, **kwargs): 89 | super(Dense, self).__init__(**kwargs) 90 | 91 | if dropout: 92 | self.dropout = placeholders['dropout'] 93 | else: 94 | self.dropout = 0. 95 | 96 | self.act = act 97 | self.sparse_inputs = sparse_inputs 98 | self.featureless = featureless 99 | self.bias = bias 100 | 101 | # helper variable for sparse dropout 102 | self.num_features_nonzero = placeholders['num_features_nonzero'] 103 | 104 | with tf.variable_scope(self.name + '_vars'): 105 | self.vars['weights'] = glorot([input_dim, output_dim], 106 | name='weights') 107 | if self.bias: 108 | self.vars['bias'] = zeros([output_dim], name='bias') 109 | 110 | if self.logging: 111 | self._log_vars() 112 | 113 | def _call(self, inputs): 114 | x = inputs 115 | 116 | # dropout 117 | if self.sparse_inputs: 118 | x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero) 119 | else: 120 | x = tf.nn.dropout(x, 1-self.dropout) 121 | 122 | # transform 123 | output = dot(x, self.vars['weights'], sparse=self.sparse_inputs) 124 | 125 | # bias 126 | if self.bias: 127 | output += self.vars['bias'] 128 | 129 | return self.act(output) 130 | 131 | 132 | class GraphConvolution(Layer): 133 | """Graph convolution layer.""" 134 | def __init__(self, input_dim, output_dim, placeholders, dropout=0., 135 | sparse_inputs=False, act=tf.nn.relu, bias=False, 136 | featureless=False, **kwargs): 137 | super(GraphConvolution, self).__init__(**kwargs) 138 | 139 | if dropout: 140 | self.dropout = placeholders['dropout'] 141 | else: 142 | self.dropout = 0. 
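        # The dropout rate comes from the 'dropout' placeholder, which
        # train.py defines with tf.placeholder_with_default(0., shape=()):
        # feeding FLAGS.dropout enables dropout at train time, while leaving
        # the default of 0. disables it for validation and test runs.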
143 | 144 | self.act = act 145 | self.support = placeholders['support'] 146 | self.sparse_inputs = sparse_inputs 147 | self.featureless = featureless 148 | self.bias = bias 149 | 150 | # helper variable for sparse dropout 151 | self.num_features_nonzero = placeholders['num_features_nonzero'] 152 | 153 | with tf.variable_scope(self.name + '_vars'): 154 | for i in range(len(self.support)): 155 | self.vars['weights_' + str(i)] = glorot([input_dim, output_dim], 156 | name='weights_' + str(i)) 157 | if self.bias: 158 | self.vars['bias'] = zeros([output_dim], name='bias') 159 | 160 | if self.logging: 161 | self._log_vars() 162 | 163 | def _call(self, inputs): 164 | x = inputs 165 | 166 | # dropout 167 | if self.sparse_inputs: 168 | x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero) 169 | else: 170 | x = tf.nn.dropout(x, 1-self.dropout) 171 | 172 | # convolve 173 | supports = list() 174 | for i in range(len(self.support)): 175 | if not self.featureless: 176 | pre_sup = dot(x, self.vars['weights_' + str(i)], 177 | sparse=self.sparse_inputs) 178 | else: 179 | pre_sup = self.vars['weights_' + str(i)] 180 | support = dot(self.support[i], pre_sup, sparse=True) 181 | supports.append(support) 182 | output = tf.add_n(supports) 183 | 184 | # bias 185 | if self.bias: 186 | output += self.vars['bias'] 187 | 188 | return self.act(output) 189 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def masked_softmax_cross_entropy(preds, labels, mask): 5 | """Softmax cross-entropy loss with masking.""" 6 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels) 7 | mask = tf.cast(mask, dtype=tf.float32) 8 | mask /= tf.reduce_mean(mask) 9 | loss *= mask 10 | return tf.reduce_mean(loss) 11 | 12 | 13 | def masked_accuracy(preds, labels, mask): 14 | """Accuracy with masking.""" 15 | correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1)) 16 | accuracy_all = tf.cast(correct_prediction, tf.float32) 17 | mask = tf.cast(mask, dtype=tf.float32) 18 | mask /= tf.reduce_mean(mask) 19 | accuracy_all *= mask 20 | return tf.reduce_mean(accuracy_all) 21 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/models.py: -------------------------------------------------------------------------------- 1 | from gcn.layers import * 2 | from gcn.metrics import * 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | class Model(object): 9 | def __init__(self, **kwargs): 10 | allowed_kwargs = {'name', 'logging'} 11 | for kwarg in kwargs.keys(): 12 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 13 | name = kwargs.get('name') 14 | if not name: 15 | name = self.__class__.__name__.lower() 16 | self.name = name 17 | 18 | logging = kwargs.get('logging', False) 19 | self.logging = logging 20 | 21 | self.vars = {} 22 | self.placeholders = {} 23 | 24 | self.layers = [] 25 | self.activations = [] 26 | 27 | self.inputs = None 28 | self.outputs = None 29 | 30 | self.loss = 0 31 | self.accuracy = 0 32 | self.optimizer = None 33 | self.opt_op = None 34 | 35 | def _build(self): 36 | raise NotImplementedError 37 | 38 | def build(self): 39 | """ Wrapper for _build() """ 40 | with tf.variable_scope(self.name): 41 | self._build() 42 | 43 | # Build sequential layer model 44 | self.activations.append(self.inputs) 45 | for layer in self.layers: 
46 | hidden = layer(self.activations[-1]) 47 | self.activations.append(hidden) 48 | self.outputs = self.activations[-1] 49 | 50 | # Store model variables for easy access 51 | variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 52 | self.vars = {var.name: var for var in variables} 53 | 54 | # Build metrics 55 | self._loss() 56 | self._accuracy() 57 | 58 | self.opt_op = self.optimizer.minimize(self.loss) 59 | 60 | def predict(self): 61 | pass 62 | 63 | def _loss(self): 64 | raise NotImplementedError 65 | 66 | def _accuracy(self): 67 | raise NotImplementedError 68 | 69 | def save(self, sess=None): 70 | if not sess: 71 | raise AttributeError("TensorFlow session not provided.") 72 | saver = tf.train.Saver(self.vars) 73 | save_path = saver.save(sess, "tmp/%s.ckpt" % self.name) 74 | print("Model saved in file: %s" % save_path) 75 | 76 | def load(self, sess=None): 77 | if not sess: 78 | raise AttributeError("TensorFlow session not provided.") 79 | saver = tf.train.Saver(self.vars) 80 | save_path = "tmp/%s.ckpt" % self.name 81 | saver.restore(sess, save_path) 82 | print("Model restored from file: %s" % save_path) 83 | 84 | 85 | class MLP(Model): 86 | def __init__(self, placeholders, input_dim, **kwargs): 87 | super(MLP, self).__init__(**kwargs) 88 | 89 | self.inputs = placeholders['features'] 90 | self.input_dim = input_dim 91 | # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions 92 | self.output_dim = placeholders['labels'].get_shape().as_list()[1] 93 | self.placeholders = placeholders 94 | 95 | self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) 96 | 97 | self.build() 98 | 99 | def _loss(self): 100 | # Weight decay loss 101 | for var in self.layers[0].vars.values(): 102 | self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var) 103 | 104 | # Cross entropy error 105 | self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'], 106 | self.placeholders['labels_mask']) 107 | 108 | def _accuracy(self): 109 | self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'], 110 | self.placeholders['labels_mask']) 111 | 112 | def _build(self): 113 | self.layers.append(Dense(input_dim=self.input_dim, 114 | output_dim=FLAGS.hidden1, 115 | placeholders=self.placeholders, 116 | act=tf.nn.relu, 117 | dropout=True, 118 | sparse_inputs=True, 119 | logging=self.logging)) 120 | 121 | self.layers.append(Dense(input_dim=FLAGS.hidden1, 122 | output_dim=self.output_dim, 123 | placeholders=self.placeholders, 124 | act=lambda x: x, 125 | dropout=True, 126 | logging=self.logging)) 127 | 128 | def predict(self): 129 | return tf.nn.softmax(self.outputs) 130 | 131 | 132 | class GCN(Model): 133 | def __init__(self, placeholders, input_dim, **kwargs): 134 | super(GCN, self).__init__(**kwargs) 135 | 136 | self.inputs = placeholders['features'] 137 | self.input_dim = input_dim 138 | # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions 139 | self.output_dim = placeholders['labels'].get_shape().as_list()[1] 140 | self.placeholders = placeholders 141 | 142 | self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) 143 | 144 | self.build() 145 | 146 | def _loss(self): 147 | # Weight decay loss 148 | for var in self.layers[0].vars.values(): 149 | self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var) 150 | 151 | # Cross entropy error 152 | self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'], 153 | 
self.placeholders['labels_mask']) 154 | 155 | def _accuracy(self): 156 | self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'], 157 | self.placeholders['labels_mask']) 158 | 159 | def _build(self): 160 | 161 | self.layers.append(GraphConvolution(input_dim=self.input_dim, 162 | output_dim=FLAGS.hidden1, 163 | placeholders=self.placeholders, 164 | act=tf.nn.relu, 165 | dropout=True, 166 | sparse_inputs=True, 167 | logging=self.logging)) 168 | 169 | self.layers.append(GraphConvolution(input_dim=FLAGS.hidden1, 170 | output_dim=self.output_dim, 171 | placeholders=self.placeholders, 172 | act=lambda x: x, 173 | dropout=True, 174 | logging=self.logging)) 175 | 176 | def predict(self): 177 | return tf.nn.softmax(self.outputs) 178 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import time 5 | import tensorflow as tf 6 | 7 | from gcn.utils import * 8 | from gcn.models import GCN, MLP 9 | 10 | # Set random seed 11 | seed = 123 12 | np.random.seed(seed) 13 | tf.set_random_seed(seed) 14 | 15 | # Settings 16 | flags = tf.app.flags 17 | FLAGS = flags.FLAGS 18 | flags.DEFINE_string('dataset', 'cora', 'Dataset string.') # 'cora', 'citeseer', 'pubmed' 19 | flags.DEFINE_string('model', 'gcn', 'Model string.') # 'gcn', 'gcn_cheby', 'dense' 20 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 21 | flags.DEFINE_integer('epochs', 200, 'Number of epochs to train.') 22 | flags.DEFINE_integer('hidden1', 16, 'Number of units in hidden layer 1.') 23 | flags.DEFINE_float('dropout', 0.5, 'Dropout rate (1 - keep probability).') 24 | flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.') 25 | flags.DEFINE_integer('early_stopping', 10, 'Tolerance for early stopping (# of epochs).') 26 | flags.DEFINE_integer('max_degree', 3, 'Maximum Chebyshev polynomial degree.') 27 | 28 | # Load data 29 | adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(FLAGS.dataset) 30 | 31 | # Some preprocessing 32 | features = preprocess_features(features) 33 | if FLAGS.model == 'gcn': 34 | support = [preprocess_adj(adj)] 35 | num_supports = 1 36 | model_func = GCN 37 | elif FLAGS.model == 'gcn_cheby': 38 | support = chebyshev_polynomials(adj, FLAGS.max_degree) 39 | num_supports = 1 + FLAGS.max_degree 40 | model_func = GCN 41 | elif FLAGS.model == 'dense': 42 | support = [preprocess_adj(adj)] # Not used 43 | num_supports = 1 44 | model_func = MLP 45 | else: 46 | raise ValueError('Invalid argument for model: ' + str(FLAGS.model)) 47 | 48 | # Define placeholders 49 | placeholders = { 50 | 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)], 51 | 'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)), 52 | 'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])), 53 | 'labels_mask': tf.placeholder(tf.int32), 54 | 'dropout': tf.placeholder_with_default(0., shape=()), 55 | 'num_features_nonzero': tf.placeholder(tf.int32) # helper variable for sparse dropout 56 | } 57 | 58 | # Create model 59 | model = model_func(placeholders, input_dim=features[2][1], logging=True) 60 | 61 | # Initialize session 62 | sess = tf.Session() 63 | 64 | 65 | # Define model evaluation function 66 | def evaluate(features, support, labels, mask, placeholders): 67 | t_test = 
time.time() 68 | feed_dict_val = construct_feed_dict(features, support, labels, mask, placeholders) 69 | outs_val = sess.run([model.loss, model.accuracy], feed_dict=feed_dict_val) 70 | return outs_val[0], outs_val[1], (time.time() - t_test) 71 | 72 | 73 | # Init variables 74 | sess.run(tf.global_variables_initializer()) 75 | 76 | cost_val = [] 77 | 78 | # Train model 79 | for epoch in range(FLAGS.epochs): 80 | 81 | t = time.time() 82 | # Construct feed dictionary 83 | feed_dict = construct_feed_dict(features, support, y_train, train_mask, placeholders) 84 | feed_dict.update({placeholders['dropout']: FLAGS.dropout}) 85 | 86 | # Training step 87 | outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict) 88 | 89 | # Validation 90 | cost, acc, duration = evaluate(features, support, y_val, val_mask, placeholders) 91 | cost_val.append(cost) 92 | 93 | # Print results 94 | print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(outs[1]), 95 | "train_acc=", "{:.5f}".format(outs[2]), "val_loss=", "{:.5f}".format(cost), 96 | "val_acc=", "{:.5f}".format(acc), "time=", "{:.5f}".format(time.time() - t)) 97 | 98 | if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(cost_val[-(FLAGS.early_stopping+1):-1]): 99 | print("Early stopping...") 100 | break 101 | 102 | print("Optimization Finished!") 103 | 104 | # Testing 105 | test_cost, test_acc, test_duration = evaluate(features, support, y_test, test_mask, placeholders) 106 | print("Test set results:", "cost=", "{:.5f}".format(test_cost), 107 | "accuracy=", "{:.5f}".format(test_acc), "time=", "{:.5f}".format(test_duration)) 108 | -------------------------------------------------------------------------------- /third_party/gcn/gcn/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pkl 3 | import networkx as nx 4 | import scipy.sparse as sp 5 | from scipy.sparse.linalg.eigen.arpack import eigsh 6 | import sys 7 | 8 | 9 | def parse_index_file(filename): 10 | """Parse index file.""" 11 | index = [] 12 | for line in open(filename): 13 | index.append(int(line.strip())) 14 | return index 15 | 16 | 17 | def sample_mask(idx, l): 18 | """Create mask.""" 19 | mask = np.zeros(l) 20 | mask[idx] = 1 21 | return np.array(mask, dtype=np.bool) 22 | 23 | 24 | def load_data(dataset_str): 25 | """ 26 | Loads input data from gcn/data directory 27 | 28 | ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object; 29 | ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object; 30 | ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances 31 | (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object; 32 | ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object; 33 | ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object; 34 | ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object; 35 | ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict 36 | object; 37 | ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object. 38 | 39 | All objects above must be saved using the Python pickle module. 40 | 41 | :param dataset_str: Dataset name 42 | :return: All data input files loaded (as well as the training/test data). 
43 | """ 44 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'] 45 | objects = [] 46 | for i in range(len(names)): 47 | with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f: 48 | if sys.version_info > (3, 0): 49 | objects.append(pkl.load(f, encoding='latin1')) 50 | else: 51 | objects.append(pkl.load(f)) 52 | 53 | x, y, tx, ty, allx, ally, graph = tuple(objects) 54 | test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str)) 55 | test_idx_range = np.sort(test_idx_reorder) 56 | 57 | if dataset_str == 'citeseer': 58 | # Fix citeseer dataset (there are some isolated nodes in the graph) 59 | # Find isolated nodes, add them as zero-vecs into the right position 60 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) 61 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) 62 | tx_extended[test_idx_range-min(test_idx_range), :] = tx 63 | tx = tx_extended 64 | ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) 65 | ty_extended[test_idx_range-min(test_idx_range), :] = ty 66 | ty = ty_extended 67 | 68 | features = sp.vstack((allx, tx)).tolil() 69 | features[test_idx_reorder, :] = features[test_idx_range, :] 70 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph)) 71 | 72 | labels = np.vstack((ally, ty)) 73 | labels[test_idx_reorder, :] = labels[test_idx_range, :] 74 | 75 | idx_test = test_idx_range.tolist() 76 | idx_train = range(len(y)) 77 | idx_val = range(len(y), len(y)+500) 78 | 79 | train_mask = sample_mask(idx_train, labels.shape[0]) 80 | val_mask = sample_mask(idx_val, labels.shape[0]) 81 | test_mask = sample_mask(idx_test, labels.shape[0]) 82 | 83 | y_train = np.zeros(labels.shape) 84 | y_val = np.zeros(labels.shape) 85 | y_test = np.zeros(labels.shape) 86 | y_train[train_mask, :] = labels[train_mask, :] 87 | y_val[val_mask, :] = labels[val_mask, :] 88 | y_test[test_mask, :] = labels[test_mask, :] 89 | 90 | return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask 91 | 92 | 93 | def sparse_to_tuple(sparse_mx): 94 | """Convert sparse matrix to tuple representation.""" 95 | def to_tuple(mx): 96 | if not sp.isspmatrix_coo(mx): 97 | mx = mx.tocoo() 98 | coords = np.vstack((mx.row, mx.col)).transpose() 99 | values = mx.data 100 | shape = mx.shape 101 | return coords, values, shape 102 | 103 | if isinstance(sparse_mx, list): 104 | for i in range(len(sparse_mx)): 105 | sparse_mx[i] = to_tuple(sparse_mx[i]) 106 | else: 107 | sparse_mx = to_tuple(sparse_mx) 108 | 109 | return sparse_mx 110 | 111 | 112 | def preprocess_features(features): 113 | """Row-normalize feature matrix and convert to tuple representation""" 114 | rowsum = np.array(features.sum(1)) 115 | r_inv = np.power(rowsum, -1).flatten() 116 | r_inv[np.isinf(r_inv)] = 0. 117 | r_mat_inv = sp.diags(r_inv) 118 | features = r_mat_inv.dot(features) 119 | return sparse_to_tuple(features) 120 | 121 | 122 | def normalize_adj(adj): 123 | """Symmetrically normalize adjacency matrix.""" 124 | adj = sp.coo_matrix(adj) 125 | rowsum = np.array(adj.sum(1)) 126 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 127 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
128 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 129 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 130 | 131 | 132 | def preprocess_adj(adj): 133 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" 134 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 135 | return sparse_to_tuple(adj_normalized) 136 | 137 | 138 | def construct_feed_dict(features, support, labels, labels_mask, placeholders): 139 | """Construct feed dictionary.""" 140 | feed_dict = dict() 141 | feed_dict.update({placeholders['labels']: labels}) 142 | feed_dict.update({placeholders['labels_mask']: labels_mask}) 143 | feed_dict.update({placeholders['features']: features}) 144 | feed_dict.update({placeholders['support'][i]: support[i] for i in range(len(support))}) 145 | feed_dict.update({placeholders['num_features_nonzero']: features[1].shape}) 146 | return feed_dict 147 | 148 | 149 | def chebyshev_polynomials(adj, k): 150 | """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices (tuple representation).""" 151 | print("Calculating Chebyshev polynomials up to order {}...".format(k)) 152 | 153 | adj_normalized = normalize_adj(adj) 154 | laplacian = sp.eye(adj.shape[0]) - adj_normalized 155 | largest_eigval, _ = eigsh(laplacian, 1, which='LM') 156 | scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0]) 157 | 158 | t_k = list() 159 | t_k.append(sp.eye(adj.shape[0])) 160 | t_k.append(scaled_laplacian) 161 | 162 | def chebyshev_recurrence(t_k_minus_one, t_k_minus_two, scaled_lap): 163 | s_lap = sp.csr_matrix(scaled_lap, copy=True) 164 | return 2 * s_lap.dot(t_k_minus_one) - t_k_minus_two 165 | 166 | for i in range(2, k+1): 167 | t_k.append(chebyshev_recurrence(t_k[-1], t_k[-2], scaled_laplacian)) 168 | 169 | return sparse_to_tuple(t_k) 170 | -------------------------------------------------------------------------------- /third_party/gcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='gcn', 5 | version='1.0', 6 | description='Graph Convolutional Networks in Tensorflow', 7 | author='Thomas Kipf', 8 | author_email='thomas.kipf@gmail.com', 9 | url='https://tkipf.github.io', 10 | download_url='https://github.com/tkipf/gcn', 11 | license='MIT', 12 | install_requires=['numpy', 13 | 'tensorflow', 14 | 'networkx', 15 | 'scipy' 16 | ], 17 | package_data={'gcn': ['README.md']}, 18 | packages=find_packages()) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 
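# Example invocation (flag values are illustrative; all flags are defined
# below in this file):
#   python train.py --model_name=Gat --dataset=cora --datapath=data/ \
#     --save_dir=/tmp/models/cora/gat --drop_edge_prop=10 --lr=0.005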
14 | 15 | 16 | """Training script for GNN models for link prediction/node classification.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import datetime 23 | import os 24 | 25 | from absl import app 26 | from absl import flags 27 | 28 | import models.edge_models as edge_models 29 | import models.node_edge_models as node_edge_models 30 | import models.node_models as node_models 31 | 32 | import numpy as np 33 | import scipy.sparse as sp 34 | import tensorflow as tf 35 | 36 | from utils.data_utils import load_data 37 | from utils.data_utils import mask_test_edges 38 | from utils.data_utils import mask_val_test_edges 39 | from utils.data_utils import process_adj 40 | from utils.data_utils import sparse_to_tuple 41 | from utils.train_utils import check_improve 42 | from utils.train_utils import format_metrics 43 | from utils.train_utils import format_params 44 | 45 | flags.DEFINE_string('model_name', 'Gat', 'Which model to use.') 46 | flags.DEFINE_integer('epochs', 10000, 'Number of epochs to train for.') 47 | flags.DEFINE_integer('patience', 100, 'Patience for early stopping.') 48 | flags.DEFINE_string('dataset', 'cora', 49 | 'Dataset to use: (cora - citeseer - pubmed).') 50 | flags.DEFINE_string('datapath', 'data/', 51 | 'Path to directory with data files.') 52 | flags.DEFINE_string('save_dir', '/tmp/models/cora/gat', 53 | 'Directory where to save model checkpoints and summaries.') 54 | flags.DEFINE_float('lr', 0.005, 'Learning rate to use.') 55 | flags.DEFINE_string( 56 | 'model_checkpoint', '', 'Model checkpoint to load before ' 57 | 'training or for testing. If not specified the model will be trained from ' 58 | 'scratch.') 59 | flags.DEFINE_float('drop_edge_prop', 0, 60 | 'Percentage of edges to remove (0 to keep all edges).') 61 | flags.DEFINE_float('node_l2_reg', 0.0005, 'L2 regularization to use for node ' 62 | 'model parameters.') 63 | flags.DEFINE_float('edge_l2_reg', 0., 'L2 regularization to use for edge ' 64 | 'model parameters.') 65 | flags.DEFINE_float('edge_reg', 0., 'Regularization to use for the edge ' 66 | 'loss.') 67 | flags.DEFINE_integer( 68 | 'cheby_k_loc', 1, 'K for K-localized filters in Chebyshev ' 69 | 'polynomials approximation.') 70 | flags.DEFINE_integer( 71 | 'semi_emb_k', -1, 'Which layer to regularize for ' 72 | 'semi-supervised embedding model.') 73 | flags.DEFINE_float('p_drop_node', 0.6, 'Dropout probability for node model.') 74 | flags.DEFINE_float('p_drop_edge', 0., 'Dropout probability for edge model.') 75 | flags.DEFINE_integer( 76 | 'topk', 1000, 'Top k entries to keep in adjacency for' 77 | ' NodeEdge models.') 78 | flags.DEFINE_string( 79 | 'n_hidden_node', '8', 'Number of hidden units per layer in node model. ' 80 | 'The last layer has as many nodes as the number of classes ' 81 | 'in the dataset.') 82 | flags.DEFINE_string( 83 | 'n_att_node', '8-1', 84 | 'Number of attention heads per layer in the node model. ' 85 | '(This is only for graph attention models).') 86 | flags.DEFINE_string('n_hidden_edge', '32-16', 87 | 'Number of hidden units per layer in edge model.') 88 | flags.DEFINE_string( 89 | 'n_att_edge', '8-4', 90 | 'Number of attention heads per layer in the edge model. 
' 91 | '(This is only for edge graph attention models).') 92 | flags.DEFINE_string( 93 | 'att_mechanism', 'dot', 94 | 'Attention mechanism to use: dot product, asymmetric dot ' 95 | 'product or attention (dot - att - asym-dot).') 96 | flags.DEFINE_string( 97 | 'edge_loss', 'w_sigmoid_ce', 'Edge loss (w_sigmoid_ce - neg_sampling_ce). ' 98 | 'w_sigmoid_ce for weighted sigmoid cross entropy and neg_sampling_ce for ' 99 | 'negative sampling.') 100 | flags.DEFINE_boolean('sparse_features', True, 101 | 'True if node features are sparse.') 102 | flags.DEFINE_boolean( 103 | 'normalize_adj', True, 'Whether to normalize adjacency or not (True for ' 104 | 'GCN models and False for GAT models).') 105 | flags.DEFINE_integer('run_id', 0, 'Run id.') 106 | 107 | FLAGS = flags.FLAGS 108 | NODE_MODELS = ['Gat', 'Gcn', 'Mlp', 'Hgat', 'Pgcn', 'SemiEmb', 'Cheby'] 109 | NODE_EDGE_MODELS = [ 110 | 'GaeGat', 'GaeGcn', 'GatGraphite', 'GaeGatConcat', 'GaeGcnConcat', 'Gcat' 111 | ] 112 | EDGE_MODELS = ['Gae', 'Egat', 'Emlp', 'Vgae'] 113 | 114 | 115 | class Config(object): 116 | """Gets config parameters from flags to train the GNN models.""" 117 | 118 | def __init__(self): 119 | # Model parameters 120 | self.n_hidden_node = list(map(int, FLAGS.n_hidden_node.split('-'))) 121 | self.n_att_node = list(map(int, FLAGS.n_att_node.split('-'))) 122 | self.n_hidden_edge = list(map(int, FLAGS.n_hidden_edge.split('-'))) 123 | self.n_att_edge = list(map(int, FLAGS.n_att_edge.split('-'))) 124 | self.topk = FLAGS.topk 125 | self.att_mechanism = FLAGS.att_mechanism 126 | self.edge_loss = FLAGS.edge_loss 127 | self.cheby_k_loc = FLAGS.cheby_k_loc 128 | self.semi_emb_k = FLAGS.semi_emb_k 129 | 130 | # Dataset parameters 131 | self.sparse_features = FLAGS.sparse_features 132 | 133 | # Training parameters 134 | self.lr = FLAGS.lr 135 | self.epochs = FLAGS.epochs 136 | self.patience = FLAGS.patience 137 | self.node_l2_reg = FLAGS.node_l2_reg 138 | self.edge_l2_reg = FLAGS.edge_l2_reg 139 | self.edge_reg = FLAGS.edge_reg 140 | self.p_drop_node = FLAGS.p_drop_node 141 | self.p_drop_edge = FLAGS.p_drop_edge 142 | 143 | def set_num_nodes_edges(self, data): 144 | if self.sparse_features: 145 | self.nb_nodes, self.input_dim = data['features'][-1] 146 | else: 147 | self.nb_nodes, self.input_dim = data['features'].shape 148 | self.nb_classes = data['node_labels'].shape[-1] 149 | self.n_hidden_node += [int(self.nb_classes)] 150 | self.nb_edges = np.sum(data['adj_train'] > 0) - self.nb_nodes 151 | self.multilabel = np.max(np.sum(data['node_labels'], 1)) > 1 152 | 153 | def get_filename_suffix(self, run_id): 154 | """Formats all params in a string for log file suffix.""" 155 | all_params = [ 156 | self.lr, self.epochs, self.patience, self.node_l2_reg, self.edge_l2_reg, 157 | self.edge_reg, self.p_drop_node, self.p_drop_edge, '.'.join([ 158 | str(x) for x in self.n_hidden_node 159 | ]), '.'.join([str(x) for x in self.n_att_node]), 160 | '.'.join([str(x) for x in self.n_hidden_edge]), '.'.join( 161 | [str(x) for x in self.n_att_edge]), self.topk, self.att_mechanism, 162 | self.edge_loss, self.cheby_k_loc, self.semi_emb_k, run_id 163 | ] 164 | file_suffix = '-'.join([str(x) for x in all_params]) 165 | return file_suffix 166 | 167 | 168 | class TrainTest(object): 169 | """Class to train node and edge classification models.""" 170 | 171 | def __init__(self, model_name): 172 | # initialize global step 173 | self.global_step = 0 174 | self.model_name = model_name 175 | self.data = {'train': {}, 'test': {}, 'val': {}} 176 | 177 | def load_dataset(self, 
dataset, sparse_features, datapath): 178 | """Loads citation dataset.""" 179 | dataset = load_data(dataset, datapath) 180 | adj_true = dataset[0] + sp.eye(dataset[0].shape[0]) 181 | # adj_true to compute link prediction metrics 182 | self.data['adj_true'] = adj_true.todense() 183 | if sparse_features: 184 | self.data['features'] = sparse_to_tuple(dataset[1]) 185 | else: 186 | self.data['features'] = dataset[1] 187 | self.data['node_labels'] = dataset[2] 188 | self.data['train']['node_mask'] = dataset[3] 189 | self.data['val']['node_mask'] = dataset[4] 190 | self.data['test']['node_mask'] = dataset[5] 191 | 192 | def mask_edges(self, adj_true, drop_edge_prop): 193 | """Loads edge masks and removes edges from the training adjacency.""" 194 | # adj_train to compute loss 195 | if drop_edge_prop > 0: 196 | if self.model_name in NODE_MODELS: 197 | self.data['adj_train'], test_mask = mask_test_edges( 198 | sp.coo_matrix(adj_true), drop_edge_prop * 0.01) 199 | else: 200 | self.data['adj_train'], val_mask, test_mask = mask_val_test_edges( 201 | sp.coo_matrix(adj_true), drop_edge_prop * 0.01) 202 | self.data['val']['edge_mask'] = val_mask 203 | self.data['train']['edge_mask'] = val_mask # unused 204 | self.data['test']['edge_mask'] = test_mask 205 | self.data['adj_train'] += sp.eye(adj_true.shape[0]) 206 | self.data['adj_train'] = self.data['adj_train'].todense() 207 | else: 208 | self.data['adj_train'] = adj_true 209 | 210 | def process_adj(self, norm_adj): 211 | # adj_train_norm for inference 212 | if norm_adj: 213 | adj_train_norm = process_adj(self.data['adj_train'], self.model_name) 214 | else: 215 | adj_train_norm = sp.coo_matrix(self.data['adj_train']) 216 | self.data['adj_train_norm'] = sparse_to_tuple(adj_train_norm) 217 | 218 | def init_global_step(self): 219 | self.global_step = 0 220 | 221 | def create_saver(self, save_dir, filename_suffix): 222 | """Creates saver to save model checkpoints.""" 223 | self.summary_writer = tf.summary.FileWriter( 224 | save_dir, tf.get_default_graph(), filename_suffix=filename_suffix) 225 | self.saver = tf.train.Saver() 226 | # logging file to print metrics and loss 227 | self.log_file = tf.gfile.Open( 228 | os.path.join(save_dir, '{}.log'.format(filename_suffix)), 'w') 229 | 230 | def _create_summary(self, loss, metrics, split): 231 | """Create summaries for tensorboard.""" 232 | with tf.name_scope('{}-summary'.format(split)): 233 | tf.summary.scalar('loss', loss) 234 | for metric in metrics: 235 | tf.summary.scalar(metric, metrics[metric]) 236 | summary_op = tf.summary.merge_all() 237 | return summary_op 238 | 239 | def _make_feed_dict(self, split): 240 | """Creates feed dictionaries for edge models and node models.""" 241 | if split == 'train': 242 | is_training = True 243 | else: 244 | is_training = False 245 | return self.model.make_feed_dict(self.data, split, is_training) 246 | 247 | def _get_model_and_targets(self, multilabel): 248 | """Define targets to select best model based on val metrics.""" 249 | if self.model_name in NODE_MODELS: 250 | model_class = getattr(node_models, self.model_name) 251 | if multilabel: 252 | target_metrics = {'f1': 1, 'loss': 0} 253 | else: 254 | target_metrics = {'node_acc': 1, 'loss': 0} 255 | # target_metrics = {'node_acc': 1} 256 | elif self.model_name in NODE_EDGE_MODELS: 257 | model_class = getattr(node_edge_models, self.model_name) 258 | target_metrics = {'node_acc': 1} 259 | else: 260 | model_class = getattr(edge_models, self.model_name) 261 | target_metrics = {'edge_pr_auc': 1} #, 'loss': 0} 262 | return model_class, 
target_metrics 263 | 264 | def build_model(self, config): 265 | """Build model graph.""" 266 | model_class, self.target_metrics = self._get_model_and_targets( 267 | config.multilabel) 268 | self.model = model_class(config) 269 | all_ops = self.model.build_graph() 270 | loss, train_op, metric_op, metric_update_op = all_ops 271 | self.train_ops = [train_op] 272 | self.eval_ops = [loss, metric_update_op] 273 | self.metrics = metric_op 274 | self.train_summary = self._create_summary(loss, metric_op, 'train') 275 | self.val_summary = self._create_summary(loss, metric_op, 'val') 276 | 277 | def _eval_model(self, sess, split): 278 | """Evaluates model.""" 279 | sess.run(tf.local_variables_initializer()) 280 | if split == 'train': 281 | metrics = {} 282 | # tmp way to not eval on train for edge model 283 | metrics['loss'] = sess.run( 284 | self.eval_ops[0], feed_dict=self._make_feed_dict(split)) 285 | else: 286 | loss, _ = sess.run(self.eval_ops, feed_dict=self._make_feed_dict(split)) 287 | metrics = sess.run(self.metrics, feed_dict=self._make_feed_dict(split)) 288 | metrics['loss'] = loss 289 | return metrics 290 | 291 | def _init_best_metrics(self): 292 | best_metrics = {} 293 | for metric in self.target_metrics: 294 | if self.target_metrics[metric] == 1: 295 | best_metrics[metric] = -1 296 | else: 297 | best_metrics[metric] = np.inf 298 | return best_metrics 299 | 300 | def _log(self, message): 301 | """Writes into train.log file.""" 302 | time = datetime.datetime.now().strftime('%d.%b %Y %H:%M:%S') 303 | self.log_file.write(time + ' : ' + message + '\n') 304 | 305 | def init_model_weights(self, sess): 306 | self._log('Initializing model weights...') 307 | sess.run(tf.global_variables_initializer()) 308 | sess.run(tf.local_variables_initializer()) 309 | 310 | def restore_checkpoint(self, sess, model_checkpoint=None): 311 | """Loads model checkpoint if found and computes evaluation metrics.""" 312 | if model_checkpoint is None or not tf.train.checkpoint_exists( 313 | model_checkpoint): 314 | self.init_model_weights(sess) 315 | else: 316 | self._log('Loading existing model saved at {}'.format(model_checkpoint)) 317 | self.saver.restore(sess, model_checkpoint) 318 | self.global_step = int(model_checkpoint.split('-')[-1]) 319 | val_metrics = self._eval_model(sess, 'val') 320 | test_metrics = self._eval_model(sess, 'test') 321 | self._log(format_metrics(val_metrics, 'val')) 322 | self._log(format_metrics(test_metrics, 'test')) 323 | 324 | def train(self, sess, config): 325 | """Trains node classification model or joint node edge model.""" 326 | self._log('Training {} model...'.format(self.model_name)) 327 | self._log('Training parameters : \n ' + format_params(config)) 328 | epochs = config.epochs 329 | lr = config.lr 330 | patience = config.patience 331 | # best_step = self.global_step 332 | # step for patience 333 | curr_step = 0 334 | # best metrics to select model 335 | best_val_metrics = self._init_best_metrics() 336 | best_test_metrics = self._init_best_metrics() 337 | # train the model 338 | for epoch in range(epochs): 339 | self.global_step += 1 340 | sess.run(self.train_ops, feed_dict=self._make_feed_dict('train')) 341 | train_metrics = self._eval_model(sess, 'train') 342 | val_metrics = self._eval_model(sess, 'val') 343 | self._log('Epoch {} : lr = {:.4f} | '.format(epoch, lr) + 344 | format_metrics(train_metrics, 'train') + 345 | format_metrics(val_metrics, 'val')) 346 | # write summaries 347 | train_summary = sess.run(self.train_summary, 348 | self._make_feed_dict('train')) 349 | 
val_summary = sess.run(self.val_summary, self._make_feed_dict('val')) 350 | self.summary_writer.add_summary( 351 | train_summary, global_step=self.global_step) 352 | self.summary_writer.add_summary(val_summary, global_step=self.global_step) 353 | # save model checkpoint if val acc increased and val loss decreased 354 | comp = check_improve(best_val_metrics, val_metrics, self.target_metrics) 355 | if np.any(comp): 356 | if np.all(comp): 357 | # best_step = self.global_step 358 | # save_path = os.path.join(save_dir, 'model') 359 | # self.saver.save(sess, save_path, global_step=self.global_step) 360 | best_test_metrics = self._eval_model(sess, 'test') 361 | best_val_metrics = val_metrics 362 | curr_step = 0 363 | else: 364 | curr_step += 1 365 | if curr_step == patience: 366 | self._log('Early stopping') 367 | break 368 | 369 | self._log('\n' + '*' * 40 + ' Best model metrics ' + '*' * 40) 370 | # load best model to evaluate on test set 371 | # save_path = os.path.join(save_dir, 'model-{}'.format(best_step)) 372 | # self.restore_checkpoint(sess, save_path) 373 | self._log(format_metrics(best_val_metrics, 'val')) 374 | self._log(format_metrics(best_test_metrics, 'test')) 375 | self._log('\n' + '*' * 40 + ' Training done ' + '*' * 40) 376 | 377 | def run(self, config, save_dir, file_prefix): 378 | """Build and train a model.""" 379 | tf.reset_default_graph() 380 | self.init_global_step() 381 | # build model 382 | self.build_model(config) 383 | # create summary writer and save for model weights 384 | if not os.path.exists(save_dir): 385 | tf.gfile.MakeDirs(save_dir) 386 | self.create_saver(save_dir, file_prefix) 387 | # run sessions 388 | with tf.Session() as sess: 389 | self.init_model_weights(sess) 390 | self.train(sess, config) 391 | sess.close() 392 | self.log_file.close() 393 | 394 | 395 | def main(_): 396 | # parse configuration parameters 397 | trainer = TrainTest(FLAGS.model_name) 398 | print('Loading dataset...') 399 | # load the dataset and process adjacency and node features 400 | trainer.load_dataset(FLAGS.dataset, FLAGS.sparse_features, FLAGS.datapath) 401 | trainer.mask_edges(trainer.data['adj_true'], FLAGS.drop_edge_prop) 402 | trainer.process_adj(FLAGS.normalize_adj) 403 | print('Dataset loaded...') 404 | config = Config() 405 | config.set_num_nodes_edges(trainer.data) 406 | filename_suffix = config.get_filename_suffix(FLAGS.run_id) 407 | trainer.run(config, FLAGS.save_dir, filename_suffix) 408 | 409 | 410 | if __name__ == '__main__': 411 | app.run(main) 412 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 
14 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 Google LLC 2 | # 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | #Unless required by applicable law or agreed to in writing, software 10 | #distributed under the License is distributed on an "AS IS" BASIS, 11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | #See the License for the specific language governing permissions and 13 | #limitations under the License. 14 | 15 | 16 | """Utils functions to load and process citation data.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | import os 22 | import pickle as pkl 23 | import sys 24 | 25 | import networkx as nx 26 | import numpy as np 27 | import scipy.sparse as sp 28 | from scipy.sparse.linalg.eigen.arpack import eigsh 29 | import tensorflow as tf 30 | from third_party.gcn.gcn.utils import normalize_adj 31 | from third_party.gcn.gcn.utils import parse_index_file 32 | from third_party.gcn.gcn.utils import sample_mask 33 | from third_party.gcn.gcn.utils import sparse_to_tuple 34 | from third_party.gcn.gcn.utils import preprocess_features 35 | 36 | 37 | def load_test_edge_mask(dataset_str, data_path, drop_edge_prop): 38 | """Remove test edges by loading edge masks.""" 39 | edge_mask_path = os.path.join( 40 | data_path, 'emask.{}.remove{}.npz'.format(dataset_str, drop_edge_prop)) 41 | with tf.gfile.Open(edge_mask_path) as f: 42 | mask = sp.load_npz(f) 43 | return mask 44 | 45 | 46 | def load_edge_masks(dataset_str, data_path, adj_true, drop_edge_prop): 47 | """Loads adjacency matrix as sparse matrix and masks for val & test links. 48 | 49 | Args: 50 | dataset_str: dataset to use 51 | data_path: path to data folder 52 | adj_true: true adjacency matrix in dense format, 53 | drop_edge_prop: proportion of edges to remove. 54 | 55 | Returns: 56 | adj_matrix: adjacency matrix 57 | train_mask: mask for train edges 58 | val_mask: mask for val edges 59 | test_mask: mask for test edges 60 | """ 61 | edge_mask_path = os.path.join( 62 | data_path, 'emask.{}.remove{}.'.format(dataset_str, drop_edge_prop)) 63 | val_mask = sp.load_npz(edge_mask_path + 'val.npz') 64 | test_mask = sp.load_npz(edge_mask_path + 'test.npz') 65 | train_mask = 1. - val_mask.todense() - test_mask.todense() 66 | # remove val and test edges from true A 67 | adj_train = np.multiply(adj_true, train_mask) 68 | train_mask -= np.eye(train_mask.shape[0]) 69 | return adj_train, sparse_to_tuple(val_mask), sparse_to_tuple( 70 | val_mask), sparse_to_tuple(test_mask) 71 | 72 | 73 | def add_top_k_edges(data, edge_mask_path, gae_scores_path, topk, nb_nodes, 74 | norm_adj): 75 | """Loads GAE scores and adds topK edges to train adjacency.""" 76 | test_mask = sp.load_npz(os.path.join(edge_mask_path, 'test_mask.npz')) 77 | train_mask = 1. 
- test_mask.todense() 78 | # remove val and test edges from true A 79 | adj_train_curr = np.multiply(data['adj_true'], train_mask) 80 | # Predict test edges using precomputed scores 81 | scores = np.load(os.path.join(gae_scores_path, 'gae_scores.npy')) 82 | # scores_mask = 1 - np.eye(nb_nodes) 83 | scores_mask = np.zeros((nb_nodes, nb_nodes)) 84 | scores_mask[:140, 140:] = 1. 85 | scores_mask[140:, :140] = 1. 86 | scores = np.multiply(scores, scores_mask).reshape((-1,)) 87 | threshold = scores[np.argsort(-scores)[topk]] 88 | adj_train_curr += 1 * (scores > threshold).reshape((nb_nodes, nb_nodes)) 89 | adj_train_curr = 1 * (adj_train_curr > 0) 90 | if norm_adj: 91 | adj_train_norm = normalize_adj(data['adj_train']) 92 | else: 93 | adj_train_norm = sp.coo_matrix(data['adj_train']) 94 | return adj_train_curr, sparse_to_tuple(adj_train_norm) 95 | 96 | 97 | def process_adj(adj, model_name): 98 | """Symmetrically normalize adjacency matrix.""" 99 | if model_name == 'Cheby': 100 | laplacian = sp.eye(adj.shape[0]) - normalize_adj(adj - sp.eye(adj.shape[0])) 101 | # TODO(chamii): compare with 102 | # adj) 103 | largest_eigval, _ = eigsh(laplacian, 1, which='LM') 104 | laplacian_norm = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0]) 105 | return laplacian_norm 106 | else: 107 | return normalize_adj(adj) 108 | 109 | 110 | def load_data(dataset_str, data_path): 111 | if dataset_str in ['cora', 'citeseer', 'pubmed']: 112 | return load_citation_data(dataset_str, data_path) 113 | else: 114 | return load_ppi_data(data_path) 115 | 116 | 117 | def load_ppi_data(data_path): 118 | """Load PPI dataset.""" 119 | with tf.gfile.Open(os.path.join(data_path, 'ppi.edges.npz')) as f: 120 | adj = sp.load_npz(f) 121 | 122 | with tf.gfile.Open(os.path.join(data_path, 'ppi.features.norm.npy')) as f: 123 | features = np.load(f) 124 | 125 | with tf.gfile.Open(os.path.join(data_path, 'ppi.labels.npz')) as f: 126 | labels = sp.load_npz(f).todense() 127 | 128 | train_mask = np.load( 129 | tf.gfile.Open(os.path.join(data_path, 'ppi.train_mask.npy'))) > 0 130 | val_mask = np.load( 131 | tf.gfile.Open(os.path.join(data_path, 'ppi.test_mask.npy'))) > 0 132 | test_mask = np.load( 133 | tf.gfile.Open(os.path.join(data_path, 'ppi.test_mask.npy'))) > 0 134 | 135 | return adj, features, labels, train_mask, val_mask, test_mask 136 | 137 | 138 | def load_citation_data(dataset_str, data_path): 139 | """Load data.""" 140 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'] 141 | objects = {} 142 | for name in names: 143 | with tf.gfile.Open( 144 | os.path.join(data_path, 'ind.{}.{}'.format(dataset_str, name)), 145 | 'rb') as f: 146 | if sys.version_info > (3, 0): 147 | objects[name] = pkl.load(f) # , encoding='latin1') comment to pass lint 148 | else: 149 | objects[name] = pkl.load(f) 150 | 151 | test_idx_reorder = parse_index_file( 152 | os.path.join(data_path, 'ind.{}.test.index'.format(dataset_str))) 153 | test_idx_range = np.sort(test_idx_reorder) 154 | 155 | if dataset_str == 'citeseer': 156 | # Fix citeseer dataset (there are some isolated nodes in the graph) 157 | # Find isolated nodes, add them as zero-vecs into the right position 158 | test_idx_range_full = range( 159 | min(test_idx_reorder), 160 | max(test_idx_reorder) + 1) 161 | tx_extended = sp.lil_matrix((len(test_idx_range_full), 162 | objects['x'].shape[1])) 163 | tx_extended[test_idx_range - min(test_idx_range), :] = objects['tx'] 164 | objects['tx'] = tx_extended 165 | ty_extended = np.zeros((len(test_idx_range_full), 166 | objects['y'].shape[1])) 167 | 
ty_extended[test_idx_range - min(test_idx_range), :] = objects['ty'] 168 | objects['ty'] = ty_extended 169 | 170 | features = sp.vstack((objects['allx'], objects['tx'])).tolil() 171 | features[test_idx_reorder, :] = features[test_idx_range, :] 172 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(objects['graph'])) 173 | 174 | labels = np.vstack((objects['ally'], objects['ty'])) 175 | labels[test_idx_reorder, :] = labels[test_idx_range, :] 176 | 177 | idx_test = test_idx_range.tolist() 178 | idx_train = range(len(objects['y'])) 179 | idx_val = range(len(objects['y']), len(objects['y']) + 500) 180 | 181 | train_mask = sample_mask(idx_train, labels.shape[0]) 182 | val_mask = sample_mask(idx_val, labels.shape[0]) 183 | test_mask = sample_mask(idx_test, labels.shape[0]) 184 | 185 | features = preprocess_features(features) 186 | return adj, features, labels, train_mask, val_mask, test_mask 187 | 188 | 189 | def construct_feed_dict(adj_normalized, adj, features, placeholders): 190 | # construct feed dictionary 191 | feed_dict = dict() 192 | feed_dict.update({placeholders['features']: features}) 193 | feed_dict.update({placeholders['adj']: adj_normalized}) 194 | feed_dict.update({placeholders['adj_orig']: adj}) 195 | return feed_dict 196 | 197 | 198 | def mask_val_test_edges(adj, prop): 199 | """Function to mask test and val edges.""" 200 | # NOTE: Splits are randomized and results might slightly 201 | # deviate from reported numbers in the paper. 202 | 203 | # Remove diagonal elements 204 | adj = adj - sp.dia_matrix( 205 | (adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape) 206 | adj.eliminate_zeros() 207 | # Check that diag is zero: 208 | assert np.diag(adj.todense()).sum() == 0 209 | 210 | adj_triu = sp.triu(adj) 211 | adj_tuple = sparse_to_tuple(adj_triu) 212 | edges = adj_tuple[0] 213 | edges_all = sparse_to_tuple(adj)[0] 214 | num_test = int(np.floor(edges.shape[0] * prop)) 215 | # num_val = int(np.floor(edges.shape[0] * 0.05)) # we keep 5% for validation 216 | # we keep 5% of the remaining (non-test) edges for validation 217 | num_val = int(np.floor((edges.shape[0] - num_test) * 0.05)) 218 | 219 | all_edge_idx = list(range(edges.shape[0])) 220 | np.random.shuffle(all_edge_idx) 221 | val_edge_idx = all_edge_idx[:num_val] 222 | test_edge_idx = all_edge_idx[num_val:(num_val + num_test)] 223 | test_edges = edges[test_edge_idx] 224 | val_edges = edges[val_edge_idx] 225 | train_edges = np.delete( 226 | edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0) 227 | 228 | def ismember(a, b, tol=5): 229 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) 230 | return np.any(rows_close) 231 | 232 | test_edges_false = [] 233 | while len(test_edges_false) < len(test_edges): 234 | idx_i = np.random.randint(0, adj.shape[0]) 235 | idx_j = np.random.randint(0, adj.shape[0]) 236 | if idx_i == idx_j: 237 | continue 238 | if ismember([idx_i, idx_j], edges_all): 239 | continue 240 | if test_edges_false: 241 | if ismember([idx_j, idx_i], np.array(test_edges_false)): 242 | continue 243 | if ismember([idx_i, idx_j], np.array(test_edges_false)): 244 | continue 245 | test_edges_false.append([idx_i, idx_j]) 246 | 247 | val_edges_false = [] 248 | while len(val_edges_false) < len(val_edges): 249 | idx_i = np.random.randint(0, adj.shape[0]) 250 | idx_j = np.random.randint(0, adj.shape[0]) 251 | if idx_i == idx_j: 252 | continue 253 | if ismember([idx_i, idx_j], train_edges): 254 | continue 255 | if ismember([idx_j, idx_i], train_edges): 256 | continue 257 | if ismember([idx_i, idx_j], val_edges): 258 | continue 259 | if 
ismember([idx_j, idx_i], val_edges): 260 | continue 261 | if val_edges_false: 262 | if ismember([idx_j, idx_i], np.array(val_edges_false)): 263 | continue 264 | if ismember([idx_i, idx_j], np.array(val_edges_false)): 265 | continue 266 | val_edges_false.append([idx_i, idx_j]) 267 | 268 | assert ~ismember(test_edges_false, edges_all) 269 | assert ~ismember(val_edges_false, edges_all) 270 | assert ~ismember(val_edges, train_edges) 271 | assert ~ismember(test_edges, train_edges) 272 | assert ~ismember(val_edges, test_edges) 273 | 274 | data = np.ones(train_edges.shape[0]) 275 | 276 | # Re-build adj matrix 277 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), 278 | shape=adj.shape) 279 | adj_train = adj_train + adj_train.T 280 | 281 | # NOTE: these edge lists only contain single direction of edge! 282 | num_nodes = adj.shape[0] 283 | val_mask = np.zeros((num_nodes, num_nodes)) 284 | for i, j in val_edges: 285 | val_mask[i, j] = 1 286 | val_mask[j, i] = 1 287 | for i, j in val_edges_false: 288 | val_mask[i, j] = 1 289 | val_mask[j, i] = 1 290 | test_mask = np.zeros((num_nodes, num_nodes)) 291 | for i, j in test_edges: 292 | test_mask[i, j] = 1 293 | test_mask[j, i] = 1 294 | for i, j in test_edges_false: 295 | test_mask[i, j] = 1 296 | test_mask[j, i] = 1 297 | return adj_train, sparse_to_tuple(sp.coo_matrix(val_mask)), sparse_to_tuple(sp.coo_matrix(test_mask)) 298 | 299 | 300 | def mask_test_edges(adj, prop): 301 | """Function to mask test edges. 302 | 303 | Args: 304 | adj: scipy sparse matrix 305 | prop: proportion of edges to remove (float in [0, 1]) 306 | 307 | Returns: 308 | adj_train: adjacency with edges removed 309 | test_edges: mask of positive and negative test edges in tuple format 310 | """ 311 | # Remove diagonal elements 312 | adj = adj - sp.dia_matrix( 313 | (adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape) 314 | adj.eliminate_zeros() 315 | # Check that diag is zero: 316 | assert np.diag(adj.todense()).sum() == 0 317 | 318 | adj_triu = sp.triu(adj) 319 | adj_tuple = sparse_to_tuple(adj_triu) 320 | edges = adj_tuple[0] 321 | edges_all = sparse_to_tuple(adj)[0] 322 | num_test = int(np.floor(edges.shape[0] * prop)) 323 | 324 | all_edge_idx = list(range(edges.shape[0])) 325 | np.random.shuffle(all_edge_idx) 326 | test_edge_idx = all_edge_idx[:num_test] 327 | test_edges = edges[test_edge_idx] 328 | train_edges = np.delete(edges, test_edge_idx, axis=0) 329 | 330 | def ismember(a, b, tol=5): 331 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) 332 | return np.any(rows_close) 333 | 334 | test_edges_false = [] 335 | while len(test_edges_false) < len(test_edges): 336 | idx_i = np.random.randint(0, adj.shape[0]) 337 | idx_j = np.random.randint(0, adj.shape[0]) 338 | if idx_i == idx_j: 339 | continue 340 | if ismember([idx_i, idx_j], edges_all): 341 | continue 342 | if test_edges_false: 343 | if ismember([idx_j, idx_i], np.array(test_edges_false)): 344 | continue 345 | if ismember([idx_i, idx_j], np.array(test_edges_false)): 346 | continue 347 | test_edges_false.append([idx_i, idx_j]) 348 | 349 | assert ~ismember(test_edges_false, edges_all) 350 | assert ~ismember(test_edges, train_edges) 351 | 352 | data = np.ones(train_edges.shape[0]) 353 | 354 | # Re-build adj matrix 355 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), 356 | shape=adj.shape) 357 | adj_train = adj_train + adj_train.T 358 | 359 | # NOTE: these edge lists only contain single direction of edge! 
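  # By contrast, the dense mask built below marks both (i, j) and (j, i) for
  # every positive and negative test pair, so predicted scores can be read
  # off a full num_nodes x num_nodes score matrix in either direction.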
--------------------------------------------------------------------------------
/utils/link_prediction_utils.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | #    https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | 
15 | 
16 | """Heuristics for link prediction."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | from absl import flags
23 | from data_utils import mask_test_edges
24 | import networkx as nx
25 | import numpy as np
26 | import scipy.sparse as sp
27 | import sklearn.metrics as skm
28 | 
29 | 
30 | flags.DEFINE_string('adj_path', '../data/cora.adj.npz', 'path to graph to use.')
31 | flags.DEFINE_string('prop_drop', '10-30-50', 'proportion of edges to remove.')
32 | flags.DEFINE_string('methods', 'svd-katz-common_neighbours',
33 |                     'which methods to use.')
34 | FLAGS = flags.FLAGS
35 | 
36 | 
37 | class LinkPredictionHeuristics(object):
38 |   """Link prediction heuristics."""
39 | 
40 |   def __init__(self, adj_matrix):
41 |     self.adj_matrix = adj_matrix
42 | 
43 |   def common_neighbours(self):
44 |     """Computes scores for each node pair based on common neighbours."""
45 |     scores = self.adj_matrix.dot(self.adj_matrix)
46 |     return scores
47 | 
48 |   def svd(self, rank=64):
49 |     """Computes scores using low-rank factorization with SVD."""
50 |     adj_matrix = self.adj_matrix.asfptype()
51 |     u, s, v = sp.linalg.svds(A=adj_matrix, k=rank)
52 |     adj_low_rank = u.dot(np.diag(s).dot(v))
53 |     return adj_low_rank
54 | 
55 |   def adamic_adar(self):
56 |     """Computes Adamic-Adar scores."""
57 |     graph = nx.from_scipy_sparse_matrix(self.adj_matrix)
58 |     scores = nx.adamic_adar_index(graph)
59 |     return scores
60 | 
61 |   def jaccard_coeff(self):
62 |     """Computes Jaccard coefficients."""
63 |     graph = nx.from_scipy_sparse_matrix(self.adj_matrix)
64 |     coeffs = nx.jaccard_coefficient(graph)
65 |     return coeffs
66 | 
67 |   def katz(self, beta=0.001, steps=25):
68 |     """Computes truncated Katz scores: sum_k beta^k A^k."""
69 |     coeff = beta
70 |     katz_scores = beta * self.adj_matrix
71 |     adj_power = self.adj_matrix
72 |     for _ in range(2, steps + 1):
73 |       adj_power = adj_power.dot(self.adj_matrix)
74 |       coeff *= beta  # update first so A^k is weighted by beta^k, not beta^(k-1)
75 |       katz_scores += coeff * adj_power
76 |     return katz_scores
77 | 
78 | 
79 | def get_scores_from_generator(gen, nb_nodes=2708):  # 2708 = Cora's node count
80 |   """Helper function to get scores in numpy array format from generator."""
81 |   adj = np.zeros((nb_nodes, nb_nodes))
82 |   for i, j, score in gen:
83 |     adj[i, j] = score
84 |   return adj
85 | 
86 | 
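# Sketch of the heuristics API on a toy graph (hypothetical data):
#
#   adj = sp.csr_matrix(np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]]))
#   lp = LinkPredictionHeuristics(adj)
#   cn_scores = lp.common_neighbours()   # sparse matrix of pairwise scores
#   katz_scores = lp.katz(beta=0.001, steps=5)
#   aa_scores = get_scores_from_generator(lp.adamic_adar(), nb_nodes=3)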
87 | def compute_lp_metrics(edges, true_adj, pred_adj):
88 |   """Computes link prediction metrics (ROC-AUC, AP) on test edges."""
89 |   labels = np.array(true_adj[edges]).reshape((-1,))
90 |   scores = np.array(pred_adj[edges]).reshape((-1,))
91 |   roc = skm.roc_auc_score(labels, scores)
92 |   ap = skm.average_precision_score(labels, scores)
93 |   return roc, ap
94 | 
95 | 
96 | if __name__ == '__main__':
97 |   adj_true = sp.load_npz(FLAGS.adj_path)  # keep sparse for edge masking
98 |   for delete_prop in FLAGS.prop_drop.split('-'):
99 |     for method in FLAGS.methods.split('-'):
100 |       adj_train, test_edges = mask_test_edges(
101 |           adj_true, float(delete_prop) * 0.01)
102 |       # Score with heuristics computed on the pruned training graph.
103 |       lp = LinkPredictionHeuristics(adj_train)
104 |       adj_scores = getattr(lp, method)()
105 |       if sp.issparse(adj_scores):
106 |         adj_scores = adj_scores.todense()
107 |       roc_score, ap_score = compute_lp_metrics(
108 |           test_edges, adj_true.todense(), adj_scores)
109 |       print('method={} | prop={} | roc_auc={} ap={}\n'.format(
110 |           method, delete_prop, round(roc_score, 4), round(ap_score, 4)))
111 | 
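# Hypothetical invocation (flag values mirror the defaults above; assumes the
# absl flags are parsed, e.g. via absl.app):
#
#   python link_prediction_utils.py --adj_path=../data/cora.adj.npz \
#     --prop_drop=10-30-50 --methods=svd-katz-common_neighbours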
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | #    https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | 
15 | 
16 | """Utils functions for GNN models."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import tensorflow as tf
23 | 
24 | WEIGHT_INIT = tf.contrib.layers.xavier_initializer()
25 | BIAS_INIT = tf.zeros_initializer()
26 | 
27 | 
28 | ############################## LAYERS #############################
29 | 
30 | 
31 | def sparse_dropout(tensor, p_drop, is_training):
32 |   """Dropout with sparse tensor."""
33 |   return tf.SparseTensor(
34 |       indices=tensor.indices,
35 |       values=tf.layers.dropout(
36 |           inputs=tensor.values,
37 |           rate=p_drop,
38 |           training=is_training),
39 |       dense_shape=tensor.dense_shape)
40 | 
41 | 
42 | def dense(node_features,
43 |           in_dim,
44 |           out_dim,
45 |           p_drop,
46 |           is_training,
47 |           sparse,
48 |           use_bias=False):
49 |   """Dense layer with sparse or dense tensor and dropout."""
50 |   w_dense = tf.get_variable(
51 |       initializer=WEIGHT_INIT,
52 |       dtype=tf.float32,
53 |       name='linear',
54 |       shape=(in_dim, out_dim))
55 |   if sparse:
56 |     node_features = sparse_dropout(node_features, p_drop, is_training)
57 |     node_features = tf.sparse_tensor_dense_matmul(node_features, w_dense)
58 |   else:
59 |     node_features = tf.layers.dropout(
60 |         inputs=node_features, rate=p_drop, training=is_training)
61 |     node_features = tf.matmul(node_features, w_dense)
62 |   if use_bias:
63 |     node_features = tf.contrib.layers.bias_add(node_features)
64 |   return node_features
65 | 
66 | 
67 | def sp_gcn_layer(node_features, adj_matrix, in_dim, out_dim, p_drop,
68 |                  is_training, sparse):
69 |   """Single graph convolution layer with sparse tensors AXW.
70 | 
71 |   Args:
72 |     node_features: Tensor of shape (nb_nodes, in_dim) or SparseTensor.
73 |     adj_matrix: Sparse Tensor, normalized adjacency matrix.
74 |     in_dim: integer specifying the input feature dimension.
75 |     out_dim: integer specifying the output feature dimension.
76 |     p_drop: dropout probability.
77 |     is_training: boolean, True if the model is being trained, False otherwise.
78 |     sparse: True if node_features are sparse.
79 | 
80 |   Returns:
81 |     node_features: tensor of shape (nb_nodes, out_dim). New node
82 |       features obtained from applying one GCN layer.
83 | 
84 |   Raises:
85 |   """
86 |   node_features = dense(node_features, in_dim, out_dim, p_drop, is_training,
87 |                         sparse)
88 |   node_features = tf.layers.dropout(
89 |       inputs=node_features, rate=p_drop, training=is_training)
90 |   node_features = tf.sparse_tensor_dense_matmul(adj_matrix, node_features)
91 |   return node_features
92 | 
93 | 
94 | def gcn_layer(node_features, adj_matrix, in_dim, out_dim, p_drop, is_training,
95 |               sparse):
96 |   """Single graph convolution layer with dense A.
97 | 
98 |   Args:
99 |     node_features: Tensor of shape (nb_nodes, in_dim) or SparseTensor.
100 |     adj_matrix: Tensor, normalized adjacency matrix.
101 |     in_dim: integer specifying the input feature dimension.
102 |     out_dim: integer specifying the output feature dimension.
103 |     p_drop: dropout probability.
104 |     is_training: boolean, True if the model is being trained, False otherwise.
105 |     sparse: True if node_features are sparse.
106 | 
107 |   Returns:
108 |     node_features: tensor of shape (nb_nodes, out_dim). New node
109 |       features obtained from applying one GCN layer.
110 | 
111 |   Raises:
112 |   """
113 |   node_features = dense(node_features, in_dim, out_dim, p_drop, is_training,
114 |                         sparse)
115 |   node_features = tf.layers.dropout(
116 |       inputs=node_features, rate=p_drop, training=is_training)
117 |   node_features = tf.matmul(adj_matrix, node_features)
118 |   return node_features
119 | 
120 | 
121 | def gcn_pool_layer(node_features, adj_matrix, in_dim, out_dim, sparse,
122 |                    is_training, p_drop):
123 |   """GCN with max-pooling over neighbours instead of averaging."""
124 |   node_features = dense(node_features, in_dim, out_dim, p_drop, is_training,
125 |                         sparse)
126 |   node_features = tf.expand_dims(node_features, 0)  # 1 x N x d
127 |   # broadcasting (adj in N x N x 1 and features are 1 x N x d)
128 |   node_features = tf.multiply(node_features, adj_matrix)
129 |   node_features = tf.transpose(node_features, perm=[0, 2, 1])
130 |   node_features = tf.reduce_max(node_features, axis=-1)  # N x d
131 |   return node_features
132 | 
133 | 
134 | def sp_gat_layer(node_features, adj_matrix, in_dim, out_dim, p_drop,
135 |                  is_training, sparse):
136 |   """Single graph attention layer using sparse tensors.
137 | 
138 |   Args:
139 |     node_features: Tensor of shape (nb_nodes, in_dim) or SparseTensor.
140 |     adj_matrix: Sparse Tensor.
141 |     in_dim: integer specifying the input feature dimension.
142 |     out_dim: integer specifying the output feature dimension.
143 |     p_drop: dropout probability.
144 |     is_training: boolean, True if the model is being trained, False otherwise.
145 |     sparse: True if node features are sparse.
146 | 
147 |   Returns:
148 |     node_features: tensor of shape (nb_nodes, out_dim). New node
149 |       features obtained from applying one head of attention to input.
150 | 
151 |   Raises:
152 |   """
153 |   # Linear transform
154 |   node_features = dense(node_features, in_dim, out_dim, p_drop, is_training,
155 |                         sparse)
156 |   # Attention scores
157 |   alpha = sp_compute_adj_att(node_features, adj_matrix)
158 |   alpha = tf.SparseTensor(
159 |       indices=alpha.indices,
160 |       values=tf.nn.leaky_relu(alpha.values),
161 |       dense_shape=alpha.dense_shape)
162 |   alpha = tf.sparse_softmax(alpha)
163 |   alpha = sparse_dropout(alpha, p_drop, is_training)
164 |   node_features = tf.layers.dropout(
165 |       inputs=node_features, rate=p_drop, training=is_training)
166 |   # Compute self-attention features
167 |   node_features = tf.sparse_tensor_dense_matmul(alpha, node_features)
168 |   node_features = tf.contrib.layers.bias_add(node_features)
169 |   return node_features
170 | 
171 | 
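# Usage sketch for a single sparse attention head (hypothetical dimensions;
# `feats` is a node-feature tensor and `sp_adj` a tf.SparseTensor adjacency).
# Each head creates its own variables, so wrap calls in distinct scopes as
# `gat_module` does below:
#
#   with tf.variable_scope('gat-layer0-att0'):
#     h = sp_gat_layer(feats, sp_adj, in_dim=1433, out_dim=8, p_drop=0.6,
#                      is_training=True, sparse=True)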
172 | def gat_layer(node_features, adj_matrix, out_dim, p_drop, is_training, i, j):
173 |   """Single graph attention layer.
174 | 
175 |   Args:
176 |     node_features: Tensor of shape (nb_nodes, feature_dim).
177 |     adj_matrix: adjacency matrix. Tensor of shape (nb_nodes, nb_nodes) and type
178 |       float. There should be 1 if there is a connection between two nodes and 0
179 |       otherwise.
180 |     out_dim: integer specifying the output feature dimension.
181 |     p_drop: dropout probability.
182 |     is_training: boolean, True if the model is being trained, False otherwise.
183 |     i: layer index, used for naming variables.
184 |     j: attention mechanism index, used for naming variables.
185 | 
186 |   Returns:
187 |     node_features: tensor of shape (nb_nodes, out_dim). New node
188 |       features obtained from applying one head of attention to input.
189 | 
190 |   Raises:
191 |   """
192 |   with tf.variable_scope('gat-{}-{}'.format(i, j)):
193 |     node_features = tf.layers.dropout(
194 |         inputs=node_features, rate=p_drop, training=is_training)
195 |     # Linear transform of the features
196 |     w_dense = tf.get_variable(
197 |         initializer=WEIGHT_INIT,
198 |         dtype=tf.float32,
199 |         name='linear',
200 |         shape=(node_features.shape[1], out_dim))
201 |     node_features = tf.matmul(node_features, w_dense)
202 |     alpha = compute_adj_att(node_features)
203 |     alpha = tf.nn.leaky_relu(alpha)
204 |     # Mask values before activation to inject the graph structure:
205 |     # add -infinity to non-neighbour pairs before normalization.
206 |     bias_mat = -1e9 * (1. - adj_matrix)
207 |     # multiply here if adjacency is weighted
208 |     alpha = tf.nn.softmax(alpha + bias_mat, axis=-1)
209 |     # alpha = tf.nn.softmax(alpha, axis=-1)
210 |     alpha = tf.layers.dropout(inputs=alpha, rate=p_drop, training=is_training)
211 |     node_features = tf.layers.dropout(
212 |         inputs=node_features, rate=p_drop, training=is_training)
213 |     # Compute self-attention features
214 |     node_features = tf.matmul(alpha, node_features)
215 |     node_features = tf.contrib.layers.bias_add(node_features)
216 |   return node_features
217 | 
218 | 
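# Dense counterpart sketch (hypothetical tensors; `adj` is a dense 0/1
# adjacency matrix; i and j name the layer/head variable scope):
#
#   h = gat_layer(feats, adj, out_dim=8, p_drop=0.6, is_training=True,
#                 i=0, j=0)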
219 | def sp_egat_layer(node_features, adj_matrix, in_dim, out_dim, p_drop,
220 |                   is_training, sparse):
221 |   """Single graph attention layer using sparse tensors.
222 | 
223 |   Args:
224 |     node_features: Tensor of shape (nb_nodes, in_dim) or SparseTensor.
225 |     adj_matrix: Sparse Tensor.
226 |     in_dim: integer specifying the input feature dimension.
227 |     out_dim: integer specifying the output feature dimension.
228 |     p_drop: dropout probability.
229 |     is_training: boolean, True if the model is being trained, False otherwise.
230 |     sparse: True if node features are sparse.
231 | 
232 |   Returns:
233 |     node_features: tensor of shape (nb_nodes, out_dim). New node
234 |       features obtained from applying one head of attention to input.
235 | 
236 |   Raises:
237 |   """
238 |   # Linear transform
239 |   node_features = dense(node_features, in_dim, out_dim, p_drop, is_training,
240 |                         sparse)
241 |   # Attention scores
242 |   alpha = sp_compute_adj_att(node_features, adj_matrix)
243 |   alpha = tf.SparseTensor(
244 |       indices=alpha.indices,
245 |       values=tf.nn.leaky_relu(alpha.values),
246 |       dense_shape=alpha.dense_shape)
247 |   alpha = tf.sparse_softmax(alpha)
248 |   alpha = sparse_dropout(alpha, p_drop, is_training)
249 |   node_features = tf.layers.dropout(
250 |       inputs=node_features, rate=p_drop, training=is_training)
251 |   # Compute self-attention features
252 |   node_features = tf.sparse_tensor_dense_matmul(alpha, node_features)
253 |   node_features = tf.contrib.layers.bias_add(node_features)
254 |   return node_features
255 | 
256 | 
257 | ############################## MULTI LAYERS #############################
258 | 
259 | 
260 | def mlp_module(node_features, n_hidden, p_drop, is_training, in_dim,
261 |                sparse_features, use_bias, return_hidden=False):
262 |   """Multi-layer perceptron (MLP) module."""
263 |   nb_layers = len(n_hidden)
264 |   hidden_layers = [node_features]
265 |   for i, out_dim in enumerate(n_hidden):
266 |     with tf.variable_scope('mlp-{}'.format(i)):
267 |       if i > 0:
268 |         sparse_features = False
269 |       if i == nb_layers - 1:
270 |         use_bias = False
271 |       h_i = dense(hidden_layers[-1], in_dim, out_dim, p_drop, is_training,
272 |                   sparse_features, use_bias)
273 |       if i < nb_layers - 1:
274 |         h_i = tf.nn.relu(h_i)
275 |       in_dim = out_dim
276 |       hidden_layers.append(h_i)
277 |   if return_hidden:
278 |     return hidden_layers
279 |   else:
280 |     return hidden_layers[-1]
281 | 
282 | 
283 | def gcn_module(node_features, adj_matrix, n_hidden, p_drop, is_training, in_dim,
284 |                sparse_features):
285 |   """GCN module with multiple layers."""
286 |   nb_layers = len(n_hidden)
287 |   for i, out_dim in enumerate(n_hidden):
288 |     if i > 0:
289 |       sparse_features = False
290 |     with tf.variable_scope('gcn-{}'.format(i)):
291 |       node_features = sp_gcn_layer(node_features, adj_matrix, in_dim, out_dim,
292 |                                    p_drop, is_training, sparse_features)
293 |       if i < nb_layers - 1:
294 |         node_features = tf.nn.relu(node_features)
295 |       in_dim = out_dim
296 |   return node_features
297 | 
298 | 
299 | def cheby_module(node_features, cheby_poly, n_hidden, p_drop, is_training,
300 |                  in_dim, sparse_features):
301 |   """Chebyshev GCN module with multiple layers."""
302 |   nb_layers = len(n_hidden)
303 |   for i, out_dim in enumerate(n_hidden):
304 |     if i > 0:
305 |       sparse_features = False
306 |     feats = []
307 |     for j, poly in enumerate(cheby_poly):
308 |       with tf.variable_scope('cheb-{}-{}'.format(i, j)):
309 |         sparse_poly = tf.contrib.layers.dense_to_sparse(poly)
310 |         feats.append(sp_gcn_layer(node_features, sparse_poly, in_dim, out_dim,
311 |                                   p_drop, is_training, sparse_features))
312 |     node_features = tf.add_n(feats)
313 |     if i < nb_layers - 1:
314 |       node_features = tf.nn.relu(node_features)
315 |     in_dim = out_dim
316 |   return node_features
317 | 
318 | 
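# Stacking sketch (hypothetical dimensions): a two-layer GCN mapping
# 1433-dimensional sparse inputs to 7-class logits:
#
#   logits = gcn_module(feats, sp_adj, n_hidden=[16, 7], p_drop=0.5,
#                       is_training=True, in_dim=1433, sparse_features=True)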
319 | def gat_module(node_features, adj_matrix, n_hidden, n_att, p_drop, is_training,
320 |                in_dim, sparse_features, average_last):
321 |   """GAT module with multi-headed attention and multiple layers."""
322 |   nb_layers = len(n_att)
323 |   for i, k in enumerate(n_att):
324 |     out_dim = n_hidden[i]
325 |     att = []
326 |     if i > 0:
327 |       sparse_features = False
328 |     for j in range(k):
329 |       with tf.variable_scope('gat-layer{}-att{}'.format(i, j)):
330 |         att.append(
331 |             sp_gat_layer(node_features, adj_matrix, in_dim, out_dim, p_drop,
332 |                          is_training, sparse_features))
333 |     # intermediate layers, concatenate features
334 |     if i < nb_layers - 1:
335 |       in_dim = out_dim * k
336 |       node_features = tf.nn.elu(tf.concat(att, axis=-1))
337 |   if average_last:
338 |     # last layer, average features instead of concatenating
339 |     logits = tf.add_n(att) / n_att[-1]
340 |   else:
341 |     logits = tf.concat(att, axis=-1)
342 |   return logits
343 | 
344 | 
345 | def egat_module(node_features, adj_matrix, n_hidden, n_att, p_drop,
346 |                 is_training, in_dim, sparse_features, average_last):
347 |   """Edge-GAT module with multi-headed attention and multiple layers."""
348 |   nb_layers = len(n_att)
349 |   for i, k in enumerate(n_att):
350 |     out_dim = n_hidden[i]
351 |     att = []
352 |     if i > 0:
353 |       sparse_features = False
354 |     for j in range(k):
355 |       with tf.variable_scope('egat-layer{}-att{}'.format(i, j)):
356 |         att.append(
357 |             sp_gat_layer(node_features, adj_matrix, in_dim, out_dim, p_drop,
358 |                          is_training, sparse_features))
359 |     # intermediate layers, concatenate features
360 |     if i < nb_layers - 1:
361 |       in_dim = out_dim * k
362 |       node_features = tf.nn.elu(tf.concat(att, axis=-1))
363 |   if average_last:
364 |     # last layer, average features instead of concatenating
365 |     logits = tf.add_n(att) / n_att[-1]
366 |   else:
367 |     logits = tf.concat(att, axis=-1)
368 |   return logits
369 | 
370 | 
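# Configuration sketch (hypothetical dimensions, mirroring the usual GAT
# setup): eight heads of 8 units each, then one averaged output layer:
#
#   logits = gat_module(feats, sp_adj, n_hidden=[8, 7], n_att=[8, 1],
#                       p_drop=0.6, is_training=True, in_dim=1433,
#                       sparse_features=True, average_last=True)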
371 | ###################### EDGE SCORES FUNCTIONS #############################
372 | 
373 | 
374 | def sp_compute_adj_att(node_features, adj_matrix_sp):
375 |   """Self-attention for edges as in GAT with sparse adjacency."""
376 |   out_dim = node_features.shape[-1]
377 |   # Self-attention mechanism
378 |   a_row = tf.get_variable(
379 |       initializer=WEIGHT_INIT,
380 |       dtype=tf.float32,
381 |       name='selfatt-row',
382 |       shape=(out_dim, 1))
383 |   a_col = tf.get_variable(
384 |       initializer=WEIGHT_INIT,
385 |       dtype=tf.float32,
386 |       name='selfatt-col',
387 |       shape=(out_dim, 1))
388 |   alpha_row = tf.matmul(node_features, a_row)
389 |   alpha_col = tf.matmul(node_features, a_col)
390 |   # Compute matrix with self-attention scores using broadcasting
391 |   alpha = tf.sparse_add(adj_matrix_sp * alpha_row,
392 |                         adj_matrix_sp * tf.transpose(alpha_col, perm=[1, 0]))
393 |   return alpha
394 | 
395 | 
396 | def compute_adj_att(node_features):
397 |   """Self-attention for edges as in GAT."""
398 |   out_dim = node_features.shape[-1]
399 |   # Self-attention mechanism
400 |   a_row = tf.get_variable(
401 |       initializer=WEIGHT_INIT,
402 |       dtype=tf.float32,
403 |       name='selfatt-row',
404 |       shape=(out_dim, 1))
405 |   a_col = tf.get_variable(
406 |       initializer=WEIGHT_INIT,
407 |       dtype=tf.float32,
408 |       name='selfatt-col',
409 |       shape=(out_dim, 1))
410 |   alpha_row = tf.matmul(node_features, a_row)
411 |   alpha_col = tf.matmul(node_features, a_col)
412 |   # Compute matrix with self-attention scores using broadcasting
413 |   alpha = alpha_row + tf.transpose(alpha_col, perm=[1, 0])
414 |   # alpha += alpha_col + tf.transpose(alpha_row, perm=[1, 0])
415 |   return alpha
416 | 
417 | 
418 | def compute_weighted_mat_dot(node_features, nb_dots=1):
419 |   """Compute weighted dot product with matrix multiplication."""
420 |   adj_scores = []
421 |   in_dim = node_features.shape[-1]
422 |   for i in range(nb_dots):
423 |     weight_mat = tf.get_variable(
424 |         initializer=WEIGHT_INIT,
425 |         dtype=tf.float32,
426 |         name='w-dot-{}'.format(i),
427 |         shape=(in_dim, in_dim))
428 |     adj_scores.append(tf.matmul(node_features, tf.matmul(
429 |         weight_mat, tf.transpose(node_features, perm=[1, 0]))))
430 |   return tf.add_n(adj_scores)
431 | 
432 | 
433 | def compute_weighted_dot(node_features, nb_dots=4):
434 |   """Compute weighted dot product."""
435 |   adj_scores = []
436 |   in_dim = node_features.shape[-1]
437 |   for i in range(nb_dots):
438 |     weight_vec = tf.get_variable(
439 |         initializer=WEIGHT_INIT,
440 |         dtype=tf.float32,
441 |         name='w-dot-{}'.format(i),
442 |         shape=(1, in_dim))
443 |     weight_vec = tf.nn.softmax(weight_vec, axis=-1)
444 |     adj_scores.append(tf.matmul(tf.multiply(weight_vec, node_features),
445 |                                 tf.transpose(node_features, perm=[1, 0])))
446 |   return tf.add_n(adj_scores)
447 | 
448 | 
449 | def compute_l2_sim_matrix(node_features):
450 |   """Compute squared-L2 distance between each pair of nodes."""
451 |   # N x N
452 |   # d_scores = tf.matmul(node_features, tf.transpose(node_features,perm=[1, 0]))
453 |   # diag = tf.diag_part(d_scores)
454 |   # d_scores *= -2.
455 |   # d_scores += tf.reshape(diag, (-1, 1)) + tf.reshape(diag, (1, -1))
456 |   l2_norm = tf.reduce_sum(tf.square(node_features), 1)
457 |   na = tf.reshape(l2_norm, [-1, 1])
458 |   nb = tf.reshape(l2_norm, [1, -1])
459 |   # return pairwise squared Euclidean distance matrix
460 |   l2_scores = tf.maximum(
461 |       na - 2*tf.matmul(node_features, node_features, False, True) + nb, 0.0)
462 |   return l2_scores
463 | 
464 | 
465 | def compute_dot_sim_matrix(node_features):
466 |   """Compute edge scores with dot product."""
467 |   sim = tf.matmul(node_features, tf.transpose(node_features, perm=[1, 0]))
468 |   return sim
469 | 
470 | 
471 | def compute_dot_norm(features):
472 |   """Compute edge scores with normalized dot product (cosine similarity)."""
473 |   features = tf.nn.l2_normalize(features, axis=-1)
474 |   sim = tf.matmul(features, tf.transpose(features, perm=[1, 0]))
475 |   return sim
476 | 
477 | 
478 | def compute_asym_dot(node_features):
479 |   """Compute edge scores with asymmetric dot product."""
480 |   feat_left, feat_right = tf.split(node_features, 2, axis=-1)
481 |   feat_left = tf.nn.l2_normalize(feat_left, axis=-1)
482 |   feat_right = tf.nn.l2_normalize(feat_right, axis=-1)
483 |   sim = tf.matmul(feat_left, tf.transpose(feat_right, perm=[1, 0]))
484 |   return sim
485 | 
486 | 
487 | def compute_adj(features, att_mechanism, p_drop, is_training):
488 |   """Compute adj matrix scores given node features."""
489 |   features = tf.layers.dropout(
490 |       inputs=features, rate=p_drop, training=is_training)
491 |   if att_mechanism == 'dot':
492 |     return compute_dot_sim_matrix(features)
493 |   elif att_mechanism == 'weighted-mat-dot':
494 |     return compute_weighted_mat_dot(features)
495 |   elif att_mechanism == 'weighted-dot':
496 |     return compute_weighted_dot(features)
497 |   elif att_mechanism == 'att':
498 |     return compute_adj_att(features)
499 |   elif att_mechanism == 'dot-norm':
500 |     return compute_dot_norm(features)
501 |   elif att_mechanism == 'asym-dot':
502 |     return compute_asym_dot(features)
503 |   else:
504 |     return compute_l2_sim_matrix(features)
505 | 
506 | 
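# Edge-scoring sketch (hypothetical embeddings `h` of shape (nb_nodes, d)):
#
#   adj_pred = compute_adj(h, att_mechanism='dot-norm', p_drop=0.,
#                          is_training=False)
#   # Unrecognized mechanism strings fall through to the squared-L2 scores.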
507 | def get_sp_topk(adj_pred, sp_adj_train, nb_nodes, k):
508 |   """Returns a binary mask with the top-k predicted edges and train edges."""
509 |   _, indices = tf.nn.top_k(tf.reshape(adj_pred, (-1,)), k)
510 |   indices = tf.reshape(tf.cast(indices, tf.int64), (-1, 1))
511 |   sp_adj_pred = tf.SparseTensor(
512 |       indices=indices,
513 |       values=tf.ones(k),
514 |       dense_shape=(nb_nodes * nb_nodes,))
515 |   sp_adj_pred = tf.sparse_reshape(sp_adj_pred,
516 |                                   shape=(nb_nodes, nb_nodes, 1))
517 |   sp_adj_train = tf.SparseTensor(
518 |       indices=sp_adj_train.indices,
519 |       values=tf.ones_like(sp_adj_train.values),
520 |       dense_shape=sp_adj_train.dense_shape)
521 |   sp_adj_train = tf.sparse_reshape(sp_adj_train,
522 |                                    shape=(nb_nodes, nb_nodes, 1))
523 |   sp_adj_pred = tf.sparse_concat(
524 |       sp_inputs=[sp_adj_pred, sp_adj_train], axis=-1)
525 |   return tf.sparse_reduce_max(sp_adj_pred, axis=-1)
526 | 
527 | 
528 | @tf.custom_gradient
529 | def mask_edges(scores, mask):
530 |   """Masks edge scores but passes gradients through as if unmasked."""
531 |   masked_scores = tf.multiply(scores, mask)
532 |   def grad(dy):
533 |     return dy, None  # tf.multiply(scores, dy)
534 |   return masked_scores, grad
535 | 
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 Google LLC
2 | #
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 | #
7 | #    https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | #Unless required by applicable law or agreed to in writing, software
10 | #distributed under the License is distributed on an "AS IS" BASIS,
11 | #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | #See the License for the specific language governing permissions and
13 | #limitations under the License.
14 | 
15 | """Helper functions for training."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | 
22 | def format_metrics(metrics, mode):
23 |   """Format metrics for logging."""
24 |   result = ''
25 |   for metric in metrics:
26 |     result += '{}_{} = {:.4f} | '.format(mode, metric, float(metrics[metric]))
27 |   return result
28 | 
29 | 
30 | def format_params(config):
31 |   """Format training parameters for logging."""
32 |   result = ''
33 |   for key, value in config.__dict__.items():
34 |     result += '{}={} \n '.format(key, str(value))
35 |   return result
36 | 
37 | 
38 | def check_improve(best_metrics, metrics, targets):
39 |   """Checks if any of the target metrics improved."""
40 |   return [
41 |       compare(metrics[target], best_metrics[target], targets[target])
42 |       for target in targets
43 |   ]
44 | 
45 | 
46 | def compare(x1, x2, increasing):
47 |   """Compares metric values; increasing == 1 means higher is better."""
48 |   if increasing == 1:
49 |     return x1 >= x2
50 |   else:
51 |     return x1 <= x2
52 | 
--------------------------------------------------------------------------------