├── LICENSE.md ├── README.md ├── attention_pooling.py ├── chebygin.py ├── checkpoints ├── checkpoint_TU_657452_epoch50_seed0000111.pth.tar ├── checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar ├── checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar ├── checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar ├── checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar ├── checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar ├── checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar ├── checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar ├── checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar ├── checkpoint_triangles_051609_epoch100_seed0000111.pth.tar ├── checkpoint_triangles_230187_epoch100_seed0000111.pth.tar ├── checkpoint_triangles_586710_epoch100_seed0000111.pth.tar ├── checkpoint_triangles_658037_epoch100_seed0000111.pth.tar ├── colors-3_alpha_WS_test_seed111_orig.pkl ├── colors-3_alpha_WS_train_seed111_orig.pkl ├── mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl ├── mnist-75sp_alpha_WS_test_seed111_noisy.pkl ├── mnist-75sp_alpha_WS_test_seed111_orig.pkl ├── mnist-75sp_alpha_WS_train_seed111_orig.pkl ├── triangles_alpha_WS_test_seed111_orig.pkl └── triangles_alpha_WS_train_seed111_orig.pkl ├── data ├── COLORS-3.zip ├── TRIANGLES.zip ├── datasets.png ├── mnist_animation.gif └── triangles_animation.gif ├── extract_superpixels.py ├── generate_data.py ├── graphdata.py ├── logs ├── colors-3_global_max_seed111.log ├── colors-3_sup_seed111.log ├── colors-3_unsup_seed111.log ├── colors-3_weaksup_seed111.log ├── mnist-75sp_global_max_seed111.log ├── mnist-75sp_sup_seed111.log ├── mnist-75sp_unsup_seed111.log ├── mnist-75sp_weaksup_seed111.log ├── prepare_data.log ├── proteins_wsup_seed111.log ├── triangles_global_max_seed111.log ├── triangles_sup_seed111.log ├── triangles_unsup_seed111.log └── triangles_weaksup_seed111.log ├── main.py ├── notebooks ├── MNIST_eval_models.ipynb ├── TRIANGLES_eval_models.ipynb ├── convert2TU.ipynb ├── graphs_visualize.ipynb ├── molecules_social_visualize.ipynb ├── superpixels_visualize.ipynb └── synthetic_graphs_visualize.ipynb ├── scripts ├── colors.sh ├── mnist_75sp.sh ├── prepare_data.sh └── triangles.sh ├── train_test.py └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Educational Community License, Version 2.0 (ECL-2.0) 2 | 3 | Version 2.0, April 2007 4 | 5 | http://www.osedu.org/licenses/ 6 | 7 | The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. The original Apache 2.0 license can be found at: http://www.apache.org/licenses /LICENSE-2.0 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 18 | 19 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 20 | 21 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 22 | 23 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 24 | 25 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 26 | 27 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 28 | 29 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 30 | 31 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 32 | 33 | 2. Grant of Copyright License. 34 | 35 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 36 | 37 | 3. Grant of Patent License. 
38 | 39 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted. 40 | 41 | 4. Redistribution. 42 | 43 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 44 | 45 | You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 46 | 47 | 5. Submission of Contributions. 
48 | 49 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 50 | 51 | 6. Trademarks. 52 | 53 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 54 | 55 | 7. Disclaimer of Warranty. 56 | 57 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 58 | 59 | 8. Limitation of Liability. 60 | 61 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 62 | 63 | 9. Accepting Warranty or Additional Liability. 64 | 65 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 66 | 67 | END OF TERMS AND CONDITIONS 68 | 69 | APPENDIX: How to apply the Educational Community License to your work 70 | 71 | To apply the Educational Community License to your work, attach 72 | the following boilerplate notice, with the fields enclosed by 73 | brackets "[]" replaced with your own identifying information. 74 | (Don't include the brackets!) The text should be enclosed in the 75 | appropriate comment syntax for the file format. We also recommend 76 | that a file or class name and description of purpose be included on 77 | the same "printed page" as the copyright notice for easier 78 | identification within third-party archives. 
79 | 80 | Copyright [yyyy] [name of copyright owner] Licensed under the 81 | Educational Community License, Version 2.0 (the "License"); you may 82 | not use this file except in compliance with the License. You may 83 | obtain a copy of the License at 84 | 85 | http://www.osedu.org/licenses /ECL-2.0 86 | 87 | Unless required by applicable law or agreed to in writing, 88 | software distributed under the License is distributed on an "AS IS" 89 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 90 | or implied. See the License for the specific language governing 91 | permissions and limitations under the License. 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Intro 2 | 3 | This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper: 4 | 5 | [Boris Knyazev, Graham W. Taylor, Mohamed R. Amer. Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850). 6 | 7 | See slides [here](https://drive.google.com/open?id=1HcmhSEnf8ll6-BxXK1PiGzcXDa6BbKnC). 8 | 9 | [An earlier short version](https://rlgm.github.io/papers/54.pdf) of our paper was presented as a **contributed talk** at [ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019](https://rlgm.github.io/cfp/). 10 | 11 | **Update:** 12 | 13 | In the code for MNIST, the `dist` variable should have been squared to make it a Gaussian. All figures and results were generated without squaring it. I don't think it's very important in terms of results, but if you square it, `sigma` should be adjusted accordingly. 14 | 15 | 16 | | MNIST | TRIANGLES 17 | |:-------------------------:|:-------------------------:| 18 | |
![MNIST](data/mnist_animation.gif) | ![TRIANGLES](data/triangles_animation.gif)
| 19 | 20 | 21 | For MNIST from top to bottom rows: 22 | 23 | - input test images with additive Gaussian noise with standard deviation in the range from 0 to 1.4 with step 0.2 24 | - attention coefficients (alpha) predicted by the **unsupervised** model 25 | - attention coefficients (alpha) predicted by the **supervised** model 26 | - attention coefficients (alpha) predicted by our **weakly-supervised** model 27 | 28 | For TRIANGLES from top to bottom rows: 29 | 30 | - **on the left**: input test graph (with 4-100 nodes) with ground truth attention coefficients, **on the right**: graph obtained by **ground truth** node pooling 31 | - **on the left**: input test graph (with 4-100 nodes) with unsupervised attention coefficients, **on the right**: graph obtained by **unsupervised** node pooling 32 | - **on the left**: input test graph (with 4-100 nodes) with supervised attention coefficients, **on the right**: graph obtained by **supervised** node pooling 33 | - **on the left**: input test graph (with 4-100 nodes) with weakly-supervised attention coefficients, **on the right**: graph obtained by **weakly-supervised** node pooling 34 | 35 | 36 | Note that during training, our MNIST models have not encountered noisy images and our TRIANGLES models have not encountered graphs larger than with N=25 nodes. 37 | 38 | 39 | ## Examples using [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric) 40 | 41 | 42 | 43 | 44 | 45 | COLORS and TRIANGLES datasets are now also available in the [TU](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) format, so that you can use a general TU datareader. See PyTorch Geometric examples for 46 | [COLORS](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/colors_topk_pool.py) and [TRIANGLES](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/triangles_sag_pool.py). 47 | 48 | 49 | ## Example of evaluating a pretrained model on MNIST 50 | 51 | For more examples, see [MNIST_eval_models](notebooks/MNIST_eval_models.ipynb) and 52 | [TRIANGLES_eval_models](notebooks/TRIANGLES_eval_models.ipynb). 53 | 54 | ```python 55 | # Download model checkpoint or 'git clone' this repo 56 | import urllib.request 57 | # Let's use the model with supervised attention (other models can be found in the Table below) 58 | model_name = 'checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar' 59 | model_url = 'https://github.com/bknyaz/graph_attention_pool/raw/master/checkpoints/%s' % model_name 60 | model_path = 'checkpoints/%s' % model_name 61 | urllib.request.urlretrieve(model_url, model_path) 62 | ``` 63 | 64 | ```python 65 | # Load the model 66 | import torch 67 | from chebygin import ChebyGIN 68 | 69 | state = torch.load(model_path) 70 | args = state['args'] 71 | model = ChebyGIN(in_features=5, out_features=10, filters=args.filters, K=args.filter_scale, 72 | n_hidden=args.n_hidden, aggregation=args.aggregation, dropout=args.dropout, 73 | readout=args.readout, pool=args.pool, pool_arch=args.pool_arch) 74 | model.load_state_dict(state['state_dict']) 75 | model = model.eval() 76 | ``` 77 | 78 | ```python 79 | # Load image using standard PyTorch Dataset 80 | from torchvision import datasets 81 | data = datasets.MNIST('./data', train=False, download=True) 82 | images = (data.test_data.numpy() / 255.) 
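# Note: in recent torchvision releases `test_data` / `test_labels` are deprecated
# aliases of `data.data` / `data.targets`; if you see a deprecation warning here,
# images = (data.data.numpy() / 255.) is the equivalent call.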
83 | import numpy as np 84 | img = images[0].astype(np.float32) # 28x28 MNIST image 85 | ``` 86 | 87 | ```python 88 | # Extract superpixels and create node features 89 | import scipy.ndimage 90 | from skimage.segmentation import slic 91 | from scipy.spatial.distance import cdist 92 | 93 | # The number (n_segments) of superpixels returned by SLIC is usually smaller than requested, so we request more 94 | superpixels = slic(img, n_segments=95, compactness=0.25, multichannel=False) 95 | sp_indices = np.unique(superpixels) 96 | n_sp = len(sp_indices) # should be 74 with these parameters of slic 97 | 98 | sp_intensity = np.zeros((n_sp, 1), np.float32) 99 | sp_coord = np.zeros((n_sp, 2), np.float32) # row, col 100 | for seg in sp_indices: 101 | mask = superpixels == seg 102 | sp_intensity[seg] = np.mean(img[mask]) 103 | sp_coord[seg] = np.array(scipy.ndimage.measurements.center_of_mass(mask)) 104 | 105 | # The model is invariant to the order of nodes in a graph 106 | # We can shuffle nodes and obtain exactly the same results 107 | ind = np.random.permutation(n_sp) 108 | sp_coord = sp_coord[ind] 109 | sp_intensity = sp_intensity[ind] 110 | ``` 111 | 112 | ```python 113 | # Create edges between nodes in the form of adjacency matrix 114 | sp_coord = sp_coord / images.shape[1] 115 | dist = cdist(sp_coord, sp_coord) # distance between all pairs of nodes 116 | sigma = 0.1 * np.pi # width of a Guassian 117 | A = np.exp(- dist / sigma ** 2) # transform distance to spatial closeness 118 | A[np.diag_indices_from(A)] = 0 # remove self-loops 119 | A = torch.from_numpy(A).float().unsqueeze(0) 120 | ``` 121 | 122 | ```python 123 | # Prepare an input to the model and process it 124 | N_nodes = sp_intensity.shape[0] 125 | mask = torch.ones(1, N_nodes, dtype=torch.uint8) 126 | 127 | # mean and std computed for superpixel features in the training set 128 | mn = torch.tensor([0.11225057, 0.11225057, 0.11225057, 0.44206527, 0.43950436]).view(1, 1, -1) 129 | sd = torch.tensor([0.2721889, 0.2721889, 0.2721889, 0.2987583, 0.30080357]).view(1, 1, -1) 130 | 131 | node_features = (torch.from_numpy(np.pad(np.concatenate((sp_intensity, sp_coord), axis=1), 132 | ((0, 0), (2, 0)), 'edge')).unsqueeze(0) - mn) / sd 133 | 134 | y, other_outputs = model([node_features, A, mask, None, {'N_nodes': torch.zeros(1, 1) + N_nodes}]) 135 | alpha = other_outputs['alpha'][0].data 136 | ``` 137 | 138 | - `y` is a vector with 10 unnormalized class scores. To get a predicted label, we can use ```torch.argmax(y)```. 139 | 140 | - `alpha` is a vector of attention coefficients alpha for each node. 141 | 142 | ## Tasks & Datasets 143 | 144 | 1. We design two synthetic graph tasks, COLORS and TRIANGLES, in which we predict the number of green nodes and the number of triangles respectively. 145 | 146 | 2. We also experiment with the [MNIST](http://yann.lecun.com/exdb/mnist/) image classification dataset, which we preprocess by extracting superpixels - a more natural way to feed images to a graph. We denote this dataset as MNIST-75sp. 147 | 148 | 3. We validate our weakly-supervised approach on three common graph classification benchmarks: [COLLAB, PROTEINS and D&D](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets). 149 | 150 | For COLORS, TRIANGLES and MNIST we know ground truth attention for nodes, which allows us to study graph neural networks with attention in depth. 151 | 152 |
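As a reference for what the TRIANGLES label actually is: for a simple undirected graph with adjacency matrix A, the number of triangles equals trace(A^3)/6, because every triangle contributes six closed walks of length 3. Below is a minimal NumPy sketch of this check; it is purely illustrative and is not the dataset generation code (the datasets themselves are produced by `generate_data.py`).

```python
import numpy as np

def count_triangles(A):
    # Number of triangles in a simple undirected graph (symmetric 0/1 adjacency, no self-loops)
    A = np.asarray(A, dtype=np.int64)
    closed_walks = np.trace(A @ A @ A)  # each triangle yields 6 closed walks of length 3
    return int(closed_walks // 6)

# Toy example: one triangle (nodes 0-1-2) plus a pendant node 3
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]])
print(count_triangles(A))  # 1
```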
153 | 154 | 155 | ## Data generation 156 | 157 | To generate all data using a single command: ```./scripts/prepare_data.sh```. 158 | 159 | All generated/downloaded ata will be stored in the local ```./data``` directory. 160 | It can take about 1 hour to prepare all data (see my [log](logs/prepare_data.log)) and all data take about 2 GB. 161 | 162 | Alternatively, you can generate data for each task as described below. 163 | 164 | In case of any issues with running these scripts, data can be downloaded from [here](https://drive.google.com/drive/folders/1Prc-n9Nr8-5z-xphdRScftKKIxU4Olzh?usp=sharing). 165 | 166 | ### COLORS 167 | To generate training, validation and test data for our **Colors** dataset with different dimensionalities: 168 | 169 | ```for dim in 3 8 16 32; do python generate_data.py --dim $dim; done``` 170 | 171 | ### MNIST-75sp 172 | To generate training and test data for our MNIST-75sp dataset using 4 CPU threads: 173 | 174 | ```for split in train test; do python extract_superpixels.py -s $split -t 4; done``` 175 | 176 | ## Data visualization 177 | Once datasets are generated or downloaded, you can use the following IPython notebooks to load and visualize data: 178 | 179 | [COLORS and TRIANGLES](notebooks/synthetic_graphs_visualize.ipynb), [MNIST](notebooks/superpixels_visualize.ipynb) and 180 | [COLLAB, PROTEINS and D&D](notebooks/graphs_visualize.ipynb). 181 | 182 | 183 | # Pretrained ChebyGIN models 184 | 185 | Generalization results on the test sets for three tasks. Other results are available in the paper. 186 | 187 | Click on the result to download a trained model in the PyTorch format. 188 | 189 | | Model | COLORS-Test-LargeC | TRIANGLES-Test-Large | MNIST-75sp-Test-Noisy 190 | | --------------------- |:-------------:|:-------------:|:-------------:| 191 | | Script to train models | [colors.sh](scripts/colors.sh) | [triangles.sh](scripts/triangles.sh) | [mnist_75sp.sh](./scripts/mnist_75sp.sh) | 192 | | Global pooling | [15 ± 7](./checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar) | [30 ± 1](./checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar) | [80 ± 12](./checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar) | 193 | | Unsupervised attention | [11 ± 6](./checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar) | [26 ± 2](./checkpoints//checkpoint_triangles_051609_epoch100_seed0000111.pth.tar) | [80 ± 23](./checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar) | 194 | | Supervised attention | [75 ± 17](./checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar) | [48 ± 1](./checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar) | [92.3 ± 0.4](./checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar) | 195 | | Weakly-supervised attention | [73 ± 14 ](./checkpoints//checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar) | [30 ± 1](./checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar) | [88.8 ± 4](./checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar) | | 196 | 197 | 198 | The scripts to train the models must be run from the main directory, e.g.: ```./scripts/mnist_75sp.sh``` 199 | 200 | Examples of evaluating our trained models can be found in notebooks: [MNIST_eval_models](notebooks/MNIST_eval_models.ipynb) and 201 | [TRIANGLES_eval_models](notebooks/TRIANGLES_eval_models.ipynb). 
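Each checkpoint in the table appears to follow the same layout as in the MNIST example above (a dict with `'args'` and `'state_dict'`), so a quick way to see which hyperparameters a model was trained with is to load the file and print its stored arguments. A small sketch, assuming you have cloned the repository so the checkpoint path exists locally (any file from the table should work the same way):

```python
import torch

# The TRIANGLES supervised-attention checkpoint from the table is used here as an example
state = torch.load('checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar',
                   map_location='cpu')
print(state['args'])                                  # argparse Namespace with the training settings
print(len(state['state_dict']), 'parameter tensors')  # weights, loadable into ChebyGIN as shown above
```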
202 | 203 | 204 | ## Other examples of training models 205 | 206 | To tune hyperparameters on the validation set for COLORS, TRIANGLES and MNIST, use the ```--validation``` flag. 207 | 208 | For COLLAB, PROTEINS and D&D tuning of hyperparameters is included in the training script. Use the `--ax` flag. 209 | 210 | Example of running 10 weakly-supervised experiments on **PROTEINS** with cross-validation of hyperparameters *including initialization parameters (distribution and scale) of the attention model* (the `--tune_init` flag): 211 | 212 | ``` 213 | for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 25 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 100 -f 64,64,64 -K 1 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/PROTEINS --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None --tune_init | tee logs/proteins_wsup_"$dataseed"_"$seed".log; done; done 214 | ``` 215 | 216 | No initialization tuning on **COLLAB**: 217 | 218 | ``` 219 | for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 35 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 32 -f 64,64,64 -K 3 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/COLLAB --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None | tee logs/collab_wsup_"$dataseed"_"$seed".log; done; done 220 | ``` 221 | 222 | Note that results can be better if using `--pool_arch gnn_prev`, but we didn't focus on that. 223 | 224 | # Requirements 225 | 226 | Python packages required (can be installed via pip or conda): 227 | 228 | - python >= 3.6.1 229 | - PyTorch >= 0.4.1 230 | - [Ax](https://github.com/facebook/Ax) for hyper-parameter tuning on COLLAB, PROTEINS and D\&D 231 | - networkx 232 | - OpenCV 233 | - SciPy 234 | - scikit-image 235 | - scikit-learn 236 | 237 | # Reference 238 | 239 | Please cite our paper if you use our data or code: 240 | 241 | ``` 242 | @inproceedings{knyazev2019understanding, 243 | title={Understanding attention and generalization in graph neural networks}, 244 | author={Knyazev, Boris and Taylor, Graham W and Amer, Mohamed}, 245 | booktitle={Advances in Neural Information Processing Systems}, 246 | pages={4202--4212}, 247 | year={2019}, 248 | pdf={http://arxiv.org/abs/1905.02850} 249 | } 250 | ``` 251 | -------------------------------------------------------------------------------- /attention_pooling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.sparse 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from utils import * 7 | 8 | 9 | class AttentionPooling(nn.Module): 10 | ''' 11 | Graph pooling layer implementing top-k and threshold-based pooling. 
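    Attention coefficients (alpha) either come from ground truth ('gt') or are predicted by a
    trainable projection that is supervised ('sup') or unsupervised ('unsup'), as selected by
    pool_type. Top-k pooling keeps roughly a fixed ratio of the highest-alpha nodes in each graph,
    while threshold pooling keeps the nodes with alpha above a fixed threshold; when drop_nodes is
    True, the discarded nodes and their edges are removed from the batch.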
12 | ''' 13 | def __init__(self, 14 | in_features, # feature dimensionality in the current graph layer 15 | in_features_prev, # feature dimensionality in the previous graph layer 16 | pool_type, 17 | pool_arch, 18 | large_graph, 19 | attn_gnn=None, 20 | kl_weight=None, 21 | drop_nodes=True, 22 | init='normal', 23 | scale=None, 24 | debug=False): 25 | super(AttentionPooling, self).__init__() 26 | self.pool_type = pool_type 27 | self.pool_arch = pool_arch 28 | self.large_graph = large_graph 29 | self.kl_weight = kl_weight 30 | self.proj = None 31 | self.drop_nodes = drop_nodes 32 | self.is_topk = self.pool_type[2].lower() == 'topk' 33 | self.scale =scale 34 | self.init = init 35 | self.debug = debug 36 | self.clamp_value = 60 37 | self.torch = torch.__version__ 38 | if self.is_topk: 39 | self.topk_ratio = float(self.pool_type[3]) # r 40 | assert self.topk_ratio > 0 and self.topk_ratio <= 1, ('invalid top-k ratio', self.topk_ratio, self.pool_type) 41 | else: 42 | self.threshold = float(self.pool_type[3]) # \tilde{alpha} 43 | assert self.threshold >= 0 and self.threshold <= 1, ('invalid pooling threshold', self.threshold, self.pool_type) 44 | 45 | if self.pool_type[1] in ['unsup', 'sup']: 46 | assert self.pool_arch not in [None, 'None'], self.pool_arch 47 | 48 | n_in = in_features_prev if self.pool_arch[1] == 'prev' else in_features 49 | if self.pool_arch[0] == 'fc': 50 | p_optimal = torch.from_numpy(np.pad(np.array([0, 1]), (0, n_in - 2), 'constant')).float().view(1, n_in) 51 | if len(self.pool_arch) == 2: 52 | # single layer projection 53 | self.proj = nn.Linear(n_in, 1, bias=False) 54 | p = self.proj.weight.data 55 | if scale is not None: 56 | if init == 'normal': 57 | p = torch.randn(n_in) # std=1, seed 9753 for optimal initialization 58 | elif init == 'uniform': 59 | p = torch.rand(n_in) * 2 - 1 # [-1,1] 60 | else: 61 | raise NotImplementedError(init) 62 | p *= scale # multiply std for normal or change range for uniform 63 | else: 64 | print('Default PyTorch init is used for layer %s, std=%.3f' % (str(p.shape), p.std())) 65 | self.proj.weight.data = p.view_as(self.proj.weight.data) 66 | p = self.proj.weight.data.view(1, n_in) 67 | else: 68 | # multi-layer projection 69 | filters = list(map(int, self.pool_arch[2:])) 70 | self.proj = [] 71 | for layer in range(len(filters)): 72 | self.proj.append(nn.Linear(in_features=n_in if layer == 0 else filters[layer - 1], 73 | out_features=filters[layer])) 74 | if layer == 0: 75 | p = self.proj[0].weight.data 76 | if scale is not None: 77 | if init == 'normal': 78 | p = torch.randn(filters[layer], n_in) 79 | elif init == 'uniform': 80 | p = torch.rand(filters[layer], n_in) * 2 - 1 # [-1,1] 81 | else: 82 | raise NotImplementedError(init) 83 | p *= scale # multiply std for normal or change range for uniform 84 | else: 85 | print('Default PyTorch init is used for layer %s, std=%.3f' % (str(p.shape), p.std())) 86 | self.proj[0].weight.data = p.view_as(self.proj[0].weight.data) 87 | p = self.proj[0].weight.data.view(-1, n_in) 88 | self.proj.append(nn.ReLU(True)) 89 | 90 | self.proj.append(nn.Linear(filters[-1], 1)) 91 | self.proj = nn.Sequential(*self.proj) 92 | 93 | # Compute cosine similarity with the optimal vector and print values 94 | # ignore the last dimension, because it does not receive gradients during training 95 | # n_in=4 for colors-3 because some of our test subsets have 4 dimensional features 96 | cos_sim = self.cosine_sim(p[:, :-1], p_optimal[:, :-1]) 97 | if p.shape[0] == 1: 98 | print('p values', p[0].data.cpu().numpy()) 99 | 
print('cos_sim', cos_sim.item()) 100 | else: 101 | for fn in [torch.max, torch.min, torch.mean, torch.std]: 102 | print('cos_sim', fn(cos_sim).item()) 103 | elif self.pool_arch[0] == 'gnn': 104 | self.proj = attn_gnn(n_in) 105 | else: 106 | raise ValueError('invalid pooling layer architecture', self.pool_arch) 107 | 108 | elif self.pool_type[1] == 'gt': 109 | if not self.is_topk and self.threshold > 0: 110 | print('For ground truth attention threshold should be 0, but it is %f' % self.threshold) 111 | else: 112 | raise NotImplementedError(self.pool_type[1]) 113 | 114 | def __repr__(self): 115 | return 'AttentionPooling(pool_type={}, pool_arch={}, topk={}, kl_weight={}, init={}, scale={}, proj={})'.format( 116 | self.pool_type, 117 | self.pool_arch, 118 | self.is_topk, 119 | self.kl_weight, 120 | self.init, 121 | self.scale, 122 | self.proj) 123 | 124 | def cosine_sim(self, a, b): 125 | return torch.mm(a, b.t()) / (torch.norm(a, dim=1, keepdim=True) * torch.norm(b, dim=1, keepdim=True)) 126 | 127 | def mask_out(self, x, mask): 128 | return x.view_as(mask) * mask 129 | 130 | def drop_nodes_edges(self, x, A, mask): 131 | N_nodes = torch.sum(mask, dim=1).long() # B 132 | N_nodes_max = N_nodes.max() 133 | idx = None 134 | if N_nodes_max > 0: 135 | B, N, C = x.shape 136 | # Drop nodes 137 | mask, idx = torch.topk(mask.byte(), N_nodes_max, dim=1, largest=True, sorted=False) 138 | x = torch.gather(x, dim=1, index=idx.unsqueeze(2).expand(-1, -1, C)) 139 | # Drop edges 140 | A = torch.gather(A, dim=1, index=idx.unsqueeze(2).expand(-1, -1, N)) 141 | A = torch.gather(A, dim=2, index=idx.unsqueeze(1).expand(-1, N_nodes_max, -1)) 142 | 143 | return x, A, mask, N_nodes, idx 144 | 145 | def forward(self, data): 146 | 147 | KL_loss = None 148 | x, A, mask, _, params_dict = data[:5] 149 | 150 | mask_float = mask.float() 151 | N_nodes_float = params_dict['N_nodes'].float() 152 | B, N, C = x.shape 153 | A = A.view(B, N, N) 154 | alpha_gt = None 155 | if 'node_attn' in params_dict: 156 | if not isinstance(params_dict['node_attn'], list): 157 | params_dict['node_attn'] = [params_dict['node_attn']] 158 | alpha_gt = params_dict['node_attn'][-1].view(B, N) 159 | if 'node_attn_eval' in params_dict: 160 | if not isinstance(params_dict['node_attn_eval'], list): 161 | params_dict['node_attn_eval'] = [params_dict['node_attn_eval']] 162 | 163 | if (self.pool_type[1] == 'gt' or (self.pool_type[1] == 'sup' and self.training)) and alpha_gt is None: 164 | raise ValueError('ground truth node attention values node_attn required for %s' % self.pool_type) 165 | 166 | if self.pool_type[1] in ['unsup', 'sup']: 167 | attn_input = data[-1] if self.pool_arch[1] == 'prev' else x.clone() 168 | if self.pool_arch[0] == 'fc': 169 | alpha_pre = self.proj(attn_input).view(B, N) 170 | else: 171 | # to support python2 172 | input = [attn_input] 173 | input.extend(data[1:]) 174 | alpha_pre = self.proj(input)[0].view(B, N) 175 | # softmax with masking out dummy nodes 176 | alpha_pre = torch.clamp(alpha_pre, -self.clamp_value, self.clamp_value) 177 | alpha = normalize_batch(self.mask_out(torch.exp(alpha_pre), mask_float).view(B, N)) 178 | if self.pool_type[1] == 'sup' and self.training: 179 | if self.torch.find('1.') == 0: 180 | KL_loss_per_node = self.mask_out(F.kl_div(torch.log(alpha + 1e-14), alpha_gt, reduction='none'), 181 | mask_float.view(B,N)) 182 | else: 183 | KL_loss_per_node = self.mask_out(F.kl_div(torch.log(alpha + 1e-14), alpha_gt, reduce=False), 184 | mask_float.view(B, N)) 185 | KL_loss = self.kl_weight * 
torch.mean(KL_loss_per_node.sum(dim=1) / (N_nodes_float + 1e-7)) # mean over nodes, then mean over batches 186 | else: 187 | alpha = alpha_gt 188 | 189 | x = x * alpha.view(B, N, 1) 190 | if self.large_graph: 191 | # For large graphs during training, all alpha values can be very small hindering training 192 | x = x * N_nodes_float.view(B, 1, 1) 193 | if self.is_topk: 194 | N_remove = torch.round(N_nodes_float * (1 - self.topk_ratio)).long() # number of nodes to be removed for each graph 195 | idx = torch.sort(alpha, dim=1, descending=False)[1] # indices of alpha in ascending order 196 | mask = mask.clone().view(B, N) 197 | for b in range(B): 198 | idx_b = idx[b, mask[b, idx[b]]] # take indices of non-dummy nodes for current data example 199 | mask[b, idx_b[:N_remove[b]]] = 0 200 | else: 201 | mask = (mask & (alpha.view_as(mask) > self.threshold)).view(B, N) 202 | 203 | if self.drop_nodes: 204 | x, A, mask, N_nodes_pooled, idx = self.drop_nodes_edges(x, A, mask) 205 | if idx is not None and 'node_attn' in params_dict: 206 | # update ground truth (or weakly labeled) attention for a reduced graph 207 | params_dict['node_attn'].append(normalize_batch(self.mask_out(torch.gather(alpha_gt, dim=1, index=idx), mask.float()))) 208 | if idx is not None and 'node_attn_eval' in params_dict: 209 | # update ground truth (or weakly labeled) attention for a reduced graph 210 | params_dict['node_attn_eval'].append(normalize_batch(self.mask_out(torch.gather(params_dict['node_attn_eval'][-1], dim=1, index=idx), mask.float()))) 211 | else: 212 | N_nodes_pooled = torch.sum(mask, dim=1).long() # B 213 | if 'node_attn' in params_dict: 214 | params_dict['node_attn'].append((self.mask_out(params_dict['node_attn'][-1], mask.float()))) 215 | if 'node_attn_eval' in params_dict: 216 | params_dict['node_attn_eval'].append((self.mask_out(params_dict['node_attn_eval'][-1], mask.float()))) 217 | 218 | params_dict['N_nodes'] = N_nodes_pooled 219 | 220 | mask_matrix = mask.unsqueeze(2) & mask.unsqueeze(1) 221 | A = A * mask_matrix.float() # or A[~mask_matrix] = 0 222 | 223 | # Add additional losses regularizing the model 224 | if KL_loss is not None: 225 | if 'reg' not in params_dict: 226 | params_dict['reg'] = [] 227 | params_dict['reg'].append(KL_loss) 228 | 229 | # Keep attention coefficients for evaluation 230 | for key, value in zip(['alpha', 'mask'], [alpha, mask]): 231 | if key not in params_dict: 232 | params_dict[key] = [] 233 | params_dict[key].append(value.detach()) 234 | 235 | if self.debug and alpha_gt is not None: 236 | idx_correct_pool = (alpha_gt > 0) 237 | idx_correct_drop = (alpha_gt == 0) 238 | alpha_correct_pool = alpha[idx_correct_pool].sum() / N_nodes_float.sum() 239 | alpha_correct_drop = alpha[idx_correct_drop].sum() / N_nodes_float.sum() 240 | ratio_avg = (N_nodes_pooled.float() / N_nodes_float).mean() 241 | 242 | for key, values in zip(['alpha_correct_pool_debug', 'alpha_correct_drop_debug', 'ratio_avg_debug'], 243 | [alpha_correct_pool, alpha_correct_drop, ratio_avg]): 244 | if key not in params_dict: 245 | params_dict[key] = [] 246 | params_dict[key].append(values.detach()) 247 | 248 | output = [x, A, mask] 249 | output.extend(data[3:]) 250 | return output 251 | -------------------------------------------------------------------------------- /chebygin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.sparse 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import 
Parameter 7 | from attention_pooling import * 8 | from utils import * 9 | 10 | 11 | class ChebyGINLayer(nn.Module): 12 | ''' 13 | General Graph Neural Network layer that depending on arguments can be: 14 | 1. Graph Convolution Layer (T. Kipf and M. Welling, ICLR 2017) 15 | 2. Chebyshev Graph Convolution Layer (M. Defferrard et al., NeurIPS 2017) 16 | 3. GIN Layer (K. Xu et al., ICLR 2019) 17 | 4. ChebyGIN Layer (B. Knyazev et al., ICLR 2019 Workshop on Representation Learning on Graphs and Manifolds) 18 | The first three types (1-3) of layers are particular cases of the fourth (4) case. 19 | ''' 20 | def __init__(self, 21 | in_features, 22 | out_features, 23 | K, 24 | n_hidden=0, 25 | aggregation='mean', 26 | activation=nn.ReLU(True), 27 | n_relations=1): 28 | super(ChebyGINLayer, self).__init__() 29 | self.in_features = in_features 30 | self.out_features = out_features 31 | self.n_relations = n_relations 32 | assert K > 0, 'order is assumed to be > 0' 33 | self.K = K 34 | assert n_hidden >= 0, ('invalid n_hidden value', n_hidden) 35 | self.n_hidden = n_hidden 36 | assert aggregation in ['mean', 'sum'], ('invalid aggregation', aggregation) 37 | self.aggregation = aggregation 38 | self.activation = activation 39 | n_in = self.in_features * self.K * n_relations 40 | if self.n_hidden == 0: 41 | fc = [nn.Linear(n_in, self.out_features)] 42 | else: 43 | fc = [nn.Linear(n_in, n_hidden), 44 | nn.ReLU(True), 45 | nn.Linear(n_hidden, self.out_features)] 46 | if activation is not None: 47 | fc.append(activation) 48 | self.fc = nn.Sequential(*fc) 49 | print('ChebyGINLayer', list(self.fc.children())[0].weight.shape, 50 | torch.norm(list(self.fc.children())[0].weight, dim=1)[:10]) 51 | 52 | def __repr__(self): 53 | return 'ChebyGINLayer(in_features={}, out_features={}, K={}, n_hidden={}, aggregation={})\nfc={}'.format( 54 | self.in_features, 55 | self.out_features, 56 | self.K, 57 | self.n_hidden, 58 | self.aggregation, 59 | str(self.fc)) 60 | 61 | def chebyshev_basis(self, L, X, K): 62 | ''' 63 | Return T_k X where T_k are the Chebyshev polynomials of order up to K. 64 | :param L: graph Laplacian, batch (B), nodes (N), nodes (N) 65 | :param X: input of size batch (B), nodes (N), features (F) 66 | :param K: Chebyshev polynomial order, i.e. filter size (number of hopes) 67 | :return: Tensor of size (B,N,K,F) as a result of multiplying T_k(L) by X for each order 68 | ''' 69 | if K > 1: 70 | Xt = [X] 71 | Xt.append(torch.bmm(L, X)) # B,N,F 72 | for k in range(2, K): 73 | Xt.append(2 * torch.bmm(L, Xt[k - 1]) - Xt[k - 2]) # B,N,F 74 | Xt = torch.stack(Xt, 2) # B,N,K,F 75 | return Xt 76 | else: 77 | # GCN 78 | assert K == 1, K 79 | return torch.bmm(L, X).unsqueeze(2) # B,N,1,F 80 | 81 | def laplacian_batch(self, A, add_identity=False): 82 | ''' 83 | Computes normalized Laplacian transformed so that its eigenvalues are in range [-1, 1]. 84 | Note that sum of all eigenvalues = trace(L) = 0. 
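        Concretely, the returned matrix is -D^(-1/2) A D^(-1/2) (negated so that it forms a valid
        Chebyshev basis), which matches the rescaled Laplacian 2L/lambda_max - I under the common
        assumption lambda_max ~= 2; with add_identity=True, self-loops are added first and the sign
        is kept positive, giving GCN/GIN-style propagation.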
85 | :param A: Tensor of size (B,N,N) containing batch (B) of adjacency matrices of shape N,N 86 | :return: Normalized Laplacian of size (B,N,N) 87 | ''' 88 | B, N = A.shape[:2] 89 | if add_identity: 90 | A = A + torch.eye(N, device=A.get_device() if A.is_cuda else 'cpu').unsqueeze(0) 91 | D = torch.sum(A, 1) # nodes degree (B,N) 92 | D_hat = (D + 1e-5) ** (-0.5) 93 | L = D_hat.view(B, N, 1) * A * D_hat.view(B, 1, N) # B,N,N 94 | if not add_identity: 95 | L = -L # for ChebyNet to make a valid Chebyshev basis 96 | return D, L 97 | 98 | def forward(self, data): 99 | x, A, mask = data[:3] 100 | B, N, F = x.shape 101 | assert N == A.shape[1] == A.shape[2], ('invalid shape', N, x.shape, A.shape) 102 | 103 | if len(A.shape) == 3: 104 | A = A.unsqueeze(3) 105 | 106 | y_out = [] 107 | for rel in range(A.shape[3]): 108 | D, L = self.laplacian_batch(A[:, :, :, rel], add_identity=self.K == 1) # for the first layer this can be done at the preprocessing stage 109 | y = self.chebyshev_basis(L, x, self.K) # B,N,K,F 110 | 111 | if self.aggregation == 'sum': 112 | # Sum features of neighbors 113 | if self.K == 1: 114 | # GIN 115 | y = y * D.view(B, N, 1, 1) 116 | else: 117 | # ChebyGIN 118 | D_GIN = torch.ones(B, N, self.K, device=x.get_device() if x.is_cuda else 'cpu') 119 | D_GIN[:, :, 1:] = D.view(B, N, 1).expand(-1, -1, self.K - 1) # keep self-loop features the same 120 | y = y * D_GIN.view(B, N, self.K, 1) # apply summation for other scales 121 | 122 | y_out.append(y) 123 | 124 | y = torch.cat(y_out, dim=2) 125 | y = self.fc(y.view(B, N, -1)) # B,N,F 126 | 127 | if len(mask.shape) == 2: 128 | mask = mask.unsqueeze(2) 129 | 130 | y = y * mask.float() 131 | output = [y, A, mask] 132 | output.extend(data[3:] + [x]) # for python2 133 | 134 | return output 135 | 136 | 137 | class GraphReadout(nn.Module): 138 | ''' 139 | Global pooling layer applied after the last graph layer. 140 | ''' 141 | def __init__(self, 142 | pool_type): 143 | super(GraphReadout, self).__init__() 144 | self.pool_type = pool_type 145 | dim = 1 # pooling over nodes 146 | if pool_type == 'max': 147 | self.readout_layer = lambda x, mask: torch.max(x, dim=dim)[0] 148 | elif pool_type in ['avg', 'mean']: 149 | # sum over all nodes, then divide by the number of valid nodes in each sample of the batch 150 | self.readout_layer = lambda x, mask: torch.sum(x, dim=dim) / torch.sum(mask, dim=dim).float() 151 | elif pool_type in ['sum']: 152 | self.readout_layer = lambda x, mask: torch.sum(x, dim=dim) 153 | else: 154 | raise NotImplementedError(pool_type) 155 | 156 | def __repr__(self): 157 | return 'GraphReadout({})'.format(self.pool_type) 158 | 159 | def forward(self, data): 160 | x, A, mask = data[:3] 161 | B, N = x.shape[:2] 162 | x = self.readout_layer(x, mask.view(B, N, 1)) 163 | output = [x] 164 | output.extend(data[1:]) # [x, *data[1:]] doesn't work in Python2 165 | return output 166 | 167 | 168 | class ChebyGIN(nn.Module): 169 | ''' 170 | Graph Neural Network class. 
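    Stacks ChebyGINLayer graph convolutions, optionally preceded by AttentionPooling layers
    (configured via the `pool` argument), followed by a GraphReadout over nodes and a fully
    connected head when out_features > 0. With out_features == 0 the node features of the last
    graph layer are returned instead, which is how this class is reused as the attention GNN
    when pool_arch starts with 'gnn'.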
171 | ''' 172 | def __init__(self, 173 | in_features, 174 | out_features, 175 | filters, 176 | K=1, 177 | n_hidden=0, 178 | aggregation='mean', 179 | dropout=0, 180 | readout='max', 181 | pool=None, # Example: 'attn_gt_threshold_0_skip_skip'.split('_'), 182 | pool_arch='fc_prev'.split('_'), 183 | large_graph=False, # > ~500 graphs 184 | kl_weight=None, 185 | graph_layer_fn=None, 186 | init='normal', 187 | scale=None, 188 | debug=False): 189 | super(ChebyGIN, self).__init__() 190 | self.out_features = out_features 191 | assert len(filters) > 0, 'filters must be an iterable object with at least one element' 192 | assert K > 0, 'filter scale must be a positive integer' 193 | self.pool = pool 194 | self.pool_arch = pool_arch 195 | self.debug = debug 196 | n_prev = None 197 | 198 | attn_gnn = None 199 | if graph_layer_fn is None: 200 | graph_layer_fn = lambda n_in, n_out, K_, n_hidden_, activation: ChebyGINLayer(in_features=n_in, 201 | out_features=n_out, 202 | K=K_, 203 | n_hidden=n_hidden_, 204 | aggregation=aggregation, 205 | activation=activation) 206 | if self.pool_arch is not None and self.pool_arch[0] == 'gnn': 207 | attn_gnn = lambda n_in: ChebyGIN(in_features=n_in, 208 | out_features=0, 209 | filters=[32, 32, 1], 210 | K=np.min((K, 2)), 211 | n_hidden=0, 212 | graph_layer_fn=graph_layer_fn) 213 | 214 | graph_layers = [] 215 | 216 | for layer, f in enumerate(filters + [None]): 217 | 218 | n_in = in_features if layer == 0 else filters[layer - 1] 219 | # Pooling layers 220 | # It's a non-standard way to put pooling before convolution, but it's important for our work 221 | if self.pool is not None and len(self.pool) > len(filters) + layer and self.pool[layer + 3] != 'skip': 222 | graph_layers.append(AttentionPooling(in_features=n_in, in_features_prev=n_prev, 223 | pool_type=self.pool[:3] + [self.pool[layer + 3]], 224 | pool_arch=self.pool_arch, 225 | large_graph=large_graph, 226 | kl_weight=kl_weight, 227 | attn_gnn=attn_gnn, 228 | init=init, 229 | scale=scale, 230 | debug=debug)) 231 | 232 | if f is not None: 233 | # Graph "convolution" layers 234 | # no ReLU if the last layer and no fc layer after that 235 | graph_layers.append(graph_layer_fn(n_in, f, K, n_hidden, 236 | None if self.out_features == 0 and layer == len(filters) - 1 else nn.ReLU(True))) 237 | n_prev = n_in 238 | 239 | if self.out_features > 0: 240 | # Global pooling over nodes 241 | graph_layers.append(GraphReadout(readout)) 242 | self.graph_layers = nn.Sequential(*graph_layers) 243 | 244 | if self.out_features > 0: 245 | # Fully connected (classification/regression) layers 246 | self.fc = nn.Sequential(*(([nn.Dropout(p=dropout)] if dropout > 0 else []) + [nn.Linear(filters[-1], out_features)])) 247 | 248 | def forward(self, data): 249 | data = self.graph_layers(data) 250 | if self.out_features > 0: 251 | y = self.fc(data[0]) # B,out_features 252 | else: 253 | y = data[0] # B,N,out_features 254 | return y, data[4] 255 | -------------------------------------------------------------------------------- /checkpoints/checkpoint_TU_657452_epoch50_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_TU_657452_epoch50_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_triangles_051609_epoch100_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_051609_epoch100_seed0000111.pth.tar 
-------------------------------------------------------------------------------- /checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar -------------------------------------------------------------------------------- /checkpoints/colors-3_alpha_WS_test_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/colors-3_alpha_WS_test_seed111_orig.pkl -------------------------------------------------------------------------------- /checkpoints/colors-3_alpha_WS_train_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/colors-3_alpha_WS_train_seed111_orig.pkl -------------------------------------------------------------------------------- /checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl -------------------------------------------------------------------------------- /checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy.pkl -------------------------------------------------------------------------------- /checkpoints/mnist-75sp_alpha_WS_test_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_orig.pkl -------------------------------------------------------------------------------- /checkpoints/mnist-75sp_alpha_WS_train_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_train_seed111_orig.pkl -------------------------------------------------------------------------------- 
/checkpoints/triangles_alpha_WS_test_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/triangles_alpha_WS_test_seed111_orig.pkl -------------------------------------------------------------------------------- /checkpoints/triangles_alpha_WS_train_seed111_orig.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/triangles_alpha_WS_train_seed111_orig.pkl -------------------------------------------------------------------------------- /data/COLORS-3.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/COLORS-3.zip -------------------------------------------------------------------------------- /data/TRIANGLES.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/TRIANGLES.zip -------------------------------------------------------------------------------- /data/datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/datasets.png -------------------------------------------------------------------------------- /data/mnist_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/mnist_animation.gif -------------------------------------------------------------------------------- /data/triangles_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/triangles_animation.gif -------------------------------------------------------------------------------- /extract_superpixels.py: -------------------------------------------------------------------------------- 1 | # Compute superpixels for MNIST/CIFAR-10 using SLIC algorithm 2 | # https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic 3 | 4 | import numpy as np 5 | import random 6 | import os 7 | import scipy 8 | import pickle 9 | from skimage.segmentation import slic 10 | from torchvision import datasets 11 | import multiprocessing as mp 12 | import scipy.ndimage 13 | import scipy.spatial 14 | import argparse 15 | import datetime 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Extract SLIC superpixels from images') 20 | parser.add_argument('-D', '--dataset', type=str, default='mnist', choices=['mnist', 'cifar10']) 21 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to the dataset') 22 | parser.add_argument('-o', '--out_dir', type=str, default='./data', help='path where to save superpixels') 23 | parser.add_argument('-s', '--split', type=str, default='train', choices=['train', 'val', 'test']) 24 | parser.add_argument('-t', '--threads', type=int, default=0, help='number of parallel threads') 25 | parser.add_argument('-n', '--n_sp', 
type=int, default=75, help='max number of superpixels per image') 26 | parser.add_argument('-c', '--compactness', type=int, default=0.25, help='compactness of the SLIC algorithm ' 27 | '(Balances color proximity and space proximity): ' 28 | '0.25 is a good value for MNIST ' 29 | 'and 10 for color images like CIFAR-10') 30 | parser.add_argument('--seed', type=int, default=111, help='seed for shuffling nodes') 31 | args = parser.parse_args() 32 | 33 | for arg in vars(args): 34 | print(arg, getattr(args, arg)) 35 | 36 | return args 37 | 38 | 39 | def process_image(params): 40 | 41 | img, index, n_images, args, to_print, shuffle = params 42 | 43 | assert img.dtype == np.uint8, img.dtype 44 | img = (img / 255.).astype(np.float32) 45 | 46 | n_sp_extracted = args.n_sp + 1 # number of actually extracted superpixels (can be different from requested in SLIC) 47 | n_sp_query = args.n_sp + (20 if args.dataset == 'mnist' else 50) # number of superpixels we ask to extract (larger to extract more superpixels - closer to the desired n_sp) 48 | while n_sp_extracted > args.n_sp: 49 | superpixels = slic(img, n_segments=n_sp_query, compactness=args.compactness, multichannel=len(img.shape) > 2) 50 | sp_indices = np.unique(superpixels) 51 | n_sp_extracted = len(sp_indices) 52 | n_sp_query -= 1 # reducing the number of superpixels until we get <= n superpixels 53 | 54 | assert n_sp_extracted <= args.n_sp and n_sp_extracted > 0, (args.split, index, n_sp_extracted, args.n_sp) 55 | assert n_sp_extracted == np.max(superpixels) + 1, ('superpixel indices', np.unique(superpixels)) # make sure superpixel indices are numbers from 0 to n-1 56 | 57 | if shuffle: 58 | ind = np.random.permutation(n_sp_extracted) 59 | else: 60 | ind = np.arange(n_sp_extracted) 61 | 62 | sp_order = sp_indices[ind].astype(np.int32) 63 | if len(img.shape) == 2: 64 | img = img[:, :, None] 65 | 66 | n_ch = 1 if img.shape[2] == 1 else 3 67 | 68 | sp_intensity, sp_coord = [], [] 69 | for seg in sp_order: 70 | mask = (superpixels == seg).squeeze() 71 | avg_value = np.zeros(n_ch) 72 | for c in range(n_ch): 73 | avg_value[c] = np.mean(img[:, :, c][mask]) 74 | cntr = np.array(scipy.ndimage.measurements.center_of_mass(mask)) # row, col 75 | sp_intensity.append(avg_value) 76 | sp_coord.append(cntr) 77 | sp_intensity = np.array(sp_intensity, np.float32) 78 | sp_coord = np.array(sp_coord, np.float32) 79 | if to_print: 80 | print('image={}/{}, shape={}, min={:.2f}, max={:.2f}, n_sp={}'.format(index + 1, n_images, img.shape, 81 | img.min(), img.max(), sp_intensity.shape[0])) 82 | 83 | return sp_intensity, sp_coord, sp_order, superpixels 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | dt = datetime.datetime.now() 89 | print('start time:', dt) 90 | 91 | args = parse_args() 92 | 93 | if not os.path.isdir(args.out_dir): 94 | os.mkdir(args.out_dir) 95 | 96 | random.seed(args.seed) 97 | np.random.seed(args.seed) # to make node random permutation reproducible (not tested) 98 | 99 | # Read image data using torchvision 100 | is_train = args.split.lower() == 'train' 101 | if args.dataset == 'mnist': 102 | data = datasets.MNIST(args.data_dir, train=is_train, download=True) 103 | assert args.compactness < 10, ('high compactness can result in bad superpixels on MNIST') 104 | assert args.n_sp > 1 and args.n_sp < 28*28, ( 105 | 'the number of superpixels cannot exceed the total number of pixels or be too small') 106 | elif args.dataset == 'cifar10': 107 | data = datasets.CIFAR10(args.data_dir, train=is_train, download=True) 108 | assert args.compactness > 1, ('low 
compactness can result in bad superpixels on CIFAR-10') 109 | assert args.n_sp > 1 and args.n_sp < 32*32, ( 110 | 'the number of superpixels cannot exceed the total number of pixels or be too small') 111 | else: 112 | raise NotImplementedError('unsupported dataset: ' + args.dataset) 113 | 114 | images = data.train_data if is_train else data.test_data 115 | labels = data.train_labels if is_train else data.test_labels 116 | if not isinstance(images, np.ndarray): 117 | images = images.numpy() 118 | if isinstance(labels, list): 119 | labels = np.array(labels) 120 | if not isinstance(labels, np.ndarray): 121 | labels = labels.numpy() 122 | 123 | n_images = len(labels) 124 | 125 | if args.threads <= 0: 126 | sp_data = [] 127 | for i in range(n_images): 128 | sp_data.append(process_image((images[i], i, n_images, args, True, True))) 129 | else: 130 | with mp.Pool(processes=args.threads) as pool: 131 | sp_data = pool.map(process_image, [(images[i], i, n_images, args, True, True) for i in range(n_images)]) 132 | 133 | superpixels = [sp_data[i][3] for i in range(n_images)] 134 | sp_data = [sp_data[i][:3] for i in range(n_images)] 135 | with open('%s/%s_%dsp_%s.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f: 136 | pickle.dump((labels.astype(np.int32), sp_data), f, protocol=2) 137 | with open('%s/%s_%dsp_%s_superpixels.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f: 138 | pickle.dump(superpixels, f, protocol=2) 139 | 140 | print('done in {}'.format(datetime.datetime.now() - dt)) 141 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import argparse 5 | import networkx as nx 6 | import datetime 7 | import random 8 | import multiprocessing as mp 9 | from utils import * 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='Generate synthetic graph datasets') 14 | parser.add_argument('-D', '--dataset', type=str, default='colors', choices=['colors', 'triangles']) 15 | parser.add_argument('-o', '--out_dir', type=str, default='./data', help='path where to save superpixels') 16 | parser.add_argument('--N_train', type=int, default=500, help='number of training graphs (500 for colors and 30000 for triangles)') 17 | parser.add_argument('--N_val', type=int, default=2500, help='number of graphs in the validation set (2500 for colors and 5000 for triangles)') 18 | parser.add_argument('--N_test', type=int, default=2500, help='number of graphs in each test subset (2500 for colors and 5000 for triangles)') 19 | parser.add_argument('--label_min', type=int, default=0, 20 | help='smallest label value for a graph (i.e. smallest number of green nodes); 1 for triangles') 21 | parser.add_argument('--label_max', type=int, default=10, 22 | help='largest label value for a graph (i.e. 
largest number of green nodes)') 23 | parser.add_argument('--N_min', type=int, default=4, help='minimum number of nodes') 24 | parser.add_argument('--N_max', type=int, default=200, help='maximum number of nodes (default: 200 for colors and 100 for triangles') 25 | parser.add_argument('--N_max_train', type=int, default=25, help='maximum number of nodes in the training set') 26 | parser.add_argument('--dim', type=int, default=3, help='node feature dimensionality') 27 | parser.add_argument('--green_ch_index', type=int, default=1, 28 | help='index of non-zero value in a one-hot node feature vector, ' 29 | 'i.e. [0, 1, 0] in case green_channel_index=1 and dim=3') 30 | parser.add_argument('--seed', type=int, default=111, help='seed for shuffling nodes') 31 | parser.add_argument('--threads', type=int, default=0, help='only for triangles') 32 | args = parser.parse_args() 33 | 34 | for arg in vars(args): 35 | print(arg, getattr(args, arg)) 36 | 37 | return args 38 | 39 | 40 | def check_graph_duplicates(Adj_matrices, node_features=None): 41 | n_graphs = len(Adj_matrices) 42 | print('check for duplicates for %d graphs' % n_graphs) 43 | n_duplicates = 0 44 | for i in range(n_graphs): 45 | if node_features is not None: 46 | assert Adj_matrices[i].shape[0] == node_features[i].shape[0], ( 47 | 'invalid data', i, Adj_matrices[i].shape[0], node_features[i].shape[0]) 48 | for j in range(i + 1, n_graphs): 49 | if Adj_matrices[i].shape[0] == Adj_matrices[j].shape[0]: 50 | if np.allclose(Adj_matrices[i], Adj_matrices[j]): # adjacency matrices are the same 51 | # for Colors graphs are not considered duplicates if they have the same adjacency matrix, 52 | # but different node features 53 | if node_features is None or np.allclose(node_features[i], node_features[j]): 54 | n_duplicates += 1 55 | print('duplicates %d/%d' % (n_duplicates, n_graphs * (n_graphs - 1) / 2)) 56 | if n_duplicates > 0: 57 | raise ValueError('%d duplicates found in the dataset' % n_duplicates) 58 | 59 | print('no duplicated graphs') 60 | 61 | 62 | # COLORS 63 | def get_node_features_Colors(N_nodes, N_green, dim, green_ch_index=1, new_colors=False): 64 | node_features = np.zeros((N_nodes, dim)) 65 | 66 | # Generate indices for non-zero values, 67 | # so that the total number of nodes with features having value 1 in the green_ch_index equals N_green 68 | idx_not_green = rnd.randint(0, dim - 1, size=N_nodes - N_green) # for dim=3 generate values 0,1 for non-green nodes 69 | idx_non_zero = np.concatenate((idx_not_green, np.zeros(N_green, np.int) + dim - 1)) # make green_ch_index=2 temporary 70 | idx_non_zero_cp = idx_non_zero.copy() 71 | idx_non_zero[idx_non_zero_cp == dim - 1] = green_ch_index # set idx_non_zero=1 for green nodes 72 | idx_non_zero[idx_non_zero_cp == green_ch_index] = dim - 1 # set idx_non_zero=2 for those nodes that were green temporary 73 | rnd.shuffle(idx_non_zero) # shuffle nodes 74 | node_features[np.arange(N_nodes), idx_non_zero] = 1 75 | 76 | if new_colors: 77 | for ind in np.where(idx_non_zero != green_ch_index)[0]: # for non-green nodes 78 | node_features[ind] = rnd.randint(0, 2, size=dim) 79 | node_features[ind, green_ch_index] = 0 # set value at green_ch_index to 0 to avoid confusion with green nodes 80 | 81 | label = np.sum((np.sum(node_features, 1) == node_features[:, green_ch_index]) & (node_features[:, green_ch_index] == 1)) 82 | 83 | gt_attn = (idx_non_zero == green_ch_index).reshape(-1, 1) 84 | label2 = np.sum(gt_attn) 85 | assert N_green == label == label2, ('invalid node features', N_green, label, label2) 86 | 
return node_features, idx_non_zero, gt_attn 87 | 88 | 89 | def generate_graphs_Colors(N_graphs, N_min, N_max, dim, args, rnd, new_colors=False): 90 | Adj_matrices, node_features, GT_attn, graph_labels, N_edges = [], [], [], [], [] 91 | n_labels = args.label_max - args.label_min + 1 92 | n_graphs_per_shape = int(np.ceil(N_graphs / (N_max - N_min + 1) / n_labels) * n_labels) 93 | for n_nodes in np.array(range(N_min, N_max + 1)): 94 | c = 0 95 | while True: 96 | labels = np.arange(args.label_min, n_labels) 97 | labels = labels[labels <= n_nodes] 98 | rnd.shuffle(labels) 99 | for lbl in labels: 100 | features, idx_non_zero, gt_attn = get_node_features_Colors(N_nodes=n_nodes, 101 | N_green=lbl, 102 | dim=dim, 103 | green_ch_index=args.green_ch_index, 104 | new_colors=new_colors) 105 | n_edges = int((rnd.rand() + 1) * n_nodes) 106 | A = nx.to_numpy_array(nx.gnm_random_graph(n_nodes, n_edges)) 107 | add = True 108 | for k in range(len(Adj_matrices)): 109 | if A.shape[0] == Adj_matrices[k].shape[0] and np.allclose(A, Adj_matrices[k]): 110 | if np.allclose(node_features[k], features): 111 | add = False 112 | break 113 | if add: 114 | Adj_matrices.append(A.astype(np.bool)) # binary adjacency matrix 115 | graph_labels.append(lbl) 116 | node_features.append(features.astype(np.bool)) # binary features 117 | GT_attn.append(gt_attn) # binary GT attention 118 | N_edges.append(n_edges) 119 | c += 1 120 | if c >= n_graphs_per_shape: 121 | break 122 | if c >= n_graphs_per_shape: 123 | break 124 | graph_labels = np.array(graph_labels, np.int32) 125 | N_edges = np.array(N_edges, np.int32) 126 | print(N_graphs, len(graph_labels)) 127 | 128 | return {'Adj_matrices': Adj_matrices, 129 | 'GT_attn': GT_attn, # not normalized to sum=1 130 | 'graph_labels': graph_labels, 131 | 'node_features': node_features, 132 | 'N_edges': N_edges} 133 | 134 | 135 | # TRIANGLES 136 | def get_gt_atnn_triangles(args): 137 | G, N = args 138 | node_ids = [] 139 | if G is not None: 140 | for clq in nx.enumerate_all_cliques(G): 141 | if len(clq) == 3: 142 | node_ids.extend(clq) 143 | node_ids = np.array(node_ids) 144 | gt_attn = np.zeros((N, 1), np.int32) 145 | for i in np.unique(node_ids): 146 | gt_attn[i] = int(np.sum(node_ids == i)) 147 | return gt_attn # unnormalized (do not sum to 1, i.e. use int32 for storage efficiency) 148 | 149 | 150 | def get_graph_triangles(args): 151 | N_nodes, rnd = args 152 | N_edges = int((rnd.rand() + 1) * N_nodes) 153 | G = nx.dense_gnm_random_graph(N_nodes, N_edges, seed=None) 154 | A = nx.to_numpy_array(G) 155 | A_cube = A.dot(A).dot(A) 156 | label = int(np.trace(A_cube) / 6.) 
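# Editor's note on the line above: (A^3)[i, i] counts closed walks of length 3 that start and end
# at node i, and every triangle {i, j, k} contributes exactly six such walks (3 starting nodes x 2
# directions), so trace(A^3) / 6 equals the number of triangles in an undirected simple graph.
# Quick check: for the complete graph K4, trace(A^3) = 24, and 24 / 6 = 4 = C(4, 3) triangles.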
# number of triangles 157 | return A.astype(np.bool), label, N_edges, G 158 | 159 | 160 | def generate_graphs_Triangles(N_graphs, N_min, N_max, args, rnd): 161 | N_nodes = rnd.randint(N_min, N_max + 1, size=int(N_graphs * 10)) 162 | print('generating %d graphs with %d-%d nodes' % (N_graphs * 10, N_min, N_max)) 163 | 164 | if args.threads > 0: 165 | with mp.Pool(processes=args.threads) as pool: 166 | data = pool.map(get_graph_triangles, [(N_nodes[i], rnd) for i in range(len(N_nodes))]) 167 | else: 168 | data = [get_graph_triangles((N_nodes[i], rnd)) for i in range(len(N_nodes))] 169 | labels = np.array([data[i][1] for i in range(len(data))], np.int32) 170 | Adj_matrices, node_features, G, graph_labels, N_edges, node_degrees = [], [], [], [], [], [] 171 | for lbl in range(args.label_min, args.label_max + 1): 172 | idx = np.where(labels == lbl)[0] 173 | c = 0 174 | for i in idx: 175 | add = True 176 | for k in range(len(Adj_matrices)): 177 | if data[i][0].shape[0] == Adj_matrices[k].shape[0] and labels[i] == graph_labels[k] and np.allclose(data[i][0], Adj_matrices[k]): 178 | add = False 179 | break 180 | if add: 181 | Adj_matrices.append(data[i][0]) 182 | graph_labels.append(labels[i]) 183 | G.append(data[i][3]) 184 | N_edges.append(data[i][2]) 185 | node_degrees.append(data[i][0].astype(np.int32).sum(1).max()) 186 | c += 1 187 | if c >= int(N_graphs / (args.label_max - args.label_min + 1)): 188 | break 189 | print('label={}, number of graphs={}/{}, total number of generated graphs={}'.format(lbl, c, len(idx), len(Adj_matrices))) 190 | 191 | assert c == int(N_graphs / (args.label_max - args.label_min + 1)), ( 192 | 'invalid data', c, int(N_graphs / (args.label_max - args.label_min + 1))) 193 | 194 | print('computing GT attention for %d graphs' % len(Adj_matrices)) 195 | if args.threads > 0: 196 | with mp.Pool(processes=args.threads) as pool: 197 | GT_attn = pool.map(get_gt_atnn_triangles, [(G[i], Adj_matrices[i].shape[0]) for i in range(len(Adj_matrices))]) 198 | else: 199 | GT_attn = [get_gt_atnn_triangles((G[i], Adj_matrices[i].shape[0])) for i in range(len(Adj_matrices))] 200 | 201 | graph_labels = np.array(graph_labels, np.int32) 202 | N_edges = np.array(N_edges, np.int32) 203 | 204 | return {'Adj_matrices': Adj_matrices, 205 | 'GT_attn': GT_attn, # not normalized to sum=1 206 | 'graph_labels': graph_labels, 207 | 'N_edges': N_edges, 208 | 'Max_degree': np.max(node_degrees)} 209 | 210 | 211 | if __name__ == '__main__': 212 | 213 | dt = datetime.datetime.now() 214 | print('start time:', dt) 215 | 216 | args = parse_args() 217 | 218 | if not os.path.isdir(args.out_dir): 219 | os.mkdir(args.out_dir) 220 | 221 | random.seed(args.seed) # for networkx 222 | np.random.seed(args.seed) 223 | rnd = np.random.RandomState(args.seed) 224 | 225 | def print_stats(data, split_name): 226 | print('%s: %d graphs' % (split_name, len(data['graph_labels']))) 227 | for lbl in np.unique(data['graph_labels']): 228 | print('%s: label=%d, %d graphs' % (split_name, lbl, np.sum(data['graph_labels'] == lbl))) 229 | 230 | if args.dataset.lower() == 'colors': 231 | 232 | # Generate train and test sets 233 | data_test_combined, Adj_matrices, node_features = [], [], [] 234 | for N_graphs, N_nodes_min, N_nodes_max, dim, name in zip([args.N_train + args.N_val + args.N_test, args.N_test, args.N_test], 235 | [args.N_min, args.N_max_train + 1, args.N_max_train + 1], 236 | [args.N_max_train, args.N_max, args.N_max], 237 | [args.dim, args.dim, args.dim + 1], 238 | ['test orig', 'test large', 'test large-c']): 239 | data = 
generate_graphs_Colors(N_graphs, N_nodes_min, N_nodes_max, dim, args, rnd, new_colors=dim==args.dim + 1) 240 | 241 | if name.find('orig') >= 0: 242 | idx = rnd.permutation(len(data['graph_labels'])) 243 | data_train = copy_data(data, idx[:args.N_train]) 244 | print_stats(data_train, name.replace('test', 'train')) 245 | node_features += data_train['node_features'] 246 | Adj_matrices += data_train['Adj_matrices'] 247 | data_val = copy_data(data, idx[args.N_train: args.N_train + args.N_val]) 248 | print_stats(data_val, name.replace('test', 'val')) 249 | node_features += data_val['node_features'] 250 | Adj_matrices += data_val['Adj_matrices'] 251 | 252 | data_test = copy_data(data, idx[args.N_train + args.N_val: args.N_train + args.N_val + args.N_test]) 253 | else: 254 | data_test = copy_data(data, rnd.permutation(len(data['graph_labels']))[:args.N_test]) 255 | 256 | Adj_matrices += data_test['Adj_matrices'] 257 | node_features += data_test['node_features'] 258 | data_test_combined.append(data_test) 259 | print_stats(data_test, name) 260 | 261 | # Check for duplicates in the combined train+val+test sets 262 | check_graph_duplicates(Adj_matrices, node_features) 263 | 264 | # Saving 265 | with open('%s/random_graphs_colors_dim%d_train.pkl' % (args.out_dir, args.dim), 'wb') as f: 266 | pickle.dump(data_train, f, protocol=2) 267 | 268 | with open('%s/random_graphs_colors_dim%d_val.pkl' % (args.out_dir, args.dim), 'wb') as f: 269 | pickle.dump(data_val, f, protocol=2) 270 | 271 | with open('%s/random_graphs_colors_dim%d_test.pkl' % (args.out_dir, args.dim), 'wb') as f: 272 | pickle.dump(concat_data(data_test_combined), f, protocol=2) 273 | 274 | elif args.dataset.lower() == 'triangles': 275 | 276 | data = generate_graphs_Triangles((args.N_train + args.N_val + args.N_test), args.N_min, args.N_max_train, args, rnd) 277 | # Create balanced splits 278 | idx_train, idx_val, idx_test = [], [], [] 279 | classes = np.unique(data['graph_labels']) 280 | n_classes = len(classes) 281 | for lbl in classes: 282 | idx = np.where(data['graph_labels'] == lbl)[0] 283 | rnd.shuffle(idx) 284 | n_train = int(args.N_train / n_classes) 285 | n_val = int(args.N_val / n_classes) 286 | n_test = int(args.N_test / n_classes) 287 | idx_train.append(idx[:n_train]) 288 | idx_val.append(idx[n_train: n_train + n_val]) 289 | idx_test.append(idx[n_train + n_val: n_train + n_val + n_test]) 290 | data_train = copy_data(data, np.concatenate(idx_train)) 291 | print_stats(data_train, 'train orig') 292 | data_val = copy_data(data, np.concatenate(idx_val)) 293 | print_stats(data_val, 'val orig') 294 | data_test = copy_data(data, np.concatenate(idx_test)) 295 | print_stats(data_test, 'test orig') 296 | 297 | data = generate_graphs_Triangles(args.N_test, args.N_max_train + 1, args.N_max, args, rnd) 298 | data_test_large = copy_data(data, rnd.permutation(len(data['graph_labels']))[:args.N_test]) 299 | print_stats(data_test_large, 'test large') 300 | 301 | check_graph_duplicates(data_train['Adj_matrices'] + data_val['Adj_matrices'] + 302 | data_test['Adj_matrices'] + data_test_large['Adj_matrices']) 303 | 304 | # Saving 305 | # Max degree is max over all graphs in the training and test sets 306 | max_degree = np.max(np.array([d['Max_degree'] for d in (data_train, data_val, data_test, data_test_large)])) 307 | data_train['Max_degree'] = max_degree 308 | with open('%s/random_graphs_triangles_train.pkl' % args.out_dir, 'wb') as f: 309 | pickle.dump(data_train, f, protocol=2) 310 | 311 | data_val['Max_degree'] = max_degree 312 | with 
open('%s/random_graphs_triangles_val.pkl' % args.out_dir, 'wb') as f: 313 | pickle.dump(data_val, f, protocol=2) 314 | 315 | data_test = concat_data((data_test, data_test_large)) 316 | data_test['Max_degree'] = max_degree 317 | with open('%s/random_graphs_triangles_test.pkl' % args.out_dir, 'wb') as f: 318 | pickle.dump(data_test, f, protocol=2) 319 | 320 | else: 321 | raise NotImplementedError('unsupported dataset: ' + args.dataset) 322 | 323 | print('done in {}'.format(datetime.datetime.now() - dt)) 324 | -------------------------------------------------------------------------------- /graphdata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join as pjoin 4 | import pickle 5 | import copy 6 | import torch 7 | import torch.utils 8 | import torch.utils.data 9 | import torch.nn.functional as F 10 | import torchvision 11 | from scipy.spatial.distance import cdist 12 | from utils import * 13 | 14 | 15 | def compute_adjacency_matrix_images(coord, sigma=0.1): 16 | coord = coord.reshape(-1, 2) 17 | dist = cdist(coord, coord) 18 | A = np.exp(- dist / (sigma * np.pi) ** 2) 19 | A[np.diag_indices_from(A)] = 0 20 | return A 21 | 22 | 23 | def precompute_graph_images(img_size): 24 | col, row = np.meshgrid(np.arange(img_size), np.arange(img_size)) 25 | coord = np.stack((col, row), axis=2) / img_size # 28,28,2 26 | A = torch.from_numpy(compute_adjacency_matrix_images(coord)).float().unsqueeze(0) 27 | coord = torch.from_numpy(coord).float().unsqueeze(0).view(1, -1, 2) 28 | mask = torch.ones(1, img_size * img_size, dtype=torch.uint8) 29 | return A, coord, mask 30 | 31 | 32 | def collate_batch_images(batch, A, mask, use_mean_px=True, coord=None, 33 | gt_attn_threshold=0, replicate_features=True): 34 | B = len(batch) 35 | C, H, W = batch[0][0].shape 36 | N_nodes = H * W 37 | params_dict = {'N_nodes': torch.zeros(B, dtype=torch.long) + N_nodes, 'node_attn_eval': None} 38 | has_WS_attn = len(batch[0]) > 2 39 | if has_WS_attn: 40 | WS_attn = torch.from_numpy(np.stack([batch[b][2].reshape(N_nodes) for b in range(B)]).astype(np.float32)).view(B, N_nodes) 41 | WS_attn = normalize_batch(WS_attn) 42 | params_dict.update({'node_attn': WS_attn}) # use these scores for training 43 | 44 | if use_mean_px: 45 | x = torch.stack([batch[b][0].view(C, N_nodes).t() for b in range(B)]).float() 46 | if gt_attn_threshold == 0: 47 | GT_attn = (x > 0).view(B, N_nodes).float() 48 | else: 49 | GT_attn = x.view(B, N_nodes).float().clone() 50 | GT_attn[GT_attn < gt_attn_threshold] = 0 51 | GT_attn = normalize_batch(GT_attn) 52 | 53 | params_dict.update({'node_attn_eval': GT_attn}) # use this for evaluation of attention 54 | if not has_WS_attn: 55 | params_dict.update({'node_attn': GT_attn}) # use this to train attention 56 | else: 57 | raise NotImplementedError('this case is not well supported') 58 | 59 | if coord is not None: 60 | if use_mean_px: 61 | x = torch.cat((x, coord.expand(B, -1, -1)), dim=2) 62 | else: 63 | x = coord.expand(B, -1, -1) 64 | if x is None: 65 | x = torch.ones(B, N_nodes, 1) # dummy features 66 | 67 | if replicate_features: 68 | x = F.pad(x, (2, 0), 'replicate') 69 | 70 | try: 71 | labels = torch.Tensor([batch[b][1] for b in range(B)]).long() 72 | except: 73 | labels = torch.stack([batch[b][1] for b in range(B)]).long() 74 | 75 | return [x, A.expand(B, -1, -1), mask.expand(B, -1), labels, params_dict] 76 | 77 | 78 | def collate_batch(batch): 79 | ''' 80 | Creates a batch of same size graphs by zero-padding node 
features and adjacency matrices up to 81 | the maximum number of nodes in the current batch rather than in the entire dataset. 82 | ''' 83 | 84 | B = len(batch) 85 | N_nodes = [batch[b][2] for b in range(B)] 86 | C = batch[0][0].shape[1] 87 | N_nodes_max = int(np.max(N_nodes)) 88 | 89 | mask = torch.zeros(B, N_nodes_max, dtype=torch.bool) # use byte for older PyTorch 90 | A = torch.zeros(B, N_nodes_max, N_nodes_max) 91 | x = torch.zeros(B, N_nodes_max, C) 92 | has_GT_attn = len(batch[0]) > 4 and batch[0][4] is not None 93 | if has_GT_attn: 94 | GT_attn = torch.zeros(B, N_nodes_max) 95 | has_WS_attn = len(batch[0]) > 5 and batch[0][5] is not None 96 | if has_WS_attn: 97 | WS_attn = torch.zeros(B, N_nodes_max) 98 | 99 | for b in range(B): 100 | x[b, :N_nodes[b]] = batch[b][0] 101 | A[b, :N_nodes[b], :N_nodes[b]] = batch[b][1] 102 | mask[b][:N_nodes[b]] = 1 # mask with values of 0 for dummy (zero padded) nodes, otherwise 1 103 | if has_GT_attn: 104 | GT_attn[b, :N_nodes[b]] = batch[b][4].squeeze() 105 | if has_WS_attn: 106 | WS_attn[b, :N_nodes[b]] = batch[b][5].squeeze() 107 | 108 | N_nodes = torch.from_numpy(np.array(N_nodes)).long() 109 | 110 | params_dict = {'N_nodes': N_nodes} 111 | if has_WS_attn: 112 | params_dict.update({'node_attn': WS_attn}) # use this to train attention 113 | if has_GT_attn: 114 | params_dict.update({'node_attn_eval': GT_attn}) # use this for evaluation of attention 115 | if not has_WS_attn: 116 | params_dict.update({'node_attn': GT_attn}) # use this to train attention 117 | elif has_WS_attn: 118 | params_dict.update({'node_attn_eval': WS_attn}) # use this for evaluation of attention 119 | 120 | labels = torch.from_numpy(np.array([batch[b][3] for b in range(B)])).long() 121 | return [x, A, mask, labels, params_dict] 122 | 123 | 124 | class MNIST(torchvision.datasets.MNIST): 125 | ''' 126 | Wrapper around MNIST to use predefined attention coefficients 127 | ''' 128 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, attn_coef=None): 129 | super(MNIST, self).__init__(root, train, transform, target_transform, download) 130 | self.alpha_WS = None 131 | if attn_coef is not None and train: 132 | print('loading weakly-supervised labels from %s' % attn_coef) 133 | with open(attn_coef, 'rb') as f: 134 | self.alpha_WS = pickle.load(f) 135 | print(train, len(self.alpha_WS)) 136 | 137 | def __getitem__(self, index): 138 | img, target = super(MNIST, self).__getitem__(index) 139 | if self.alpha_WS is None: 140 | return img, target 141 | else: 142 | return img, target, self.alpha_WS[index] 143 | 144 | 145 | class MNIST75sp(torch.utils.data.Dataset): 146 | def __init__(self, 147 | data_dir, 148 | split, 149 | use_mean_px=True, 150 | use_coord=True, 151 | gt_attn_threshold=0, 152 | attn_coef=None): 153 | 154 | self.data_dir = data_dir 155 | self.split = split 156 | self.is_test = split.lower() in ['test', 'val'] 157 | with open(pjoin(data_dir, 'mnist_75sp_%s.pkl' % split), 'rb') as f: 158 | self.labels, self.sp_data = pickle.load(f) 159 | 160 | self.use_mean_px = use_mean_px 161 | self.use_coord = use_coord 162 | self.n_samples = len(self.labels) 163 | self.img_size = 28 164 | self.gt_attn_threshold = gt_attn_threshold 165 | 166 | self.alpha_WS = None 167 | if attn_coef is not None and not self.is_test: 168 | with open(attn_coef, 'rb') as f: 169 | self.alpha_WS = pickle.load(f) 170 | print('using weakly-supervised labels from %s (%d samples)' % (attn_coef, len(self.alpha_WS))) 171 | 172 | def train_val_split(self, samples_idx): 173 | 
self.sp_data = [self.sp_data[i] for i in samples_idx] 174 | self.labels = self.labels[samples_idx] 175 | self.n_samples = len(self.labels) 176 | 177 | def precompute_graph_data(self, replicate_features, threads=0): 178 | print('precompute all data for the %s set...' % self.split.upper()) 179 | self.Adj_matrices, self.node_features, self.GT_attn, self.WS_attn = [], [], [], [] 180 | for index, sample in enumerate(self.sp_data): 181 | mean_px, coord = sample[:2] 182 | coord = coord / self.img_size 183 | A = compute_adjacency_matrix_images(coord) 184 | N_nodes = A.shape[0] 185 | x = None 186 | if self.use_mean_px: 187 | x = mean_px.reshape(N_nodes, -1) 188 | if self.use_coord: 189 | coord = coord.reshape(N_nodes, 2) 190 | if self.use_mean_px: 191 | x = np.concatenate((x, coord), axis=1) 192 | else: 193 | x = coord 194 | if x is None: 195 | x = np.ones(N_nodes, 1) # dummy features 196 | if replicate_features: 197 | x = np.pad(x, ((0, 0), (2, 0)), 'edge') # replicate features to make it possible to test on colored images 198 | if self.gt_attn_threshold == 0: 199 | gt_attn = (mean_px > 0).astype(np.float32) 200 | else: 201 | gt_attn = mean_px.copy() 202 | gt_attn[gt_attn < self.gt_attn_threshold] = 0 203 | self.GT_attn.append(normalize(gt_attn)) 204 | 205 | if self.alpha_WS is not None: 206 | self.WS_attn.append(normalize(self.alpha_WS[index])) 207 | 208 | self.node_features.append(x) 209 | self.Adj_matrices.append(A) 210 | 211 | 212 | def __len__(self): 213 | return self.n_samples 214 | 215 | def __getitem__(self, index): 216 | data = [self.node_features[index], 217 | self.Adj_matrices[index], 218 | self.Adj_matrices[index].shape[0], 219 | self.labels[index], 220 | self.GT_attn[index]] 221 | 222 | if self.alpha_WS is not None: 223 | data.append(self.WS_attn[index]) 224 | 225 | data = list_to_torch(data) # convert to torch 226 | 227 | return data 228 | 229 | 230 | class SyntheticGraphs(torch.utils.data.Dataset): 231 | def __init__(self, 232 | data_dir, 233 | dataset, 234 | split, 235 | degree_feature=True, 236 | attn_coef=None, 237 | threads=12): 238 | 239 | self.is_test = split.lower() in ['test', 'val'] 240 | self.split = split 241 | self.degree_feature = degree_feature 242 | 243 | if dataset.find('colors') >= 0: 244 | dim = int(dataset.split('-')[1]) 245 | data_file = 'random_graphs_colors_dim%d_%s.pkl' % (dim, split) 246 | is_triangles = False 247 | self.feature_dim = dim + 1 248 | if dataset.find('triangles') >= 0: 249 | data_file = 'random_graphs_triangles_%s.pkl' % split 250 | is_triangles = True 251 | else: 252 | NotImplementedError(dataset) 253 | 254 | with open(pjoin(data_dir, data_file), 'rb') as f: 255 | data = pickle.load(f) 256 | 257 | for key in data: 258 | if not isinstance(data[key], list) and not isinstance(data[key], np.ndarray): 259 | print(split, key, data[key]) 260 | else: 261 | print(split, key, len(data[key])) 262 | 263 | self.Node_degrees = [np.sum(A, 1).astype(np.int32) for A in data['Adj_matrices']] 264 | 265 | if is_triangles: 266 | # use one-hot degree features as node features 267 | self.feature_dim = data['Max_degree'] + 1 268 | self.node_features = [] 269 | for i in range(len(data['Adj_matrices'])): 270 | N = data['Adj_matrices'][i].shape[0] 271 | if degree_feature: 272 | D_onehot = np.zeros((N, self.feature_dim )) 273 | D_onehot[np.arange(N), self.Node_degrees[i]] = 1 274 | else: 275 | D_onehot = np.zeros((N, 1)) 276 | self.node_features.append(D_onehot) 277 | if not degree_feature: 278 | self.feature_dim = 1 279 | else: 280 | # Add 1 feature to support new colors 
at test time 281 | self.node_features = [] 282 | for i in range(len(data['node_features'])): 283 | features = data['node_features'][i] 284 | if features.shape[1] < self.feature_dim: 285 | features = np.pad(features, ((0, 0), (0, 1)), 'constant') 286 | self.node_features.append(features) 287 | 288 | self.alpha_WS = None 289 | if attn_coef is not None and not self.is_test: 290 | with open(attn_coef, 'rb') as f: 291 | self.alpha_WS = pickle.load(f) 292 | print('using weakly-supervised labels from %s (%d samples)' % (attn_coef, len(self.alpha_WS))) 293 | self.WS_attn = [] 294 | for index in range(len(self.alpha_WS)): 295 | self.WS_attn.append(normalize(self.alpha_WS[index])) 296 | 297 | N_nodes = np.array([A.shape[0] for A in data['Adj_matrices']]) 298 | self.Adj_matrices = data['Adj_matrices'] 299 | self.GT_attn = data['GT_attn'] 300 | # Normalizing ground truth attention so that it sums to 1 301 | for i in range(len(self.GT_attn)): 302 | self.GT_attn[i] = normalize(self.GT_attn[i]) 303 | #assert np.sum(self.GT_attn[i]) == 1, (i, np.sum(self.GT_attn[i]), self.GT_attn[i]) 304 | self.labels = data['graph_labels'].astype(np.int32) 305 | self.classes = np.unique(self.labels) 306 | self.n_classes = len(self.classes) 307 | R = np.corrcoef(self.labels, N_nodes)[0, 1] 308 | 309 | degrees = [] 310 | for i in range(len(self.Node_degrees)): 311 | degrees.extend(list(self.Node_degrees[i])) 312 | degrees = np.array(degrees, np.int32) 313 | 314 | print('N nodes avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(N_nodes))) 315 | print('N edges avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(data['N_edges']))) 316 | print('Node degree avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(degrees))) 317 | print('Node features dim: \t\t%d' % self.feature_dim) 318 | print('N classes: \t\t\t%d' % self.n_classes) 319 | print('Correlation of labels with graph size: \t%.2f' % R) 320 | print('Classes: \t\t\t%s' % str(self.classes)) 321 | for lbl in self.classes: 322 | idx = self.labels == lbl 323 | print('Class {}: \t\t\t{} samples, N_nodes: avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(lbl, np.sum(idx), *stats(N_nodes[idx]))) 324 | 325 | def __len__(self): 326 | return len(self.Adj_matrices) 327 | 328 | def __getitem__(self, index): 329 | data = [self.node_features[index], 330 | self.Adj_matrices[index], 331 | self.Adj_matrices[index].shape[0], 332 | self.labels[index], 333 | self.GT_attn[index]] 334 | 335 | if self.alpha_WS is not None: 336 | data.append(self.WS_attn[index]) 337 | 338 | data = list_to_torch(data) # convert to torch 339 | 340 | return data 341 | 342 | 343 | class GraphData(torch.utils.data.Dataset): 344 | def __init__(self, 345 | datareader, 346 | fold_id, 347 | split, # train, val, train_val, test 348 | degree_feature=True, 349 | attn_labels=None): 350 | self.fold_id = fold_id 351 | self.split = split 352 | self.w_sup_signal_attn = None 353 | print('''The degree_feature argument is ignored for this dataset. 354 | It will automatically be set to True if nodes do not have any features. 
Otherwise it will be set to False''') 355 | if attn_labels is not None: 356 | if isinstance(attn_labels, str) and os.path.isfile(attn_labels): 357 | with open(attn_labels, 'rb') as f: 358 | self.w_sup_signal_attn = pickle.load(f) 359 | else: 360 | self.w_sup_signal_attn = attn_labels 361 | for i in range(len(self.w_sup_signal_attn)): 362 | alpha = self.w_sup_signal_attn[i] 363 | alpha[alpha < 1e-3] = 0 # assuming that some nodes should have zero importance 364 | self.w_sup_signal_attn[i] = normalize(alpha) 365 | print(('!!!using weakly supervised labels (%d samples)!!!' % len(self.w_sup_signal_attn)).upper()) 366 | 367 | self.set_fold(datareader.data, fold_id) 368 | 369 | def set_fold(self, data, fold_id): 370 | 371 | self.total = len(data['targets']) 372 | self.N_nodes_max = data['N_nodes_max'] 373 | self.num_classes = data['num_classes'] 374 | self.num_features = data['num_features'] 375 | if self.split in ['train', 'val']: 376 | self.idx = data['splits'][self.split][fold_id] 377 | else: 378 | assert self.split in ['train_val', 'test'], ('unexpected split', self.split) 379 | self.idx = data['splits'][self.split] 380 | 381 | # use deepcopy to make sure we don't alter objects in folds 382 | self.labels = np.array(copy.deepcopy([data['targets'][i] for i in self.idx])) 383 | self.adj_list = copy.deepcopy([data['adj_list'][i] for i in self.idx]) 384 | self.features_onehot = copy.deepcopy([data['features_onehot'][i] for i in self.idx]) 385 | self.N_edges = np.array([A.sum() // 2 for A in self.adj_list]) # assume undirected graph with binary edges 386 | print('%s: %d/%d' % (self.split.upper(), len(self.labels), len(data['targets']))) 387 | classes = np.unique(self.labels) 388 | for lbl in classes: 389 | print('Class %d: \t\t\t%d samples' % (lbl, np.sum(self.labels == lbl))) 390 | 391 | def __len__(self): 392 | return len(self.labels) 393 | 394 | def __getitem__(self, index): 395 | if isinstance(index, str): 396 | # To make data format consistent with SyntheticGraphs 397 | if index == 'Adj_matrices': 398 | return self.adj_list 399 | elif index == 'GT_attn': 400 | print('Ground truth attention is unavailable for this dataset: weakly-supervised labels will be returned') 401 | return self.w_sup_signal_attn 402 | elif index == 'graph_labels': 403 | return self.labels 404 | elif index == 'node_features': 405 | return self.features_onehot 406 | elif index == 'N_edges': 407 | return self.N_edges 408 | else: 409 | raise KeyError(index) 410 | else: 411 | data = [self.features_onehot[index], 412 | self.adj_list[index], 413 | self.adj_list[index].shape[0], 414 | self.labels[index], 415 | None] # no GT attention 416 | if self.w_sup_signal_attn is not None: 417 | data.append(self.w_sup_signal_attn[index]) 418 | data = list_to_torch(data) # convert to torch 419 | 420 | return data 421 | 422 | 423 | class DataReader(): 424 | ''' 425 | Class to read the txt files containing all data of the dataset 426 | Should work for any dataset from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 427 | ''' 428 | 429 | def __init__(self, 430 | data_dir, # folder with txt files 431 | N_nodes=None, # maximum number of nodes in the training set 432 | rnd_state=None, 433 | use_cont_node_attr=False, # use or not additional float valued node attributes available in some datasets 434 | folds=10, 435 | fold_id=None): 436 | 437 | self.data_dir = data_dir 438 | self.rnd_state = np.random.RandomState() if rnd_state is None else rnd_state 439 | self.use_cont_node_attr = use_cont_node_attr 440 | self.N_nodes = N_nodes 
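# Editor's note: the branch below first tries to load a cached data.pkl from data_dir and otherwise
# parses the raw TU-format text files (graph_indicator, _A, node_labels, graph_labels and, when
# use_cont_node_attr is set, node_attributes), then pickles the parsed dictionary back to data.pkl.
# The cache is keyed only on file existence, so delete data.pkl after changing the raw files or options.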
441 | if os.path.isfile('%s/data.pkl' % data_dir): 442 | print('loading data from %s/data.pkl' % data_dir) 443 | with open('%s/data.pkl' % data_dir, 'rb') as f: 444 | data = pickle.load(f) 445 | else: 446 | files = os.listdir(self.data_dir) 447 | data = {} 448 | nodes, graphs = self.read_graph_nodes_relations( 449 | list(filter(lambda f: f.find('graph_indicator') >= 0, files))[0]) 450 | lst = list(filter(lambda f: f.find('node_labels') >= 0, files)) 451 | if len(lst) > 0: 452 | data['features'] = self.read_node_features(lst[0], nodes, graphs, fn=lambda s: int(s.strip())) 453 | else: 454 | data['features'] = None 455 | data['adj_list'] = self.read_graph_adj(list(filter(lambda f: f.find('_A') >= 0, files))[0], nodes, graphs) 456 | data['targets'] = np.array( 457 | self.parse_txt_file(list(filter(lambda f: f.find('graph_labels') >= 0, files))[0], 458 | line_parse_fn=lambda s: int(float(s.strip())))) 459 | 460 | if self.use_cont_node_attr: 461 | data['attr'] = self.read_node_features(list(filter(lambda f: f.find('node_attributes') >= 0, files))[0], 462 | nodes, graphs, 463 | fn=lambda s: np.array(list(map(float, s.strip().split(','))))) 464 | 465 | features, n_edges, degrees = [], [], [] 466 | for sample_id, adj in enumerate(data['adj_list']): 467 | N = len(adj) # number of nodes 468 | if data['features'] is not None: 469 | assert N == len(data['features'][sample_id]), (N, len(data['features'][sample_id])) 470 | n = np.sum(adj) # total sum of edges 471 | # assert n % 2 == 0, n 472 | n_edges.append(int(n / 2)) # undirected edges, so need to divide by 2 473 | if not np.allclose(adj, adj.T): 474 | print(sample_id, 'not symmetric') 475 | degrees.extend(list(np.sum(adj, 1))) 476 | if data['features'] is not None: 477 | features.append(np.array(data['features'][sample_id])) 478 | 479 | # Create features over graphs as one-hot vectors for each node 480 | if data['features'] is not None: 481 | features_all = np.concatenate(features) 482 | features_min = features_all.min() 483 | num_features = int(features_all.max() - features_min + 1) # number of possible values 484 | 485 | features_onehot = [] 486 | for i, x in enumerate(features): 487 | feature_onehot = np.zeros((len(x), num_features)) 488 | for node, value in enumerate(x): 489 | feature_onehot[node, value - features_min] = 1 490 | if self.use_cont_node_attr: 491 | feature_onehot = np.concatenate((feature_onehot, np.array(data['attr'][i])), axis=1) 492 | features_onehot.append(feature_onehot) 493 | 494 | if self.use_cont_node_attr: 495 | num_features = features_onehot[0].shape[1] 496 | else: 497 | degree_max = int(np.max([np.sum(A, 1).max() for A in data['adj_list']])) 498 | num_features = degree_max + 1 499 | features_onehot = [] 500 | for A in data['adj_list']: 501 | n = A.shape[0] 502 | D = np.sum(A, 1).astype(np.int) 503 | D_onehot = np.zeros((n, num_features)) 504 | D_onehot[np.arange(n), D] = 1 505 | features_onehot.append(D_onehot) 506 | 507 | shapes = [len(adj) for adj in data['adj_list']] 508 | labels = data['targets'] # graph class labels 509 | labels -= np.min(labels) # to start from 0 510 | 511 | classes = np.unique(labels) 512 | num_classes = len(classes) 513 | 514 | if not np.all(np.diff(classes) == 1): 515 | print('making labels sequential, otherwise pytorch might crash') 516 | labels_new = np.zeros(labels.shape, dtype=labels.dtype) - 1 517 | for lbl in range(num_classes): 518 | labels_new[labels == classes[lbl]] = lbl 519 | labels = labels_new 520 | classes = np.unique(labels) 521 | assert len(np.unique(labels)) == num_classes, 
np.unique(labels) 522 | 523 | def stats(x): 524 | return (np.mean(x), np.std(x), np.min(x), np.max(x)) 525 | 526 | print('N nodes avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(shapes)) 527 | print('N edges avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(n_edges)) 528 | print('Node degree avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(degrees)) 529 | print('Node features dim: \t\t%d' % num_features) 530 | print('N classes: \t\t\t%d' % num_classes) 531 | print('Classes: \t\t\t%s' % str(classes)) 532 | for lbl in classes: 533 | print('Class %d: \t\t\t%d samples' % (lbl, np.sum(labels == lbl))) 534 | 535 | if data['features'] is not None: 536 | for u in np.unique(features_all): 537 | print('feature {}, count {}/{}'.format(u, np.count_nonzero(features_all == u), len(features_all))) 538 | 539 | N_graphs = len(labels) # number of samples (graphs) in data 540 | assert N_graphs == len(data['adj_list']) == len(features_onehot), 'invalid data' 541 | 542 | data['features_onehot'] = features_onehot 543 | data['targets'] = labels 544 | 545 | data['N_nodes_max'] = np.max(shapes) # max number of nodes 546 | data['num_features'] = num_features 547 | data['num_classes'] = num_classes 548 | 549 | # Save preprocessed data for faster loading 550 | with open('%s/data.pkl' % data_dir, 'wb') as f: 551 | pickle.dump(data, f, protocol=2) 552 | 553 | labels = data['targets'] 554 | # Create test sets first 555 | N_graphs = len(labels) 556 | shapes = np.array([len(adj) for adj in data['adj_list']]) 557 | train_ids, val_ids, train_val_ids, test_ids = self.split_ids_shape(np.arange(N_graphs), shapes, N_nodes, folds=folds) 558 | 559 | # Create train sets 560 | splits = {'train': [], 'val': [], 'train_val': train_val_ids, 'test': test_ids} 561 | for fold in range(folds): 562 | splits['train'].append(train_ids[fold]) 563 | splits['val'].append(val_ids[fold]) 564 | 565 | data['splits'] = splits 566 | 567 | self.data = data 568 | 569 | 570 | def split_ids_shape(self, ids_all, shapes, N_nodes, folds=1, fold_id=0): 571 | if N_nodes > 0: 572 | small_graphs_ind = np.where(shapes <= N_nodes)[0] 573 | print('{}/{} graphs with at least {} nodes'.format(len(small_graphs_ind), len(shapes), N_nodes)) 574 | idx = self.rnd_state.permutation(len(small_graphs_ind)) 575 | if len(idx) > 1000: 576 | n = 1000 577 | else: 578 | n = 500 579 | train_val_ids = small_graphs_ind[idx[:n]] 580 | test_ids = small_graphs_ind[idx[n:]] 581 | large_graphs_ind = np.where(shapes > N_nodes)[0] 582 | test_ids = np.concatenate((test_ids, large_graphs_ind)) 583 | else: 584 | idx = self.rnd_state.permutation(len(ids_all)) 585 | n = len(ids_all) // folds # number of test samples 586 | test_ids = ids_all[idx[fold_id * n: (fold_id + 1) * n if fold_id < folds - 1 else -1]] 587 | train_val_ids = [] 588 | for i in ids_all: 589 | if i not in test_ids: 590 | train_val_ids.append(i) 591 | train_val_ids = np.array(train_val_ids) 592 | 593 | assert np.all( 594 | np.unique(np.concatenate((train_val_ids, test_ids))) == sorted(ids_all)), 'some graphs are missing in the test sets' 595 | if folds > 0: 596 | print('generating %d-fold cross-validation splits' % folds) 597 | train_ids, val_ids = self.split_ids(train_val_ids, folds=folds) 598 | # Sanity checks 599 | for fold in range(folds): 600 | ind = np.concatenate((train_ids[fold], val_ids[fold])) 601 | print(fold, len(train_ids[fold]), len(val_ids[fold])) 602 | assert len(train_ids[fold]) + len(val_ids[fold]) == len(np.unique(ind)) == len(ind) == len(train_val_ids), 'invalid splits' 603 | else: 604 | train_ids, val_ids = [], [] 605 | 
606 | return train_ids, val_ids, train_val_ids, test_ids 607 | 608 | def split_ids(self, ids, folds=10): 609 | n = len(ids) 610 | stride = int(np.ceil(n / float(folds))) 611 | test_ids = [ids[i: i + stride] for i in range(0, n, stride)] 612 | assert np.all( 613 | np.unique(np.concatenate(test_ids)) == sorted(ids)), 'some graphs are missing in the test sets' 614 | assert len(test_ids) == folds, 'invalid test sets' 615 | train_ids = [] 616 | for fold in range(folds): 617 | train_ids.append(np.array([e for e in ids if e not in test_ids[fold]])) 618 | assert len(train_ids[fold]) + len(test_ids[fold]) == len( 619 | np.unique(list(train_ids[fold]) + list(test_ids[fold]))) == n, 'invalid splits' 620 | 621 | return train_ids, test_ids 622 | 623 | def parse_txt_file(self, fpath, line_parse_fn=None): 624 | with open(pjoin(self.data_dir, fpath), 'r') as f: 625 | lines = f.readlines() 626 | data = [line_parse_fn(s) if line_parse_fn is not None else s for s in lines] 627 | return data 628 | 629 | def read_graph_adj(self, fpath, nodes, graphs): 630 | edges = self.parse_txt_file(fpath, line_parse_fn=lambda s: s.split(',')) 631 | adj_dict = {} 632 | for edge in edges: 633 | node1 = int(edge[0].strip()) - 1 # -1 because of zero-indexing in our code 634 | node2 = int(edge[1].strip()) - 1 635 | graph_id = nodes[node1] 636 | assert graph_id == nodes[node2], ('invalid data', graph_id, nodes[node2]) 637 | if graph_id not in adj_dict: 638 | n = len(graphs[graph_id]) 639 | adj_dict[graph_id] = np.zeros((n, n)) 640 | ind1 = np.where(graphs[graph_id] == node1)[0] 641 | ind2 = np.where(graphs[graph_id] == node2)[0] 642 | assert len(ind1) == len(ind2) == 1, (ind1, ind2) 643 | adj_dict[graph_id][ind1, ind2] = 1 644 | 645 | adj_list = [adj_dict[graph_id] for graph_id in sorted(list(graphs.keys()))] 646 | 647 | return adj_list 648 | 649 | def read_graph_nodes_relations(self, fpath): 650 | graph_ids = self.parse_txt_file(fpath, line_parse_fn=lambda s: int(s.rstrip())) 651 | nodes, graphs = {}, {} 652 | for node_id, graph_id in enumerate(graph_ids): 653 | if graph_id not in graphs: 654 | graphs[graph_id] = [] 655 | graphs[graph_id].append(node_id) 656 | nodes[node_id] = graph_id 657 | graph_ids = np.unique(list(graphs.keys())) 658 | for graph_id in graph_ids: 659 | graphs[graph_id] = np.array(graphs[graph_id]) 660 | return nodes, graphs 661 | 662 | def read_node_features(self, fpath, nodes, graphs, fn): 663 | node_features_all = self.parse_txt_file(fpath, line_parse_fn=fn) 664 | node_features = {} 665 | for node_id, x in enumerate(node_features_all): 666 | graph_id = nodes[node_id] 667 | if graph_id not in node_features: 668 | node_features[graph_id] = [None] * len(graphs[graph_id]) 669 | ind = np.where(graphs[graph_id] == node_id)[0] 670 | assert len(ind) == 1, ind 671 | assert node_features[graph_id][ind[0]] is None, node_features[graph_id][ind[0]] 672 | node_features[graph_id][ind[0]] = x 673 | node_features_lst = [node_features[graph_id] for graph_id in sorted(list(graphs.keys()))] 674 | return node_features_lst 675 | -------------------------------------------------------------------------------- /logs/colors-3_global_max_seed111.log: -------------------------------------------------------------------------------- 1 | start time: 2019-06-21 11:23:37.828931 2 | gpus: 1 3 | dataset colors-3 4 | data_dir ./data 5 | epochs 100 6 | batch_size 32 7 | lr 0.001 8 | lr_decay_step [90] 9 | wdecay 0.0001 10 | dropout 0.0 11 | filters [64, 64] 12 | filter_scale 2 13 | n_hidden 0 14 | aggregation mean 15 | readout sum 16 | 
kl_weight 100 17 | pool None 18 | pool_arch ['fc', 'prev'] 19 | n_nodes 25 20 | cv_folds 5 21 | img_features ['mean', 'coord'] 22 | img_noise_levels [0.4, 0.6] 23 | validation False 24 | debug False 25 | eval_attn_train True 26 | eval_attn_test True 27 | test_batch_size 100 28 | alpha_ws None 29 | log_interval 400 30 | results ./checkpoints/ 31 | resume None 32 | device cuda 33 | seed 111 34 | threads 0 35 | experiment_ID: 828931 36 | train Adj_matrices 500 37 | train GT_attn 500 38 | train graph_labels 500 39 | train node_features 500 40 | train N_edges 500 41 | N nodes avg/std/min/max: 14.26/6.22/4/25 42 | N edges avg/std/min/max: 20.87/10.49/4/48 43 | Node degree avg/std/min/max: 2.93/1.53/0/9 44 | Node features dim: 4 45 | N classes: 11 46 | Correlation of labels with graph size: 0.18 47 | Classes: [ 0 1 2 3 4 5 6 7 8 9 10] 48 | Class 0: 58 samples, N_nodes: avg/std/min/max: 12.90/6.79/4/25 49 | Class 1: 56 samples, N_nodes: avg/std/min/max: 14.00/6.55/4/25 50 | Class 2: 52 samples, N_nodes: avg/std/min/max: 14.46/7.10/4/25 51 | Class 3: 56 samples, N_nodes: avg/std/min/max: 11.86/5.95/4/25 52 | Class 4: 53 samples, N_nodes: avg/std/min/max: 13.55/6.16/4/25 53 | Class 5: 49 samples, N_nodes: avg/std/min/max: 14.73/6.38/5/25 54 | Class 6: 49 samples, N_nodes: avg/std/min/max: 13.57/4.98/6/23 55 | Class 7: 30 samples, N_nodes: avg/std/min/max: 15.37/5.64/7/25 56 | Class 8: 33 samples, N_nodes: avg/std/min/max: 17.24/4.91/9/24 57 | Class 9: 35 samples, N_nodes: avg/std/min/max: 16.11/4.90/9/25 58 | Class 10: 29 samples, N_nodes: avg/std/min/max: 16.66/4.95/10/25 59 | test Adj_matrices 7500 60 | test GT_attn 7500 61 | test graph_labels 7500 62 | test node_features 7500 63 | test N_edges 7500 64 | N nodes avg/std/min/max: 80.64/62.43/4/200 65 | N edges avg/std/min/max: 120.95/98.42/4/396 66 | Node degree avg/std/min/max: 3.00/1.78/0/14 67 | Node features dim: 4 68 | N classes: 11 69 | Correlation of labels with graph size: 0.06 70 | Classes: [ 0 1 2 3 4 5 6 7 8 9 10] 71 | Class 0: 730 samples, N_nodes: avg/std/min/max: 76.34/62.01/4/200 72 | Class 1: 699 samples, N_nodes: avg/std/min/max: 77.19/63.03/4/200 73 | Class 2: 706 samples, N_nodes: avg/std/min/max: 77.07/62.79/4/200 74 | Class 3: 699 samples, N_nodes: avg/std/min/max: 75.73/62.78/4/200 75 | Class 4: 717 samples, N_nodes: avg/std/min/max: 77.32/62.59/4/200 76 | Class 5: 697 samples, N_nodes: avg/std/min/max: 81.44/63.24/5/200 77 | Class 6: 687 samples, N_nodes: avg/std/min/max: 82.32/61.91/6/200 78 | Class 7: 690 samples, N_nodes: avg/std/min/max: 82.41/61.77/7/200 79 | Class 8: 638 samples, N_nodes: avg/std/min/max: 84.81/62.83/8/200 80 | Class 9: 641 samples, N_nodes: avg/std/min/max: 85.45/61.25/9/200 81 | Class 10: 596 samples, N_nodes: avg/std/min/max: 89.37/60.59/10/200 82 | ChebyGINLayer torch.Size([64, 8]) tensor([0.5559, 0.4973, 0.4914, 0.5325, 0.4516, 0.5541, 0.7513, 0.6097, 0.5125, 83 | 0.5658], grad_fn=) 84 | ChebyGINLayer torch.Size([64, 128]) tensor([0.5931, 0.5994, 0.5544, 0.5491, 0.5944, 0.5986, 0.6081, 0.5848, 0.5648, 85 | 0.5838], grad_fn=) 86 | ChebyGIN( 87 | (graph_layers): Sequential( 88 | (0): ChebyGINLayer(in_features=4, out_features=64, K=2, n_hidden=0, aggregation=mean) 89 | fc=Sequential( 90 | (0): Linear(in_features=8, out_features=64, bias=True) 91 | (1): ReLU(inplace) 92 | ) 93 | (1): ChebyGINLayer(in_features=64, out_features=64, K=2, n_hidden=0, aggregation=mean) 94 | fc=Sequential( 95 | (0): Linear(in_features=128, out_features=64, bias=True) 96 | (1): ReLU(inplace) 97 | ) 98 | (2): GraphReadout(sum) 
99 | ) 100 | (fc): Sequential( 101 | (0): Linear(in_features=64, out_features=1, bias=True) 102 | ) 103 | ) 104 | model capacity: 8897 105 | model is checked for nodes shuffling 106 | Train set (epoch 1): [500/500 (100%)] Loss: 4.0388 (avg: 13.2787), other losses: [] Acc metric: 70/500 (14.00%) AttnAUC: [] avg sec/iter: 0.0076 107 | 108 | 109 | saving the model to ./checkpoints//checkpoint_colors-3_828931_epoch1_seed0000111.pth.tar 110 | model is checked for nodes shuffling 111 | lbl: 0, avg acc: 0.00% (0/58) 112 | lbl: 1, avg acc: 21.43% (12/56) 113 | lbl: 2, avg acc: 21.15% (11/52) 114 | lbl: 3, avg acc: 23.21% (13/56) 115 | lbl: 4, avg acc: 32.08% (17/53) 116 | lbl: 5, avg acc: 20.41% (10/49) 117 | lbl: 6, avg acc: 14.29% (7/49) 118 | lbl: 7, avg acc: 10.00% (3/30) 119 | lbl: 8, avg acc: 0.00% (0/33) 120 | lbl: 9, avg acc: 0.00% (0/35) 121 | lbl: 10, avg acc: 0.00% (0/29) 122 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 14.60% (73/500) 123 | Train set (epoch 1): Avg loss: 4.7120, Acc metric: 73/500 (14.60%) AttnAUC: [] avg sec/iter: 0.0268 124 | 125 | model is checked for nodes shuffling 126 | lbl: 0, avg acc: 0.00% (0/262) 127 | lbl: 1, avg acc: 31.73% (79/249) 128 | lbl: 2, avg acc: 28.29% (73/258) 129 | lbl: 3, avg acc: 24.32% (63/259) 130 | lbl: 4, avg acc: 29.00% (78/269) 131 | lbl: 5, avg acc: 25.00% (59/236) 132 | lbl: 6, avg acc: 18.40% (39/212) 133 | lbl: 7, avg acc: 4.59% (10/218) 134 | lbl: 8, avg acc: 0.00% (0/196) 135 | lbl: 9, avg acc: 0.00% (0/192) 136 | lbl: 10, avg acc: 0.00% (0/149) 137 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 16.04% (401/2500) 138 | lbl: 0, avg acc: 0.00% (0/226) 139 | lbl: 1, avg acc: 0.00% (0/216) 140 | lbl: 2, avg acc: 0.00% (0/223) 141 | lbl: 3, avg acc: 0.00% (0/237) 142 | lbl: 4, avg acc: 0.00% (0/229) 143 | lbl: 5, avg acc: 0.00% (0/231) 144 | lbl: 6, avg acc: 1.25% (3/240) 145 | lbl: 7, avg acc: 2.07% (5/242) 146 | lbl: 8, avg acc: 2.82% (6/213) 147 | lbl: 9, avg acc: 2.71% (6/221) 148 | lbl: 10, avg acc: 1.80% (4/222) 149 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 0.96% (24/2500) 150 | lbl: 0, avg acc: 0.00% (0/242) 151 | lbl: 1, avg acc: 0.00% (0/234) 152 | lbl: 2, avg acc: 0.00% (0/225) 153 | lbl: 3, avg acc: 0.00% (0/203) 154 | lbl: 4, avg acc: 0.00% (0/219) 155 | lbl: 5, avg acc: 0.00% (0/230) 156 | lbl: 6, avg acc: 0.43% (1/235) 157 | lbl: 7, avg acc: 3.04% (7/230) 158 | lbl: 8, avg acc: 1.31% (3/229) 159 | lbl: 9, avg acc: 3.07% (7/228) 160 | lbl: 10, avg acc: 2.67% (6/225) 161 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 0.96% (24/2500) 162 | Test set (epoch 1): Avg loss: 218.5921, Acc metric: 449/7500 (5.99%) AttnAUC: [] avg sec/iter: 0.0118 163 | 164 | Train set (epoch 2): [500/500 (100%)] Loss: 0.5873 (avg: 2.1991), other losses: [] Acc metric: 133/500 (26.60%) AttnAUC: [] avg sec/iter: 0.0035 165 | 166 | 167 | Train set (epoch 3): [500/500 (100%)] Loss: 0.1954 (avg: 0.2053), other losses: [] Acc metric: 382/500 (76.40%) AttnAUC: [] avg sec/iter: 0.0030 168 | 169 | 170 | Train set (epoch 4): [500/500 (100%)] Loss: 0.1411 (avg: 0.1459), other losses: [] Acc metric: 417/500 (83.40%) AttnAUC: [] avg sec/iter: 0.0030 171 | 172 | 173 | Train set (epoch 5): [500/500 (100%)] Loss: 0.1857 (avg: 0.1294), other losses: [] Acc metric: 427/500 (85.40%) AttnAUC: [] avg sec/iter: 0.0030 174 | 175 | 176 | Train set (epoch 6): [500/500 (100%)] Loss: 0.1421 (avg: 0.1294), other losses: [] Acc metric: 432/500 (86.40%) AttnAUC: [] avg sec/iter: 0.0069 177 | 178 | 179 | Train set (epoch 7): [500/500 (100%)] Loss: 0.0833 (avg: 
0.1182), other losses: [] Acc metric: 432/500 (86.40%) AttnAUC: [] avg sec/iter: 0.0036 180 | 181 | 182 | Train set (epoch 8): [500/500 (100%)] Loss: 0.1359 (avg: 0.1069), other losses: [] Acc metric: 437/500 (87.40%) AttnAUC: [] avg sec/iter: 0.0035 183 | 184 | 185 | Train set (epoch 9): [500/500 (100%)] Loss: 0.2092 (avg: 0.1067), other losses: [] Acc metric: 435/500 (87.00%) AttnAUC: [] avg sec/iter: 0.0031 186 | 187 | 188 | Train set (epoch 10): [500/500 (100%)] Loss: 0.0502 (avg: 0.0982), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0032 189 | 190 | 191 | Train set (epoch 11): [500/500 (100%)] Loss: 0.0328 (avg: 0.0914), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0032 192 | 193 | 194 | Train set (epoch 12): [500/500 (100%)] Loss: 0.1351 (avg: 0.0882), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0034 195 | 196 | 197 | Train set (epoch 13): [500/500 (100%)] Loss: 0.1683 (avg: 0.0780), other losses: [] Acc metric: 459/500 (91.80%) AttnAUC: [] avg sec/iter: 0.0034 198 | 199 | 200 | Train set (epoch 14): [500/500 (100%)] Loss: 0.0541 (avg: 0.0779), other losses: [] Acc metric: 457/500 (91.40%) AttnAUC: [] avg sec/iter: 0.0033 201 | 202 | 203 | Train set (epoch 15): [500/500 (100%)] Loss: 0.0494 (avg: 0.0691), other losses: [] Acc metric: 461/500 (92.20%) AttnAUC: [] avg sec/iter: 0.0034 204 | 205 | 206 | Train set (epoch 16): [500/500 (100%)] Loss: 0.0533 (avg: 0.0665), other losses: [] Acc metric: 460/500 (92.00%) AttnAUC: [] avg sec/iter: 0.0041 207 | 208 | 209 | Train set (epoch 17): [500/500 (100%)] Loss: 0.0395 (avg: 0.0572), other losses: [] Acc metric: 470/500 (94.00%) AttnAUC: [] avg sec/iter: 0.0040 210 | 211 | 212 | Train set (epoch 18): [500/500 (100%)] Loss: 0.0702 (avg: 0.0482), other losses: [] Acc metric: 479/500 (95.80%) AttnAUC: [] avg sec/iter: 0.0038 213 | 214 | 215 | Train set (epoch 19): [500/500 (100%)] Loss: 0.0504 (avg: 0.0463), other losses: [] Acc metric: 482/500 (96.40%) AttnAUC: [] avg sec/iter: 0.0035 216 | 217 | 218 | Train set (epoch 20): [500/500 (100%)] Loss: 0.0798 (avg: 0.0495), other losses: [] Acc metric: 474/500 (94.80%) AttnAUC: [] avg sec/iter: 0.0035 219 | 220 | 221 | Train set (epoch 21): [500/500 (100%)] Loss: 0.0540 (avg: 0.0386), other losses: [] Acc metric: 489/500 (97.80%) AttnAUC: [] avg sec/iter: 0.0038 222 | 223 | 224 | Train set (epoch 22): [500/500 (100%)] Loss: 0.0729 (avg: 0.0523), other losses: [] Acc metric: 471/500 (94.20%) AttnAUC: [] avg sec/iter: 0.0038 225 | 226 | 227 | Train set (epoch 23): [500/500 (100%)] Loss: 0.0210 (avg: 0.0301), other losses: [] Acc metric: 493/500 (98.60%) AttnAUC: [] avg sec/iter: 0.0040 228 | 229 | 230 | Train set (epoch 24): [500/500 (100%)] Loss: 0.0591 (avg: 0.0323), other losses: [] Acc metric: 489/500 (97.80%) AttnAUC: [] avg sec/iter: 0.0040 231 | 232 | 233 | Train set (epoch 25): [500/500 (100%)] Loss: 0.0328 (avg: 0.0287), other losses: [] Acc metric: 491/500 (98.20%) AttnAUC: [] avg sec/iter: 0.0039 234 | 235 | 236 | Train set (epoch 26): [500/500 (100%)] Loss: 0.0195 (avg: 0.0281), other losses: [] Acc metric: 495/500 (99.00%) AttnAUC: [] avg sec/iter: 0.0039 237 | 238 | 239 | Train set (epoch 27): [500/500 (100%)] Loss: 0.0114 (avg: 0.0240), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0039 240 | 241 | 242 | Train set (epoch 28): [500/500 (100%)] Loss: 0.0369 (avg: 0.0207), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0044 243 | 244 | 245 | Train 
set (epoch 29): [500/500 (100%)] Loss: 0.0334 (avg: 0.0266), other losses: [] Acc metric: 495/500 (99.00%) AttnAUC: [] avg sec/iter: 0.0056 246 | 247 | 248 | Train set (epoch 30): [500/500 (100%)] Loss: 0.0079 (avg: 0.0249), other losses: [] Acc metric: 493/500 (98.60%) AttnAUC: [] avg sec/iter: 0.0078 249 | 250 | 251 | Train set (epoch 31): [500/500 (100%)] Loss: 0.0076 (avg: 0.0161), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0060 252 | 253 | 254 | Train set (epoch 32): [500/500 (100%)] Loss: 0.0076 (avg: 0.0218), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0044 255 | 256 | 257 | Train set (epoch 33): [500/500 (100%)] Loss: 0.0301 (avg: 0.0235), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0037 258 | 259 | 260 | Train set (epoch 34): [500/500 (100%)] Loss: 0.0307 (avg: 0.0154), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0033 261 | 262 | 263 | Train set (epoch 35): [500/500 (100%)] Loss: 0.0156 (avg: 0.0122), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0033 264 | 265 | 266 | Train set (epoch 36): [500/500 (100%)] Loss: 0.0095 (avg: 0.0110), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0033 267 | 268 | 269 | Train set (epoch 37): [500/500 (100%)] Loss: 0.0799 (avg: 0.0174), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033 270 | 271 | 272 | Train set (epoch 38): [500/500 (100%)] Loss: 0.0030 (avg: 0.0129), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0033 273 | 274 | 275 | Train set (epoch 39): [500/500 (100%)] Loss: 0.0798 (avg: 0.0168), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0039 276 | 277 | 278 | Train set (epoch 40): [500/500 (100%)] Loss: 0.0138 (avg: 0.0181), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0041 279 | 280 | 281 | Train set (epoch 41): [500/500 (100%)] Loss: 0.0121 (avg: 0.0092), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0040 282 | 283 | 284 | Train set (epoch 42): [500/500 (100%)] Loss: 0.0390 (avg: 0.0146), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0034 285 | 286 | 287 | Train set (epoch 43): [500/500 (100%)] Loss: 0.0235 (avg: 0.0140), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0038 288 | 289 | 290 | Train set (epoch 44): [500/500 (100%)] Loss: 0.0053 (avg: 0.0102), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033 291 | 292 | 293 | Train set (epoch 45): [500/500 (100%)] Loss: 0.0100 (avg: 0.0125), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0034 294 | 295 | 296 | Train set (epoch 46): [500/500 (100%)] Loss: 0.0064 (avg: 0.0086), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033 297 | 298 | 299 | Train set (epoch 47): [500/500 (100%)] Loss: 0.0144 (avg: 0.0266), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0034 300 | 301 | 302 | Train set (epoch 48): [500/500 (100%)] Loss: 0.0031 (avg: 0.0076), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0041 303 | 304 | 305 | Train set (epoch 49): [500/500 (100%)] Loss: 0.0038 (avg: 0.0063), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0037 306 | 307 | 308 | Train set (epoch 50): [500/500 (100%)] Loss: 0.0127 (avg: 0.0132), other losses: [] Acc metric: 500/500 
(100.00%) AttnAUC: [] avg sec/iter: 0.0040 309 | 310 | 311 | Train set (epoch 51): [500/500 (100%)] Loss: 0.0026 (avg: 0.0129), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0041 312 | 313 | 314 | Train set (epoch 52): [500/500 (100%)] Loss: 0.0416 (avg: 0.0121), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059 315 | 316 | 317 | Train set (epoch 53): [500/500 (100%)] Loss: 0.0271 (avg: 0.0148), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0060 318 | 319 | 320 | Train set (epoch 54): [500/500 (100%)] Loss: 0.0023 (avg: 0.0093), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0054 321 | 322 | 323 | Train set (epoch 55): [500/500 (100%)] Loss: 0.0026 (avg: 0.0056), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0051 324 | 325 | 326 | Train set (epoch 56): [500/500 (100%)] Loss: 0.0046 (avg: 0.0093), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0047 327 | 328 | 329 | Train set (epoch 57): [500/500 (100%)] Loss: 0.0070 (avg: 0.0046), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0071 330 | 331 | 332 | Train set (epoch 58): [500/500 (100%)] Loss: 0.0124 (avg: 0.0259), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0071 333 | 334 | 335 | Train set (epoch 59): [500/500 (100%)] Loss: 0.0123 (avg: 0.0149), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0077 336 | 337 | 338 | Train set (epoch 60): [500/500 (100%)] Loss: 0.0111 (avg: 0.0205), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0068 339 | 340 | 341 | Train set (epoch 61): [500/500 (100%)] Loss: 0.0123 (avg: 0.0052), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059 342 | 343 | 344 | Train set (epoch 62): [500/500 (100%)] Loss: 0.0062 (avg: 0.0051), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0055 345 | 346 | 347 | Train set (epoch 63): [500/500 (100%)] Loss: 0.0290 (avg: 0.0056), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0053 348 | 349 | 350 | Train set (epoch 64): [500/500 (100%)] Loss: 0.0457 (avg: 0.0792), other losses: [] Acc metric: 466/500 (93.20%) AttnAUC: [] avg sec/iter: 0.0057 351 | 352 | 353 | Train set (epoch 65): [500/500 (100%)] Loss: 0.0024 (avg: 0.0074), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0073 354 | 355 | 356 | Train set (epoch 66): [500/500 (100%)] Loss: 0.0019 (avg: 0.0048), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0061 357 | 358 | 359 | Train set (epoch 67): [500/500 (100%)] Loss: 0.0031 (avg: 0.0033), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0056 360 | 361 | 362 | Train set (epoch 68): [500/500 (100%)] Loss: 0.0108 (avg: 0.0169), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0052 363 | 364 | 365 | Train set (epoch 69): [500/500 (100%)] Loss: 0.0166 (avg: 0.0050), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067 366 | 367 | 368 | Train set (epoch 70): [500/500 (100%)] Loss: 0.0037 (avg: 0.0032), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067 369 | 370 | 371 | Train set (epoch 71): [500/500 (100%)] Loss: 0.0059 (avg: 0.0060), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0078 372 | 373 | 374 | Train set (epoch 72): [500/500 
(100%)] Loss: 0.0019 (avg: 0.0091), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0074 375 | 376 | 377 | Train set (epoch 73): [500/500 (100%)] Loss: 0.0046 (avg: 0.0028), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0075 378 | 379 | 380 | Train set (epoch 74): [500/500 (100%)] Loss: 0.0035 (avg: 0.0040), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0076 381 | 382 | 383 | Train set (epoch 75): [500/500 (100%)] Loss: 0.0064 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0075 384 | 385 | 386 | Train set (epoch 76): [500/500 (100%)] Loss: 0.0036 (avg: 0.0117), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0070 387 | 388 | 389 | Train set (epoch 77): [500/500 (100%)] Loss: 0.0032 (avg: 0.0028), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0053 390 | 391 | 392 | Train set (epoch 78): [500/500 (100%)] Loss: 0.0055 (avg: 0.0082), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059 393 | 394 | 395 | Train set (epoch 79): [500/500 (100%)] Loss: 0.0239 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0068 396 | 397 | 398 | Train set (epoch 80): [500/500 (100%)] Loss: 0.0029 (avg: 0.0249), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0064 399 | 400 | 401 | Train set (epoch 81): [500/500 (100%)] Loss: 0.0073 (avg: 0.0033), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0058 402 | 403 | 404 | Train set (epoch 82): [500/500 (100%)] Loss: 0.0060 (avg: 0.0080), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0055 405 | 406 | 407 | Train set (epoch 83): [500/500 (100%)] Loss: 0.0521 (avg: 0.0741), other losses: [] Acc metric: 455/500 (91.00%) AttnAUC: [] avg sec/iter: 0.0059 408 | 409 | 410 | Train set (epoch 84): [500/500 (100%)] Loss: 0.0029 (avg: 0.0073), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0066 411 | 412 | 413 | Train set (epoch 85): [500/500 (100%)] Loss: 0.0017 (avg: 0.0022), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067 414 | 415 | 416 | Train set (epoch 86): [500/500 (100%)] Loss: 0.0013 (avg: 0.0020), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0062 417 | 418 | 419 | Train set (epoch 87): [500/500 (100%)] Loss: 0.0027 (avg: 0.0020), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059 420 | 421 | 422 | Train set (epoch 88): [500/500 (100%)] Loss: 0.0047 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0060 423 | 424 | 425 | Train set (epoch 89): [500/500 (100%)] Loss: 0.0012 (avg: 0.0061), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072 426 | 427 | 428 | Train set (epoch 90): [500/500 (100%)] Loss: 0.0054 (avg: 0.0030), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0070 429 | 430 | 431 | Train set (epoch 91): [500/500 (100%)] Loss: 0.0017 (avg: 0.0017), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072 432 | 433 | 434 | Train set (epoch 92): [500/500 (100%)] Loss: 0.0013 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072 435 | 436 | 437 | Train set (epoch 93): [500/500 (100%)] Loss: 0.0020 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: 
[] avg sec/iter: 0.0076 438 | 439 | 440 | Train set (epoch 94): [500/500 (100%)] Loss: 0.0012 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0074 441 | 442 | 443 | Train set (epoch 95): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0066 444 | 445 | 446 | Train set (epoch 96): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0077 447 | 448 | 449 | Train set (epoch 97): [500/500 (100%)] Loss: 0.0012 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0078 450 | 451 | 452 | Train set (epoch 98): [500/500 (100%)] Loss: 0.0011 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0062 453 | 454 | 455 | Train set (epoch 99): [500/500 (100%)] Loss: 0.0019 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0046 456 | 457 | 458 | Train set (epoch 100): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0037 459 | 460 | 461 | saving the model to ./checkpoints//checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar 462 | testing with evaluation of attention: takes longer time 463 | 100/500 samples processed 464 | 200/500 samples processed 465 | 300/500 samples processed 466 | 400/500 samples processed 467 | 500/500 samples processed 468 | lbl: 0, avg acc: 100.00% (58/58) 469 | lbl: 1, avg acc: 100.00% (56/56) 470 | lbl: 2, avg acc: 100.00% (52/52) 471 | lbl: 3, avg acc: 100.00% (56/56) 472 | lbl: 4, avg acc: 100.00% (53/53) 473 | lbl: 5, avg acc: 100.00% (49/49) 474 | lbl: 6, avg acc: 100.00% (49/49) 475 | lbl: 7, avg acc: 100.00% (30/30) 476 | lbl: 8, avg acc: 100.00% (33/33) 477 | lbl: 9, avg acc: 100.00% (35/35) 478 | lbl: 10, avg acc: 100.00% (29/29) 479 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 100.00% (500/500) 480 | Train set (epoch 100): Avg loss: 0.0012, Acc metric: 500/500 (100.00%) AttnAUC: ['97.99'] avg sec/iter: 0.1890 481 | 482 | testing with evaluation of attention: takes longer time 483 | 100/7500 samples processed 484 | 200/7500 samples processed 485 | 300/7500 samples processed 486 | 400/7500 samples processed 487 | 500/7500 samples processed 488 | 600/7500 samples processed 489 | 700/7500 samples processed 490 | 800/7500 samples processed 491 | 900/7500 samples processed 492 | 1000/7500 samples processed 493 | 1100/7500 samples processed 494 | 1200/7500 samples processed 495 | 1300/7500 samples processed 496 | 1400/7500 samples processed 497 | 1500/7500 samples processed 498 | 1600/7500 samples processed 499 | 1700/7500 samples processed 500 | 1800/7500 samples processed 501 | 1900/7500 samples processed 502 | 2000/7500 samples processed 503 | 2100/7500 samples processed 504 | 2200/7500 samples processed 505 | 2300/7500 samples processed 506 | 2400/7500 samples processed 507 | 2500/7500 samples processed 508 | 2600/7500 samples processed 509 | 2700/7500 samples processed 510 | 2800/7500 samples processed 511 | 2900/7500 samples processed 512 | 3000/7500 samples processed 513 | 3100/7500 samples processed 514 | 3200/7500 samples processed 515 | 3300/7500 samples processed 516 | 3400/7500 samples processed 517 | 3500/7500 samples processed 518 | 3600/7500 samples processed 519 | 3700/7500 samples processed 520 | 3800/7500 samples processed 521 | 3900/7500 samples processed 522 | 4000/7500 samples processed 523 | 
4100/7500 samples processed 524 | 4200/7500 samples processed 525 | 4300/7500 samples processed 526 | 4400/7500 samples processed 527 | 4500/7500 samples processed 528 | 4600/7500 samples processed 529 | 4700/7500 samples processed 530 | 4800/7500 samples processed 531 | 4900/7500 samples processed 532 | 5000/7500 samples processed 533 | 5100/7500 samples processed 534 | 5200/7500 samples processed 535 | 5300/7500 samples processed 536 | 5400/7500 samples processed 537 | 5500/7500 samples processed 538 | 5600/7500 samples processed 539 | 5700/7500 samples processed 540 | 5800/7500 samples processed 541 | 5900/7500 samples processed 542 | 6000/7500 samples processed 543 | 6100/7500 samples processed 544 | 6200/7500 samples processed 545 | 6300/7500 samples processed 546 | 6400/7500 samples processed 547 | 6500/7500 samples processed 548 | 6600/7500 samples processed 549 | 6700/7500 samples processed 550 | 6800/7500 samples processed 551 | 6900/7500 samples processed 552 | 7000/7500 samples processed 553 | 7100/7500 samples processed 554 | 7200/7500 samples processed 555 | 7300/7500 samples processed 556 | 7400/7500 samples processed 557 | 7500/7500 samples processed 558 | lbl: 0, avg acc: 100.00% (262/262) 559 | lbl: 1, avg acc: 100.00% (249/249) 560 | lbl: 2, avg acc: 100.00% (258/258) 561 | lbl: 3, avg acc: 100.00% (259/259) 562 | lbl: 4, avg acc: 100.00% (269/269) 563 | lbl: 5, avg acc: 100.00% (236/236) 564 | lbl: 6, avg acc: 100.00% (212/212) 565 | lbl: 7, avg acc: 100.00% (218/218) 566 | lbl: 8, avg acc: 100.00% (196/196) 567 | lbl: 9, avg acc: 100.00% (192/192) 568 | lbl: 10, avg acc: 100.00% (149/149) 569 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 100.00% (2500/2500) 570 | lbl: 0, avg acc: 54.87% (124/226) 571 | lbl: 1, avg acc: 58.80% (127/216) 572 | lbl: 2, avg acc: 58.30% (130/223) 573 | lbl: 3, avg acc: 60.76% (144/237) 574 | lbl: 4, avg acc: 59.83% (137/229) 575 | lbl: 5, avg acc: 59.31% (137/231) 576 | lbl: 6, avg acc: 65.42% (157/240) 577 | lbl: 7, avg acc: 63.64% (154/242) 578 | lbl: 8, avg acc: 65.26% (139/213) 579 | lbl: 9, avg acc: 69.23% (153/221) 580 | lbl: 10, avg acc: 67.12% (149/222) 581 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 62.04% (1551/2500) 582 | lbl: 0, avg acc: 12.81% (31/242) 583 | lbl: 1, avg acc: 14.10% (33/234) 584 | lbl: 2, avg acc: 14.67% (33/225) 585 | lbl: 3, avg acc: 13.30% (27/203) 586 | lbl: 4, avg acc: 12.79% (28/219) 587 | lbl: 5, avg acc: 14.78% (34/230) 588 | lbl: 6, avg acc: 12.77% (30/235) 589 | lbl: 7, avg acc: 16.96% (39/230) 590 | lbl: 8, avg acc: 14.85% (34/229) 591 | lbl: 9, avg acc: 16.67% (38/228) 592 | lbl: 10, avg acc: 19.11% (43/225) 593 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 14.80% (370/2500) 594 | Test set (epoch 100): Avg loss: 1.8102, Acc metric: 4421/7500 (58.95%) AttnAUC: ['99.83'] avg sec/iter: 0.5009 595 | 596 | done in 0:00:50.659462 597 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import datetime 4 | from torchvision import transforms 5 | from graphdata import * 6 | from train_test import * 7 | import warnings 8 | 9 | warnings.filterwarnings("once") 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description='Run experiments with Graph Neural Networks') 13 | # Dataset 14 | parser.add_argument('-D', '--dataset', type=str, default='colors-3', 15 | choices=['colors-3', 'colors-4', 'colors-8', 'colors-16', 
'colors-32', 16 | 'triangles', 'mnist', 'mnist-75sp', 'TU'], 17 | help='colors-n means the colors dataset with n-dimensional features; TU is any dataset from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets') 18 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to the dataset') 19 | # Hyperparameters 20 | parser.add_argument('--epochs', type=int, default=None, help='number of training epochs') 21 | parser.add_argument('--batch_size', type=int, default=32, help='batch size for training data') 22 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 23 | parser.add_argument('--lr_decay_step', type=str, default=None, help='comma-separated epochs after which to decay the learning rate') 24 | parser.add_argument('--wdecay', type=float, default=1e-4, help='weight decay') 25 | parser.add_argument('--dropout', type=float, default=0, help='dropout rate') 26 | parser.add_argument('-f', '--filters', type=str, default='64,64,64', help='number of filters in each graph layer') 27 | parser.add_argument('-K', '--filter_scale', type=int, default=1, help='filter scale (receptive field size), must be > 0; 1 for GCN or GIN') 28 | parser.add_argument('--n_hidden', type=int, default=0, help='number of hidden units inside the graph layer') 29 | parser.add_argument('--aggregation', type=str, default='mean', choices=['mean', 'sum'], help='neighbors aggregation inside the graph layer') 30 | parser.add_argument('--readout', type=str, default=None, choices=['mean', 'sum', 'max'], help='type of global pooling over all nodes') 31 | parser.add_argument('--kl_weight', type=float, default=100, help='weight of the KL term in the loss') 32 | parser.add_argument('--pool', type=str, default=None, help='type of pooling between layers, None for global pooling only') 33 | parser.add_argument('--pool_arch', type=str, default=None, help='pooling layers architecture defining whether to use fully-connected layers or GNN and to which layer to attach (e.g.: fc_prev, gnn_prev, fc_curr, gnn_curr, fc_prev_32)') 34 | parser.add_argument('--init', type=str, default='normal', choices=['normal', 'uniform'], help='distribution used to initialize the attention model') 35 | parser.add_argument('--scale', type=str, default='1', help='initialized weights scale for the attention model, set to None to use PyTorch default init') 36 | parser.add_argument('--degree_feature', action='store_true', default=False, help='use degree features (only for the Triangles dataset)') 37 | # TU datasets arguments 38 | parser.add_argument('--n_nodes', type=int, default=25, help='maximum number of nodes in the training set for collab, proteins and dd (35 for collab, 25 for proteins, 200 or 300 for dd)') 39 | parser.add_argument('--cv_folds', type=int, default=5, help='number of folds for cross-validating hyperparameters for collab, proteins and dd (5 or 10 shows similar results, 5 is faster)') 40 | parser.add_argument('--cv_threads', type=int, default=5, help='number of parallel threads for cross-validation') 41 | parser.add_argument('--tune_init', action='store_true', default=False, help='also tune the initialization hyperparameters (scale and init) of the attention model when cross-validating') 42 | parser.add_argument('--ax', action='store_true', default=False, help='use AX for hyperparameter optimization (recommended)') 43 | parser.add_argument('--ax_trials', type=int, default=30, help='number of AX trials (hyperparameter optimization steps)') 44 | parser.add_argument('--cv', action='store_true', default=False, help='run in the cross-validation mode')
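
# --- Editor's illustration (not part of main.py): several of the options above are plain
# --- strings that are split further on: '--filters' and '--lr_decay_step' are comma-separated
# --- lists, while '--pool' and '--pool_arch' are underscore-separated specs consumed at the
# --- end of parse_args() and in load_TU(). A minimal, self-contained sketch of that parsing;
# --- the example values are taken from the defaults and scripts in this repository:
filters = list(map(int, '64,64,64'.split(',')))             # -> [64, 64, 64]
lr_decay_step = list(map(int, '20,25'.split(',')))          # -> [20, 25]
pool = 'attn_sup_threshold_skip_0.05'.split('_')            # -> ['attn', 'sup', 'threshold', 'skip', '0.05']; pool[1] selects sup/unsup/gt attention
pool_arch = 'fc_prev_32'.split('_')                         # -> ['fc', 'prev', '32']
n_hidden_attn = float(pool_arch[2]) if len(pool_arch) > 2 else 0   # optional hidden units in the attention model
print(filters, lr_decay_step, pool, pool_arch, n_hidden_attn)
# --- End of illustration.
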
45 | parser.add_argument('--seed_data', type=int, default=111, help='random seed for data splits') 46 | # Image datasets arguments 47 | parser.add_argument('--img_features', type=str, default='mean,coord', help='image features to use as node features') 48 | parser.add_argument('--img_noise_levels', type=str, default=None, 49 | help='Gaussian noise standard deviations for grayscale and color image features') 50 | # Auxiliary arguments 51 | parser.add_argument('--validation', action='store_true', default=False, help='run in the validation mode') 52 | parser.add_argument('--debug', action='store_true', default=False, help='evaluate on the test set after each epoch (only for visualization purposes)') 53 | parser.add_argument('--eval_attn_train', action='store_true', default=False, help='evaluate attention and save coefficients on the training set for models without learnable attention') 54 | parser.add_argument('--eval_attn_test', action='store_true', default=False, help='evaluate attention and save coefficients on the test set for models without learnable attention') 55 | parser.add_argument('--test_batch_size', type=int, default=100, help='batch size for test data') 56 | parser.add_argument('--alpha_ws', type=str, default=None, help='attention labels that will be used for (weak) supervision') 57 | parser.add_argument('--log_interval', type=int, default=400, help='print interval') 58 | parser.add_argument('--results', type=str, default='./results', help='directory to save model checkpoints and other results, set to None to prevent saving anything') 59 | parser.add_argument('--resume', type=str, default=None, help='checkpoint to load the model and optimizer states from and continue training') 60 | parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'], help='cuda/cpu') 61 | parser.add_argument('--seed', type=int, default=111, help='random seed for model parameters') 62 | parser.add_argument('--threads', type=int, default=0, help='number of threads for data loader') 63 | args = parser.parse_args() 64 | 65 | # Set default number of epochs, learning rate schedules and other hyperparameters 66 | if args.readout in [None, 'None']: 67 | args.readout = 'sum' if args.dataset.find('color') >= 0 else 'max' # global sum pooling for COLORS, global max pooling for all other datasets 68 | set_default_lr_decay_step = args.lr_decay_step in [None, 'None'] 69 | if args.epochs in [None, 'None']: 70 | if args.dataset.find('mnist') >= 0: 71 | args.epochs = 30 72 | if set_default_lr_decay_step: 73 | args.lr_decay_step = '20,25' 74 | elif args.dataset == 'triangles': 75 | args.epochs = 100 76 | if set_default_lr_decay_step: 77 | args.lr_decay_step = '85,95' 78 | elif args.dataset == 'TU': 79 | args.epochs = 50 80 | if set_default_lr_decay_step: 81 | args.lr_decay_step = '25,35,45' 82 | elif args.dataset.find('color') >= 0: 83 | if args.readout in [None, 'None']: 84 | args.readout = 'sum' 85 | if args.pool in [None, 'None']: 86 | args.epochs = 100 87 | if set_default_lr_decay_step: 88 | args.lr_decay_step = '90' 89 | else: 90 | args.epochs = 300 91 | if set_default_lr_decay_step: 92 | args.lr_decay_step = '280' 93 | else: 94 | raise NotImplementedError(args.dataset) 95 | 96 | args.lr_decay_step = list(map(int, args.lr_decay_step.split(','))) 97 | args.filters = list(map(int, args.filters.split(','))) 98 | args.img_features = args.img_features.split(',') 99 | args.img_noise_levels = None if args.img_noise_levels in [None, 'None'] else list(map(float, args.img_noise_levels.split(','))) 100 | args.pool = None if args.pool in [None, 'None'] else
args.pool.split('_') 101 | args.pool_arch = None if args.pool_arch in [None, 'None'] else args.pool_arch.split('_') 102 | try: 103 | args.scale = float(args.scale) 104 | except: 105 | args.scale = None 106 | 107 | args.torch = torch.__version__ 108 | 109 | for arg in vars(args): 110 | print(arg, getattr(args, arg)) 111 | 112 | return args 113 | 114 | 115 | def load_synthetic(args): 116 | train_dataset = SyntheticGraphs(args.data_dir, args.dataset, 'train', degree_feature=args.degree_feature, 117 | attn_coef=args.alpha_ws) 118 | test_dataset = SyntheticGraphs(args.data_dir, args.dataset, 'val' if args.validation else 'test', 119 | degree_feature=args.degree_feature) 120 | loss_fn = mse_loss 121 | collate_fn = collate_batch 122 | in_features = train_dataset.feature_dim 123 | out_features = 1 124 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features 125 | 126 | 127 | def load_mnist(args): 128 | use_mean_px = 'mean' in args.img_features 129 | use_coord = 'coord' in args.img_features 130 | assert use_mean_px, ('this mode is not well supported', use_mean_px) 131 | gt_attn_threshold = 0 if (args.pool is not None and args.pool[1] in ['gt'] and args.filter_scale > 1) else 0.5 132 | if args.dataset == 'mnist': 133 | train_dataset = MNIST(args.data_dir, train=True, download=True, transform=transforms.ToTensor(), 134 | attn_coef=args.alpha_ws) 135 | else: 136 | train_dataset = MNIST75sp(args.data_dir, split='train', use_mean_px=use_mean_px, use_coord=use_coord, 137 | gt_attn_threshold=gt_attn_threshold, attn_coef=args.alpha_ws) 138 | 139 | noises, color_noises = None, None 140 | if args.validation: 141 | n_val = 5000 142 | if args.dataset == 'mnist': 143 | train_dataset.train_data = train_dataset.train_data[:-n_val] 144 | train_dataset.train_labels = train_dataset.train_labels[:-n_val] 145 | test_dataset = MNIST(args.data_dir, train=True, download=True, transform=transforms.ToTensor()) 146 | test_dataset.train_data = train_dataset.train_data[-n_val:] 147 | test_dataset.train_labels = train_dataset.train_labels[-n_val:] 148 | else: 149 | train_dataset.train_val_split(np.arange(0, train_dataset.n_samples - n_val)) 150 | test_dataset = MNIST75sp(args.data_dir, split='train', use_mean_px=use_mean_px, use_coord=use_coord, 151 | gt_attn_threshold=gt_attn_threshold) 152 | test_dataset.train_val_split(np.arange(train_dataset.n_samples - n_val, train_dataset.n_samples)) 153 | else: 154 | noise_file = pjoin(args.data_dir, '%s_noise.pt' % args.dataset.replace('-', '_')) 155 | color_noise_file = pjoin(args.data_dir, '%s_color_noise.pt' % args.dataset.replace('-', '_')) 156 | if args.dataset == 'mnist': 157 | test_dataset = MNIST(args.data_dir, train=False, download=True, transform=transforms.ToTensor()) 158 | noise_shape = (len(test_dataset.test_labels), 28 * 28) 159 | else: 160 | test_dataset = MNIST75sp(args.data_dir, split='test', use_mean_px=use_mean_px, use_coord=use_coord, 161 | gt_attn_threshold=gt_attn_threshold) 162 | noise_shape = (len(test_dataset.labels), 75) 163 | 164 | # Generate/load noise (save it to make reproducible) 165 | noises = load_save_noise(noise_file, noise_shape) 166 | color_noises = load_save_noise(color_noise_file, (noise_shape[0], noise_shape[1], 3)) 167 | 168 | if args.dataset == 'mnist': 169 | A, coord, mask = precompute_graph_images(train_dataset.train_data.shape[1]) 170 | collate_fn = lambda batch: collate_batch_images(batch, A, mask, use_mean_px=use_mean_px, 171 | coord=coord if use_coord else None, 172 | gt_attn_threshold=gt_attn_threshold, 173 | 
replicate_features=args.img_noise_levels is not None) 174 | else: 175 | train_dataset.precompute_graph_data(replicate_features=args.img_noise_levels is not None, threads=12) 176 | test_dataset.precompute_graph_data(replicate_features=args.img_noise_levels is not None, threads=12) 177 | collate_fn = collate_batch 178 | 179 | loss_fn = F.cross_entropy 180 | 181 | in_features = 0 if args.img_noise_levels is None else 2 182 | for features in args.img_features: 183 | if features == 'mean': 184 | in_features += 1 185 | elif features == 'coord': 186 | in_features += 2 187 | else: 188 | raise NotImplementedError(features) 189 | in_features = np.max((in_features, 1)) # in_features=1 if neither mean nor coord is used (dummy features will be used in this case) 190 | out_features = 10 191 | 192 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, noises, color_noises 193 | 194 | 195 | def load_TU(args, cv_folds=5): 196 | loss_fn = F.cross_entropy 197 | collate_fn = collate_batch 198 | scale, init = args.scale, args.init 199 | n_hidden_attn = float(args.pool_arch[2]) if (args.pool_arch is not None and len(args.pool_arch) > 2) else 0 200 | if args.pool is None: 201 | # Global pooling models 202 | datareader = DataReader(data_dir=args.data_dir, N_nodes=args.n_nodes, rnd_state=rnd_data, folds=0) 203 | train_dataset = GraphData(datareader, None, 'train_val') 204 | test_dataset = GraphData(datareader, None, 'test') 205 | in_features = train_dataset.num_features 206 | out_features = train_dataset.num_classes 207 | pool = args.pool 208 | kl_weight = args.kl_weight 209 | elif args.pool[1] == 'gt': 210 | raise ValueError('ground truth attention for TU datasets is not available') 211 | elif args.pool[1] in ['sup', 'unsup']: 212 | datareader = DataReader(data_dir=args.data_dir, N_nodes=args.n_nodes, rnd_state=rnd_data, folds=cv_folds) 213 | if args.ax: 214 | # Cross-validation using Ax (recommended way), Python 3 must be used 215 | best_parameters = ax_optimize(datareader, args, collate_fn, loss_fn, None, folds=cv_folds, 216 | threads=args.cv_threads, n_trials=args.ax_trials) 217 | pool = args.pool 218 | kl_weight = best_parameters['kl_weight'] 219 | if args.tune_init: 220 | scale, init = best_parameters['scale'], best_parameters['init'] 221 | n_hidden_attn, layer = best_parameters['n_hidden_attn'], 1 222 | if layer == 0: 223 | pool = copy.deepcopy(args.pool) 224 | del pool[3] 225 | 226 | pool = set_pool(best_parameters['pool'], pool) 227 | 228 | else: 229 | if not args.cv: 230 | # Run with some fixed parameters without cross-validation 231 | pool_thresh_values = np.array([float(args.pool[-1])]) 232 | n_hiddens = [n_hidden_attn] 233 | layers = [1] 234 | elif args.debug: 235 | pool_thresh_values = np.array([1e-4, 1e-1]) 236 | n_hiddens = [n_hidden_attn] 237 | layers = [1] 238 | else: 239 | # Cross-validation using grid search (not recommended, since it is time-consuming and less effective) 240 | if args.data_dir.lower().find('proteins') >= 0: 241 | pool_thresh_values = np.array([2e-3, 5e-3, 1e-2, 3e-2, 5e-2]) 242 | elif args.data_dir.lower().find('dd') >= 0: 243 | pool_thresh_values = np.array([1e-4, 1e-3, 2e-3, 5e-3, 1e-2, 3e-2, 5e-2, 1e-1]) 244 | elif args.data_dir.lower().find('collab') >= 0: 245 | pool_thresh_values = np.array([1e-3, 2e-3, 5e-3, 1e-2, 3e-2, 5e-2, 1e-1]) 246 | else: 247 | raise NotImplementedError('this dataset is not supported currently') 248 | n_hiddens = np.array([0, 32]) # hidden units in the attention model 249 | layers = np.array([0, 1]) # layer where to
attach the attention model 250 | 251 | if args.pool[1] == 'sup' and not args.debug and args.cv: 252 | kl_weight_values = np.array([0.25, 1, 2, 10]) 253 | else: 254 | kl_weight_values = np.array([args.kl_weight]) # any value (ignored for unsupervised training) 255 | 256 | 257 | if len(pool_thresh_values) > 1 or len(kl_weight_values) > 1 or len(n_hiddens) > 1 or len(layers) > 1: 258 | val_acc = np.zeros((len(layers), len(n_hiddens), len(pool_thresh_values), len(kl_weight_values))) 259 | for i_, layer in enumerate(layers): 260 | if layer == 0: 261 | pool = copy.deepcopy(args.pool) 262 | del pool[3] 263 | else: 264 | pool = args.pool 265 | for j_, n_hidden_attn in enumerate(n_hiddens): 266 | for k_, pool_thresh in enumerate(pool_thresh_values): 267 | for m_, kl_weight in enumerate(kl_weight_values): 268 | val_acc[i_, j_, k_, m_] = \ 269 | cross_validation(datareader, args, collate_fn, loss_fn, set_pool(pool_thresh, pool), 270 | kl_weight, None, n_hidden_attn=n_hidden_attn, folds=cv_folds, threads=args.cv_threads) 271 | ind1, ind2, ind3, ind4 = np.where(val_acc == np.max(val_acc)) # np.argmax returns only first occurrence 272 | print(val_acc) 273 | print(ind1, ind2, ind3, ind4, layers[ind1], n_hiddens[ind2], pool_thresh_values[ind3], kl_weight_values[ind4], 274 | val_acc[ind1[0], ind2[0], ind3[0], ind4[0]]) 275 | 276 | layer = layers[ind1[0]] 277 | if layer == 0: 278 | pool = copy.deepcopy(args.pool) 279 | del pool[3] 280 | else: 281 | pool = args.pool 282 | n_hidden_attn = n_hiddens[ind2[0]] 283 | pool = set_pool(pool_thresh_values[ind3[0]], pool) 284 | kl_weight = kl_weight_values[ind4[0]] 285 | else: 286 | pool = args.pool 287 | kl_weight = args.kl_weight 288 | 289 | train_dataset = GraphData(datareader, None, 'train_val') 290 | test_dataset = GraphData(datareader, None, 'test') 291 | in_features = train_dataset.num_features 292 | out_features = train_dataset.num_classes 293 | 294 | if args.pool[1] == 'sup': 295 | # Train a model with global pooling first 296 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.threads, 297 | collate_fn=collate_fn) 298 | train_loader_test = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=False, 299 | num_workers=args.threads, collate_fn=collate_fn) 300 | # Train global pooling model 301 | start_epoch, model, optimizer, scheduler = create_model_optimizer(in_features, out_features, None, kl_weight, 302 | args, scale=scale, init=init, n_hidden_attn=n_hidden_attn) 303 | for epoch in range(start_epoch, args.epochs + 1): 304 | scheduler.step() 305 | train_loss, acc = train(model, train_loader, optimizer, epoch, args, loss_fn, None) 306 | train_loss, train_acc, attn_WS = test(model, train_loader_test, epoch, loss_fn, 'train', args, None, 307 | eval_attn=True)[:3] 308 | train_dataset = GraphData(datareader, None, 'train_val', attn_labels=attn_WS) 309 | else: 310 | raise NotImplementedError(args.pool) 311 | 312 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, pool, kl_weight, scale, init, n_hidden_attn 313 | 314 | 315 | if __name__ == '__main__': 316 | 317 | # mp.set_start_method('spawn') 318 | dt = datetime.datetime.now() 319 | print('start time:', dt) 320 | args = parse_args() 321 | args.experiment_ID = '%06d' % dt.microsecond 322 | print('experiment_ID: ', args.experiment_ID) 323 | 324 | if args.cv_threads > 1 and args.dataset == 'TU': 325 | # this requires python3 326 | torch.multiprocessing.set_start_method('spawn') 327 | 328 | print('gpus: ', 
torch.cuda.device_count()) 329 | 330 | if args.results not in [None, 'None'] and not os.path.isdir(args.results): 331 | os.mkdir(args.results) 332 | 333 | rnd, rnd_data = set_seed(args.seed, args.seed_data) 334 | 335 | pool = args.pool 336 | kl_weight = args.kl_weight 337 | scale = args.scale 338 | init = args.init 339 | n_hidden_attn = float(args.pool_arch[2]) if (args.pool_arch is not None and len(args.pool_arch) > 2) else 0 340 | if args.dataset.find('colors') >= 0 or args.dataset == 'triangles': 341 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features = load_synthetic(args) 342 | elif args.dataset in ['mnist', 'mnist-75sp']: 343 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, noises, color_noises = load_mnist(args) 344 | 345 | else: 346 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, pool, kl_weight, scale, init, n_hidden_attn = \ 347 | load_TU(args, cv_folds=args.cv_folds) 348 | 349 | 350 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.threads, 351 | collate_fn=collate_fn) 352 | # A loader to test and evaluate attn on the training set (shouldn't be shuffled and have larger batch size multiple of 50) 353 | train_loader_test = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.threads, collate_fn=collate_fn) 354 | print('test_dataset', test_dataset.split) 355 | test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, 356 | num_workers=args.threads, collate_fn=collate_fn) 357 | 358 | start_epoch, model, optimizer, scheduler = create_model_optimizer(in_features, out_features, pool, kl_weight, args, 359 | scale=scale, init=init, n_hidden_attn=n_hidden_attn) 360 | 361 | feature_stats = None 362 | if args.dataset in ['mnist', 'mnist-75sp']: 363 | feature_stats = compute_feature_stats(model, train_loader, args.device, n_batches=1000) 364 | 365 | # Test function wrapper 366 | def test_fn(loader, epoch, split, eval_attn): 367 | test_loss, acc, _, _ = test(model, loader, epoch, loss_fn, split, args, feature_stats, 368 | noises=None, img_noise_level=None, eval_attn=eval_attn, alpha_WS_name='orig') 369 | if args.dataset in ['mnist', 'mnist-75sp'] and split == 'test' and args.img_noise_levels is not None: 370 | test(model, loader, epoch, loss_fn, split, args, feature_stats, 371 | noises=noises, img_noise_level=args.img_noise_levels[0], eval_attn=eval_attn, alpha_WS_name='noisy') 372 | test(model, loader, epoch, loss_fn, split, args, feature_stats, 373 | noises=color_noises, img_noise_level=args.img_noise_levels[1], eval_attn=eval_attn, alpha_WS_name='noisy-c') 374 | return test_loss, acc 375 | 376 | if start_epoch > args.epochs: 377 | print('evaluating the model') 378 | test_fn(test_loader, start_epoch - 1, 'val' if args.validation else 'test', args.eval_attn_test) 379 | else: 380 | for epoch in range(start_epoch, args.epochs + 1): 381 | eval_epoch = epoch <= 1 or epoch == args.epochs # check for epoch == 1 just to make sure that the test function works fine for this test set before training all the way until the last epoch 382 | scheduler.step() 383 | train_loss, acc = train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats) 384 | if eval_epoch: 385 | save_checkpoint(model, scheduler, optimizer, args, epoch) 386 | # Report Training accuracy and other metrics on the training set 387 | test_fn(train_loader_test, epoch, 'train', (epoch == args.epochs) and 
args.eval_attn_train) 388 | 389 | if args.validation: 390 | test_fn(test_loader, epoch, 'val', (epoch == args.epochs) and args.eval_attn_test) 391 | elif eval_epoch or args.debug: 392 | test_fn(test_loader, epoch, 'test', (epoch == args.epochs) and args.eval_attn_test) 393 | 394 | print('done in {}'.format(datetime.datetime.now() - dt)) 395 | -------------------------------------------------------------------------------- /notebooks/TRIANGLES_eval_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import matplotlib as mpl\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pickle\n", 15 | "from skimage.segmentation import slic\n", 16 | "import scipy.ndimage\n", 17 | "import scipy.spatial\n", 18 | "import torch\n", 19 | "from torchvision import datasets\n", 20 | "from torchvision import datasets\n", 21 | "import sys\n", 22 | "sys.path.append(\"../\")\n", 23 | "from chebygin import ChebyGIN\n", 24 | "from extract_superpixels import process_image\n", 25 | "from graphdata import comput_adjacency_matrix_images\n", 26 | "from train_test import load_save_noise\n", 27 | "from utils import list_to_torch, data_to_device, normalize_zero_one" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# TRIANGLES" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "dict_keys(['Adj_matrices', 'GT_attn', 'graph_labels', 'N_edges', 'Max_degree'])\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "data_dir = '/scratch/ssd/data/graph_attention_pool/'\n", 52 | "checkpoints_dir = '../checkpoints'\n", 53 | "device = 'cuda'\n", 54 | "\n", 55 | "with open('%s/random_graphs_triangles_test.pkl' % data_dir, 'rb') as f:\n", 56 | " data = pickle.load(f)\n", 57 | " \n", 58 | "print(data.keys())\n", 59 | "targets = torch.from_numpy(data['graph_labels']).long()\n", 60 | "Node_degrees = [np.sum(A, 1).astype(np.int32) for A in data['Adj_matrices']]\n", 61 | "\n", 62 | "feature_dim = data['Max_degree'] + 1\n", 63 | "node_features = []\n", 64 | "for i in range(len(data['Adj_matrices'])):\n", 65 | " N = data['Adj_matrices'][i].shape[0]\n", 66 | " D_onehot = np.zeros((N, feature_dim ))\n", 67 | " D_onehot[np.arange(N), Node_degrees[i]] = 1\n", 68 | " node_features.append(D_onehot)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": { 75 | "collapsed": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "def acc(pred):\n", 80 | " n = len(pred)\n", 81 | " return torch.mean((torch.stack(pred).view(n) == targets[:len(pred)].view(n)).float()).item() * 100\n", 82 | "\n", 83 | "def test(model, index, show_img=False): \n", 84 | " N_nodes = data['Adj_matrices'][index].shape[0]\n", 85 | " mask = torch.ones(1, N_nodes, dtype=torch.uint8)\n", 86 | " x = torch.from_numpy(node_features[index]).unsqueeze(0).float() \n", 87 | " A = torch.from_numpy(data['Adj_matrices'][index].astype(np.float32)).float().unsqueeze(0)\n", 88 | " y, other_outputs = model(data_to_device([x, A, mask, -1, {'N_nodes': torch.zeros(1, 1) + N_nodes}], \n", 89 | " device)) \n", 90 | " y = y.round().long().data.cpu()[0][0]\n", 91 | " alpha = other_outputs['alpha'][0].data.cpu() if 'alpha' in other_outputs else 
[] \n", 92 | " return y, alpha\n", 93 | "\n", 94 | "\n", 95 | "# This function returns predictions for the entire clean and noise test sets\n", 96 | "def get_predictions(model_path):\n", 97 | " state = torch.load(model_path)\n", 98 | " args = state['args']\n", 99 | " model = ChebyGIN(in_features=14,\n", 100 | " out_features=1,\n", 101 | " filters=args.filters,\n", 102 | " K=args.filter_scale,\n", 103 | " n_hidden=args.n_hidden,\n", 104 | " aggregation=args.aggregation,\n", 105 | " dropout=args.dropout,\n", 106 | " readout=args.readout,\n", 107 | " pool=args.pool,\n", 108 | " pool_arch=args.pool_arch)\n", 109 | " model.load_state_dict(state['state_dict'])\n", 110 | " model = model.eval().to(device)\n", 111 | "# print(model) \n", 112 | "\n", 113 | " # Get predictions\n", 114 | " pred, alpha = [], []\n", 115 | " for index in range(len(data['Adj_matrices'])):\n", 116 | " y = test(model, index, index == 0)\n", 117 | " pred.append(y[0])\n", 118 | " alpha.append(y[1])\n", 119 | " if len(pred) % 1000 == 0:\n", 120 | " print('{}/{}, acc on the combined test set={:.2f}%'.format(len(pred), len(data['Adj_matrices']), acc(pred)))\n", 121 | " return pred, alpha" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Weakly-supervised attention model" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 4, 134 | "metadata": { 135 | "scrolled": true 136 | }, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "ChebyGINLayer torch.Size([64, 98]) tensor([0.5568, 0.5545, 0.5580, 0.5656, 0.5318, 0.5698, 0.5655, 0.5937, 0.6087,\n", 143 | " 0.5437], grad_fn=)\n", 144 | "ChebyGINLayer torch.Size([32, 128]) tensor([0.5730, 0.5968, 0.5778, 0.5940, 0.5981, 0.5787, 0.5619, 0.5798, 0.5741,\n", 145 | " 0.5833], grad_fn=)\n", 146 | "ChebyGINLayer torch.Size([32, 64]) tensor([0.5703, 0.5380, 0.5825, 0.5836, 0.5649, 0.5537, 0.6568, 0.6129, 0.6161,\n", 147 | " 0.5258], grad_fn=)\n", 148 | "ChebyGINLayer torch.Size([1, 64]) tensor([0.5634], grad_fn=)\n", 149 | "ChebyGINLayer torch.Size([64, 448]) tensor([0.5923, 0.5840, 0.5608, 0.5615, 0.5799, 0.5668, 0.5924, 0.5840, 0.5709,\n", 150 | " 0.5637], grad_fn=)\n", 151 | "ChebyGINLayer torch.Size([32, 128]) tensor([0.5606, 0.5821, 0.5540, 0.5596, 0.6033, 0.6147, 0.5738, 0.5865, 0.5981,\n", 152 | " 0.5800], grad_fn=)\n", 153 | "ChebyGINLayer torch.Size([32, 64]) tensor([0.5938, 0.6073, 0.5995, 0.5230, 0.6091, 0.6070, 0.5901, 0.5752, 0.5594,\n", 154 | " 0.5499], grad_fn=)\n", 155 | "ChebyGINLayer torch.Size([1, 64]) tensor([0.6102], grad_fn=)\n", 156 | "ChebyGINLayer torch.Size([64, 448]) tensor([0.5877, 0.5797, 0.5591, 0.5688, 0.5758, 0.5645, 0.5483, 0.5846, 0.5883,\n", 157 | " 0.5961], grad_fn=)\n", 158 | "1000/10000, acc on the combined test set=83.00%\n", 159 | "2000/10000, acc on the combined test set=76.30%\n", 160 | "3000/10000, acc on the combined test set=72.23%\n", 161 | "4000/10000, acc on the combined test set=68.73%\n", 162 | "5000/10000, acc on the combined test set=66.82%\n", 163 | "6000/10000, acc on the combined test set=59.90%\n", 164 | "7000/10000, acc on the combined test set=55.04%\n", 165 | "8000/10000, acc on the combined test set=51.60%\n", 166 | "9000/10000, acc on the combined test set=49.02%\n", 167 | "10000/10000, acc on the combined test set=46.69%\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "pred, alpha = get_predictions('%s/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar' % checkpoints_dir)" 173 | ] 174 | } 
175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "Python 3", 179 | "language": "python", 180 | "name": "python3" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 3 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython3", 192 | "version": "3.6.7" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 2 197 | } 198 | -------------------------------------------------------------------------------- /notebooks/convert2TU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pickle\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "data_dir = '../data/'" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def write_data_TU(data, out_path, dim_test=None):\n", 30 | " c = 0\n", 31 | " nodes = 1\n", 32 | " with open('%s_A.txt' % out_path, 'w') as f:\n", 33 | " for A in data['Adj_matrices']:\n", 34 | " N = A.shape[0]\n", 35 | " for i in range(N):\n", 36 | " for j in range(N):\n", 37 | " if A[i, j] > 0:\n", 38 | " f.write('%d, %d\\n' % (i + nodes, j + nodes))\n", 39 | " nodes += N\n", 40 | " c += np.sum(A) / 2\n", 41 | " print(c)\n", 42 | "\n", 43 | " c = 1\n", 44 | " with open('%s_graph_indicator.txt' % out_path, 'w') as f:\n", 45 | " for A in data['Adj_matrices']:\n", 46 | " N = A.shape[0]\n", 47 | " for j in range(N):\n", 48 | " f.write('%d\\n' % c)\n", 49 | " c += 1\n", 50 | " print(c)\n", 51 | "\n", 52 | " with open('%s_graph_attributes.txt' % out_path, 'w') as f:\n", 53 | " for lbl in data['graph_labels']:\n", 54 | " f.write('%d\\n' % lbl)\n", 55 | "\n", 56 | " \n", 57 | " with open('%s_node_attributes.txt' % out_path, 'w') as f:\n", 58 | " for node_id in range(len(data['GT_attn'])):\n", 59 | " attn = data['GT_attn'][node_id]\n", 60 | " N = attn.shape[0]\n", 61 | " for i in range(N):\n", 62 | " f.write('%d' % attn[i])\n", 63 | " if dim_test is not None:\n", 64 | " f.write(', ')\n", 65 | " x = data['node_features'][node_id]\n", 66 | " for j in range(dim_test): \n", 67 | " if j < x.shape[1]: \n", 68 | " f.write('%d' % x[i, j])\n", 69 | " else:\n", 70 | " f.write('0')\n", 71 | " if j < dim_test - 1:\n", 72 | " f.write(', ') \n", 73 | " f.write('\\n') \n", 74 | " \n", 75 | "# with open('%s_node_attention.txt' % out_path, 'w') as f:\n", 76 | "# for x in data['GT_attn']:\n", 77 | "# N = x.shape[0]\n", 78 | "# for i in range(N):\n", 79 | "# f.write('%d\\n' % x[i])" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "train Adj_matrices 500\n", 92 | "train GT_attn 500\n", 93 | "train graph_labels 500\n", 94 | "train node_features 500\n", 95 | "train N_edges 500\n", 96 | "\n", 97 | "\n", 98 | "val Adj_matrices 3000\n", 99 | "val GT_attn 3000\n", 100 | "val graph_labels 3000\n", 101 | "val node_features 3000\n", 102 | "val N_edges 3000\n", 103 | "\n", 104 | "\n", 105 | "test Adj_matrices 10500\n", 106 | "test GT_attn 10500\n", 107 | "test graph_labels 10500\n", 108 | "test 
node_features 10500\n", 109 | "test N_edges 10500\n", 110 | "\n", 111 | "\n", 112 | "955788.0\n", 113 | "10501\n", 114 | "10500 61.31171428571429 60.52053469623652 4 200\n", 115 | "10500 91.02742857142857 93.66902222113747 4 397\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "for dim in [3]:\n", 121 | " dim_test = dim + 1\n", 122 | " data = {}\n", 123 | " out_dir = '%s/COLORS-%d' % (data_dir, dim)\n", 124 | " try:\n", 125 | " os.mkdir(out_dir)\n", 126 | " except Exception as e:\n", 127 | " print(e)\n", 128 | "\n", 129 | " for split in ['train', 'val', 'test']:\n", 130 | " \n", 131 | " with open('%s/random_graphs_colors_dim%d_%s.pkl' % (data_dir, dim, split), 'rb') as f:\n", 132 | " data_tmp = pickle.load(f) \n", 133 | " for key in data_tmp: \n", 134 | " if split == 'train':\n", 135 | " data[key] = data_tmp[key]\n", 136 | " else:\n", 137 | " if isinstance(data[key], list):\n", 138 | " data[key].extend(data_tmp[key])\n", 139 | " else:\n", 140 | " data[key] = np.concatenate((data[key], data_tmp[key]))\n", 141 | " print(split, key, len(data[key]))\n", 142 | " print('\\n')\n", 143 | " write_data_TU(data, '%s/COLORS-%d' % (out_dir, dim), dim_test)\n", 144 | " \n", 145 | "nodes = [A.shape[0] for A in data['Adj_matrices']]\n", 146 | "edges = [np.sum(A) // 2 for A in data['Adj_matrices']]\n", 147 | "print(len(nodes), np.mean(nodes), np.std(nodes), np.min(nodes), np.max(nodes))\n", 148 | "print(len(edges), np.mean(edges), np.std(edges), np.min(edges), np.max(edges)) " 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "\n", 161 | "\n", 162 | "736732.0\n", 163 | "30001\n", 164 | "[Errno 17] File exists: '../data//TRIANGLES'\n", 165 | "val Adj_matrices 35000\n", 166 | "val GT_attn 35000\n", 167 | "val graph_labels 35000\n", 168 | "val N_edges 35000\n", 169 | "val Max_degree 14\n", 170 | "\n", 171 | "\n", 172 | "859328.0\n", 173 | "35001\n", 174 | "[Errno 17] File exists: '../data//TRIANGLES'\n", 175 | "test Adj_matrices 45000\n", 176 | "test GT_attn 45000\n", 177 | "test graph_labels 45000\n", 178 | "test N_edges 45000\n", 179 | "test Max_degree 14\n", 180 | "\n", 181 | "\n", 182 | "1473512.0\n", 183 | "45001\n", 184 | "45000 20.854177777777778 17.460003254790188 4 100\n", 185 | "45000 32.74471111111111 28.14843482573703 4 198\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "data = {}\n", 191 | "for split in ['train', 'val', 'test']:\n", 192 | " out_dir = '%s/TRIANGLES' % (data_dir)\n", 193 | " try:\n", 194 | " os.mkdir(out_dir)\n", 195 | " except Exception as e:\n", 196 | " print(e)\n", 197 | " with open('%s/random_graphs_triangles_%s.pkl' % (data_dir, split), 'rb') as f:\n", 198 | " data_tmp = pickle.load(f)\n", 199 | " for key in data_tmp:\n", 200 | " if split == 'train':\n", 201 | " data[key] = data_tmp[key]\n", 202 | " else:\n", 203 | " if key == 'Max_degree':\n", 204 | " print(split, key, data[key])\n", 205 | " data[key] = np.max((data[key], data_tmp[key]))\n", 206 | " else:\n", 207 | " if isinstance(data[key], list):\n", 208 | " data[key].extend(data_tmp[key])\n", 209 | " else:\n", 210 | " data[key] = np.concatenate((data[key], data_tmp[key]))\n", 211 | " print(split, key, len(data[key]))\n", 212 | " print('\\n')\n", 213 | " write_data_TU(data, '%s/TRIANGLES' % (out_dir))\n", 214 | " \n", 215 | "nodes = [A.shape[0] for A in data['Adj_matrices']]\n", 216 | "edges = [np.sum(A) // 2 for A in data['Adj_matrices']]\n", 217 | 
"print(len(nodes), np.mean(nodes), np.std(nodes), np.min(nodes), np.max(nodes))\n", 218 | "print(len(edges), np.mean(edges), np.std(edges), np.min(edges), np.max(edges)) " 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "collapsed": true 226 | }, 227 | "outputs": [], 228 | "source": [] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.7.3" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 2 252 | } 253 | -------------------------------------------------------------------------------- /scripts/colors.sh: -------------------------------------------------------------------------------- 1 | dataset="colors-3" 2 | seed=111 3 | checkpoints_dir=./checkpoints/ 4 | params="-D $dataset --test_batch_size 100 -K 2 -f 64,64 --aggregation mean --n_hidden 0 --readout sum --dropout 0 --threads 0 --pool_arch fc_prev --seed $seed --results $checkpoints_dir -d ./data" 5 | 6 | logs_dir=./logs/ 7 | 8 | # Global pooling 9 | python main.py $params --eval_attn_train --eval_attn_test --epochs 100 --lr_decay_step 90 | tee $logs_dir/"$dataset"_global_max_seed"$seed".log; 10 | 11 | # Unsupervised attention 12 | thresh=0.03 13 | pool=unsup 14 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log; 15 | 16 | # Supervised attention 17 | thresh=0.05 18 | pool=sup 19 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log; 20 | 21 | # Weakly-supervised attention 22 | thresh=0.05 23 | python main.py $params --pool attn_sup_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log; 24 | -------------------------------------------------------------------------------- /scripts/mnist_75sp.sh: -------------------------------------------------------------------------------- 1 | dataset="mnist-75sp" 2 | seed=3026 3 | checkpoints_dir=./checkpoints/ 4 | params="-D $dataset --epochs 30 --lr_decay_step 20,25 --test_batch_size 100 -K 4 -f 4,64,512 --aggregation mean --n_hidden 0 --readout max --dropout 0.5 --threads 0 --img_features mean,coord --img_noise_levels 0.4,0.6 --pool_arch fc_prev --kl_weight 100 --seed $seed --results $checkpoints_dir -d ./data" 5 | 6 | logs_dir=./logs/ 7 | thresh=0.01 8 | 9 | # Global pooling 10 | python main.py $params --eval_attn_train --eval_attn_test | tee $logs_dir/"$dataset"_global_max_seed"$seed".log; 11 | 12 | # Unsupervised and supervised attention 13 | for pool in unsup sup; 14 | do python main.py $params --pool attn_"$pool"_threshold_skip_skip_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log; 15 | done 16 | 17 | # Weakly-supervised attention 18 | python main.py $params --pool attn_sup_threshold_skip_skip_"$thresh" --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log; 19 | 
-------------------------------------------------------------------------------- /scripts/prepare_data.sh: -------------------------------------------------------------------------------- 1 | date 2 | seed=111 3 | 4 | # Generate Colors data 5 | out_dir=./data 6 | for dim in 3 8 16 32; do python generate_data.py --dim $dim -o $out_dir --seed $seed; done 7 | 8 | # Generate Triangles data 9 | python generate_data.py -D triangles --N_train 30000 --N_val 5000 --N_test 5000 --label_min 1 --label_max 10 --N_max 100 -o $out_dir --seed $seed 10 | 11 | # Generate MNIST-75sp data 12 | for split in train test; do python extract_superpixels.py -s $split -o $out_dir --seed $seed; done 13 | 14 | # Generate CIFAR-10-150sp data 15 | #for split in train test; do python extract_superpixels.py -D cifar10 -c 10 -n 150 -s $split -t 0 -o $out_dir; done 16 | 17 | # Generate noise for MNIST-75sp 18 | python -c "import sys,torch; print('seed=%s\nout file noise=%s\nout file color noise=%s' % (sys.argv[1], sys.argv[2], sys.argv[3])); torch.manual_seed(int(sys.argv[1])); noise=torch.randn(10000,75, dtype=torch.float); torch.save(noise, sys.argv[2]); colornoise=torch.randn(10000,75,3, dtype=torch.float); torch.save(colornoise, sys.argv[3]);" $seed "$out_dir"/mnist_75sp_noise.pt "$out_dir"/mnist_75sp_color_noise.pt 19 | 20 | # Download and unzip COLLAB, PROTEINS and D&D 21 | for dataset in COLLAB PROTEINS DD; 22 | do wget https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/"$dataset".zip -P $out_dir; unzip "$out_dir"/"$dataset".zip -d $out_dir; done 23 | 24 | date 25 | -------------------------------------------------------------------------------- /scripts/triangles.sh: -------------------------------------------------------------------------------- 1 | dataset="triangles" 2 | seed=111 3 | checkpoints_dir=./checkpoints/ 4 | params="-D $dataset --epochs 100 --lr_decay_step 85,95 --test_batch_size 100 -K 7 -f 64,64,64 --aggregation sum --n_hidden 64 --readout max --dropout 0 --threads 0 --pool_arch gnn_curr --seed $seed --results $checkpoints_dir -d ./data" 5 | 6 | logs_dir=./logs/ 7 | 8 | # Global pooling 9 | python main.py $params --eval_attn_train --eval_attn_test | tee $logs_dir/"$dataset"_global_max_seed"$seed".log; 10 | 11 | # Unsupervised attention 12 | thresh=0.0001 13 | pool=unsup 14 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh"_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log; 15 | 16 | # Supervised attention 17 | thresh=0.001 18 | pool=sup 19 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh"_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log; 20 | 21 | # Weakly-supervised attention 22 | thresh=0.01 23 | python main.py $params --pool attn_sup_threshold_skip_"$thresh"_"$thresh" --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log; 24 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from chebygin import * 7 | from utils import * 8 | from graphdata import * 9 | import torch.multiprocessing as mp 10 | import multiprocessing 11 | try: 12 | import ax 13 | from ax.service.managed_loop import optimize 14 | except Exception as e: 15 | print('AX is not available: %s' % str(e)) 
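
# --- Editor's note (illustration, not part of train_test.py): the cross-validation helpers
# --- below score each hyperparameter configuration by the mean validation accuracy over
# --- folds minus its standard deviation, so configurations that are accurate but unstable
# --- across folds are penalized. A toy, self-contained example with hypothetical per-fold
# --- accuracies:
import numpy as np
val_acc = [86.0, 84.0, 88.0, 79.0, 83.0]         # hypothetical fold accuracies (%)
metric = np.mean(val_acc) - np.std(val_acc)      # 84.00 - 3.03 = 80.97
print('avg %.2f, std %.2f, metric %.2f' % (np.mean(val_acc), np.std(val_acc), metric))
# --- End of note.
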
16 | 17 | 18 | def set_pool(pool_thresh, args_pool): 19 | pool = copy.deepcopy(args_pool) 20 | for i, s in enumerate(pool): 21 | try: 22 | thresh = float(s) 23 | pool[i] = str(pool_thresh) 24 | except: 25 | continue 26 | return pool 27 | 28 | 29 | def train_evaluate(datareader, args, collate_fn, loss_fn, feature_stats, parameterization, folds=10, threads=5): 30 | 31 | print('parameterization', parameterization) 32 | 33 | pool_thresh, kl_weight = parameterization['pool'], parameterization['kl_weight'] 34 | pool = args.pool 35 | 36 | if args.tune_init: 37 | scale, init = parameterization['scale'], parameterization['init'] 38 | else: 39 | scale, init = args.scale, args.init 40 | 41 | n_hidden_attn, layer = parameterization['n_hidden_attn'], 1 42 | if layer == 0: 43 | pool = copy.deepcopy(args.pool) 44 | del pool[3] 45 | 46 | pool = set_pool(pool_thresh, pool) 47 | 48 | manager = multiprocessing.Manager() 49 | val_acc = manager.dict() 50 | assert threads <= folds, (threads, folds) 51 | n_it = int(np.ceil(float(folds) / threads)) 52 | for i in range(n_it): 53 | processes = [] 54 | if threads <= 1: 55 | single_job(i * threads, datareader, args, collate_fn, loss_fn, pool, kl_weight, 56 | feature_stats, val_acc, scale=scale, init=init, n_hidden_attn=n_hidden_attn) 57 | else: 58 | for fold in range(threads): 59 | p = mp.Process(target=single_job, 60 | args=(i * threads + fold, datareader, args, collate_fn, loss_fn, pool, kl_weight, 61 | feature_stats, val_acc, scale, init, n_hidden_attn)) 62 | p.start() 63 | processes.append(p) 64 | 65 | for p in processes: 66 | p.join() 67 | 68 | print(val_acc) 69 | val_acc = list(val_acc.values()) 70 | print('average and std over {} folds: {} +- {}'.format(folds, np.mean(val_acc), np.std(val_acc))) 71 | metric = np.mean(val_acc) - np.std(val_acc) # large std is considered bad 72 | print('metric: avg acc - std: {}'.format(metric)) 73 | return metric 74 | 75 | 76 | def ax_optimize(datareader, args, collate_fn, loss_fn, feature_stats, folds=10, threads=5, n_trials=30): 77 | parameters = [ 78 | {"name": "pool", "type": "range", "bounds": [1e-4, 2e-2], "log_scale": False}, 79 | {"name": "kl_weight", "type": "range", "bounds": [0.1, 10.], "log_scale": False}, 80 | {"name": "n_hidden_attn", "type": "choice", "values": [0, 32]} # hidden units in the attention layer (0: no hidden layer) 81 | ] 82 | 83 | if args.tune_init: 84 | parameters.extend([{"name": "scale", "type": "range", "bounds": [0.1, 2.], "log_scale": False}, 85 | {"name": "init", "type": "choice", "values": ['normal', 'uniform']}]) 86 | 87 | best_parameters, values, experiment, model = optimize( 88 | parameters=parameters, 89 | evaluation_function=lambda parameterization: train_evaluate(datareader, 90 | args, collate_fn, loss_fn, 91 | feature_stats, parameterization, folds=folds, 92 | threads=threads), 93 | total_trials=n_trials, 94 | objective_name='accuracy', 95 | ) 96 | 97 | print('best_parameters', best_parameters) 98 | print('values', values) 99 | return best_parameters 100 | 101 | 102 | def train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats=None, log=True): 103 | model.train() 104 | optimizer.zero_grad() 105 | n_samples, correct, train_loss = 0, 0, 0 106 | alpha_pred, alpha_GT = {}, {} 107 | start = time.time() 108 | 109 | # with torch.autograd.set_detect_anomaly(True): 110 | for batch_idx, data in enumerate(train_loader): 111 | data = data_to_device(data, args.device) 112 | if feature_stats is not None: 113 | data[0] = (data[0] - feature_stats[0]) / feature_stats[1] 114 | if 
batch_idx == 0 and epoch <= 1: 115 | sanity_check(model.eval(), data) # to disable the effect of dropout or other regularizers that can change behavior from batch to batch 116 | model.train() 117 | optimizer.zero_grad() 118 | mask = [data[2].view(len(data[2]), -1)] 119 | output, other_outputs = model(data) 120 | other_losses = other_outputs['reg'] if 'reg' in other_outputs else [] 121 | alpha = other_outputs['alpha'] if 'alpha' in other_outputs else [] 122 | mask.extend(other_outputs['mask'] if 'mask' in other_outputs else []) 123 | targets = data[3] 124 | loss = loss_fn(output, targets) 125 | for l in other_losses: 126 | loss += l 127 | loss_item = loss.item() 128 | train_loss += loss_item 129 | n_samples += len(targets) 130 | loss.backward() # accumulates gradient 131 | optimizer.step() # update weights 132 | time_iter = time.time() - start 133 | correct += count_correct(output.detach(), targets.detach()) 134 | update_attn(data, alpha, alpha_pred, alpha_GT, mask) 135 | acc = 100. * correct / n_samples # average over all examples in the dataset 136 | train_loss_avg = train_loss / (batch_idx + 1) 137 | 138 | if log and ((batch_idx > 0 and batch_idx % args.log_interval == 0) or batch_idx == len(train_loader) - 1): 139 | print('Train set (epoch {}): [{}/{} ({:.0f}%)]\tLoss: {:.4f} (avg: {:.4f}), other losses: {}\tAcc metric: {}/{} ({:.2f}%)\t AttnAUC: {}\t avg sec/iter: {:.4f}'.format( 140 | epoch, n_samples, len(train_loader.dataset), 100. * n_samples / len(train_loader.dataset), 141 | loss_item, train_loss_avg, ['%.4f' % l.item() for l in other_losses], 142 | correct, n_samples, acc, ['%.2f' % a for a in attn_AUC(alpha_GT, alpha_pred)], 143 | time_iter / (batch_idx + 1))) 144 | 145 | assert n_samples == len(train_loader.dataset), (n_samples, len(train_loader.dataset)) 146 | 147 | return train_loss, acc 148 | 149 | 150 | def test(model, test_loader, epoch, loss_fn, split, args, feature_stats=None, noises=None, 151 | img_noise_level=None, eval_attn=False, alpha_WS_name=''): 152 | model.eval() 153 | n_samples, correct, test_loss = 0, 0, 0 154 | pred, targets, N_nodes = [], [], [] 155 | start = time.time() 156 | alpha_pred, alpha_GT = {}, {} 157 | if eval_attn: 158 | alpha_pred[0] = [] 159 | print('testing with evaluation of attention: takes longer time') 160 | if args.debug: 161 | debug_data = {} 162 | 163 | with torch.no_grad(): 164 | for batch_idx, data in enumerate(test_loader): 165 | data = data_to_device(data, args.device) 166 | if feature_stats is not None: 167 | assert feature_stats[0].shape[2] == feature_stats[1].shape[2] == data[0].shape[2], \ 168 | (feature_stats[0].shape, feature_stats[1].shape, data[0].shape) 169 | data[0] = (data[0] - feature_stats[0]) / feature_stats[1] 170 | if batch_idx == 0 and epoch <= 1: 171 | sanity_check(model, data) 172 | 173 | if noises is not None: 174 | noise = noises[n_samples:n_samples + len(data[0])].to(args.device) * img_noise_level 175 | if len(noise.shape) == 2: 176 | noise = noise.unsqueeze(2) 177 | data[0][:, :, :3] = data[0][:, :, :3] + noise 178 | 179 | mask = [data[2].view(len(data[2]), -1)] 180 | N_nodes.append(data[4]['N_nodes'].detach()) 181 | targets.append(data[3].detach()) 182 | output, other_outputs = model(data) 183 | other_losses = other_outputs['reg'] if 'reg' in other_outputs else [] 184 | alpha = other_outputs['alpha'] if 'alpha' in other_outputs else [] 185 | mask.extend(other_outputs['mask'] if 'mask' in other_outputs else []) 186 | if args.debug: 187 | for key in other_outputs: 188 | if key.find('debug') >= 0: 189 | if key 
not in debug_data: 190 | debug_data[key] = [] 191 | debug_data[key].append([d.data.cpu().numpy() for d in other_outputs[key]]) 192 | if args.torch.find('1.') == 0: 193 | loss = loss_fn(output, data[3], reduction='sum') 194 | else: 195 | loss = loss_fn(output, data[3], reduce=False).sum() 196 | for l in other_losses: 197 | loss += l 198 | test_loss += loss.item() 199 | pred.append(output.detach()) 200 | 201 | 202 | update_attn(data, alpha, alpha_pred, alpha_GT, mask) 203 | if eval_attn: 204 | assert len(alpha) == 0, ('invalid mode, eval_attn should be false for this type of pooling') 205 | alpha_pred[0].extend(attn_heatmaps(model, args.device, data, output.data, test_loader.batch_size, constant_mask=args.dataset=='mnist')) 206 | 207 | n_samples += len(data[0]) 208 | if eval_attn and (n_samples % 100 == 0 or n_samples == len(test_loader.dataset)): 209 | print('{}/{} samples processed'.format(n_samples, len(test_loader.dataset))) 210 | 211 | assert n_samples == len(test_loader.dataset), (n_samples, len(test_loader.dataset)) 212 | 213 | pred = torch.cat(pred) 214 | targets = torch.cat(targets) 215 | N_nodes = torch.cat(N_nodes) 216 | if args.dataset.find('colors') >= 0: 217 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=25) 218 | if pred.shape[0] > 2500: 219 | correct += count_correct(pred[2500:5000], targets[2500:5000], N_nodes=N_nodes[2500:5000], N_nodes_min=26, N_nodes_max=200) 220 | correct += count_correct(pred[5000:], targets[5000:], N_nodes=N_nodes[5000:], N_nodes_min=26, N_nodes_max=200) 221 | elif args.dataset == 'triangles': 222 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=25) 223 | if pred.shape[0] > 5000: 224 | correct += count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=26, N_nodes_max=100) 225 | else: 226 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=1e5) 227 | 228 | time_iter = time.time() - start 229 | 230 | test_loss_avg = test_loss / n_samples 231 | acc = 100. 
* correct / n_samples # average over all examples in the dataset 232 | print('{} set (epoch {}): Avg loss: {:.4f}, Acc metric: {}/{} ({:.2f}%)\t AttnAUC: {}\t avg sec/iter: {:.4f}\n'.format( 233 | split.capitalize(), epoch, test_loss_avg, correct, n_samples, acc, 234 | ['%.2f' % a for a in attn_AUC(alpha_GT, alpha_pred)], time_iter / (batch_idx + 1))) 235 | 236 | if args.debug: 237 | for key in debug_data: 238 | for layer in range(len(debug_data[key][0])): 239 | print('{} (layer={}): {:.5f}'.format(key, layer, np.mean([d[layer] for d in debug_data[key]]))) 240 | 241 | if eval_attn: 242 | alpha_pred = alpha_pred[0] 243 | if args.results in [None, 'None', ''] or alpha_WS_name == '': 244 | print('skip saving alpha values, invalid results dir (%s) or alpha_WS_name (%s)' % (args.results, alpha_WS_name)) 245 | else: 246 | file_path = pjoin(args.results, '%s_alpha_WS_%s_seed%d_%s.pkl' % (args.dataset, split, args.seed, alpha_WS_name)) 247 | if os.path.isfile(file_path): 248 | print('WARNING: file %s exists and will be overwritten' % file_path) 249 | with open(file_path, 'wb') as f: 250 | pickle.dump(alpha_pred, f, protocol=2) 251 | 252 | return test_loss, acc, alpha_pred, pred 253 | 254 | 255 | def update_attn(data, alpha, alpha_pred, alpha_GT, mask): 256 | key = 'node_attn_eval' 257 | for layer in range(len(mask)): 258 | mask[layer] = mask[layer].data.cpu().numpy() > 0 259 | if key in data[4]: 260 | if not isinstance(data[4][key], list): 261 | data[4][key] = [data[4][key]] 262 | for layer in range(len(data[4][key])): 263 | if layer not in alpha_GT: 264 | alpha_GT[layer] = [] 265 | # print(key, layer, len(data[4][key]), len(mask)) 266 | alpha_GT[layer].extend(masked_alpha(data[4][key][layer].data.cpu().numpy(), mask[layer])) 267 | for layer in range(len(alpha)): 268 | if layer not in alpha_pred: 269 | alpha_pred[layer] = [] 270 | alpha_pred[layer].extend(masked_alpha(alpha[layer].data.cpu().numpy(), mask[layer])) 271 | 272 | 273 | def masked_alpha(alpha, mask): 274 | alpha_lst = [] 275 | for i in range(len(alpha)): 276 | # print('gt', len(alpha), alpha[i].shape, mask[i].shape, alpha[i][mask[i] > 0].shape, mask[i].sum(), mask[i].min(), mask[i].max(), mask[i].dtype) 277 | alpha_lst.append(alpha[i][mask[i]]) 278 | return alpha_lst 279 | 280 | 281 | def attn_heatmaps(model, device, data, output_org, batch_size=1, constant_mask=False): 282 | labels = torch.argmax(output_org, dim=1) 283 | B, N_nodes_max, C = data[0].shape # N_nodes should be the same in the batch 284 | alpha_WS = [] 285 | if N_nodes_max > 1000: 286 | print('WARNING: graph is too large (%d nodes) and not supported by this function (evaluation will be incorrect for graphs in this batch).' % N_nodes_max) 287 | for b in range(B): 288 | n = data[2][b].sum().item() 289 | alpha_WS.append(np.zeros((1, n)) + 1. / n) 290 | return alpha_WS 291 | 292 | if constant_mask: 293 | mask = torch.ones(N_nodes_max, N_nodes_max - 1).to(device) 294 | 295 | # Indices of nodes such that in each row one index (i.e. 
one node) is removed 296 | node_ids = torch.arange(start=0, end=N_nodes_max, device=device).view(1, -1).repeat(N_nodes_max, 1) 297 | node_ids[np.diag_indices(N_nodes_max, 2)] = -1 298 | node_ids = node_ids[node_ids >= 0].view(N_nodes_max, N_nodes_max - 1).long() 299 | 300 | with torch.no_grad(): 301 | for b in range(B): 302 | x = torch.gather(data[0][b].unsqueeze(0).expand(N_nodes_max, -1, -1), dim=1, index=node_ids.unsqueeze(2).expand(-1, -1, C)) 303 | if not constant_mask: 304 | mask = torch.gather(data[2][b].unsqueeze(0).expand(N_nodes_max, -1), dim=1, index=node_ids) 305 | A = torch.gather(data[1][b].unsqueeze(0).expand(N_nodes_max, -1, -1), dim=1, index=node_ids.unsqueeze(2).expand(-1, -1, N_nodes_max)) 306 | A = torch.gather(A, dim=2, index=node_ids.unsqueeze(1).expand(-1, N_nodes_max - 1, -1)) 307 | output = torch.zeros(N_nodes_max).to(device) 308 | n_chunks = int(np.ceil(N_nodes_max / float(batch_size))) 309 | for i in range(n_chunks): 310 | idx = np.arange(i * batch_size, (i + 1) * batch_size) if i < n_chunks - 1 else np.arange(i * batch_size, N_nodes_max) 311 | output[idx] = model([x[idx], A[idx], mask[idx], None, {}])[0][:, labels[b]].data 312 | 313 | alpha = torch.abs(output - output_org[b, labels[b]]).view(1, N_nodes_max) #* mask_org[b].view(1, N_nodes_max) 314 | if not constant_mask: 315 | alpha = alpha[data[2][b].view(1, N_nodes_max)] 316 | alpha_WS.append(normalize(alpha).data.cpu().numpy()) 317 | 318 | return alpha_WS 319 | 320 | 321 | def save_checkpoint(model, scheduler, optimizer, args, epoch): 322 | if args.results in [None, 'None']: 323 | print('skip saving checkpoint, invalid results dir: %s' % args.results) 324 | return 325 | file_path = '%s/checkpoint_%s_%s_epoch%d_seed%07d.pth.tar' % (args.results, args.dataset, args.experiment_ID, epoch, args.seed) 326 | try: 327 | print('saving the model to %s' % file_path) 328 | state = { 329 | 'epoch': epoch, 330 | 'args': args, 331 | 'state_dict': model.state_dict(), 332 | 'scheduler': scheduler.state_dict(), 333 | 'optimizer': optimizer.state_dict(), 334 | } 335 | if os.path.isfile(file_path): 336 | print('WARNING: file %s exists and will be overwritten' % file_path) 337 | torch.save(state, file_path) 338 | except Exception as e: 339 | print('error saving the model', e) 340 | 341 | 342 | def load_checkpoint(model, optimizer, scheduler, file_path): 343 | print('loading the model from %s' % file_path) 344 | state = torch.load(file_path) 345 | model.load_state_dict(state['state_dict']) 346 | optimizer.load_state_dict(state['optimizer']) 347 | scheduler.load_state_dict(state['scheduler']) 348 | print('loading from epoch %d done' % state['epoch']) 349 | return state['epoch'] + 1 # +1 because we already finished training for this epoch 350 | 351 | 352 | def create_model_optimizer(in_features, out_features, pool, kl_weight, args, scale=None, init=None, n_hidden_attn=None): 353 | set_seed(args.seed, seed_data=None) 354 | model = ChebyGIN(in_features=in_features, 355 | out_features=out_features, 356 | filters=args.filters, 357 | K=args.filter_scale, 358 | n_hidden=args.n_hidden, 359 | aggregation=args.aggregation, 360 | dropout=args.dropout, 361 | readout=args.readout, 362 | pool=pool, 363 | pool_arch=args.pool_arch if n_hidden_attn in [None, 0] else args.pool_arch[:2] + ['%d' % n_hidden_attn], 364 | large_graph=args.dataset.lower() == 'mnist', 365 | kl_weight=float(kl_weight), 366 | init=args.init if init is None else init, 367 | scale=args.scale if scale is None else scale, 368 | debug=args.debug) 369 | print(model) 370 | # Compute 
the total number of trainable parameters 371 | print('model capacity: %d' % 372 | np.sum([np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()])) 373 | 374 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay, betas=(0.5, 0.999)) 375 | scheduler = lr_scheduler.MultiStepLR(optimizer, args.lr_decay_step, gamma=0.1) 376 | epoch = 1 377 | if args.resume not in [None, 'None']: 378 | epoch = load_checkpoint(model, optimizer, scheduler, args.resume) 379 | if epoch < args.epochs + 1: 380 | print('resuming training for epoch %d' % epoch) 381 | 382 | model.to(args.device) 383 | 384 | return epoch, model, optimizer, scheduler 385 | 386 | 387 | def single_job(fold, datareader, args, collate_fn, loss_fn, pool, kl_weight, feature_stats, val_acc, 388 | scale=None, init=None, n_hidden_attn=None): 389 | 390 | set_seed(args.seed, seed_data=None) 391 | 392 | wsup = args.pool[1] == 'sup' 393 | train_loader = DataLoader(GraphData(datareader, fold, 'train'), batch_size=args.batch_size, shuffle=True, 394 | num_workers=args.threads, collate_fn=collate_fn) 395 | val_loader = DataLoader(GraphData(datareader, fold, 'val'), batch_size=args.test_batch_size, shuffle=False, 396 | num_workers=args.threads, collate_fn=collate_fn) 397 | start_epoch, model, optimizer, scheduler = create_model_optimizer(train_loader.dataset.num_features, 398 | train_loader.dataset.num_classes, 399 | None if wsup else pool, kl_weight, args, 400 | scale=scale, init=init, n_hidden_attn=n_hidden_attn) 401 | 402 | for epoch in range(start_epoch, args.epochs + 1): 403 | scheduler.step() 404 | train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats, log=False) 405 | 406 | if wsup: 407 | train_loader_test = DataLoader(GraphData(datareader, fold, 'train'), batch_size=args.test_batch_size, shuffle=False, 408 | num_workers=args.threads, collate_fn=collate_fn) 409 | train_loss, train_acc, attn_WS = test(model, train_loader_test, epoch, loss_fn, 'train', args, feature_stats, eval_attn=True)[:3] # test_loss, acc, alpha_pred, pred 410 | train_loader = DataLoader(GraphData(datareader, fold, 'train', attn_labels=attn_WS), 411 | batch_size=args.batch_size, shuffle=True, 412 | num_workers=args.threads, collate_fn=collate_fn) 413 | val_loader = DataLoader(GraphData(datareader, fold, 'val'), batch_size=args.test_batch_size, shuffle=False, 414 | num_workers=args.threads, collate_fn=collate_fn) 415 | start_epoch, model, optimizer, scheduler = create_model_optimizer(train_loader.dataset.num_features, 416 | train_loader.dataset.num_classes, 417 | pool, kl_weight, args, 418 | scale=scale, init=init, n_hidden_attn=n_hidden_attn) 419 | for epoch in range(start_epoch, args.epochs + 1): 420 | scheduler.step() 421 | train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats, log=False) 422 | 423 | acc = test(model, val_loader, epoch, loss_fn, 'val', args, feature_stats)[1] 424 | 425 | val_acc[fold] = acc 426 | 427 | 428 | def cross_validation(datareader, args, collate_fn, loss_fn, pool, kl_weight, feature_stats, n_hidden_attn=None, folds=10, threads=5): 429 | print('%d-fold cross-validation' % folds) 430 | manager = multiprocessing.Manager() 431 | val_acc = manager.dict() 432 | assert threads <= folds, (threads, folds) 433 | n_it = int(np.ceil(float(folds) / threads)) 434 | for i in range(n_it): 435 | processes = [] 436 | if threads <= 1: 437 | single_job(i * threads, datareader, args, collate_fn, loss_fn, pool, kl_weight, 438 | feature_stats, val_acc, scale=args.scale, init=args.init, 
n_hidden_attn=n_hidden_attn) 439 | else: 440 | for fold in range(threads): 441 | p = mp.Process(target=single_job, args=(i * threads + fold, datareader, args, collate_fn, loss_fn, pool, kl_weight, 442 | feature_stats, val_acc, args.scale, args.init, n_hidden_attn)) 443 | p.start() 444 | processes.append(p) 445 | 446 | for p in processes: 447 | p.join() 448 | 449 | print(val_acc) 450 | val_acc = list(val_acc.values()) 451 | print('average and std over {} folds: {} +- {}'.format(folds, np.mean(val_acc), np.std(val_acc))) 452 | metric = np.mean(val_acc) - np.std(val_acc) 453 | print('metric: avg acc - std: {}'.format(metric)) 454 | return metric 455 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import copy 5 | from graphdata import * 6 | import torch.nn.functional as F 7 | from torchvision import datasets, transforms 8 | from sklearn.metrics import roc_auc_score 9 | import numbers 10 | import random 11 | 12 | 13 | def load_save_noise(f, noise_shape): 14 | if os.path.isfile(f): 15 | print('loading noise from %s' % f) 16 | noises = torch.load(f) 17 | else: 18 | noises = torch.randn(noise_shape, dtype=torch.float) 19 | # np.save(f, noises.numpy()) 20 | torch.save(noises, f) 21 | return noises 22 | 23 | 24 | def list_to_torch(data): 25 | for i in range(len(data)): 26 | if data[i] is None: 27 | continue 28 | elif isinstance(data[i], np.ndarray): 29 | if data[i].dtype == np.bool: 30 | data[i] = data[i].astype(np.float32) 31 | data[i] = torch.from_numpy(data[i]).float() 32 | elif isinstance(data[i], list): 33 | data[i] = list_to_torch(data[i]) 34 | return data 35 | 36 | 37 | def data_to_device(data, device): 38 | if isinstance(data, dict): 39 | keys = list(data.keys()) 40 | else: 41 | keys = range(len(data)) 42 | for i in keys: 43 | if isinstance(data[i], list) or isinstance(data[i], dict): 44 | data[i] = data_to_device(data[i], device) 45 | else: 46 | if isinstance(data[i], torch.Tensor): 47 | try: 48 | data[i] = data[i].to(device) 49 | except: 50 | print('error', i, data[i], type(data[i])) 51 | raise 52 | return data 53 | 54 | 55 | def count_correct(output, target, N_nodes=None, N_nodes_min=0, N_nodes_max=25): 56 | if output.shape[1] == 1: 57 | # Regression 58 | pred = output.round().long() 59 | else: 60 | # Classification 61 | pred = output.max(1, keepdim=True)[1] 62 | target = target.long().squeeze().cpu() # for older pytorch 63 | pred = pred.squeeze().cpu() # for older pytorch 64 | if N_nodes is not None: 65 | idx = (N_nodes >= N_nodes_min) & (N_nodes <= N_nodes_max) 66 | if idx.sum() > 0: 67 | correct = pred[idx].eq(target[idx]).sum().item() 68 | for lbl in torch.unique(target, sorted=True): 69 | idx_lbl = target[idx] == lbl 70 | eq = (pred[idx][idx_lbl] == target[idx][idx_lbl]).float() 71 | print('lbl: {}, avg acc: {:2.2f}% ({}/{})'.format(lbl, 100 * eq.mean(), int(eq.sum()), 72 | int(idx_lbl.float().sum()))) 73 | 74 | eq = (pred[idx] == target[idx]).float() 75 | print('{} <= N_nodes <= {} (min={}, max={}), avg acc: {:2.2f}% ({}/{})'.format(N_nodes_min, 76 | N_nodes_max, 77 | N_nodes[idx].min(), 78 | N_nodes[idx].max(), 79 | 100 * eq.mean(), 80 | int(eq.sum()), int(idx.sum()))) 81 | else: 82 | correct = 0 83 | print('no graphs with nodes >= {} and <= {}'.format(N_nodes_min, N_nodes_max)) 84 | else: 85 | correct = pred.eq(target).sum().item() 86 | 87 | return correct 88 | 89 | 90 | def attn_AUC(alpha_GT, 
alpha): 91 | auc = [] 92 | if len(alpha) > 0 and alpha_GT is not None and len(alpha_GT) > 0: 93 | for layer in alpha: 94 | alpha_gt = np.concatenate([a.flatten() for a in alpha_GT[layer]]) > 0 95 | if len(np.unique(alpha_gt)) <= 1: 96 | print('Only one class ({}) present in y_true. ROC AUC score is not defined in that case.'.format(np.unique(alpha_gt))) 97 | auc.append(np.nan) 98 | else: 99 | auc.append(100 * roc_auc_score(y_true=alpha_gt, 100 | y_score=np.concatenate([a.flatten() for a in alpha[layer]]))) 101 | return auc 102 | 103 | 104 | def stats(arr): 105 | return np.mean(arr), np.std(arr), np.min(arr), np.max(arr) 106 | 107 | 108 | def normalize(x, eps=1e-7): 109 | return x / (x.sum() + eps) 110 | 111 | 112 | def normalize_batch(x, dim=1, eps=1e-7): 113 | return x / (x.sum(dim=dim, keepdim=True) + eps) 114 | 115 | 116 | def normalize_zero_one(im, eps=1e-7): 117 | m1 = im.min() 118 | m2 = im.max() 119 | return (im - m1) / (m2 - m1 + eps) 120 | 121 | 122 | def mse_loss(target, output, reduction='mean', reduce=None): 123 | loss = (target.float().squeeze() - output.float().squeeze()) ** 2 124 | if reduce is None: 125 | if reduction == 'mean': 126 | return torch.mean(loss) 127 | elif reduction == 'sum': 128 | return torch.sum(loss) 129 | elif reduction == 'none': 130 | return loss 131 | else: 132 | NotImplementedError(reduction) 133 | elif not reduce: 134 | return loss 135 | else: 136 | NotImplementedError('use reduction if reduce=True') 137 | 138 | 139 | def shuffle_nodes(batch): 140 | x, A, mask, labels, params_dict = batch 141 | for b in range(x.shape[0]): 142 | idx = np.random.permutation(x.shape[1]) 143 | x[b] = x[b, idx] 144 | A[b] = A[b, :, idx][idx, :] 145 | mask[b] = mask[b, idx] 146 | if 'node_attn' in params_dict: 147 | params_dict['node_attn'][b] = params_dict['node_attn'][b, idx] 148 | return [x, A, mask, labels, params_dict] 149 | 150 | 151 | def copy_batch(data): 152 | data_cp = [] 153 | for i in range(len(data)): 154 | if isinstance(data[i], dict): 155 | data_cp.append({key: data[i][key].clone() for key in data[i]}) 156 | else: 157 | data_cp.append(data[i].clone()) 158 | return data_cp 159 | 160 | 161 | def sanity_check(model, data): 162 | with torch.no_grad(): 163 | output1 = model(copy_batch(data))[0] 164 | output2 = model(shuffle_nodes(copy_batch(data)))[0] 165 | if not torch.allclose(output1, output2, rtol=1e-02, atol=1e-03): 166 | print('WARNING: model outputs different depending on the nodes order', (torch.norm(output1 - output2), 167 | torch.max(output1 - output2), 168 | torch.max(output1), 169 | torch.max(output2))) 170 | print('model is checked for nodes shuffling') 171 | 172 | 173 | def set_seed(seed, seed_data=None): 174 | random.seed(seed) # for some libraries 175 | rnd = np.random.RandomState(seed) 176 | if seed_data is not None: 177 | rnd_data = np.random.RandomState(seed_data) 178 | else: 179 | rnd_data = rnd 180 | torch.backends.cudnn.deterministic = True 181 | torch.backends.cudnn.benchmark = True 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed(seed) 184 | torch.cuda.manual_seed_all(seed) 185 | return rnd, rnd_data 186 | 187 | 188 | def compute_feature_stats(model, train_loader, device, n_batches=100): 189 | print('computing mean and std of input features') 190 | model.eval() 191 | x = [] 192 | with torch.no_grad(): 193 | for batch_idx, data in enumerate(train_loader): 194 | x.append(data[0].data.cpu().numpy()) # B,N,F 195 | if batch_idx > n_batches: 196 | break 197 | x = np.concatenate(x, axis=1).reshape(-1, x[0].shape[-1]) 198 | 
print('features shape loaded', x.shape) 199 | 200 | mn = x.mean(axis=0, keepdims=True) 201 | sd = x.std(axis=0, keepdims=True) 202 | print('mn', mn) 203 | print('std', sd) 204 | sd[sd < 1e-2] = 1 # to prevent dividing by a small number 205 | print('corrected (non-zero) std', sd) 206 | 207 | mn = torch.from_numpy(mn).float().to(device).unsqueeze(0) 208 | sd = torch.from_numpy(sd).float().to(device).unsqueeze(0) 209 | return mn, sd 210 | 211 | 212 | def copy_data(data, idx): 213 | data_new = {} 214 | for key in data: 215 | if key == 'Max_degree': 216 | data_new[key] = data[key] 217 | print(key, data_new[key]) 218 | else: 219 | data_new[key] = copy.deepcopy([data[key][i] for i in idx]) 220 | if key in ['graph_labels', 'N_edges']: 221 | data_new[key] = np.array(data_new[key], np.int32) 222 | print(key, len(data_new[key])) 223 | 224 | return data_new 225 | 226 | 227 | def concat_data(data): 228 | data_new = {} 229 | for key in data[0]: 230 | if key == 'Max_degree': 231 | data_new[key] = np.max(np.array([ d[key] for d in data ])) 232 | print(key, data_new[key]) 233 | else: 234 | if key in ['graph_labels', 'N_edges']: 235 | data_new[key] = np.concatenate([ d[key] for d in data ]) 236 | else: 237 | lst = [] 238 | for d in data: 239 | lst.extend(d[key]) 240 | data_new[key] = lst 241 | print(key, len(data_new[key])) 242 | 243 | return data_new 244 | --------------------------------------------------------------------------------
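Usage note: the mean/std tensors returned by compute_feature_stats are meant to be applied to the node-feature tensor data[0] before the forward pass, mirroring the normalization step in train_test.train() and train_test.test(). A minimal sketch, assuming a model, DataLoader and device have already been constructed elsewhere (the names below are placeholders, not objects defined in utils.py):

mn, sd = compute_feature_stats(model, train_loader, device, n_batches=100)
for data in train_loader:
    data = data_to_device(data, device)
    data[0] = (data[0] - mn) / sd  # same normalization as applied in train_test.py
    output, other_outputs = model(data)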