├── LICENSE.md
├── README.md
├── attention_pooling.py
├── chebygin.py
├── checkpoints
│   ├── checkpoint_TU_657452_epoch50_seed0000111.pth.tar
│   ├── checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar
│   ├── checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar
│   ├── checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar
│   ├── checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar
│   ├── checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar
│   ├── checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar
│   ├── checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar
│   ├── checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar
│   ├── checkpoint_triangles_051609_epoch100_seed0000111.pth.tar
│   ├── checkpoint_triangles_230187_epoch100_seed0000111.pth.tar
│   ├── checkpoint_triangles_586710_epoch100_seed0000111.pth.tar
│   ├── checkpoint_triangles_658037_epoch100_seed0000111.pth.tar
│   ├── colors-3_alpha_WS_test_seed111_orig.pkl
│   ├── colors-3_alpha_WS_train_seed111_orig.pkl
│   ├── mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl
│   ├── mnist-75sp_alpha_WS_test_seed111_noisy.pkl
│   ├── mnist-75sp_alpha_WS_test_seed111_orig.pkl
│   ├── mnist-75sp_alpha_WS_train_seed111_orig.pkl
│   ├── triangles_alpha_WS_test_seed111_orig.pkl
│   └── triangles_alpha_WS_train_seed111_orig.pkl
├── data
│   ├── COLORS-3.zip
│   ├── TRIANGLES.zip
│   ├── datasets.png
│   ├── mnist_animation.gif
│   └── triangles_animation.gif
├── extract_superpixels.py
├── generate_data.py
├── graphdata.py
├── logs
│   ├── colors-3_global_max_seed111.log
│   ├── colors-3_sup_seed111.log
│   ├── colors-3_unsup_seed111.log
│   ├── colors-3_weaksup_seed111.log
│   ├── mnist-75sp_global_max_seed111.log
│   ├── mnist-75sp_sup_seed111.log
│   ├── mnist-75sp_unsup_seed111.log
│   ├── mnist-75sp_weaksup_seed111.log
│   ├── prepare_data.log
│   ├── proteins_wsup_seed111.log
│   ├── triangles_global_max_seed111.log
│   ├── triangles_sup_seed111.log
│   ├── triangles_unsup_seed111.log
│   └── triangles_weaksup_seed111.log
├── main.py
├── notebooks
│   ├── MNIST_eval_models.ipynb
│   ├── TRIANGLES_eval_models.ipynb
│   ├── convert2TU.ipynb
│   ├── graphs_visualize.ipynb
│   ├── molecules_social_visualize.ipynb
│   ├── superpixels_visualize.ipynb
│   └── synthetic_graphs_visualize.ipynb
├── scripts
│   ├── colors.sh
│   ├── mnist_75sp.sh
│   ├── prepare_data.sh
│   └── triangles.sh
├── train_test.py
└── utils.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Educational Community License, Version 2.0 (ECL-2.0)
2 |
3 | Version 2.0, April 2007
4 |
5 | http://www.osedu.org/licenses/
6 |
7 | The Educational Community License version 2.0 ("ECL") consists of the Apache 2.0 license, modified to change the scope of the patent grant in section 3 to be specific to the needs of the education communities using this license. The original Apache 2.0 license can be found at: http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10 |
11 | 1. Definitions.
12 |
13 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
14 |
15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
18 |
19 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
20 |
21 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
22 |
23 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
24 |
25 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
26 |
27 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
28 |
29 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
30 |
31 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
32 |
33 | 2. Grant of Copyright License.
34 |
35 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
36 |
37 | 3. Grant of Patent License.
38 |
39 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. Any patent license granted hereby with respect to contributions by an individual employed by an institution or organization is limited to patent claims where the individual that is the author of the Work is also the inventor of the patent claims licensed, and where the organization or institution has the right to grant such license under applicable grant and research funding agreements. No other express or implied licenses are granted.
40 |
41 | 4. Redistribution.
42 |
43 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44 |
45 | You must give any other recipients of the Work or Derivative Works a copy of this License; and You must cause any modified files to carry prominent notices stating that You changed the files; and You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
46 |
47 | 5. Submission of Contributions.
48 |
49 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
50 |
51 | 6. Trademarks.
52 |
53 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
54 |
55 | 7. Disclaimer of Warranty.
56 |
57 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
58 |
59 | 8. Limitation of Liability.
60 |
61 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62 |
63 | 9. Accepting Warranty or Additional Liability.
64 |
65 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
66 |
67 | END OF TERMS AND CONDITIONS
68 |
69 | APPENDIX: How to apply the Educational Community License to your work
70 |
71 | To apply the Educational Community License to your work, attach
72 | the following boilerplate notice, with the fields enclosed by
73 | brackets "[]" replaced with your own identifying information.
74 | (Don't include the brackets!) The text should be enclosed in the
75 | appropriate comment syntax for the file format. We also recommend
76 | that a file or class name and description of purpose be included on
77 | the same "printed page" as the copyright notice for easier
78 | identification within third-party archives.
79 |
80 | Copyright [yyyy] [name of copyright owner] Licensed under the
81 | Educational Community License, Version 2.0 (the "License"); you may
82 | not use this file except in compliance with the License. You may
83 | obtain a copy of the License at
84 |
85 | http://www.osedu.org/licenses/ECL-2.0
86 |
87 | Unless required by applicable law or agreed to in writing,
88 | software distributed under the License is distributed on an "AS IS"
89 | BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
90 | or implied. See the License for the specific language governing
91 | permissions and limitations under the License.
92 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Intro
2 |
3 | This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper:
4 |
5 | [Boris Knyazev, Graham W. Taylor, Mohamed R. Amer. Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850).
6 |
7 | See slides [here](https://drive.google.com/open?id=1HcmhSEnf8ll6-BxXK1PiGzcXDa6BbKnC).
8 |
9 | [An earlier short version](https://rlgm.github.io/papers/54.pdf) of our paper was presented as a **contributed talk** at [ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019](https://rlgm.github.io/cfp/).
10 |
11 | **Update:**
12 |
13 | In the code for MNIST, the `dist` variable should have been squared to make it a Gaussian. All figures and results were generated without squaring it. I don't think it's very important in terms of results, but if you square it, `sigma` should be adjusted accordingly.
14 |
15 |
16 | | MNIST | TRIANGLES |
17 | |:-------------------------:|:-------------------------:|
18 | | ![MNIST attention animation](data/mnist_animation.gif) | ![TRIANGLES attention animation](data/triangles_animation.gif) |
19 |
20 |
21 | For MNIST from top to bottom rows:
22 |
23 | - input test images with additive Gaussian noise with standard deviation in the range from 0 to 1.4 with step 0.2
24 | - attention coefficients (alpha) predicted by the **unsupervised** model
25 | - attention coefficients (alpha) predicted by the **supervised** model
26 | - attention coefficients (alpha) predicted by our **weakly-supervised** model
27 |
28 | For TRIANGLES from top to bottom rows:
29 |
30 | - **on the left**: input test graph (with 4-100 nodes) with ground truth attention coefficients, **on the right**: graph obtained by **ground truth** node pooling
31 | - **on the left**: input test graph (with 4-100 nodes) with unsupervised attention coefficients, **on the right**: graph obtained by **unsupervised** node pooling
32 | - **on the left**: input test graph (with 4-100 nodes) with supervised attention coefficients, **on the right**: graph obtained by **supervised** node pooling
33 | - **on the left**: input test graph (with 4-100 nodes) with weakly-supervised attention coefficients, **on the right**: graph obtained by **weakly-supervised** node pooling
34 |
35 |
36 | Note that during training, our MNIST models have not encountered noisy images and our TRIANGLES models have not encountered graphs with more than N=25 nodes.
37 |
38 |
39 | ## Examples using [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric)
40 |
41 |
42 |
43 |
44 |
45 | COLORS and TRIANGLES datasets are now also available in the [TU](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) format, so that you can use a general TU datareader. See PyTorch Geometric examples for
46 | [COLORS](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/colors_topk_pool.py) and [TRIANGLES](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/triangles_sag_pool.py).
47 |
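For example, TRIANGLES can be loaded through PyTorch Geometric's generic `TUDataset` reader. Below is a minimal sketch, assuming a recent PyTorch Geometric version; the `root` path is arbitrary and `use_node_attr=True` is an assumption made to pick up node attributes if they are present:

```python
# Minimal sketch: load the TU-format TRIANGLES dataset with PyTorch Geometric
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='./data/TU', name='TRIANGLES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch = next(iter(loader))
# batch.x: node features, batch.edge_index: edges in COO format, batch.y: graph targets
print(batch)
```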
48 |
49 | ## Example of evaluating a pretrained model on MNIST
50 |
51 | For more examples, see [MNIST_eval_models](notebooks/MNIST_eval_models.ipynb) and
52 | [TRIANGLES_eval_models](notebooks/TRIANGLES_eval_models.ipynb).
53 |
54 | ```python
55 | # Download model checkpoint or 'git clone' this repo
56 | import urllib.request
57 | # Let's use the model with supervised attention (other models can be found in the Table below)
58 | model_name = 'checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar'
59 | model_url = 'https://github.com/bknyaz/graph_attention_pool/raw/master/checkpoints/%s' % model_name
60 | model_path = 'checkpoints/%s' % model_name
61 | urllib.request.urlretrieve(model_url, model_path)
62 | ```
63 |
64 | ```python
65 | # Load the model
66 | import torch
67 | from chebygin import ChebyGIN
68 |
69 | state = torch.load(model_path)
70 | args = state['args']
71 | model = ChebyGIN(in_features=5, out_features=10, filters=args.filters, K=args.filter_scale,
72 | n_hidden=args.n_hidden, aggregation=args.aggregation, dropout=args.dropout,
73 | readout=args.readout, pool=args.pool, pool_arch=args.pool_arch)
74 | model.load_state_dict(state['state_dict'])
75 | model = model.eval()
76 | ```
77 |
78 | ```python
79 | # Load image using standard PyTorch Dataset
80 | from torchvision import datasets
81 | data = datasets.MNIST('./data', train=False, download=True)
82 | images = (data.test_data.numpy() / 255.)
83 | import numpy as np
84 | img = images[0].astype(np.float32) # 28x28 MNIST image
85 | ```
86 |
87 | ```python
88 | # Extract superpixels and create node features
89 | import scipy.ndimage
90 | from skimage.segmentation import slic
91 | from scipy.spatial.distance import cdist
92 |
93 | # The number (n_segments) of superpixels returned by SLIC is usually smaller than requested, so we request more
94 | superpixels = slic(img, n_segments=95, compactness=0.25, multichannel=False)
95 | sp_indices = np.unique(superpixels)
96 | n_sp = len(sp_indices) # should be 74 with these parameters of slic
97 |
98 | sp_intensity = np.zeros((n_sp, 1), np.float32)
99 | sp_coord = np.zeros((n_sp, 2), np.float32) # row, col
100 | for seg in sp_indices:
101 | mask = superpixels == seg
102 | sp_intensity[seg] = np.mean(img[mask])
103 | sp_coord[seg] = np.array(scipy.ndimage.measurements.center_of_mass(mask))
104 |
105 | # The model is invariant to the order of nodes in a graph
106 | # We can shuffle nodes and obtain exactly the same results
107 | ind = np.random.permutation(n_sp)
108 | sp_coord = sp_coord[ind]
109 | sp_intensity = sp_intensity[ind]
110 | ```
111 |
112 | ```python
113 | # Create edges between nodes in the form of adjacency matrix
114 | sp_coord = sp_coord / images.shape[1]
115 | dist = cdist(sp_coord, sp_coord) # distance between all pairs of nodes
116 | sigma = 0.1 * np.pi # width of a Gaussian
117 | A = np.exp(- dist / sigma ** 2) # transform distance to spatial closeness
118 | A[np.diag_indices_from(A)] = 0 # remove self-loops
119 | A = torch.from_numpy(A).float().unsqueeze(0)
120 | ```
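
As mentioned in the update at the top, `dist` should ideally be squared to obtain a proper Gaussian kernel. A sketch of that variant is below; the `sigma_sq` value is only a placeholder that would need re-tuning, and the pretrained models expect the unsquared `A` computed above:

```python
# Properly squared Gaussian variant (not used by the pretrained models)
sigma_sq = 0.1 * np.pi  # placeholder; would need to be re-tuned
A_gauss = np.exp(- dist ** 2 / sigma_sq ** 2)
A_gauss[np.diag_indices_from(A_gauss)] = 0  # remove self-loops
```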
121 |
122 | ```python
123 | # Prepare an input to the model and process it
124 | N_nodes = sp_intensity.shape[0]
125 | mask = torch.ones(1, N_nodes, dtype=torch.uint8)
126 |
127 | # mean and std computed for superpixel features in the training set
128 | mn = torch.tensor([0.11225057, 0.11225057, 0.11225057, 0.44206527, 0.43950436]).view(1, 1, -1)
129 | sd = torch.tensor([0.2721889, 0.2721889, 0.2721889, 0.2987583, 0.30080357]).view(1, 1, -1)
130 |
131 | node_features = (torch.from_numpy(np.pad(np.concatenate((sp_intensity, sp_coord), axis=1),
132 | ((0, 0), (2, 0)), 'edge')).unsqueeze(0) - mn) / sd
133 |
134 | y, other_outputs = model([node_features, A, mask, None, {'N_nodes': torch.zeros(1, 1) + N_nodes}])
135 | alpha = other_outputs['alpha'][0].data
136 | ```
137 |
138 | - `y` is a vector with 10 unnormalized class scores. To get a predicted label, we can use ```torch.argmax(y)```.
139 |
140 | - `alpha` is a vector of attention coefficients (one per node); a short usage sketch follows below.
141 |
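Putting the two outputs together, here is a minimal sketch of inspecting the prediction and the most attended superpixels (using the variables defined in the snippets above):

```python
# Predicted digit and indices of the most attended superpixels
pred = torch.argmax(y, dim=1).item()                        # y has shape (1, 10)
idx = torch.sort(alpha.view(-1), descending=True)[1][:10]   # indices of the 10 largest attention coefficients
print('predicted digit:', pred)
print('most attended superpixels:', idx.tolist())
```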
142 | ## Tasks & Datasets
143 |
144 | 1. We design two synthetic graph tasks, COLORS and TRIANGLES, in which we predict the number of green nodes and the number of triangles respectively.
145 |
146 | 2. We also experiment with the [MNIST](http://yann.lecun.com/exdb/mnist/) image classification dataset, which we preprocess by extracting superpixels, a more natural way to represent images as graphs. We denote this dataset as MNIST-75sp.
147 |
148 | 3. We validate our weakly-supervised approach on three common graph classification benchmarks: [COLLAB, PROTEINS and D&D](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets).
149 |
150 | For COLORS, TRIANGLES and MNIST we know ground truth attention for nodes, which allows us to study graph neural networks with attention in depth.
151 |
152 |
153 |
154 |
155 | ## Data generation
156 |
157 | To generate all data using a single command: ```./scripts/prepare_data.sh```.
158 |
159 | All generated/downloaded data will be stored in the local ```./data``` directory.
160 | It can take about 1 hour to prepare all data (see my [log](logs/prepare_data.log)), and all data take about 2 GB.
161 |
162 | Alternatively, you can generate data for each task as described below.
163 |
164 | In case of any issues with running these scripts, data can be downloaded from [here](https://drive.google.com/drive/folders/1Prc-n9Nr8-5z-xphdRScftKKIxU4Olzh?usp=sharing).
165 |
166 | ### COLORS
167 | To generate training, validation and test data for our **Colors** dataset with different dimensionalities:
168 |
169 | ```for dim in 3 8 16 32; do python generate_data.py --dim $dim; done```
170 |
171 | ### MNIST-75sp
172 | To generate training and test data for our MNIST-75sp dataset using 4 CPU threads:
173 |
174 | ```for split in train test; do python extract_superpixels.py -s $split -t 4; done```
175 |
176 | ## Data visualization
177 | Once datasets are generated or downloaded, you can use the following IPython notebooks to load and visualize data:
178 |
179 | [COLORS and TRIANGLES](notebooks/synthetic_graphs_visualize.ipynb), [MNIST](notebooks/superpixels_visualize.ipynb) and
180 | [COLLAB, PROTEINS and D&D](notebooks/graphs_visualize.ipynb).
181 |
182 |
183 | # Pretrained ChebyGIN models
184 |
185 | Generalization results on the test sets for three tasks. Other results are available in the paper.
186 |
187 | Click on the result to download a trained model in the PyTorch format.
188 |
189 | | Model | COLORS-Test-LargeC | TRIANGLES-Test-Large | MNIST-75sp-Test-Noisy |
190 | | --------------------- |:-------------:|:-------------:|:-------------:|
191 | | Script to train models | [colors.sh](scripts/colors.sh) | [triangles.sh](scripts/triangles.sh) | [mnist_75sp.sh](./scripts/mnist_75sp.sh) |
192 | | Global pooling | [15 ± 7](./checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar) | [30 ± 1](./checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar) | [80 ± 12](./checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar) |
193 | | Unsupervised attention | [11 ± 6](./checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar) | [26 ± 2](./checkpoints/checkpoint_triangles_051609_epoch100_seed0000111.pth.tar) | [80 ± 23](./checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar) |
194 | | Supervised attention | [75 ± 17](./checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar) | [48 ± 1](./checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar) | [92.3 ± 0.4](./checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar) |
195 | | Weakly-supervised attention | [73 ± 14](./checkpoints/checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar) | [30 ± 1](./checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar) | [88.8 ± 4](./checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar) |
196 |
197 |
198 | The scripts to train the models must be run from the main directory, e.g.: ```./scripts/mnist_75sp.sh```
199 |
200 | Examples of evaluating our trained models can be found in notebooks: [MNIST_eval_models](notebooks/MNIST_eval_models.ipynb) and
201 | [TRIANGLES_eval_models](notebooks/TRIANGLES_eval_models.ipynb).
202 |
203 |
204 | ## Other examples of training models
205 |
206 | To tune hyperparameters on the validation set for COLORS, TRIANGLES and MNIST, use the ```--validation``` flag.
207 |
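A hypothetical invocation for COLORS (the `-D colors-3` value and log path are my assumptions; the exact flags used for the paper are in [colors.sh](scripts/colors.sh)):

```
# hypothetical example; see scripts/colors.sh for the exact training flags
python main.py -D colors-3 --validation | tee logs/colors-3_validation.log
```
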
208 | For COLLAB, PROTEINS and D&D, hyperparameters are tuned as part of the training run via the `--ax` flag (see the commands below).
209 |
210 | Example of running 10 weakly-supervised experiments on **PROTEINS** with cross-validation of hyperparameters *including initialization parameters (distribution and scale) of the attention model* (the `--tune_init` flag):
211 |
212 | ```
213 | for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 25 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 100 -f 64,64,64 -K 1 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/PROTEINS --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None --tune_init | tee logs/proteins_wsup_"$dataseed"_"$seed".log; done; done
214 | ```
215 |
216 | No initialization tuning on **COLLAB**:
217 |
218 | ```
219 | for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 35 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 32 -f 64,64,64 -K 3 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/COLLAB --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None | tee logs/collab_wsup_"$dataseed"_"$seed".log; done; done
220 | ```
221 |
222 | Note that results can be better if using `--pool_arch gnn_prev`, but we didn't focus on that.
223 |
224 | # Requirements
225 |
226 | Python packages required (can be installed via pip or conda):
227 |
228 | - python >= 3.6.1
229 | - PyTorch >= 0.4.1
230 | - [Ax](https://github.com/facebook/Ax) for hyper-parameter tuning on COLLAB, PROTEINS and D\&D
231 | - networkx
232 | - OpenCV
233 | - SciPy
234 | - scikit-image
235 | - scikit-learn
236 |
237 | # Reference
238 |
239 | Please cite our paper if you use our data or code:
240 |
241 | ```
242 | @inproceedings{knyazev2019understanding,
243 | title={Understanding attention and generalization in graph neural networks},
244 | author={Knyazev, Boris and Taylor, Graham W and Amer, Mohamed},
245 | booktitle={Advances in Neural Information Processing Systems},
246 | pages={4202--4212},
247 | year={2019},
248 | pdf={http://arxiv.org/abs/1905.02850}
249 | }
250 | ```
251 |
--------------------------------------------------------------------------------
/attention_pooling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.sparse
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from utils import *
7 |
8 |
9 | class AttentionPooling(nn.Module):
10 | '''
11 | Graph pooling layer implementing top-k and threshold-based pooling.
12 | '''
13 | def __init__(self,
14 | in_features, # feature dimensionality in the current graph layer
15 | in_features_prev, # feature dimensionality in the previous graph layer
16 | pool_type,
17 | pool_arch,
18 | large_graph,
19 | attn_gnn=None,
20 | kl_weight=None,
21 | drop_nodes=True,
22 | init='normal',
23 | scale=None,
24 | debug=False):
25 | super(AttentionPooling, self).__init__()
26 | self.pool_type = pool_type
27 | self.pool_arch = pool_arch
28 | self.large_graph = large_graph
29 | self.kl_weight = kl_weight
30 | self.proj = None
31 | self.drop_nodes = drop_nodes
32 | self.is_topk = self.pool_type[2].lower() == 'topk'
33 | self.scale = scale
34 | self.init = init
35 | self.debug = debug
36 | self.clamp_value = 60
37 | self.torch = torch.__version__
38 | if self.is_topk:
39 | self.topk_ratio = float(self.pool_type[3]) # r
40 | assert self.topk_ratio > 0 and self.topk_ratio <= 1, ('invalid top-k ratio', self.topk_ratio, self.pool_type)
41 | else:
42 | self.threshold = float(self.pool_type[3]) # \tilde{alpha}
43 | assert self.threshold >= 0 and self.threshold <= 1, ('invalid pooling threshold', self.threshold, self.pool_type)
44 |
45 | if self.pool_type[1] in ['unsup', 'sup']:
46 | assert self.pool_arch not in [None, 'None'], self.pool_arch
47 |
48 | n_in = in_features_prev if self.pool_arch[1] == 'prev' else in_features
49 | if self.pool_arch[0] == 'fc':
50 | p_optimal = torch.from_numpy(np.pad(np.array([0, 1]), (0, n_in - 2), 'constant')).float().view(1, n_in)
51 | if len(self.pool_arch) == 2:
52 | # single layer projection
53 | self.proj = nn.Linear(n_in, 1, bias=False)
54 | p = self.proj.weight.data
55 | if scale is not None:
56 | if init == 'normal':
57 | p = torch.randn(n_in) # std=1, seed 9753 for optimal initialization
58 | elif init == 'uniform':
59 | p = torch.rand(n_in) * 2 - 1 # [-1,1]
60 | else:
61 | raise NotImplementedError(init)
62 | p *= scale # multiply std for normal or change range for uniform
63 | else:
64 | print('Default PyTorch init is used for layer %s, std=%.3f' % (str(p.shape), p.std()))
65 | self.proj.weight.data = p.view_as(self.proj.weight.data)
66 | p = self.proj.weight.data.view(1, n_in)
67 | else:
68 | # multi-layer projection
69 | filters = list(map(int, self.pool_arch[2:]))
70 | self.proj = []
71 | for layer in range(len(filters)):
72 | self.proj.append(nn.Linear(in_features=n_in if layer == 0 else filters[layer - 1],
73 | out_features=filters[layer]))
74 | if layer == 0:
75 | p = self.proj[0].weight.data
76 | if scale is not None:
77 | if init == 'normal':
78 | p = torch.randn(filters[layer], n_in)
79 | elif init == 'uniform':
80 | p = torch.rand(filters[layer], n_in) * 2 - 1 # [-1,1]
81 | else:
82 | raise NotImplementedError(init)
83 | p *= scale # multiply std for normal or change range for uniform
84 | else:
85 | print('Default PyTorch init is used for layer %s, std=%.3f' % (str(p.shape), p.std()))
86 | self.proj[0].weight.data = p.view_as(self.proj[0].weight.data)
87 | p = self.proj[0].weight.data.view(-1, n_in)
88 | self.proj.append(nn.ReLU(True))
89 |
90 | self.proj.append(nn.Linear(filters[-1], 1))
91 | self.proj = nn.Sequential(*self.proj)
92 |
93 | # Compute cosine similarity with the optimal vector and print values
94 | # ignore the last dimension, because it does not receive gradients during training
95 | # n_in=4 for colors-3 because some of our test subsets have 4 dimensional features
96 | cos_sim = self.cosine_sim(p[:, :-1], p_optimal[:, :-1])
97 | if p.shape[0] == 1:
98 | print('p values', p[0].data.cpu().numpy())
99 | print('cos_sim', cos_sim.item())
100 | else:
101 | for fn in [torch.max, torch.min, torch.mean, torch.std]:
102 | print('cos_sim', fn(cos_sim).item())
103 | elif self.pool_arch[0] == 'gnn':
104 | self.proj = attn_gnn(n_in)
105 | else:
106 | raise ValueError('invalid pooling layer architecture', self.pool_arch)
107 |
108 | elif self.pool_type[1] == 'gt':
109 | if not self.is_topk and self.threshold > 0:
110 | print('For ground truth attention threshold should be 0, but it is %f' % self.threshold)
111 | else:
112 | raise NotImplementedError(self.pool_type[1])
113 |
114 | def __repr__(self):
115 | return 'AttentionPooling(pool_type={}, pool_arch={}, topk={}, kl_weight={}, init={}, scale={}, proj={})'.format(
116 | self.pool_type,
117 | self.pool_arch,
118 | self.is_topk,
119 | self.kl_weight,
120 | self.init,
121 | self.scale,
122 | self.proj)
123 |
124 | def cosine_sim(self, a, b):
125 | return torch.mm(a, b.t()) / (torch.norm(a, dim=1, keepdim=True) * torch.norm(b, dim=1, keepdim=True))
126 |
127 | def mask_out(self, x, mask):
128 | return x.view_as(mask) * mask
129 |
130 | def drop_nodes_edges(self, x, A, mask):
131 | N_nodes = torch.sum(mask, dim=1).long() # B
132 | N_nodes_max = N_nodes.max()
133 | idx = None
134 | if N_nodes_max > 0:
135 | B, N, C = x.shape
136 | # Drop nodes
137 | mask, idx = torch.topk(mask.byte(), N_nodes_max, dim=1, largest=True, sorted=False)
138 | x = torch.gather(x, dim=1, index=idx.unsqueeze(2).expand(-1, -1, C))
139 | # Drop edges
140 | A = torch.gather(A, dim=1, index=idx.unsqueeze(2).expand(-1, -1, N))
141 | A = torch.gather(A, dim=2, index=idx.unsqueeze(1).expand(-1, N_nodes_max, -1))
142 |
143 | return x, A, mask, N_nodes, idx
144 |
145 | def forward(self, data):
146 |
147 | KL_loss = None
148 | x, A, mask, _, params_dict = data[:5]
149 |
150 | mask_float = mask.float()
151 | N_nodes_float = params_dict['N_nodes'].float()
152 | B, N, C = x.shape
153 | A = A.view(B, N, N)
154 | alpha_gt = None
155 | if 'node_attn' in params_dict:
156 | if not isinstance(params_dict['node_attn'], list):
157 | params_dict['node_attn'] = [params_dict['node_attn']]
158 | alpha_gt = params_dict['node_attn'][-1].view(B, N)
159 | if 'node_attn_eval' in params_dict:
160 | if not isinstance(params_dict['node_attn_eval'], list):
161 | params_dict['node_attn_eval'] = [params_dict['node_attn_eval']]
162 |
163 | if (self.pool_type[1] == 'gt' or (self.pool_type[1] == 'sup' and self.training)) and alpha_gt is None:
164 | raise ValueError('ground truth node attention values node_attn required for %s' % self.pool_type)
165 |
166 | if self.pool_type[1] in ['unsup', 'sup']:
167 | attn_input = data[-1] if self.pool_arch[1] == 'prev' else x.clone()
168 | if self.pool_arch[0] == 'fc':
169 | alpha_pre = self.proj(attn_input).view(B, N)
170 | else:
171 | # to support python2
172 | input = [attn_input]
173 | input.extend(data[1:])
174 | alpha_pre = self.proj(input)[0].view(B, N)
175 | # softmax with masking out dummy nodes
176 | alpha_pre = torch.clamp(alpha_pre, -self.clamp_value, self.clamp_value)
177 | alpha = normalize_batch(self.mask_out(torch.exp(alpha_pre), mask_float).view(B, N))
178 | if self.pool_type[1] == 'sup' and self.training:
179 | if self.torch.find('1.') == 0:
180 | KL_loss_per_node = self.mask_out(F.kl_div(torch.log(alpha + 1e-14), alpha_gt, reduction='none'),
181 | mask_float.view(B,N))
182 | else:
183 | KL_loss_per_node = self.mask_out(F.kl_div(torch.log(alpha + 1e-14), alpha_gt, reduce=False),
184 | mask_float.view(B, N))
185 | KL_loss = self.kl_weight * torch.mean(KL_loss_per_node.sum(dim=1) / (N_nodes_float + 1e-7)) # mean over nodes, then mean over batches
186 | else:
187 | alpha = alpha_gt
188 |
189 | x = x * alpha.view(B, N, 1)
190 | if self.large_graph:
191 | # For large graphs during training, all alpha values can be very small hindering training
192 | x = x * N_nodes_float.view(B, 1, 1)
193 | if self.is_topk:
194 | N_remove = torch.round(N_nodes_float * (1 - self.topk_ratio)).long() # number of nodes to be removed for each graph
195 | idx = torch.sort(alpha, dim=1, descending=False)[1] # indices of alpha in ascending order
196 | mask = mask.clone().view(B, N)
197 | for b in range(B):
198 | idx_b = idx[b, mask[b, idx[b]]] # take indices of non-dummy nodes for current data example
199 | mask[b, idx_b[:N_remove[b]]] = 0
200 | else:
201 | mask = (mask & (alpha.view_as(mask) > self.threshold)).view(B, N)
202 |
203 | if self.drop_nodes:
204 | x, A, mask, N_nodes_pooled, idx = self.drop_nodes_edges(x, A, mask)
205 | if idx is not None and 'node_attn' in params_dict:
206 | # update ground truth (or weakly labeled) attention for a reduced graph
207 | params_dict['node_attn'].append(normalize_batch(self.mask_out(torch.gather(alpha_gt, dim=1, index=idx), mask.float())))
208 | if idx is not None and 'node_attn_eval' in params_dict:
209 | # update ground truth (or weakly labeled) attention for a reduced graph
210 | params_dict['node_attn_eval'].append(normalize_batch(self.mask_out(torch.gather(params_dict['node_attn_eval'][-1], dim=1, index=idx), mask.float())))
211 | else:
212 | N_nodes_pooled = torch.sum(mask, dim=1).long() # B
213 | if 'node_attn' in params_dict:
214 | params_dict['node_attn'].append((self.mask_out(params_dict['node_attn'][-1], mask.float())))
215 | if 'node_attn_eval' in params_dict:
216 | params_dict['node_attn_eval'].append((self.mask_out(params_dict['node_attn_eval'][-1], mask.float())))
217 |
218 | params_dict['N_nodes'] = N_nodes_pooled
219 |
220 | mask_matrix = mask.unsqueeze(2) & mask.unsqueeze(1)
221 | A = A * mask_matrix.float() # or A[~mask_matrix] = 0
222 |
223 | # Add additional losses regularizing the model
224 | if KL_loss is not None:
225 | if 'reg' not in params_dict:
226 | params_dict['reg'] = []
227 | params_dict['reg'].append(KL_loss)
228 |
229 | # Keep attention coefficients for evaluation
230 | for key, value in zip(['alpha', 'mask'], [alpha, mask]):
231 | if key not in params_dict:
232 | params_dict[key] = []
233 | params_dict[key].append(value.detach())
234 |
235 | if self.debug and alpha_gt is not None:
236 | idx_correct_pool = (alpha_gt > 0)
237 | idx_correct_drop = (alpha_gt == 0)
238 | alpha_correct_pool = alpha[idx_correct_pool].sum() / N_nodes_float.sum()
239 | alpha_correct_drop = alpha[idx_correct_drop].sum() / N_nodes_float.sum()
240 | ratio_avg = (N_nodes_pooled.float() / N_nodes_float).mean()
241 |
242 | for key, values in zip(['alpha_correct_pool_debug', 'alpha_correct_drop_debug', 'ratio_avg_debug'],
243 | [alpha_correct_pool, alpha_correct_drop, ratio_avg]):
244 | if key not in params_dict:
245 | params_dict[key] = []
246 | params_dict[key].append(values.detach())
247 |
248 | output = [x, A, mask]
249 | output.extend(data[3:])
250 | return output
251 |
--------------------------------------------------------------------------------
/chebygin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.sparse
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.parameter import Parameter
7 | from attention_pooling import *
8 | from utils import *
9 |
10 |
11 | class ChebyGINLayer(nn.Module):
12 | '''
13 | General Graph Neural Network layer that depending on arguments can be:
14 | 1. Graph Convolution Layer (T. Kipf and M. Welling, ICLR 2017)
15 | 2. Chebyshev Graph Convolution Layer (M. Defferrard et al., NeurIPS 2016)
16 | 3. GIN Layer (K. Xu et al., ICLR 2019)
17 | 4. ChebyGIN Layer (B. Knyazev et al., ICLR 2019 Workshop on Representation Learning on Graphs and Manifolds)
18 | The first three types (1-3) of layers are particular cases of the fourth (4) case.
19 | '''
20 | def __init__(self,
21 | in_features,
22 | out_features,
23 | K,
24 | n_hidden=0,
25 | aggregation='mean',
26 | activation=nn.ReLU(True),
27 | n_relations=1):
28 | super(ChebyGINLayer, self).__init__()
29 | self.in_features = in_features
30 | self.out_features = out_features
31 | self.n_relations = n_relations
32 | assert K > 0, 'order is assumed to be > 0'
33 | self.K = K
34 | assert n_hidden >= 0, ('invalid n_hidden value', n_hidden)
35 | self.n_hidden = n_hidden
36 | assert aggregation in ['mean', 'sum'], ('invalid aggregation', aggregation)
37 | self.aggregation = aggregation
38 | self.activation = activation
39 | n_in = self.in_features * self.K * n_relations
40 | if self.n_hidden == 0:
41 | fc = [nn.Linear(n_in, self.out_features)]
42 | else:
43 | fc = [nn.Linear(n_in, n_hidden),
44 | nn.ReLU(True),
45 | nn.Linear(n_hidden, self.out_features)]
46 | if activation is not None:
47 | fc.append(activation)
48 | self.fc = nn.Sequential(*fc)
49 | print('ChebyGINLayer', list(self.fc.children())[0].weight.shape,
50 | torch.norm(list(self.fc.children())[0].weight, dim=1)[:10])
51 |
52 | def __repr__(self):
53 | return 'ChebyGINLayer(in_features={}, out_features={}, K={}, n_hidden={}, aggregation={})\nfc={}'.format(
54 | self.in_features,
55 | self.out_features,
56 | self.K,
57 | self.n_hidden,
58 | self.aggregation,
59 | str(self.fc))
60 |
61 | def chebyshev_basis(self, L, X, K):
62 | '''
63 | Return T_k X where T_k are the Chebyshev polynomials of order up to K.
64 | :param L: graph Laplacian, batch (B), nodes (N), nodes (N)
65 | :param X: input of size batch (B), nodes (N), features (F)
66 | :param K: Chebyshev polynomial order, i.e. filter size (number of hops)
67 | :return: Tensor of size (B,N,K,F) as a result of multiplying T_k(L) by X for each order
68 | '''
69 | if K > 1:
70 | Xt = [X]
71 | Xt.append(torch.bmm(L, X)) # B,N,F
72 | for k in range(2, K):
73 | Xt.append(2 * torch.bmm(L, Xt[k - 1]) - Xt[k - 2]) # B,N,F
74 | Xt = torch.stack(Xt, 2) # B,N,K,F
75 | return Xt
76 | else:
77 | # GCN
78 | assert K == 1, K
79 | return torch.bmm(L, X).unsqueeze(2) # B,N,1,F
80 |
81 | def laplacian_batch(self, A, add_identity=False):
82 | '''
83 | Computes normalized Laplacian transformed so that its eigenvalues are in range [-1, 1].
84 | Note that sum of all eigenvalues = trace(L) = 0.
85 | :param A: Tensor of size (B,N,N) containing batch (B) of adjacency matrices of shape N,N
86 | :return: Normalized Laplacian of size (B,N,N)
87 | '''
88 | B, N = A.shape[:2]
89 | if add_identity:
90 | A = A + torch.eye(N, device=A.get_device() if A.is_cuda else 'cpu').unsqueeze(0)
91 | D = torch.sum(A, 1) # nodes degree (B,N)
92 | D_hat = (D + 1e-5) ** (-0.5)
93 | L = D_hat.view(B, N, 1) * A * D_hat.view(B, 1, N) # B,N,N
94 | if not add_identity:
95 | L = -L # for ChebyNet to make a valid Chebyshev basis
96 | return D, L
97 |
98 | def forward(self, data):
99 | x, A, mask = data[:3]
100 | B, N, F = x.shape
101 | assert N == A.shape[1] == A.shape[2], ('invalid shape', N, x.shape, A.shape)
102 |
103 | if len(A.shape) == 3:
104 | A = A.unsqueeze(3)
105 |
106 | y_out = []
107 | for rel in range(A.shape[3]):
108 | D, L = self.laplacian_batch(A[:, :, :, rel], add_identity=self.K == 1) # for the first layer this can be done at the preprocessing stage
109 | y = self.chebyshev_basis(L, x, self.K) # B,N,K,F
110 |
111 | if self.aggregation == 'sum':
112 | # Sum features of neighbors
113 | if self.K == 1:
114 | # GIN
115 | y = y * D.view(B, N, 1, 1)
116 | else:
117 | # ChebyGIN
118 | D_GIN = torch.ones(B, N, self.K, device=x.get_device() if x.is_cuda else 'cpu')
119 | D_GIN[:, :, 1:] = D.view(B, N, 1).expand(-1, -1, self.K - 1) # keep self-loop features the same
120 | y = y * D_GIN.view(B, N, self.K, 1) # apply summation for other scales
121 |
122 | y_out.append(y)
123 |
124 | y = torch.cat(y_out, dim=2)
125 | y = self.fc(y.view(B, N, -1)) # B,N,F
126 |
127 | if len(mask.shape) == 2:
128 | mask = mask.unsqueeze(2)
129 |
130 | y = y * mask.float()
131 | output = [y, A, mask]
132 | output.extend(data[3:] + [x]) # for python2
133 |
134 | return output
135 |
136 |
137 | class GraphReadout(nn.Module):
138 | '''
139 | Global pooling layer applied after the last graph layer.
140 | '''
141 | def __init__(self,
142 | pool_type):
143 | super(GraphReadout, self).__init__()
144 | self.pool_type = pool_type
145 | dim = 1 # pooling over nodes
146 | if pool_type == 'max':
147 | self.readout_layer = lambda x, mask: torch.max(x, dim=dim)[0]
148 | elif pool_type in ['avg', 'mean']:
149 | # sum over all nodes, then divide by the number of valid nodes in each sample of the batch
150 | self.readout_layer = lambda x, mask: torch.sum(x, dim=dim) / torch.sum(mask, dim=dim).float()
151 | elif pool_type in ['sum']:
152 | self.readout_layer = lambda x, mask: torch.sum(x, dim=dim)
153 | else:
154 | raise NotImplementedError(pool_type)
155 |
156 | def __repr__(self):
157 | return 'GraphReadout({})'.format(self.pool_type)
158 |
159 | def forward(self, data):
160 | x, A, mask = data[:3]
161 | B, N = x.shape[:2]
162 | x = self.readout_layer(x, mask.view(B, N, 1))
163 | output = [x]
164 | output.extend(data[1:]) # [x, *data[1:]] doesn't work in Python2
165 | return output
166 |
167 |
168 | class ChebyGIN(nn.Module):
169 | '''
170 | Graph Neural Network class.
171 | '''
172 | def __init__(self,
173 | in_features,
174 | out_features,
175 | filters,
176 | K=1,
177 | n_hidden=0,
178 | aggregation='mean',
179 | dropout=0,
180 | readout='max',
181 | pool=None, # Example: 'attn_gt_threshold_0_skip_skip'.split('_'),
182 | pool_arch='fc_prev'.split('_'),
183 | large_graph=False, # graphs with > ~500 nodes
184 | kl_weight=None,
185 | graph_layer_fn=None,
186 | init='normal',
187 | scale=None,
188 | debug=False):
189 | super(ChebyGIN, self).__init__()
190 | self.out_features = out_features
191 | assert len(filters) > 0, 'filters must be an iterable object with at least one element'
192 | assert K > 0, 'filter scale must be a positive integer'
193 | self.pool = pool
194 | self.pool_arch = pool_arch
195 | self.debug = debug
196 | n_prev = None
197 |
198 | attn_gnn = None
199 | if graph_layer_fn is None:
200 | graph_layer_fn = lambda n_in, n_out, K_, n_hidden_, activation: ChebyGINLayer(in_features=n_in,
201 | out_features=n_out,
202 | K=K_,
203 | n_hidden=n_hidden_,
204 | aggregation=aggregation,
205 | activation=activation)
206 | if self.pool_arch is not None and self.pool_arch[0] == 'gnn':
207 | attn_gnn = lambda n_in: ChebyGIN(in_features=n_in,
208 | out_features=0,
209 | filters=[32, 32, 1],
210 | K=np.min((K, 2)),
211 | n_hidden=0,
212 | graph_layer_fn=graph_layer_fn)
213 |
214 | graph_layers = []
215 |
216 | for layer, f in enumerate(filters + [None]):
217 |
218 | n_in = in_features if layer == 0 else filters[layer - 1]
219 | # Pooling layers
220 | # It's a non-standard way to put pooling before convolution, but it's important for our work
221 | if self.pool is not None and len(self.pool) > len(filters) + layer and self.pool[layer + 3] != 'skip':
222 | graph_layers.append(AttentionPooling(in_features=n_in, in_features_prev=n_prev,
223 | pool_type=self.pool[:3] + [self.pool[layer + 3]],
224 | pool_arch=self.pool_arch,
225 | large_graph=large_graph,
226 | kl_weight=kl_weight,
227 | attn_gnn=attn_gnn,
228 | init=init,
229 | scale=scale,
230 | debug=debug))
231 |
232 | if f is not None:
233 | # Graph "convolution" layers
234 | # no ReLU if the last layer and no fc layer after that
235 | graph_layers.append(graph_layer_fn(n_in, f, K, n_hidden,
236 | None if self.out_features == 0 and layer == len(filters) - 1 else nn.ReLU(True)))
237 | n_prev = n_in
238 |
239 | if self.out_features > 0:
240 | # Global pooling over nodes
241 | graph_layers.append(GraphReadout(readout))
242 | self.graph_layers = nn.Sequential(*graph_layers)
243 |
244 | if self.out_features > 0:
245 | # Fully connected (classification/regression) layers
246 | self.fc = nn.Sequential(*(([nn.Dropout(p=dropout)] if dropout > 0 else []) + [nn.Linear(filters[-1], out_features)]))
247 |
248 | def forward(self, data):
249 | data = self.graph_layers(data)
250 | if self.out_features > 0:
251 | y = self.fc(data[0]) # B,out_features
252 | else:
253 | y = data[0] # B,N,out_features
254 | return y, data[4]
255 |
--------------------------------------------------------------------------------
/checkpoints/checkpoint_TU_657452_epoch50_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_TU_657452_epoch50_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_223919_epoch300_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_312570_epoch300_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_332172_epoch300_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_065802_epoch30_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_330394_epoch30_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_mnist-75sp_820601_epoch30_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_triangles_051609_epoch100_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_051609_epoch100_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_586710_epoch100_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/checkpoint_triangles_658037_epoch100_seed0000111.pth.tar
--------------------------------------------------------------------------------
/checkpoints/colors-3_alpha_WS_test_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/colors-3_alpha_WS_test_seed111_orig.pkl
--------------------------------------------------------------------------------
/checkpoints/colors-3_alpha_WS_train_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/colors-3_alpha_WS_train_seed111_orig.pkl
--------------------------------------------------------------------------------
/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy-c.pkl
--------------------------------------------------------------------------------
/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_noisy.pkl
--------------------------------------------------------------------------------
/checkpoints/mnist-75sp_alpha_WS_test_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_test_seed111_orig.pkl
--------------------------------------------------------------------------------
/checkpoints/mnist-75sp_alpha_WS_train_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/mnist-75sp_alpha_WS_train_seed111_orig.pkl
--------------------------------------------------------------------------------
/checkpoints/triangles_alpha_WS_test_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/triangles_alpha_WS_test_seed111_orig.pkl
--------------------------------------------------------------------------------
/checkpoints/triangles_alpha_WS_train_seed111_orig.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/checkpoints/triangles_alpha_WS_train_seed111_orig.pkl
--------------------------------------------------------------------------------
/data/COLORS-3.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/COLORS-3.zip
--------------------------------------------------------------------------------
/data/TRIANGLES.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/TRIANGLES.zip
--------------------------------------------------------------------------------
/data/datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/datasets.png
--------------------------------------------------------------------------------
/data/mnist_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/mnist_animation.gif
--------------------------------------------------------------------------------
/data/triangles_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bknyaz/graph_attention_pool/00a6771cf864048f94e8bec19b5c8225a3ed2b0e/data/triangles_animation.gif
--------------------------------------------------------------------------------
/extract_superpixels.py:
--------------------------------------------------------------------------------
1 | # Compute superpixels for MNIST/CIFAR-10 using SLIC algorithm
2 | # https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic
3 |
4 | import numpy as np
5 | import random
6 | import os
7 | import scipy
8 | import pickle
9 | from skimage.segmentation import slic
10 | from torchvision import datasets
11 | import multiprocessing as mp
12 | import scipy.ndimage
13 | import scipy.spatial
14 | import argparse
15 | import datetime
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='Extract SLIC superpixels from images')
20 | parser.add_argument('-D', '--dataset', type=str, default='mnist', choices=['mnist', 'cifar10'])
21 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to the dataset')
22 | parser.add_argument('-o', '--out_dir', type=str, default='./data', help='path where to save superpixels')
23 | parser.add_argument('-s', '--split', type=str, default='train', choices=['train', 'val', 'test'])
24 | parser.add_argument('-t', '--threads', type=int, default=0, help='number of parallel threads')
25 | parser.add_argument('-n', '--n_sp', type=int, default=75, help='max number of superpixels per image')
26 |     parser.add_argument('-c', '--compactness', type=float, default=0.25, help='compactness of the SLIC algorithm '
27 | '(Balances color proximity and space proximity): '
28 | '0.25 is a good value for MNIST '
29 | 'and 10 for color images like CIFAR-10')
30 | parser.add_argument('--seed', type=int, default=111, help='seed for shuffling nodes')
31 | args = parser.parse_args()
32 |
33 | for arg in vars(args):
34 | print(arg, getattr(args, arg))
35 |
36 | return args
37 |
38 |
39 | def process_image(params):
40 |
41 | img, index, n_images, args, to_print, shuffle = params
42 |
43 | assert img.dtype == np.uint8, img.dtype
44 | img = (img / 255.).astype(np.float32)
45 |
46 |     n_sp_extracted = args.n_sp + 1 # number of actually extracted superpixels (can differ from the number requested from SLIC)
47 |     n_sp_query = args.n_sp + (20 if args.dataset == 'mnist' else 50) # number of superpixels to request (ask for more than needed so the final count ends up closer to the desired n_sp)
48 | while n_sp_extracted > args.n_sp:
49 | superpixels = slic(img, n_segments=n_sp_query, compactness=args.compactness, multichannel=len(img.shape) > 2)
50 | sp_indices = np.unique(superpixels)
51 | n_sp_extracted = len(sp_indices)
52 | n_sp_query -= 1 # reducing the number of superpixels until we get <= n superpixels
53 |
54 | assert n_sp_extracted <= args.n_sp and n_sp_extracted > 0, (args.split, index, n_sp_extracted, args.n_sp)
55 | assert n_sp_extracted == np.max(superpixels) + 1, ('superpixel indices', np.unique(superpixels)) # make sure superpixel indices are numbers from 0 to n-1
56 |
57 | if shuffle:
58 | ind = np.random.permutation(n_sp_extracted)
59 | else:
60 | ind = np.arange(n_sp_extracted)
61 |
62 | sp_order = sp_indices[ind].astype(np.int32)
63 | if len(img.shape) == 2:
64 | img = img[:, :, None]
65 |
66 | n_ch = 1 if img.shape[2] == 1 else 3
67 |
68 | sp_intensity, sp_coord = [], []
69 | for seg in sp_order:
70 | mask = (superpixels == seg).squeeze()
71 | avg_value = np.zeros(n_ch)
72 | for c in range(n_ch):
73 | avg_value[c] = np.mean(img[:, :, c][mask])
74 | cntr = np.array(scipy.ndimage.measurements.center_of_mass(mask)) # row, col
75 | sp_intensity.append(avg_value)
76 | sp_coord.append(cntr)
77 | sp_intensity = np.array(sp_intensity, np.float32)
78 | sp_coord = np.array(sp_coord, np.float32)
79 | if to_print:
80 | print('image={}/{}, shape={}, min={:.2f}, max={:.2f}, n_sp={}'.format(index + 1, n_images, img.shape,
81 | img.min(), img.max(), sp_intensity.shape[0]))
82 |
83 | return sp_intensity, sp_coord, sp_order, superpixels
84 |
85 |
86 | if __name__ == '__main__':
87 |
88 | dt = datetime.datetime.now()
89 | print('start time:', dt)
90 |
91 | args = parse_args()
92 |
93 | if not os.path.isdir(args.out_dir):
94 | os.mkdir(args.out_dir)
95 |
96 | random.seed(args.seed)
97 | np.random.seed(args.seed) # to make node random permutation reproducible (not tested)
98 |
99 | # Read image data using torchvision
100 | is_train = args.split.lower() == 'train'
101 | if args.dataset == 'mnist':
102 | data = datasets.MNIST(args.data_dir, train=is_train, download=True)
103 | assert args.compactness < 10, ('high compactness can result in bad superpixels on MNIST')
104 | assert args.n_sp > 1 and args.n_sp < 28*28, (
105 | 'the number of superpixels cannot exceed the total number of pixels or be too small')
106 | elif args.dataset == 'cifar10':
107 | data = datasets.CIFAR10(args.data_dir, train=is_train, download=True)
108 | assert args.compactness > 1, ('low compactness can result in bad superpixels on CIFAR-10')
109 | assert args.n_sp > 1 and args.n_sp < 32*32, (
110 | 'the number of superpixels cannot exceed the total number of pixels or be too small')
111 | else:
112 | raise NotImplementedError('unsupported dataset: ' + args.dataset)
113 |
114 | images = data.train_data if is_train else data.test_data
115 | labels = data.train_labels if is_train else data.test_labels
116 | if not isinstance(images, np.ndarray):
117 | images = images.numpy()
118 | if isinstance(labels, list):
119 | labels = np.array(labels)
120 | if not isinstance(labels, np.ndarray):
121 | labels = labels.numpy()
122 |
123 | n_images = len(labels)
124 |
125 | if args.threads <= 0:
126 | sp_data = []
127 | for i in range(n_images):
128 | sp_data.append(process_image((images[i], i, n_images, args, True, True)))
129 | else:
130 | with mp.Pool(processes=args.threads) as pool:
131 | sp_data = pool.map(process_image, [(images[i], i, n_images, args, True, True) for i in range(n_images)])
132 |
133 | superpixels = [sp_data[i][3] for i in range(n_images)]
134 | sp_data = [sp_data[i][:3] for i in range(n_images)]
135 | with open('%s/%s_%dsp_%s.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f:
136 | pickle.dump((labels.astype(np.int32), sp_data), f, protocol=2)
137 | with open('%s/%s_%dsp_%s_superpixels.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f:
138 | pickle.dump(superpixels, f, protocol=2)
139 |
140 | print('done in {}'.format(datetime.datetime.now() - dt))
141 |
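    | # --- Editor's note (illustrative, not part of the original script) ---
    | # A minimal sketch of how the pickles written above could be read back,
    | # assuming the default arguments (dataset='mnist', n_sp=75, split='train',
    | # out_dir='./data'), which produce './data/mnist_75sp_train.pkl':
    | #
    | #   import pickle
    | #   with open('./data/mnist_75sp_train.pkl', 'rb') as f:
    | #       labels, sp_data = pickle.load(f)
    | #   # sp_data[i] == (sp_intensity, sp_coord, sp_order) for image i,
    | #   # i.e. the first three values returned by process_image()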
--------------------------------------------------------------------------------
/generate_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import argparse
5 | import networkx as nx
6 | import datetime
7 | import random
8 | import multiprocessing as mp
9 | from utils import *
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(description='Generate synthetic graph datasets')
14 | parser.add_argument('-D', '--dataset', type=str, default='colors', choices=['colors', 'triangles'])
15 |     parser.add_argument('-o', '--out_dir', type=str, default='./data', help='path where to save the generated graphs')
16 | parser.add_argument('--N_train', type=int, default=500, help='number of training graphs (500 for colors and 30000 for triangles)')
17 | parser.add_argument('--N_val', type=int, default=2500, help='number of graphs in the validation set (2500 for colors and 5000 for triangles)')
18 | parser.add_argument('--N_test', type=int, default=2500, help='number of graphs in each test subset (2500 for colors and 5000 for triangles)')
19 | parser.add_argument('--label_min', type=int, default=0,
20 | help='smallest label value for a graph (i.e. smallest number of green nodes); 1 for triangles')
21 | parser.add_argument('--label_max', type=int, default=10,
22 | help='largest label value for a graph (i.e. largest number of green nodes)')
23 | parser.add_argument('--N_min', type=int, default=4, help='minimum number of nodes')
24 |     parser.add_argument('--N_max', type=int, default=200, help='maximum number of nodes (default: 200 for colors and 100 for triangles)')
25 | parser.add_argument('--N_max_train', type=int, default=25, help='maximum number of nodes in the training set')
26 | parser.add_argument('--dim', type=int, default=3, help='node feature dimensionality')
27 | parser.add_argument('--green_ch_index', type=int, default=1,
28 | help='index of non-zero value in a one-hot node feature vector, '
29 |                              'i.e. [0, 1, 0] in case green_ch_index=1 and dim=3')
30 | parser.add_argument('--seed', type=int, default=111, help='seed for shuffling nodes')
31 | parser.add_argument('--threads', type=int, default=0, help='only for triangles')
32 | args = parser.parse_args()
33 |
34 | for arg in vars(args):
35 | print(arg, getattr(args, arg))
36 |
37 | return args
38 |
39 |
40 | def check_graph_duplicates(Adj_matrices, node_features=None):
41 | n_graphs = len(Adj_matrices)
42 | print('check for duplicates for %d graphs' % n_graphs)
43 | n_duplicates = 0
44 | for i in range(n_graphs):
45 | if node_features is not None:
46 | assert Adj_matrices[i].shape[0] == node_features[i].shape[0], (
47 | 'invalid data', i, Adj_matrices[i].shape[0], node_features[i].shape[0])
48 | for j in range(i + 1, n_graphs):
49 | if Adj_matrices[i].shape[0] == Adj_matrices[j].shape[0]:
50 | if np.allclose(Adj_matrices[i], Adj_matrices[j]): # adjacency matrices are the same
51 |                     # for Colors, graphs are not considered duplicates if they have the same adjacency matrix,
52 | # but different node features
53 | if node_features is None or np.allclose(node_features[i], node_features[j]):
54 | n_duplicates += 1
55 | print('duplicates %d/%d' % (n_duplicates, n_graphs * (n_graphs - 1) / 2))
56 | if n_duplicates > 0:
57 | raise ValueError('%d duplicates found in the dataset' % n_duplicates)
58 |
59 | print('no duplicated graphs')
60 |
61 |
62 | # COLORS
63 | def get_node_features_Colors(N_nodes, N_green, dim, green_ch_index=1, new_colors=False):
64 | node_features = np.zeros((N_nodes, dim))
65 |
66 | # Generate indices for non-zero values,
67 | # so that the total number of nodes with features having value 1 in the green_ch_index equals N_green
68 | idx_not_green = rnd.randint(0, dim - 1, size=N_nodes - N_green) # for dim=3 generate values 0,1 for non-green nodes
69 |     idx_non_zero = np.concatenate((idx_not_green, np.zeros(N_green, np.int) + dim - 1)) # temporarily mark green nodes with index dim - 1 (=2 for dim=3)
70 | idx_non_zero_cp = idx_non_zero.copy()
71 | idx_non_zero[idx_non_zero_cp == dim - 1] = green_ch_index # set idx_non_zero=1 for green nodes
72 |     idx_non_zero[idx_non_zero_cp == green_ch_index] = dim - 1 # swap: non-green nodes that randomly got index green_ch_index are moved to dim - 1
73 | rnd.shuffle(idx_non_zero) # shuffle nodes
74 | node_features[np.arange(N_nodes), idx_non_zero] = 1
75 |
76 | if new_colors:
77 | for ind in np.where(idx_non_zero != green_ch_index)[0]: # for non-green nodes
78 | node_features[ind] = rnd.randint(0, 2, size=dim)
79 | node_features[ind, green_ch_index] = 0 # set value at green_ch_index to 0 to avoid confusion with green nodes
80 |
81 | label = np.sum((np.sum(node_features, 1) == node_features[:, green_ch_index]) & (node_features[:, green_ch_index] == 1))
82 |
83 | gt_attn = (idx_non_zero == green_ch_index).reshape(-1, 1)
84 | label2 = np.sum(gt_attn)
85 | assert N_green == label == label2, ('invalid node features', N_green, label, label2)
86 | return node_features, idx_non_zero, gt_attn
87 |
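    | # Illustrative example (editor's note, not in the original code): calling
    | # get_node_features_Colors(N_nodes=4, N_green=2, dim=3, green_ch_index=1)
    | # can return node_features = [[0,1,0], [1,0,0], [0,1,0], [0,0,1]]:
    | # exactly two rows are "green" (one-hot at index 1), so the graph label is 2
    | # and gt_attn marks those two rows.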
88 |
89 | def generate_graphs_Colors(N_graphs, N_min, N_max, dim, args, rnd, new_colors=False):
90 | Adj_matrices, node_features, GT_attn, graph_labels, N_edges = [], [], [], [], []
91 | n_labels = args.label_max - args.label_min + 1
92 | n_graphs_per_shape = int(np.ceil(N_graphs / (N_max - N_min + 1) / n_labels) * n_labels)
93 | for n_nodes in np.array(range(N_min, N_max + 1)):
94 | c = 0
95 | while True:
96 | labels = np.arange(args.label_min, n_labels)
97 | labels = labels[labels <= n_nodes]
98 | rnd.shuffle(labels)
99 | for lbl in labels:
100 | features, idx_non_zero, gt_attn = get_node_features_Colors(N_nodes=n_nodes,
101 | N_green=lbl,
102 | dim=dim,
103 | green_ch_index=args.green_ch_index,
104 | new_colors=new_colors)
105 | n_edges = int((rnd.rand() + 1) * n_nodes)
106 | A = nx.to_numpy_array(nx.gnm_random_graph(n_nodes, n_edges))
107 | add = True
108 | for k in range(len(Adj_matrices)):
109 | if A.shape[0] == Adj_matrices[k].shape[0] and np.allclose(A, Adj_matrices[k]):
110 | if np.allclose(node_features[k], features):
111 | add = False
112 | break
113 | if add:
114 | Adj_matrices.append(A.astype(np.bool)) # binary adjacency matrix
115 | graph_labels.append(lbl)
116 | node_features.append(features.astype(np.bool)) # binary features
117 | GT_attn.append(gt_attn) # binary GT attention
118 | N_edges.append(n_edges)
119 | c += 1
120 | if c >= n_graphs_per_shape:
121 | break
122 | if c >= n_graphs_per_shape:
123 | break
124 | graph_labels = np.array(graph_labels, np.int32)
125 | N_edges = np.array(N_edges, np.int32)
126 | print(N_graphs, len(graph_labels))
127 |
128 | return {'Adj_matrices': Adj_matrices,
129 | 'GT_attn': GT_attn, # not normalized to sum=1
130 | 'graph_labels': graph_labels,
131 | 'node_features': node_features,
132 | 'N_edges': N_edges}
133 |
134 |
135 | # TRIANGLES
136 | def get_gt_atnn_triangles(args):
137 | G, N = args
138 | node_ids = []
139 | if G is not None:
140 | for clq in nx.enumerate_all_cliques(G):
141 | if len(clq) == 3:
142 | node_ids.extend(clq)
143 | node_ids = np.array(node_ids)
144 | gt_attn = np.zeros((N, 1), np.int32)
145 | for i in np.unique(node_ids):
146 | gt_attn[i] = int(np.sum(node_ids == i))
147 |     return gt_attn # unnormalized (does not sum to 1); stored as int32 for efficiency
148 |
149 |
150 | def get_graph_triangles(args):
151 | N_nodes, rnd = args
152 | N_edges = int((rnd.rand() + 1) * N_nodes)
153 | G = nx.dense_gnm_random_graph(N_nodes, N_edges, seed=None)
154 | A = nx.to_numpy_array(G)
155 | A_cube = A.dot(A).dot(A)
156 | label = int(np.trace(A_cube) / 6.) # number of triangles
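    |     # editor's note: trace(A^3) counts closed walks of length 3; each triangle is
    |     # counted 6 times (3 starting nodes x 2 directions), hence the division by 6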
157 | return A.astype(np.bool), label, N_edges, G
158 |
159 |
160 | def generate_graphs_Triangles(N_graphs, N_min, N_max, args, rnd):
161 | N_nodes = rnd.randint(N_min, N_max + 1, size=int(N_graphs * 10))
162 | print('generating %d graphs with %d-%d nodes' % (N_graphs * 10, N_min, N_max))
163 |
164 | if args.threads > 0:
165 | with mp.Pool(processes=args.threads) as pool:
166 | data = pool.map(get_graph_triangles, [(N_nodes[i], rnd) for i in range(len(N_nodes))])
167 | else:
168 | data = [get_graph_triangles((N_nodes[i], rnd)) for i in range(len(N_nodes))]
169 | labels = np.array([data[i][1] for i in range(len(data))], np.int32)
170 | Adj_matrices, node_features, G, graph_labels, N_edges, node_degrees = [], [], [], [], [], []
171 | for lbl in range(args.label_min, args.label_max + 1):
172 | idx = np.where(labels == lbl)[0]
173 | c = 0
174 | for i in idx:
175 | add = True
176 | for k in range(len(Adj_matrices)):
177 | if data[i][0].shape[0] == Adj_matrices[k].shape[0] and labels[i] == graph_labels[k] and np.allclose(data[i][0], Adj_matrices[k]):
178 | add = False
179 | break
180 | if add:
181 | Adj_matrices.append(data[i][0])
182 | graph_labels.append(labels[i])
183 | G.append(data[i][3])
184 | N_edges.append(data[i][2])
185 | node_degrees.append(data[i][0].astype(np.int32).sum(1).max())
186 | c += 1
187 | if c >= int(N_graphs / (args.label_max - args.label_min + 1)):
188 | break
189 | print('label={}, number of graphs={}/{}, total number of generated graphs={}'.format(lbl, c, len(idx), len(Adj_matrices)))
190 |
191 | assert c == int(N_graphs / (args.label_max - args.label_min + 1)), (
192 | 'invalid data', c, int(N_graphs / (args.label_max - args.label_min + 1)))
193 |
194 | print('computing GT attention for %d graphs' % len(Adj_matrices))
195 | if args.threads > 0:
196 | with mp.Pool(processes=args.threads) as pool:
197 | GT_attn = pool.map(get_gt_atnn_triangles, [(G[i], Adj_matrices[i].shape[0]) for i in range(len(Adj_matrices))])
198 | else:
199 | GT_attn = [get_gt_atnn_triangles((G[i], Adj_matrices[i].shape[0])) for i in range(len(Adj_matrices))]
200 |
201 | graph_labels = np.array(graph_labels, np.int32)
202 | N_edges = np.array(N_edges, np.int32)
203 |
204 | return {'Adj_matrices': Adj_matrices,
205 | 'GT_attn': GT_attn, # not normalized to sum=1
206 | 'graph_labels': graph_labels,
207 | 'N_edges': N_edges,
208 | 'Max_degree': np.max(node_degrees)}
209 |
210 |
211 | if __name__ == '__main__':
212 |
213 | dt = datetime.datetime.now()
214 | print('start time:', dt)
215 |
216 | args = parse_args()
217 |
218 | if not os.path.isdir(args.out_dir):
219 | os.mkdir(args.out_dir)
220 |
221 | random.seed(args.seed) # for networkx
222 | np.random.seed(args.seed)
223 | rnd = np.random.RandomState(args.seed)
224 |
225 | def print_stats(data, split_name):
226 | print('%s: %d graphs' % (split_name, len(data['graph_labels'])))
227 | for lbl in np.unique(data['graph_labels']):
228 | print('%s: label=%d, %d graphs' % (split_name, lbl, np.sum(data['graph_labels'] == lbl)))
229 |
230 | if args.dataset.lower() == 'colors':
231 |
232 | # Generate train and test sets
233 | data_test_combined, Adj_matrices, node_features = [], [], []
234 | for N_graphs, N_nodes_min, N_nodes_max, dim, name in zip([args.N_train + args.N_val + args.N_test, args.N_test, args.N_test],
235 | [args.N_min, args.N_max_train + 1, args.N_max_train + 1],
236 | [args.N_max_train, args.N_max, args.N_max],
237 | [args.dim, args.dim, args.dim + 1],
238 | ['test orig', 'test large', 'test large-c']):
239 | data = generate_graphs_Colors(N_graphs, N_nodes_min, N_nodes_max, dim, args, rnd, new_colors=dim==args.dim + 1)
240 |
241 | if name.find('orig') >= 0:
242 | idx = rnd.permutation(len(data['graph_labels']))
243 | data_train = copy_data(data, idx[:args.N_train])
244 | print_stats(data_train, name.replace('test', 'train'))
245 | node_features += data_train['node_features']
246 | Adj_matrices += data_train['Adj_matrices']
247 | data_val = copy_data(data, idx[args.N_train: args.N_train + args.N_val])
248 | print_stats(data_val, name.replace('test', 'val'))
249 | node_features += data_val['node_features']
250 | Adj_matrices += data_val['Adj_matrices']
251 |
252 | data_test = copy_data(data, idx[args.N_train + args.N_val: args.N_train + args.N_val + args.N_test])
253 | else:
254 | data_test = copy_data(data, rnd.permutation(len(data['graph_labels']))[:args.N_test])
255 |
256 | Adj_matrices += data_test['Adj_matrices']
257 | node_features += data_test['node_features']
258 | data_test_combined.append(data_test)
259 | print_stats(data_test, name)
260 |
261 | # Check for duplicates in the combined train+val+test sets
262 | check_graph_duplicates(Adj_matrices, node_features)
263 |
264 | # Saving
265 | with open('%s/random_graphs_colors_dim%d_train.pkl' % (args.out_dir, args.dim), 'wb') as f:
266 | pickle.dump(data_train, f, protocol=2)
267 |
268 | with open('%s/random_graphs_colors_dim%d_val.pkl' % (args.out_dir, args.dim), 'wb') as f:
269 | pickle.dump(data_val, f, protocol=2)
270 |
271 | with open('%s/random_graphs_colors_dim%d_test.pkl' % (args.out_dir, args.dim), 'wb') as f:
272 | pickle.dump(concat_data(data_test_combined), f, protocol=2)
273 |
274 | elif args.dataset.lower() == 'triangles':
275 |
276 | data = generate_graphs_Triangles((args.N_train + args.N_val + args.N_test), args.N_min, args.N_max_train, args, rnd)
277 | # Create balanced splits
278 | idx_train, idx_val, idx_test = [], [], []
279 | classes = np.unique(data['graph_labels'])
280 | n_classes = len(classes)
281 | for lbl in classes:
282 | idx = np.where(data['graph_labels'] == lbl)[0]
283 | rnd.shuffle(idx)
284 | n_train = int(args.N_train / n_classes)
285 | n_val = int(args.N_val / n_classes)
286 | n_test = int(args.N_test / n_classes)
287 | idx_train.append(idx[:n_train])
288 | idx_val.append(idx[n_train: n_train + n_val])
289 | idx_test.append(idx[n_train + n_val: n_train + n_val + n_test])
290 | data_train = copy_data(data, np.concatenate(idx_train))
291 | print_stats(data_train, 'train orig')
292 | data_val = copy_data(data, np.concatenate(idx_val))
293 | print_stats(data_val, 'val orig')
294 | data_test = copy_data(data, np.concatenate(idx_test))
295 | print_stats(data_test, 'test orig')
296 |
297 | data = generate_graphs_Triangles(args.N_test, args.N_max_train + 1, args.N_max, args, rnd)
298 | data_test_large = copy_data(data, rnd.permutation(len(data['graph_labels']))[:args.N_test])
299 | print_stats(data_test_large, 'test large')
300 |
301 | check_graph_duplicates(data_train['Adj_matrices'] + data_val['Adj_matrices'] +
302 | data_test['Adj_matrices'] + data_test_large['Adj_matrices'])
303 |
304 | # Saving
305 | # Max degree is max over all graphs in the training and test sets
306 | max_degree = np.max(np.array([d['Max_degree'] for d in (data_train, data_val, data_test, data_test_large)]))
307 | data_train['Max_degree'] = max_degree
308 | with open('%s/random_graphs_triangles_train.pkl' % args.out_dir, 'wb') as f:
309 | pickle.dump(data_train, f, protocol=2)
310 |
311 | data_val['Max_degree'] = max_degree
312 | with open('%s/random_graphs_triangles_val.pkl' % args.out_dir, 'wb') as f:
313 | pickle.dump(data_val, f, protocol=2)
314 |
315 | data_test = concat_data((data_test, data_test_large))
316 | data_test['Max_degree'] = max_degree
317 | with open('%s/random_graphs_triangles_test.pkl' % args.out_dir, 'wb') as f:
318 | pickle.dump(data_test, f, protocol=2)
319 |
320 | else:
321 | raise NotImplementedError('unsupported dataset: ' + args.dataset)
322 |
323 | print('done in {}'.format(datetime.datetime.now() - dt))
324 |
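    | # --- Editor's note (illustrative, not part of the original script) ---
    | # Each pickle written above stores a plain dict; a minimal loading sketch,
    | # assuming the default out_dir and the colors dataset with dim=3:
    | #
    | #   import pickle
    | #   with open('./data/random_graphs_colors_dim3_train.pkl', 'rb') as f:
    | #       data = pickle.load(f)
    | #   # data['Adj_matrices'][i]  - boolean adjacency matrix of graph i
    | #   # data['node_features'][i] - boolean one-hot node colors
    | #   # data['graph_labels'][i]  - number of green nodes (the target)
    | #   # data['GT_attn'][i]       - unnormalized ground-truth attention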
--------------------------------------------------------------------------------
/graphdata.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from os.path import join as pjoin
4 | import pickle
5 | import copy
6 | import torch
7 | import torch.utils
8 | import torch.utils.data
9 | import torch.nn.functional as F
10 | import torchvision
11 | from scipy.spatial.distance import cdist
12 | from utils import *
13 |
14 |
15 | def compute_adjacency_matrix_images(coord, sigma=0.1):
16 | coord = coord.reshape(-1, 2)
17 | dist = cdist(coord, coord)
18 | A = np.exp(- dist / (sigma * np.pi) ** 2)
19 | A[np.diag_indices_from(A)] = 0
20 | return A
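    |     # editor's note: this builds A_ij = exp(-d_ij / (sigma * pi)^2) from the pairwise
    |     # Euclidean distances d_ij between node coordinates (the callers in this module
    |     # pass coordinates normalized by the image size) and zeroes the diagonal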
21 |
22 |
23 | def precompute_graph_images(img_size):
24 | col, row = np.meshgrid(np.arange(img_size), np.arange(img_size))
25 | coord = np.stack((col, row), axis=2) / img_size # 28,28,2
26 | A = torch.from_numpy(compute_adjacency_matrix_images(coord)).float().unsqueeze(0)
27 | coord = torch.from_numpy(coord).float().unsqueeze(0).view(1, -1, 2)
28 | mask = torch.ones(1, img_size * img_size, dtype=torch.uint8)
29 | return A, coord, mask
30 |
31 |
32 | def collate_batch_images(batch, A, mask, use_mean_px=True, coord=None,
33 | gt_attn_threshold=0, replicate_features=True):
34 | B = len(batch)
35 | C, H, W = batch[0][0].shape
36 | N_nodes = H * W
37 | params_dict = {'N_nodes': torch.zeros(B, dtype=torch.long) + N_nodes, 'node_attn_eval': None}
38 | has_WS_attn = len(batch[0]) > 2
39 | if has_WS_attn:
40 | WS_attn = torch.from_numpy(np.stack([batch[b][2].reshape(N_nodes) for b in range(B)]).astype(np.float32)).view(B, N_nodes)
41 | WS_attn = normalize_batch(WS_attn)
42 | params_dict.update({'node_attn': WS_attn}) # use these scores for training
43 |
44 | if use_mean_px:
45 | x = torch.stack([batch[b][0].view(C, N_nodes).t() for b in range(B)]).float()
46 | if gt_attn_threshold == 0:
47 | GT_attn = (x > 0).view(B, N_nodes).float()
48 | else:
49 | GT_attn = x.view(B, N_nodes).float().clone()
50 | GT_attn[GT_attn < gt_attn_threshold] = 0
51 | GT_attn = normalize_batch(GT_attn)
52 |
53 | params_dict.update({'node_attn_eval': GT_attn}) # use this for evaluation of attention
54 | if not has_WS_attn:
55 | params_dict.update({'node_attn': GT_attn}) # use this to train attention
56 | else:
57 | raise NotImplementedError('this case is not well supported')
58 |
59 | if coord is not None:
60 | if use_mean_px:
61 | x = torch.cat((x, coord.expand(B, -1, -1)), dim=2)
62 | else:
63 | x = coord.expand(B, -1, -1)
64 | if x is None:
65 | x = torch.ones(B, N_nodes, 1) # dummy features
66 |
67 | if replicate_features:
68 | x = F.pad(x, (2, 0), 'replicate')
69 |
70 | try:
71 | labels = torch.Tensor([batch[b][1] for b in range(B)]).long()
72 | except:
73 | labels = torch.stack([batch[b][1] for b in range(B)]).long()
74 |
75 | return [x, A.expand(B, -1, -1), mask.expand(B, -1), labels, params_dict]
76 |
77 |
78 | def collate_batch(batch):
79 | '''
80 | Creates a batch of same size graphs by zero-padding node features and adjacency matrices up to
81 | the maximum number of nodes in the current batch rather than in the entire dataset.
82 | '''
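    |     # editor's note: for a batch of B graphs with at most N_max nodes and C-dim
    |     # features, the returned tensors have shapes x: [B, N_max, C],
    |     # A: [B, N_max, N_max], mask: [B, N_max] (True for real nodes), labels: [B]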
83 |
84 | B = len(batch)
85 | N_nodes = [batch[b][2] for b in range(B)]
86 | C = batch[0][0].shape[1]
87 | N_nodes_max = int(np.max(N_nodes))
88 |
89 | mask = torch.zeros(B, N_nodes_max, dtype=torch.bool) # use byte for older PyTorch
90 | A = torch.zeros(B, N_nodes_max, N_nodes_max)
91 | x = torch.zeros(B, N_nodes_max, C)
92 | has_GT_attn = len(batch[0]) > 4 and batch[0][4] is not None
93 | if has_GT_attn:
94 | GT_attn = torch.zeros(B, N_nodes_max)
95 | has_WS_attn = len(batch[0]) > 5 and batch[0][5] is not None
96 | if has_WS_attn:
97 | WS_attn = torch.zeros(B, N_nodes_max)
98 |
99 | for b in range(B):
100 | x[b, :N_nodes[b]] = batch[b][0]
101 | A[b, :N_nodes[b], :N_nodes[b]] = batch[b][1]
102 | mask[b][:N_nodes[b]] = 1 # mask with values of 0 for dummy (zero padded) nodes, otherwise 1
103 | if has_GT_attn:
104 | GT_attn[b, :N_nodes[b]] = batch[b][4].squeeze()
105 | if has_WS_attn:
106 | WS_attn[b, :N_nodes[b]] = batch[b][5].squeeze()
107 |
108 | N_nodes = torch.from_numpy(np.array(N_nodes)).long()
109 |
110 | params_dict = {'N_nodes': N_nodes}
111 | if has_WS_attn:
112 | params_dict.update({'node_attn': WS_attn}) # use this to train attention
113 | if has_GT_attn:
114 | params_dict.update({'node_attn_eval': GT_attn}) # use this for evaluation of attention
115 | if not has_WS_attn:
116 | params_dict.update({'node_attn': GT_attn}) # use this to train attention
117 | elif has_WS_attn:
118 | params_dict.update({'node_attn_eval': WS_attn}) # use this for evaluation of attention
119 |
120 | labels = torch.from_numpy(np.array([batch[b][3] for b in range(B)])).long()
121 | return [x, A, mask, labels, params_dict]
122 |
123 |
124 | class MNIST(torchvision.datasets.MNIST):
125 | '''
126 | Wrapper around MNIST to use predefined attention coefficients
127 | '''
128 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, attn_coef=None):
129 | super(MNIST, self).__init__(root, train, transform, target_transform, download)
130 | self.alpha_WS = None
131 | if attn_coef is not None and train:
132 | print('loading weakly-supervised labels from %s' % attn_coef)
133 | with open(attn_coef, 'rb') as f:
134 | self.alpha_WS = pickle.load(f)
135 | print(train, len(self.alpha_WS))
136 |
137 | def __getitem__(self, index):
138 | img, target = super(MNIST, self).__getitem__(index)
139 | if self.alpha_WS is None:
140 | return img, target
141 | else:
142 | return img, target, self.alpha_WS[index]
143 |
144 |
145 | class MNIST75sp(torch.utils.data.Dataset):
146 | def __init__(self,
147 | data_dir,
148 | split,
149 | use_mean_px=True,
150 | use_coord=True,
151 | gt_attn_threshold=0,
152 | attn_coef=None):
153 |
154 | self.data_dir = data_dir
155 | self.split = split
156 | self.is_test = split.lower() in ['test', 'val']
157 | with open(pjoin(data_dir, 'mnist_75sp_%s.pkl' % split), 'rb') as f:
158 | self.labels, self.sp_data = pickle.load(f)
159 |
160 | self.use_mean_px = use_mean_px
161 | self.use_coord = use_coord
162 | self.n_samples = len(self.labels)
163 | self.img_size = 28
164 | self.gt_attn_threshold = gt_attn_threshold
165 |
166 | self.alpha_WS = None
167 | if attn_coef is not None and not self.is_test:
168 | with open(attn_coef, 'rb') as f:
169 | self.alpha_WS = pickle.load(f)
170 | print('using weakly-supervised labels from %s (%d samples)' % (attn_coef, len(self.alpha_WS)))
171 |
172 | def train_val_split(self, samples_idx):
173 | self.sp_data = [self.sp_data[i] for i in samples_idx]
174 | self.labels = self.labels[samples_idx]
175 | self.n_samples = len(self.labels)
176 |
177 | def precompute_graph_data(self, replicate_features, threads=0):
178 | print('precompute all data for the %s set...' % self.split.upper())
179 | self.Adj_matrices, self.node_features, self.GT_attn, self.WS_attn = [], [], [], []
180 | for index, sample in enumerate(self.sp_data):
181 | mean_px, coord = sample[:2]
182 | coord = coord / self.img_size
183 | A = compute_adjacency_matrix_images(coord)
184 | N_nodes = A.shape[0]
185 | x = None
186 | if self.use_mean_px:
187 | x = mean_px.reshape(N_nodes, -1)
188 | if self.use_coord:
189 | coord = coord.reshape(N_nodes, 2)
190 | if self.use_mean_px:
191 | x = np.concatenate((x, coord), axis=1)
192 | else:
193 | x = coord
194 | if x is None:
195 |                 x = np.ones((N_nodes, 1)) # dummy features
196 | if replicate_features:
197 | x = np.pad(x, ((0, 0), (2, 0)), 'edge') # replicate features to make it possible to test on colored images
198 | if self.gt_attn_threshold == 0:
199 | gt_attn = (mean_px > 0).astype(np.float32)
200 | else:
201 | gt_attn = mean_px.copy()
202 | gt_attn[gt_attn < self.gt_attn_threshold] = 0
203 | self.GT_attn.append(normalize(gt_attn))
204 |
205 | if self.alpha_WS is not None:
206 | self.WS_attn.append(normalize(self.alpha_WS[index]))
207 |
208 | self.node_features.append(x)
209 | self.Adj_matrices.append(A)
210 |
211 |
212 | def __len__(self):
213 | return self.n_samples
214 |
215 | def __getitem__(self, index):
216 | data = [self.node_features[index],
217 | self.Adj_matrices[index],
218 | self.Adj_matrices[index].shape[0],
219 | self.labels[index],
220 | self.GT_attn[index]]
221 |
222 | if self.alpha_WS is not None:
223 | data.append(self.WS_attn[index])
224 |
225 | data = list_to_torch(data) # convert to torch
226 |
227 | return data
228 |
229 |
230 | class SyntheticGraphs(torch.utils.data.Dataset):
231 | def __init__(self,
232 | data_dir,
233 | dataset,
234 | split,
235 | degree_feature=True,
236 | attn_coef=None,
237 | threads=12):
238 |
239 | self.is_test = split.lower() in ['test', 'val']
240 | self.split = split
241 | self.degree_feature = degree_feature
242 |
243 | if dataset.find('colors') >= 0:
244 | dim = int(dataset.split('-')[1])
245 | data_file = 'random_graphs_colors_dim%d_%s.pkl' % (dim, split)
246 | is_triangles = False
247 | self.feature_dim = dim + 1
248 |         elif dataset.find('triangles') >= 0:
249 | data_file = 'random_graphs_triangles_%s.pkl' % split
250 | is_triangles = True
251 | else:
252 |             raise NotImplementedError(dataset)
253 |
254 | with open(pjoin(data_dir, data_file), 'rb') as f:
255 | data = pickle.load(f)
256 |
257 | for key in data:
258 | if not isinstance(data[key], list) and not isinstance(data[key], np.ndarray):
259 | print(split, key, data[key])
260 | else:
261 | print(split, key, len(data[key]))
262 |
263 | self.Node_degrees = [np.sum(A, 1).astype(np.int32) for A in data['Adj_matrices']]
264 |
265 | if is_triangles:
266 | # use one-hot degree features as node features
267 | self.feature_dim = data['Max_degree'] + 1
268 | self.node_features = []
269 | for i in range(len(data['Adj_matrices'])):
270 | N = data['Adj_matrices'][i].shape[0]
271 | if degree_feature:
272 |                     D_onehot = np.zeros((N, self.feature_dim))
273 | D_onehot[np.arange(N), self.Node_degrees[i]] = 1
274 | else:
275 | D_onehot = np.zeros((N, 1))
276 | self.node_features.append(D_onehot)
277 | if not degree_feature:
278 | self.feature_dim = 1
279 | else:
280 | # Add 1 feature to support new colors at test time
281 | self.node_features = []
282 | for i in range(len(data['node_features'])):
283 | features = data['node_features'][i]
284 | if features.shape[1] < self.feature_dim:
285 | features = np.pad(features, ((0, 0), (0, 1)), 'constant')
286 | self.node_features.append(features)
287 |
288 | self.alpha_WS = None
289 | if attn_coef is not None and not self.is_test:
290 | with open(attn_coef, 'rb') as f:
291 | self.alpha_WS = pickle.load(f)
292 | print('using weakly-supervised labels from %s (%d samples)' % (attn_coef, len(self.alpha_WS)))
293 | self.WS_attn = []
294 | for index in range(len(self.alpha_WS)):
295 | self.WS_attn.append(normalize(self.alpha_WS[index]))
296 |
297 | N_nodes = np.array([A.shape[0] for A in data['Adj_matrices']])
298 | self.Adj_matrices = data['Adj_matrices']
299 | self.GT_attn = data['GT_attn']
300 | # Normalizing ground truth attention so that it sums to 1
301 | for i in range(len(self.GT_attn)):
302 | self.GT_attn[i] = normalize(self.GT_attn[i])
303 | #assert np.sum(self.GT_attn[i]) == 1, (i, np.sum(self.GT_attn[i]), self.GT_attn[i])
304 | self.labels = data['graph_labels'].astype(np.int32)
305 | self.classes = np.unique(self.labels)
306 | self.n_classes = len(self.classes)
307 | R = np.corrcoef(self.labels, N_nodes)[0, 1]
308 |
309 | degrees = []
310 | for i in range(len(self.Node_degrees)):
311 | degrees.extend(list(self.Node_degrees[i]))
312 | degrees = np.array(degrees, np.int32)
313 |
314 | print('N nodes avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(N_nodes)))
315 | print('N edges avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(data['N_edges'])))
316 | print('Node degree avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(*stats(degrees)))
317 | print('Node features dim: \t\t%d' % self.feature_dim)
318 | print('N classes: \t\t\t%d' % self.n_classes)
319 | print('Correlation of labels with graph size: \t%.2f' % R)
320 | print('Classes: \t\t\t%s' % str(self.classes))
321 | for lbl in self.classes:
322 | idx = self.labels == lbl
323 | print('Class {}: \t\t\t{} samples, N_nodes: avg/std/min/max: \t{:.2f}/{:.2f}/{:d}/{:d}'.format(lbl, np.sum(idx), *stats(N_nodes[idx])))
324 |
325 | def __len__(self):
326 | return len(self.Adj_matrices)
327 |
328 | def __getitem__(self, index):
329 | data = [self.node_features[index],
330 | self.Adj_matrices[index],
331 | self.Adj_matrices[index].shape[0],
332 | self.labels[index],
333 | self.GT_attn[index]]
334 |
335 | if self.alpha_WS is not None:
336 | data.append(self.WS_attn[index])
337 |
338 | data = list_to_torch(data) # convert to torch
339 |
340 | return data
341 |
342 |
343 | class GraphData(torch.utils.data.Dataset):
344 | def __init__(self,
345 | datareader,
346 | fold_id,
347 | split, # train, val, train_val, test
348 | degree_feature=True,
349 | attn_labels=None):
350 | self.fold_id = fold_id
351 | self.split = split
352 | self.w_sup_signal_attn = None
353 | print('''The degree_feature argument is ignored for this dataset.
354 | It will automatically be set to True if nodes do not have any features. Otherwise it will be set to False''')
355 | if attn_labels is not None:
356 | if isinstance(attn_labels, str) and os.path.isfile(attn_labels):
357 | with open(attn_labels, 'rb') as f:
358 | self.w_sup_signal_attn = pickle.load(f)
359 | else:
360 | self.w_sup_signal_attn = attn_labels
361 | for i in range(len(self.w_sup_signal_attn)):
362 | alpha = self.w_sup_signal_attn[i]
363 | alpha[alpha < 1e-3] = 0 # assuming that some nodes should have zero importance
364 | self.w_sup_signal_attn[i] = normalize(alpha)
365 | print(('!!!using weakly supervised labels (%d samples)!!!' % len(self.w_sup_signal_attn)).upper())
366 |
367 | self.set_fold(datareader.data, fold_id)
368 |
369 | def set_fold(self, data, fold_id):
370 |
371 | self.total = len(data['targets'])
372 | self.N_nodes_max = data['N_nodes_max']
373 | self.num_classes = data['num_classes']
374 | self.num_features = data['num_features']
375 | if self.split in ['train', 'val']:
376 | self.idx = data['splits'][self.split][fold_id]
377 | else:
378 | assert self.split in ['train_val', 'test'], ('unexpected split', self.split)
379 | self.idx = data['splits'][self.split]
380 |
381 | # use deepcopy to make sure we don't alter objects in folds
382 | self.labels = np.array(copy.deepcopy([data['targets'][i] for i in self.idx]))
383 | self.adj_list = copy.deepcopy([data['adj_list'][i] for i in self.idx])
384 | self.features_onehot = copy.deepcopy([data['features_onehot'][i] for i in self.idx])
385 | self.N_edges = np.array([A.sum() // 2 for A in self.adj_list]) # assume undirected graph with binary edges
386 | print('%s: %d/%d' % (self.split.upper(), len(self.labels), len(data['targets'])))
387 | classes = np.unique(self.labels)
388 | for lbl in classes:
389 | print('Class %d: \t\t\t%d samples' % (lbl, np.sum(self.labels == lbl)))
390 |
391 | def __len__(self):
392 | return len(self.labels)
393 |
394 | def __getitem__(self, index):
395 | if isinstance(index, str):
396 | # To make data format consistent with SyntheticGraphs
397 | if index == 'Adj_matrices':
398 | return self.adj_list
399 | elif index == 'GT_attn':
400 | print('Ground truth attention is unavailable for this dataset: weakly-supervised labels will be returned')
401 | return self.w_sup_signal_attn
402 | elif index == 'graph_labels':
403 | return self.labels
404 | elif index == 'node_features':
405 | return self.features_onehot
406 | elif index == 'N_edges':
407 | return self.N_edges
408 | else:
409 | raise KeyError(index)
410 | else:
411 | data = [self.features_onehot[index],
412 | self.adj_list[index],
413 | self.adj_list[index].shape[0],
414 | self.labels[index],
415 | None] # no GT attention
416 | if self.w_sup_signal_attn is not None:
417 | data.append(self.w_sup_signal_attn[index])
418 | data = list_to_torch(data) # convert to torch
419 |
420 | return data
421 |
422 |
423 | class DataReader():
424 | '''
425 | Class to read the txt files containing all data of the dataset
426 | Should work for any dataset from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
427 | '''
428 |
429 | def __init__(self,
430 | data_dir, # folder with txt files
431 | N_nodes=None, # maximum number of nodes in the training set
432 | rnd_state=None,
433 | use_cont_node_attr=False, # use or not additional float valued node attributes available in some datasets
434 | folds=10,
435 | fold_id=None):
436 |
437 | self.data_dir = data_dir
438 | self.rnd_state = np.random.RandomState() if rnd_state is None else rnd_state
439 | self.use_cont_node_attr = use_cont_node_attr
440 | self.N_nodes = N_nodes
441 | if os.path.isfile('%s/data.pkl' % data_dir):
442 | print('loading data from %s/data.pkl' % data_dir)
443 | with open('%s/data.pkl' % data_dir, 'rb') as f:
444 | data = pickle.load(f)
445 | else:
446 | files = os.listdir(self.data_dir)
447 | data = {}
448 | nodes, graphs = self.read_graph_nodes_relations(
449 | list(filter(lambda f: f.find('graph_indicator') >= 0, files))[0])
450 | lst = list(filter(lambda f: f.find('node_labels') >= 0, files))
451 | if len(lst) > 0:
452 | data['features'] = self.read_node_features(lst[0], nodes, graphs, fn=lambda s: int(s.strip()))
453 | else:
454 | data['features'] = None
455 | data['adj_list'] = self.read_graph_adj(list(filter(lambda f: f.find('_A') >= 0, files))[0], nodes, graphs)
456 | data['targets'] = np.array(
457 | self.parse_txt_file(list(filter(lambda f: f.find('graph_labels') >= 0, files))[0],
458 | line_parse_fn=lambda s: int(float(s.strip()))))
459 |
460 | if self.use_cont_node_attr:
461 | data['attr'] = self.read_node_features(list(filter(lambda f: f.find('node_attributes') >= 0, files))[0],
462 | nodes, graphs,
463 | fn=lambda s: np.array(list(map(float, s.strip().split(',')))))
464 |
465 | features, n_edges, degrees = [], [], []
466 | for sample_id, adj in enumerate(data['adj_list']):
467 | N = len(adj) # number of nodes
468 | if data['features'] is not None:
469 | assert N == len(data['features'][sample_id]), (N, len(data['features'][sample_id]))
470 | n = np.sum(adj) # total sum of edges
471 | # assert n % 2 == 0, n
472 | n_edges.append(int(n / 2)) # undirected edges, so need to divide by 2
473 | if not np.allclose(adj, adj.T):
474 | print(sample_id, 'not symmetric')
475 | degrees.extend(list(np.sum(adj, 1)))
476 | if data['features'] is not None:
477 | features.append(np.array(data['features'][sample_id]))
478 |
479 | # Create features over graphs as one-hot vectors for each node
480 | if data['features'] is not None:
481 | features_all = np.concatenate(features)
482 | features_min = features_all.min()
483 | num_features = int(features_all.max() - features_min + 1) # number of possible values
484 |
485 | features_onehot = []
486 | for i, x in enumerate(features):
487 | feature_onehot = np.zeros((len(x), num_features))
488 | for node, value in enumerate(x):
489 | feature_onehot[node, value - features_min] = 1
490 | if self.use_cont_node_attr:
491 | feature_onehot = np.concatenate((feature_onehot, np.array(data['attr'][i])), axis=1)
492 | features_onehot.append(feature_onehot)
493 |
494 | if self.use_cont_node_attr:
495 | num_features = features_onehot[0].shape[1]
496 | else:
497 | degree_max = int(np.max([np.sum(A, 1).max() for A in data['adj_list']]))
498 | num_features = degree_max + 1
499 | features_onehot = []
500 | for A in data['adj_list']:
501 | n = A.shape[0]
502 | D = np.sum(A, 1).astype(np.int)
503 | D_onehot = np.zeros((n, num_features))
504 | D_onehot[np.arange(n), D] = 1
505 | features_onehot.append(D_onehot)
506 |
507 | shapes = [len(adj) for adj in data['adj_list']]
508 | labels = data['targets'] # graph class labels
509 | labels -= np.min(labels) # to start from 0
510 |
511 | classes = np.unique(labels)
512 | num_classes = len(classes)
513 |
514 | if not np.all(np.diff(classes) == 1):
515 | print('making labels sequential, otherwise pytorch might crash')
516 | labels_new = np.zeros(labels.shape, dtype=labels.dtype) - 1
517 | for lbl in range(num_classes):
518 | labels_new[labels == classes[lbl]] = lbl
519 | labels = labels_new
520 | classes = np.unique(labels)
521 | assert len(np.unique(labels)) == num_classes, np.unique(labels)
522 |
523 | def stats(x):
524 | return (np.mean(x), np.std(x), np.min(x), np.max(x))
525 |
526 | print('N nodes avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(shapes))
527 | print('N edges avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(n_edges))
528 | print('Node degree avg/std/min/max: \t%.2f/%.2f/%d/%d' % stats(degrees))
529 | print('Node features dim: \t\t%d' % num_features)
530 | print('N classes: \t\t\t%d' % num_classes)
531 | print('Classes: \t\t\t%s' % str(classes))
532 | for lbl in classes:
533 | print('Class %d: \t\t\t%d samples' % (lbl, np.sum(labels == lbl)))
534 |
535 | if data['features'] is not None:
536 | for u in np.unique(features_all):
537 | print('feature {}, count {}/{}'.format(u, np.count_nonzero(features_all == u), len(features_all)))
538 |
539 | N_graphs = len(labels) # number of samples (graphs) in data
540 | assert N_graphs == len(data['adj_list']) == len(features_onehot), 'invalid data'
541 |
542 | data['features_onehot'] = features_onehot
543 | data['targets'] = labels
544 |
545 | data['N_nodes_max'] = np.max(shapes) # max number of nodes
546 | data['num_features'] = num_features
547 | data['num_classes'] = num_classes
548 |
549 | # Save preprocessed data for faster loading
550 | with open('%s/data.pkl' % data_dir, 'wb') as f:
551 | pickle.dump(data, f, protocol=2)
552 |
553 | labels = data['targets']
554 | # Create test sets first
555 | N_graphs = len(labels)
556 | shapes = np.array([len(adj) for adj in data['adj_list']])
557 | train_ids, val_ids, train_val_ids, test_ids = self.split_ids_shape(np.arange(N_graphs), shapes, N_nodes, folds=folds)
558 |
559 | # Create train sets
560 | splits = {'train': [], 'val': [], 'train_val': train_val_ids, 'test': test_ids}
561 | for fold in range(folds):
562 | splits['train'].append(train_ids[fold])
563 | splits['val'].append(val_ids[fold])
564 |
565 | data['splits'] = splits
566 |
567 | self.data = data
568 |
569 |
570 | def split_ids_shape(self, ids_all, shapes, N_nodes, folds=1, fold_id=0):
571 | if N_nodes > 0:
572 | small_graphs_ind = np.where(shapes <= N_nodes)[0]
573 |             print('{}/{} graphs with at most {} nodes'.format(len(small_graphs_ind), len(shapes), N_nodes))
574 | idx = self.rnd_state.permutation(len(small_graphs_ind))
575 | if len(idx) > 1000:
576 | n = 1000
577 | else:
578 | n = 500
579 | train_val_ids = small_graphs_ind[idx[:n]]
580 | test_ids = small_graphs_ind[idx[n:]]
581 | large_graphs_ind = np.where(shapes > N_nodes)[0]
582 | test_ids = np.concatenate((test_ids, large_graphs_ind))
583 | else:
584 | idx = self.rnd_state.permutation(len(ids_all))
585 | n = len(ids_all) // folds # number of test samples
586 | test_ids = ids_all[idx[fold_id * n: (fold_id + 1) * n if fold_id < folds - 1 else -1]]
587 | train_val_ids = []
588 | for i in ids_all:
589 | if i not in test_ids:
590 | train_val_ids.append(i)
591 | train_val_ids = np.array(train_val_ids)
592 |
593 | assert np.all(
594 | np.unique(np.concatenate((train_val_ids, test_ids))) == sorted(ids_all)), 'some graphs are missing in the test sets'
595 | if folds > 0:
596 | print('generating %d-fold cross-validation splits' % folds)
597 | train_ids, val_ids = self.split_ids(train_val_ids, folds=folds)
598 | # Sanity checks
599 | for fold in range(folds):
600 | ind = np.concatenate((train_ids[fold], val_ids[fold]))
601 | print(fold, len(train_ids[fold]), len(val_ids[fold]))
602 | assert len(train_ids[fold]) + len(val_ids[fold]) == len(np.unique(ind)) == len(ind) == len(train_val_ids), 'invalid splits'
603 | else:
604 | train_ids, val_ids = [], []
605 |
606 | return train_ids, val_ids, train_val_ids, test_ids
607 |
608 | def split_ids(self, ids, folds=10):
609 | n = len(ids)
610 | stride = int(np.ceil(n / float(folds)))
611 | test_ids = [ids[i: i + stride] for i in range(0, n, stride)]
612 | assert np.all(
613 | np.unique(np.concatenate(test_ids)) == sorted(ids)), 'some graphs are missing in the test sets'
614 | assert len(test_ids) == folds, 'invalid test sets'
615 | train_ids = []
616 | for fold in range(folds):
617 | train_ids.append(np.array([e for e in ids if e not in test_ids[fold]]))
618 | assert len(train_ids[fold]) + len(test_ids[fold]) == len(
619 | np.unique(list(train_ids[fold]) + list(test_ids[fold]))) == n, 'invalid splits'
620 |
621 | return train_ids, test_ids
622 |
623 | def parse_txt_file(self, fpath, line_parse_fn=None):
624 | with open(pjoin(self.data_dir, fpath), 'r') as f:
625 | lines = f.readlines()
626 | data = [line_parse_fn(s) if line_parse_fn is not None else s for s in lines]
627 | return data
628 |
629 | def read_graph_adj(self, fpath, nodes, graphs):
630 | edges = self.parse_txt_file(fpath, line_parse_fn=lambda s: s.split(','))
631 | adj_dict = {}
632 | for edge in edges:
633 | node1 = int(edge[0].strip()) - 1 # -1 because of zero-indexing in our code
634 | node2 = int(edge[1].strip()) - 1
635 | graph_id = nodes[node1]
636 | assert graph_id == nodes[node2], ('invalid data', graph_id, nodes[node2])
637 | if graph_id not in adj_dict:
638 | n = len(graphs[graph_id])
639 | adj_dict[graph_id] = np.zeros((n, n))
640 | ind1 = np.where(graphs[graph_id] == node1)[0]
641 | ind2 = np.where(graphs[graph_id] == node2)[0]
642 | assert len(ind1) == len(ind2) == 1, (ind1, ind2)
643 | adj_dict[graph_id][ind1, ind2] = 1
644 |
645 | adj_list = [adj_dict[graph_id] for graph_id in sorted(list(graphs.keys()))]
646 |
647 | return adj_list
648 |
649 | def read_graph_nodes_relations(self, fpath):
650 | graph_ids = self.parse_txt_file(fpath, line_parse_fn=lambda s: int(s.rstrip()))
651 | nodes, graphs = {}, {}
652 | for node_id, graph_id in enumerate(graph_ids):
653 | if graph_id not in graphs:
654 | graphs[graph_id] = []
655 | graphs[graph_id].append(node_id)
656 | nodes[node_id] = graph_id
657 | graph_ids = np.unique(list(graphs.keys()))
658 | for graph_id in graph_ids:
659 | graphs[graph_id] = np.array(graphs[graph_id])
660 | return nodes, graphs
661 |
662 | def read_node_features(self, fpath, nodes, graphs, fn):
663 | node_features_all = self.parse_txt_file(fpath, line_parse_fn=fn)
664 | node_features = {}
665 | for node_id, x in enumerate(node_features_all):
666 | graph_id = nodes[node_id]
667 | if graph_id not in node_features:
668 | node_features[graph_id] = [None] * len(graphs[graph_id])
669 | ind = np.where(graphs[graph_id] == node_id)[0]
670 | assert len(ind) == 1, ind
671 | assert node_features[graph_id][ind[0]] is None, node_features[graph_id][ind[0]]
672 | node_features[graph_id][ind[0]] = x
673 | node_features_lst = [node_features[graph_id] for graph_id in sorted(list(graphs.keys()))]
674 | return node_features_lst
675 |
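    | # --- Editor's note (illustrative, not part of the original module) ---
    | # A minimal sketch of how these datasets are typically wrapped for training;
    | # the exact arguments used in main.py may differ:
    | #
    | #   import torch.utils.data
    | #   dataset = SyntheticGraphs('./data', 'colors-3', 'train')
    | #   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
    | #                                        collate_fn=collate_batch)
    | #   for x, A, mask, labels, params_dict in loader:
    | #       pass  # x: [B, N_max, C], A: [B, N_max, N_max], mask: [B, N_max]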
--------------------------------------------------------------------------------
/logs/colors-3_global_max_seed111.log:
--------------------------------------------------------------------------------
1 | start time: 2019-06-21 11:23:37.828931
2 | gpus: 1
3 | dataset colors-3
4 | data_dir ./data
5 | epochs 100
6 | batch_size 32
7 | lr 0.001
8 | lr_decay_step [90]
9 | wdecay 0.0001
10 | dropout 0.0
11 | filters [64, 64]
12 | filter_scale 2
13 | n_hidden 0
14 | aggregation mean
15 | readout sum
16 | kl_weight 100
17 | pool None
18 | pool_arch ['fc', 'prev']
19 | n_nodes 25
20 | cv_folds 5
21 | img_features ['mean', 'coord']
22 | img_noise_levels [0.4, 0.6]
23 | validation False
24 | debug False
25 | eval_attn_train True
26 | eval_attn_test True
27 | test_batch_size 100
28 | alpha_ws None
29 | log_interval 400
30 | results ./checkpoints/
31 | resume None
32 | device cuda
33 | seed 111
34 | threads 0
35 | experiment_ID: 828931
36 | train Adj_matrices 500
37 | train GT_attn 500
38 | train graph_labels 500
39 | train node_features 500
40 | train N_edges 500
41 | N nodes avg/std/min/max: 14.26/6.22/4/25
42 | N edges avg/std/min/max: 20.87/10.49/4/48
43 | Node degree avg/std/min/max: 2.93/1.53/0/9
44 | Node features dim: 4
45 | N classes: 11
46 | Correlation of labels with graph size: 0.18
47 | Classes: [ 0 1 2 3 4 5 6 7 8 9 10]
48 | Class 0: 58 samples, N_nodes: avg/std/min/max: 12.90/6.79/4/25
49 | Class 1: 56 samples, N_nodes: avg/std/min/max: 14.00/6.55/4/25
50 | Class 2: 52 samples, N_nodes: avg/std/min/max: 14.46/7.10/4/25
51 | Class 3: 56 samples, N_nodes: avg/std/min/max: 11.86/5.95/4/25
52 | Class 4: 53 samples, N_nodes: avg/std/min/max: 13.55/6.16/4/25
53 | Class 5: 49 samples, N_nodes: avg/std/min/max: 14.73/6.38/5/25
54 | Class 6: 49 samples, N_nodes: avg/std/min/max: 13.57/4.98/6/23
55 | Class 7: 30 samples, N_nodes: avg/std/min/max: 15.37/5.64/7/25
56 | Class 8: 33 samples, N_nodes: avg/std/min/max: 17.24/4.91/9/24
57 | Class 9: 35 samples, N_nodes: avg/std/min/max: 16.11/4.90/9/25
58 | Class 10: 29 samples, N_nodes: avg/std/min/max: 16.66/4.95/10/25
59 | test Adj_matrices 7500
60 | test GT_attn 7500
61 | test graph_labels 7500
62 | test node_features 7500
63 | test N_edges 7500
64 | N nodes avg/std/min/max: 80.64/62.43/4/200
65 | N edges avg/std/min/max: 120.95/98.42/4/396
66 | Node degree avg/std/min/max: 3.00/1.78/0/14
67 | Node features dim: 4
68 | N classes: 11
69 | Correlation of labels with graph size: 0.06
70 | Classes: [ 0 1 2 3 4 5 6 7 8 9 10]
71 | Class 0: 730 samples, N_nodes: avg/std/min/max: 76.34/62.01/4/200
72 | Class 1: 699 samples, N_nodes: avg/std/min/max: 77.19/63.03/4/200
73 | Class 2: 706 samples, N_nodes: avg/std/min/max: 77.07/62.79/4/200
74 | Class 3: 699 samples, N_nodes: avg/std/min/max: 75.73/62.78/4/200
75 | Class 4: 717 samples, N_nodes: avg/std/min/max: 77.32/62.59/4/200
76 | Class 5: 697 samples, N_nodes: avg/std/min/max: 81.44/63.24/5/200
77 | Class 6: 687 samples, N_nodes: avg/std/min/max: 82.32/61.91/6/200
78 | Class 7: 690 samples, N_nodes: avg/std/min/max: 82.41/61.77/7/200
79 | Class 8: 638 samples, N_nodes: avg/std/min/max: 84.81/62.83/8/200
80 | Class 9: 641 samples, N_nodes: avg/std/min/max: 85.45/61.25/9/200
81 | Class 10: 596 samples, N_nodes: avg/std/min/max: 89.37/60.59/10/200
82 | ChebyGINLayer torch.Size([64, 8]) tensor([0.5559, 0.4973, 0.4914, 0.5325, 0.4516, 0.5541, 0.7513, 0.6097, 0.5125,
83 | 0.5658], grad_fn=)
84 | ChebyGINLayer torch.Size([64, 128]) tensor([0.5931, 0.5994, 0.5544, 0.5491, 0.5944, 0.5986, 0.6081, 0.5848, 0.5648,
85 | 0.5838], grad_fn=)
86 | ChebyGIN(
87 | (graph_layers): Sequential(
88 | (0): ChebyGINLayer(in_features=4, out_features=64, K=2, n_hidden=0, aggregation=mean)
89 | fc=Sequential(
90 | (0): Linear(in_features=8, out_features=64, bias=True)
91 | (1): ReLU(inplace)
92 | )
93 | (1): ChebyGINLayer(in_features=64, out_features=64, K=2, n_hidden=0, aggregation=mean)
94 | fc=Sequential(
95 | (0): Linear(in_features=128, out_features=64, bias=True)
96 | (1): ReLU(inplace)
97 | )
98 | (2): GraphReadout(sum)
99 | )
100 | (fc): Sequential(
101 | (0): Linear(in_features=64, out_features=1, bias=True)
102 | )
103 | )
104 | model capacity: 8897
105 | model is checked for nodes shuffling
106 | Train set (epoch 1): [500/500 (100%)] Loss: 4.0388 (avg: 13.2787), other losses: [] Acc metric: 70/500 (14.00%) AttnAUC: [] avg sec/iter: 0.0076
107 |
108 |
109 | saving the model to ./checkpoints//checkpoint_colors-3_828931_epoch1_seed0000111.pth.tar
110 | model is checked for nodes shuffling
111 | lbl: 0, avg acc: 0.00% (0/58)
112 | lbl: 1, avg acc: 21.43% (12/56)
113 | lbl: 2, avg acc: 21.15% (11/52)
114 | lbl: 3, avg acc: 23.21% (13/56)
115 | lbl: 4, avg acc: 32.08% (17/53)
116 | lbl: 5, avg acc: 20.41% (10/49)
117 | lbl: 6, avg acc: 14.29% (7/49)
118 | lbl: 7, avg acc: 10.00% (3/30)
119 | lbl: 8, avg acc: 0.00% (0/33)
120 | lbl: 9, avg acc: 0.00% (0/35)
121 | lbl: 10, avg acc: 0.00% (0/29)
122 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 14.60% (73/500)
123 | Train set (epoch 1): Avg loss: 4.7120, Acc metric: 73/500 (14.60%) AttnAUC: [] avg sec/iter: 0.0268
124 |
125 | model is checked for nodes shuffling
126 | lbl: 0, avg acc: 0.00% (0/262)
127 | lbl: 1, avg acc: 31.73% (79/249)
128 | lbl: 2, avg acc: 28.29% (73/258)
129 | lbl: 3, avg acc: 24.32% (63/259)
130 | lbl: 4, avg acc: 29.00% (78/269)
131 | lbl: 5, avg acc: 25.00% (59/236)
132 | lbl: 6, avg acc: 18.40% (39/212)
133 | lbl: 7, avg acc: 4.59% (10/218)
134 | lbl: 8, avg acc: 0.00% (0/196)
135 | lbl: 9, avg acc: 0.00% (0/192)
136 | lbl: 10, avg acc: 0.00% (0/149)
137 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 16.04% (401/2500)
138 | lbl: 0, avg acc: 0.00% (0/226)
139 | lbl: 1, avg acc: 0.00% (0/216)
140 | lbl: 2, avg acc: 0.00% (0/223)
141 | lbl: 3, avg acc: 0.00% (0/237)
142 | lbl: 4, avg acc: 0.00% (0/229)
143 | lbl: 5, avg acc: 0.00% (0/231)
144 | lbl: 6, avg acc: 1.25% (3/240)
145 | lbl: 7, avg acc: 2.07% (5/242)
146 | lbl: 8, avg acc: 2.82% (6/213)
147 | lbl: 9, avg acc: 2.71% (6/221)
148 | lbl: 10, avg acc: 1.80% (4/222)
149 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 0.96% (24/2500)
150 | lbl: 0, avg acc: 0.00% (0/242)
151 | lbl: 1, avg acc: 0.00% (0/234)
152 | lbl: 2, avg acc: 0.00% (0/225)
153 | lbl: 3, avg acc: 0.00% (0/203)
154 | lbl: 4, avg acc: 0.00% (0/219)
155 | lbl: 5, avg acc: 0.00% (0/230)
156 | lbl: 6, avg acc: 0.43% (1/235)
157 | lbl: 7, avg acc: 3.04% (7/230)
158 | lbl: 8, avg acc: 1.31% (3/229)
159 | lbl: 9, avg acc: 3.07% (7/228)
160 | lbl: 10, avg acc: 2.67% (6/225)
161 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 0.96% (24/2500)
162 | Test set (epoch 1): Avg loss: 218.5921, Acc metric: 449/7500 (5.99%) AttnAUC: [] avg sec/iter: 0.0118
163 |
164 | Train set (epoch 2): [500/500 (100%)] Loss: 0.5873 (avg: 2.1991), other losses: [] Acc metric: 133/500 (26.60%) AttnAUC: [] avg sec/iter: 0.0035
165 |
166 |
167 | Train set (epoch 3): [500/500 (100%)] Loss: 0.1954 (avg: 0.2053), other losses: [] Acc metric: 382/500 (76.40%) AttnAUC: [] avg sec/iter: 0.0030
168 |
169 |
170 | Train set (epoch 4): [500/500 (100%)] Loss: 0.1411 (avg: 0.1459), other losses: [] Acc metric: 417/500 (83.40%) AttnAUC: [] avg sec/iter: 0.0030
171 |
172 |
173 | Train set (epoch 5): [500/500 (100%)] Loss: 0.1857 (avg: 0.1294), other losses: [] Acc metric: 427/500 (85.40%) AttnAUC: [] avg sec/iter: 0.0030
174 |
175 |
176 | Train set (epoch 6): [500/500 (100%)] Loss: 0.1421 (avg: 0.1294), other losses: [] Acc metric: 432/500 (86.40%) AttnAUC: [] avg sec/iter: 0.0069
177 |
178 |
179 | Train set (epoch 7): [500/500 (100%)] Loss: 0.0833 (avg: 0.1182), other losses: [] Acc metric: 432/500 (86.40%) AttnAUC: [] avg sec/iter: 0.0036
180 |
181 |
182 | Train set (epoch 8): [500/500 (100%)] Loss: 0.1359 (avg: 0.1069), other losses: [] Acc metric: 437/500 (87.40%) AttnAUC: [] avg sec/iter: 0.0035
183 |
184 |
185 | Train set (epoch 9): [500/500 (100%)] Loss: 0.2092 (avg: 0.1067), other losses: [] Acc metric: 435/500 (87.00%) AttnAUC: [] avg sec/iter: 0.0031
186 |
187 |
188 | Train set (epoch 10): [500/500 (100%)] Loss: 0.0502 (avg: 0.0982), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0032
189 |
190 |
191 | Train set (epoch 11): [500/500 (100%)] Loss: 0.0328 (avg: 0.0914), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0032
192 |
193 |
194 | Train set (epoch 12): [500/500 (100%)] Loss: 0.1351 (avg: 0.0882), other losses: [] Acc metric: 450/500 (90.00%) AttnAUC: [] avg sec/iter: 0.0034
195 |
196 |
197 | Train set (epoch 13): [500/500 (100%)] Loss: 0.1683 (avg: 0.0780), other losses: [] Acc metric: 459/500 (91.80%) AttnAUC: [] avg sec/iter: 0.0034
198 |
199 |
200 | Train set (epoch 14): [500/500 (100%)] Loss: 0.0541 (avg: 0.0779), other losses: [] Acc metric: 457/500 (91.40%) AttnAUC: [] avg sec/iter: 0.0033
201 |
202 |
203 | Train set (epoch 15): [500/500 (100%)] Loss: 0.0494 (avg: 0.0691), other losses: [] Acc metric: 461/500 (92.20%) AttnAUC: [] avg sec/iter: 0.0034
204 |
205 |
206 | Train set (epoch 16): [500/500 (100%)] Loss: 0.0533 (avg: 0.0665), other losses: [] Acc metric: 460/500 (92.00%) AttnAUC: [] avg sec/iter: 0.0041
207 |
208 |
209 | Train set (epoch 17): [500/500 (100%)] Loss: 0.0395 (avg: 0.0572), other losses: [] Acc metric: 470/500 (94.00%) AttnAUC: [] avg sec/iter: 0.0040
210 |
211 |
212 | Train set (epoch 18): [500/500 (100%)] Loss: 0.0702 (avg: 0.0482), other losses: [] Acc metric: 479/500 (95.80%) AttnAUC: [] avg sec/iter: 0.0038
213 |
214 |
215 | Train set (epoch 19): [500/500 (100%)] Loss: 0.0504 (avg: 0.0463), other losses: [] Acc metric: 482/500 (96.40%) AttnAUC: [] avg sec/iter: 0.0035
216 |
217 |
218 | Train set (epoch 20): [500/500 (100%)] Loss: 0.0798 (avg: 0.0495), other losses: [] Acc metric: 474/500 (94.80%) AttnAUC: [] avg sec/iter: 0.0035
219 |
220 |
221 | Train set (epoch 21): [500/500 (100%)] Loss: 0.0540 (avg: 0.0386), other losses: [] Acc metric: 489/500 (97.80%) AttnAUC: [] avg sec/iter: 0.0038
222 |
223 |
224 | Train set (epoch 22): [500/500 (100%)] Loss: 0.0729 (avg: 0.0523), other losses: [] Acc metric: 471/500 (94.20%) AttnAUC: [] avg sec/iter: 0.0038
225 |
226 |
227 | Train set (epoch 23): [500/500 (100%)] Loss: 0.0210 (avg: 0.0301), other losses: [] Acc metric: 493/500 (98.60%) AttnAUC: [] avg sec/iter: 0.0040
228 |
229 |
230 | Train set (epoch 24): [500/500 (100%)] Loss: 0.0591 (avg: 0.0323), other losses: [] Acc metric: 489/500 (97.80%) AttnAUC: [] avg sec/iter: 0.0040
231 |
232 |
233 | Train set (epoch 25): [500/500 (100%)] Loss: 0.0328 (avg: 0.0287), other losses: [] Acc metric: 491/500 (98.20%) AttnAUC: [] avg sec/iter: 0.0039
234 |
235 |
236 | Train set (epoch 26): [500/500 (100%)] Loss: 0.0195 (avg: 0.0281), other losses: [] Acc metric: 495/500 (99.00%) AttnAUC: [] avg sec/iter: 0.0039
237 |
238 |
239 | Train set (epoch 27): [500/500 (100%)] Loss: 0.0114 (avg: 0.0240), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0039
240 |
241 |
242 | Train set (epoch 28): [500/500 (100%)] Loss: 0.0369 (avg: 0.0207), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0044
243 |
244 |
245 | Train set (epoch 29): [500/500 (100%)] Loss: 0.0334 (avg: 0.0266), other losses: [] Acc metric: 495/500 (99.00%) AttnAUC: [] avg sec/iter: 0.0056
246 |
247 |
248 | Train set (epoch 30): [500/500 (100%)] Loss: 0.0079 (avg: 0.0249), other losses: [] Acc metric: 493/500 (98.60%) AttnAUC: [] avg sec/iter: 0.0078
249 |
250 |
251 | Train set (epoch 31): [500/500 (100%)] Loss: 0.0076 (avg: 0.0161), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0060
252 |
253 |
254 | Train set (epoch 32): [500/500 (100%)] Loss: 0.0076 (avg: 0.0218), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0044
255 |
256 |
257 | Train set (epoch 33): [500/500 (100%)] Loss: 0.0301 (avg: 0.0235), other losses: [] Acc metric: 496/500 (99.20%) AttnAUC: [] avg sec/iter: 0.0037
258 |
259 |
260 | Train set (epoch 34): [500/500 (100%)] Loss: 0.0307 (avg: 0.0154), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0033
261 |
262 |
263 | Train set (epoch 35): [500/500 (100%)] Loss: 0.0156 (avg: 0.0122), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0033
264 |
265 |
266 | Train set (epoch 36): [500/500 (100%)] Loss: 0.0095 (avg: 0.0110), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0033
267 |
268 |
269 | Train set (epoch 37): [500/500 (100%)] Loss: 0.0799 (avg: 0.0174), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033
270 |
271 |
272 | Train set (epoch 38): [500/500 (100%)] Loss: 0.0030 (avg: 0.0129), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0033
273 |
274 |
275 | Train set (epoch 39): [500/500 (100%)] Loss: 0.0798 (avg: 0.0168), other losses: [] Acc metric: 497/500 (99.40%) AttnAUC: [] avg sec/iter: 0.0039
276 |
277 |
278 | Train set (epoch 40): [500/500 (100%)] Loss: 0.0138 (avg: 0.0181), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0041
279 |
280 |
281 | Train set (epoch 41): [500/500 (100%)] Loss: 0.0121 (avg: 0.0092), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0040
282 |
283 |
284 | Train set (epoch 42): [500/500 (100%)] Loss: 0.0390 (avg: 0.0146), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0034
285 |
286 |
287 | Train set (epoch 43): [500/500 (100%)] Loss: 0.0235 (avg: 0.0140), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0038
288 |
289 |
290 | Train set (epoch 44): [500/500 (100%)] Loss: 0.0053 (avg: 0.0102), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033
291 |
292 |
293 | Train set (epoch 45): [500/500 (100%)] Loss: 0.0100 (avg: 0.0125), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0034
294 |
295 |
296 | Train set (epoch 46): [500/500 (100%)] Loss: 0.0064 (avg: 0.0086), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0033
297 |
298 |
299 | Train set (epoch 47): [500/500 (100%)] Loss: 0.0144 (avg: 0.0266), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0034
300 |
301 |
302 | Train set (epoch 48): [500/500 (100%)] Loss: 0.0031 (avg: 0.0076), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0041
303 |
304 |
305 | Train set (epoch 49): [500/500 (100%)] Loss: 0.0038 (avg: 0.0063), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0037
306 |
307 |
308 | Train set (epoch 50): [500/500 (100%)] Loss: 0.0127 (avg: 0.0132), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0040
309 |
310 |
311 | Train set (epoch 51): [500/500 (100%)] Loss: 0.0026 (avg: 0.0129), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0041
312 |
313 |
314 | Train set (epoch 52): [500/500 (100%)] Loss: 0.0416 (avg: 0.0121), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059
315 |
316 |
317 | Train set (epoch 53): [500/500 (100%)] Loss: 0.0271 (avg: 0.0148), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0060
318 |
319 |
320 | Train set (epoch 54): [500/500 (100%)] Loss: 0.0023 (avg: 0.0093), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0054
321 |
322 |
323 | Train set (epoch 55): [500/500 (100%)] Loss: 0.0026 (avg: 0.0056), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0051
324 |
325 |
326 | Train set (epoch 56): [500/500 (100%)] Loss: 0.0046 (avg: 0.0093), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0047
327 |
328 |
329 | Train set (epoch 57): [500/500 (100%)] Loss: 0.0070 (avg: 0.0046), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0071
330 |
331 |
332 | Train set (epoch 58): [500/500 (100%)] Loss: 0.0124 (avg: 0.0259), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0071
333 |
334 |
335 | Train set (epoch 59): [500/500 (100%)] Loss: 0.0123 (avg: 0.0149), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0077
336 |
337 |
338 | Train set (epoch 60): [500/500 (100%)] Loss: 0.0111 (avg: 0.0205), other losses: [] Acc metric: 498/500 (99.60%) AttnAUC: [] avg sec/iter: 0.0068
339 |
340 |
341 | Train set (epoch 61): [500/500 (100%)] Loss: 0.0123 (avg: 0.0052), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059
342 |
343 |
344 | Train set (epoch 62): [500/500 (100%)] Loss: 0.0062 (avg: 0.0051), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0055
345 |
346 |
347 | Train set (epoch 63): [500/500 (100%)] Loss: 0.0290 (avg: 0.0056), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0053
348 |
349 |
350 | Train set (epoch 64): [500/500 (100%)] Loss: 0.0457 (avg: 0.0792), other losses: [] Acc metric: 466/500 (93.20%) AttnAUC: [] avg sec/iter: 0.0057
351 |
352 |
353 | Train set (epoch 65): [500/500 (100%)] Loss: 0.0024 (avg: 0.0074), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0073
354 |
355 |
356 | Train set (epoch 66): [500/500 (100%)] Loss: 0.0019 (avg: 0.0048), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0061
357 |
358 |
359 | Train set (epoch 67): [500/500 (100%)] Loss: 0.0031 (avg: 0.0033), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0056
360 |
361 |
362 | Train set (epoch 68): [500/500 (100%)] Loss: 0.0108 (avg: 0.0169), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0052
363 |
364 |
365 | Train set (epoch 69): [500/500 (100%)] Loss: 0.0166 (avg: 0.0050), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067
366 |
367 |
368 | Train set (epoch 70): [500/500 (100%)] Loss: 0.0037 (avg: 0.0032), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067
369 |
370 |
371 | Train set (epoch 71): [500/500 (100%)] Loss: 0.0059 (avg: 0.0060), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0078
372 |
373 |
374 | Train set (epoch 72): [500/500 (100%)] Loss: 0.0019 (avg: 0.0091), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0074
375 |
376 |
377 | Train set (epoch 73): [500/500 (100%)] Loss: 0.0046 (avg: 0.0028), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0075
378 |
379 |
380 | Train set (epoch 74): [500/500 (100%)] Loss: 0.0035 (avg: 0.0040), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0076
381 |
382 |
383 | Train set (epoch 75): [500/500 (100%)] Loss: 0.0064 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0075
384 |
385 |
386 | Train set (epoch 76): [500/500 (100%)] Loss: 0.0036 (avg: 0.0117), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0070
387 |
388 |
389 | Train set (epoch 77): [500/500 (100%)] Loss: 0.0032 (avg: 0.0028), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0053
390 |
391 |
392 | Train set (epoch 78): [500/500 (100%)] Loss: 0.0055 (avg: 0.0082), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059
393 |
394 |
395 | Train set (epoch 79): [500/500 (100%)] Loss: 0.0239 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0068
396 |
397 |
398 | Train set (epoch 80): [500/500 (100%)] Loss: 0.0029 (avg: 0.0249), other losses: [] Acc metric: 499/500 (99.80%) AttnAUC: [] avg sec/iter: 0.0064
399 |
400 |
401 | Train set (epoch 81): [500/500 (100%)] Loss: 0.0073 (avg: 0.0033), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0058
402 |
403 |
404 | Train set (epoch 82): [500/500 (100%)] Loss: 0.0060 (avg: 0.0080), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0055
405 |
406 |
407 | Train set (epoch 83): [500/500 (100%)] Loss: 0.0521 (avg: 0.0741), other losses: [] Acc metric: 455/500 (91.00%) AttnAUC: [] avg sec/iter: 0.0059
408 |
409 |
410 | Train set (epoch 84): [500/500 (100%)] Loss: 0.0029 (avg: 0.0073), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0066
411 |
412 |
413 | Train set (epoch 85): [500/500 (100%)] Loss: 0.0017 (avg: 0.0022), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0067
414 |
415 |
416 | Train set (epoch 86): [500/500 (100%)] Loss: 0.0013 (avg: 0.0020), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0062
417 |
418 |
419 | Train set (epoch 87): [500/500 (100%)] Loss: 0.0027 (avg: 0.0020), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0059
420 |
421 |
422 | Train set (epoch 88): [500/500 (100%)] Loss: 0.0047 (avg: 0.0055), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0060
423 |
424 |
425 | Train set (epoch 89): [500/500 (100%)] Loss: 0.0012 (avg: 0.0061), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072
426 |
427 |
428 | Train set (epoch 90): [500/500 (100%)] Loss: 0.0054 (avg: 0.0030), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0070
429 |
430 |
431 | Train set (epoch 91): [500/500 (100%)] Loss: 0.0017 (avg: 0.0017), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072
432 |
433 |
434 | Train set (epoch 92): [500/500 (100%)] Loss: 0.0013 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0072
435 |
436 |
437 | Train set (epoch 93): [500/500 (100%)] Loss: 0.0020 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0076
438 |
439 |
440 | Train set (epoch 94): [500/500 (100%)] Loss: 0.0012 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0074
441 |
442 |
443 | Train set (epoch 95): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0066
444 |
445 |
446 | Train set (epoch 96): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0077
447 |
448 |
449 | Train set (epoch 97): [500/500 (100%)] Loss: 0.0012 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0078
450 |
451 |
452 | Train set (epoch 98): [500/500 (100%)] Loss: 0.0011 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0062
453 |
454 |
455 | Train set (epoch 99): [500/500 (100%)] Loss: 0.0019 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0046
456 |
457 |
458 | Train set (epoch 100): [500/500 (100%)] Loss: 0.0010 (avg: 0.0012), other losses: [] Acc metric: 500/500 (100.00%) AttnAUC: [] avg sec/iter: 0.0037
459 |
460 |
461 | saving the model to ./checkpoints//checkpoint_colors-3_828931_epoch100_seed0000111.pth.tar
462 | testing with evaluation of attention: takes longer time
463 | 100/500 samples processed
464 | 200/500 samples processed
465 | 300/500 samples processed
466 | 400/500 samples processed
467 | 500/500 samples processed
468 | lbl: 0, avg acc: 100.00% (58/58)
469 | lbl: 1, avg acc: 100.00% (56/56)
470 | lbl: 2, avg acc: 100.00% (52/52)
471 | lbl: 3, avg acc: 100.00% (56/56)
472 | lbl: 4, avg acc: 100.00% (53/53)
473 | lbl: 5, avg acc: 100.00% (49/49)
474 | lbl: 6, avg acc: 100.00% (49/49)
475 | lbl: 7, avg acc: 100.00% (30/30)
476 | lbl: 8, avg acc: 100.00% (33/33)
477 | lbl: 9, avg acc: 100.00% (35/35)
478 | lbl: 10, avg acc: 100.00% (29/29)
479 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 100.00% (500/500)
480 | Train set (epoch 100): Avg loss: 0.0012, Acc metric: 500/500 (100.00%) AttnAUC: ['97.99'] avg sec/iter: 0.1890
481 |
482 | testing with evaluation of attention: takes longer time
483 | 100/7500 samples processed
484 | 200/7500 samples processed
485 | 300/7500 samples processed
486 | 400/7500 samples processed
487 | 500/7500 samples processed
488 | 600/7500 samples processed
489 | 700/7500 samples processed
490 | 800/7500 samples processed
491 | 900/7500 samples processed
492 | 1000/7500 samples processed
493 | 1100/7500 samples processed
494 | 1200/7500 samples processed
495 | 1300/7500 samples processed
496 | 1400/7500 samples processed
497 | 1500/7500 samples processed
498 | 1600/7500 samples processed
499 | 1700/7500 samples processed
500 | 1800/7500 samples processed
501 | 1900/7500 samples processed
502 | 2000/7500 samples processed
503 | 2100/7500 samples processed
504 | 2200/7500 samples processed
505 | 2300/7500 samples processed
506 | 2400/7500 samples processed
507 | 2500/7500 samples processed
508 | 2600/7500 samples processed
509 | 2700/7500 samples processed
510 | 2800/7500 samples processed
511 | 2900/7500 samples processed
512 | 3000/7500 samples processed
513 | 3100/7500 samples processed
514 | 3200/7500 samples processed
515 | 3300/7500 samples processed
516 | 3400/7500 samples processed
517 | 3500/7500 samples processed
518 | 3600/7500 samples processed
519 | 3700/7500 samples processed
520 | 3800/7500 samples processed
521 | 3900/7500 samples processed
522 | 4000/7500 samples processed
523 | 4100/7500 samples processed
524 | 4200/7500 samples processed
525 | 4300/7500 samples processed
526 | 4400/7500 samples processed
527 | 4500/7500 samples processed
528 | 4600/7500 samples processed
529 | 4700/7500 samples processed
530 | 4800/7500 samples processed
531 | 4900/7500 samples processed
532 | 5000/7500 samples processed
533 | 5100/7500 samples processed
534 | 5200/7500 samples processed
535 | 5300/7500 samples processed
536 | 5400/7500 samples processed
537 | 5500/7500 samples processed
538 | 5600/7500 samples processed
539 | 5700/7500 samples processed
540 | 5800/7500 samples processed
541 | 5900/7500 samples processed
542 | 6000/7500 samples processed
543 | 6100/7500 samples processed
544 | 6200/7500 samples processed
545 | 6300/7500 samples processed
546 | 6400/7500 samples processed
547 | 6500/7500 samples processed
548 | 6600/7500 samples processed
549 | 6700/7500 samples processed
550 | 6800/7500 samples processed
551 | 6900/7500 samples processed
552 | 7000/7500 samples processed
553 | 7100/7500 samples processed
554 | 7200/7500 samples processed
555 | 7300/7500 samples processed
556 | 7400/7500 samples processed
557 | 7500/7500 samples processed
558 | lbl: 0, avg acc: 100.00% (262/262)
559 | lbl: 1, avg acc: 100.00% (249/249)
560 | lbl: 2, avg acc: 100.00% (258/258)
561 | lbl: 3, avg acc: 100.00% (259/259)
562 | lbl: 4, avg acc: 100.00% (269/269)
563 | lbl: 5, avg acc: 100.00% (236/236)
564 | lbl: 6, avg acc: 100.00% (212/212)
565 | lbl: 7, avg acc: 100.00% (218/218)
566 | lbl: 8, avg acc: 100.00% (196/196)
567 | lbl: 9, avg acc: 100.00% (192/192)
568 | lbl: 10, avg acc: 100.00% (149/149)
569 | 0 <= N_nodes <= 25 (min=4, max=25), avg acc: 100.00% (2500/2500)
570 | lbl: 0, avg acc: 54.87% (124/226)
571 | lbl: 1, avg acc: 58.80% (127/216)
572 | lbl: 2, avg acc: 58.30% (130/223)
573 | lbl: 3, avg acc: 60.76% (144/237)
574 | lbl: 4, avg acc: 59.83% (137/229)
575 | lbl: 5, avg acc: 59.31% (137/231)
576 | lbl: 6, avg acc: 65.42% (157/240)
577 | lbl: 7, avg acc: 63.64% (154/242)
578 | lbl: 8, avg acc: 65.26% (139/213)
579 | lbl: 9, avg acc: 69.23% (153/221)
580 | lbl: 10, avg acc: 67.12% (149/222)
581 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 62.04% (1551/2500)
582 | lbl: 0, avg acc: 12.81% (31/242)
583 | lbl: 1, avg acc: 14.10% (33/234)
584 | lbl: 2, avg acc: 14.67% (33/225)
585 | lbl: 3, avg acc: 13.30% (27/203)
586 | lbl: 4, avg acc: 12.79% (28/219)
587 | lbl: 5, avg acc: 14.78% (34/230)
588 | lbl: 6, avg acc: 12.77% (30/235)
589 | lbl: 7, avg acc: 16.96% (39/230)
590 | lbl: 8, avg acc: 14.85% (34/229)
591 | lbl: 9, avg acc: 16.67% (38/228)
592 | lbl: 10, avg acc: 19.11% (43/225)
593 | 26 <= N_nodes <= 200 (min=26, max=200), avg acc: 14.80% (370/2500)
594 | Test set (epoch 100): Avg loss: 1.8102, Acc metric: 4421/7500 (58.95%) AttnAUC: ['99.83'] avg sec/iter: 0.5009
595 |
596 | done in 0:00:50.659462
597 |
--------------------------------------------------------------------------------
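The log above follows a fixed per-epoch format, so the training curve can be recovered with a few lines of Python. A minimal sketch (the regular expression and the log path are illustrative assumptions, matching the "Train set (epoch N): ... Acc metric: a/b (p%)" lines shown above):

import re

# Matches both the per-epoch progress lines and the evaluation lines for the train split, e.g.
# "Train set (epoch 12): [500/500 (100%)] ... Acc metric: 450/500 (90.00%) ..."
train_acc_re = re.compile(r'Train set \(epoch (\d+)\).*Acc metric: \d+/\d+ \((\d+\.\d+)%\)')

def train_accuracy_per_epoch(log_path):
    """Return a dict {epoch: accuracy_in_percent} parsed from a training log."""
    acc = {}
    with open(log_path) as f:
        for line in f:
            m = train_acc_re.search(line)
            if m:
                acc[int(m.group(1))] = float(m.group(2))
    return acc

# Example (path is illustrative):
# print(train_accuracy_per_epoch('./logs/colors-3_sup_seed111.log'))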
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import datetime
4 | from torchvision import transforms
5 | from graphdata import *
6 | from train_test import *
7 | import warnings
8 |
9 | warnings.filterwarnings("once")
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(description='Run experiments with Graph Neural Networks')
13 | # Dataset
14 | parser.add_argument('-D', '--dataset', type=str, default='colors-3',
15 | choices=['colors-3', 'colors-4', 'colors-8', 'colors-16', 'colors-32',
16 | 'triangles', 'mnist', 'mnist-75sp', 'TU'],
17 | help='colors-n means the colors dataset with n-dimensional features; TU is any dataset from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets')
18 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to the dataset')
19 | # Hyperparameters
 20 |     parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')
21 | parser.add_argument('--batch_size', type=int, default=32, help='batch size for training data')
22 | parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate')
 23 |     parser.add_argument('--lr_decay_step', type=str, default=None, help='comma-separated epochs at which to decay the learning rate (e.g. 20,25)')
24 | parser.add_argument('--wdecay', type=float, default=1e-4, help='weight decay')
25 | parser.add_argument('--dropout', type=float, default=0, help='dropout rate')
26 | parser.add_argument('-f', '--filters', type=str, default='64,64,64', help='number of filters in each graph layer')
27 | parser.add_argument('-K', '--filter_scale', type=int, default=1, help='filter scale (receptive field size), must be > 0; 1 for GCN or GIN')
28 | parser.add_argument('--n_hidden', type=int, default=0, help='number of hidden units inside the graph layer')
29 | parser.add_argument('--aggregation', type=str, default='mean', choices=['mean', 'sum'], help='neighbors aggregation inside the graph layer')
30 | parser.add_argument('--readout', type=str, default=None, choices=['mean', 'sum', 'max'], help='type of global pooling over all nodes')
31 | parser.add_argument('--kl_weight', type=float, default=100, help='weight of the KL term in the loss')
32 | parser.add_argument('--pool', type=str, default=None, help='type of pooling between layers, None for global pooling only')
33 | parser.add_argument('--pool_arch', type=str, default=None, help='pooling layers architecture defining whether to use fully-connected layers or GNN and to which layer to attach (e.g.: fc_prev, gnn_prev, fc_curr, gnn_curr, fc_prev_32)')
34 | parser.add_argument('--init', type=str, default='normal', choices=['normal', 'uniform'], help='distribution used for initialization for the attention model')
35 | parser.add_argument('--scale', type=str, default='1', help='initialized weights scale for the attention model, set to None to use PyTorch default init')
36 | parser.add_argument('--degree_feature', action='store_true', default=False, help='use degree features (only for the Triangles dataset)')
37 | # TU datasets arguments
38 | parser.add_argument('--n_nodes', type=int, default=25, help='maximum number of nodes in the training set for collab, proteins and dd (35 for collab, 25 for proteins, 200 or 300 for dd)')
39 | parser.add_argument('--cv_folds', type=int, default=5, help='number of folds for cross-validating hyperparameters for collab, proteins and dd (5 or 10 shows similar results, 5 is faster)')
40 | parser.add_argument('--cv_threads', type=int, default=5, help='number of parallel threads for cross-validation')
 41 |     parser.add_argument('--tune_init', action='store_true', default=False, help='tune initialization hyperparameters (scale and init) during hyperparameter optimization')
42 | parser.add_argument('--ax', action='store_true', default=False, help='use AX for hyperparameter optimization (recommended)')
43 | parser.add_argument('--ax_trials', type=int, default=30, help='number of AX trials (hyperparameters optimization steps)')
44 | parser.add_argument('--cv', action='store_true', default=False, help='run in the cross-validation mode')
45 | parser.add_argument('--seed_data', type=int, default=111, help='random seed for data splits')
46 | # Image datasets arguments
47 | parser.add_argument('--img_features', type=str, default='mean,coord', help='image features to use as node features')
48 | parser.add_argument('--img_noise_levels', type=str, default=None,
49 | help='Gaussian noise standard deviations for grayscale and color image features')
50 | # Auxiliary arguments
51 | parser.add_argument('--validation', action='store_true', default=False, help='run in the validation mode')
52 | parser.add_argument('--debug', action='store_true', default=False, help='evaluate on the test set after each epoch (only for visualization purposes)')
53 | parser.add_argument('--eval_attn_train', action='store_true', default=False, help='evaluate attention and save coefficients on the training set for models without learnable attention')
54 | parser.add_argument('--eval_attn_test', action='store_true', default=False, help='evaluate attention and save coefficients on the test set for models without learnable attention')
55 | parser.add_argument('--test_batch_size', type=int, default=100, help='batch size for test data')
 56 |     parser.add_argument('--alpha_ws', type=str, default=None, help='path to attention labels (pickle) used for weak supervision')
57 | parser.add_argument('--log_interval', type=int, default=400, help='print interval')
58 | parser.add_argument('--results', type=str, default='./results', help='directory to save model checkpoints and other results, set to None to prevent saving anything')
 59 |     parser.add_argument('--resume', type=str, default=None, help='checkpoint to load the model and optimizer states from and continue training')
60 | parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'], help='cuda/cpu')
61 | parser.add_argument('--seed', type=int, default=111, help='random seed for model parameters')
62 | parser.add_argument('--threads', type=int, default=0, help='number of threads for data loader')
63 | args = parser.parse_args()
64 |
65 | # Set default number of epochs and learning rate schedules and other hyperparameters
66 | if args.readout in [None, 'None']:
 67 |         args.readout = 'sum' if args.dataset.find('color') >= 0 else 'max' # global max pooling for all datasets except for COLORS, which uses sum
68 | set_default_lr_decay_step = args.lr_decay_step in [None, 'None']
69 | if args.epochs in [None, 'None']:
70 | if args.dataset.find('mnist') >= 0:
71 | args.epochs = 30
72 | if set_default_lr_decay_step:
73 | args.lr_decay_step = '20,25'
74 | elif args.dataset == 'triangles':
75 | args.epochs = 100
76 | if set_default_lr_decay_step:
77 | args.lr_decay_step = '85,95'
78 | elif args.dataset == 'TU':
79 | args.epochs = 50
80 | if set_default_lr_decay_step:
81 | args.lr_decay_step = '25,35,45'
82 | elif args.dataset.find('color') >= 0:
83 | if args.readout in [None, 'None']:
84 | args.readout = 'sum'
85 | if args.pool in [None, 'None']:
86 | args.epochs = 100
87 | if set_default_lr_decay_step:
88 | args.lr_decay_step = '90'
89 | else:
90 | args.epochs = 300
91 | if set_default_lr_decay_step:
92 | args.lr_decay_step = '280'
93 | else:
94 | raise NotImplementedError(args.dataset)
95 |
96 | args.lr_decay_step = list(map(int, args.lr_decay_step.split(',')))
97 | args.filters = list(map(int, args.filters.split(',')))
98 | args.img_features = args.img_features.split(',')
99 | args.img_noise_levels = None if args.img_noise_levels in [None, 'None'] else list(map(float, args.img_noise_levels.split(',')))
100 | args.pool = None if args.pool in [None, 'None'] else args.pool.split('_')
101 | args.pool_arch = None if args.pool_arch in [None, 'None'] else args.pool_arch.split('_')
102 | try:
103 | args.scale = float(args.scale)
104 | except:
105 | args.scale = None
106 |
107 | args.torch = torch.__version__
108 |
109 | for arg in vars(args):
110 | print(arg, getattr(args, arg))
111 |
112 | return args
113 |
114 |
115 | def load_synthetic(args):
116 | train_dataset = SyntheticGraphs(args.data_dir, args.dataset, 'train', degree_feature=args.degree_feature,
117 | attn_coef=args.alpha_ws)
118 | test_dataset = SyntheticGraphs(args.data_dir, args.dataset, 'val' if args.validation else 'test',
119 | degree_feature=args.degree_feature)
120 | loss_fn = mse_loss
121 | collate_fn = collate_batch
122 | in_features = train_dataset.feature_dim
123 | out_features = 1
124 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features
125 |
126 |
127 | def load_mnist(args):
128 | use_mean_px = 'mean' in args.img_features
129 | use_coord = 'coord' in args.img_features
130 | assert use_mean_px, ('this mode is not well supported', use_mean_px)
131 | gt_attn_threshold = 0 if (args.pool is not None and args.pool[1] in ['gt'] and args.filter_scale > 1) else 0.5
132 | if args.dataset == 'mnist':
133 | train_dataset = MNIST(args.data_dir, train=True, download=True, transform=transforms.ToTensor(),
134 | attn_coef=args.alpha_ws)
135 | else:
136 | train_dataset = MNIST75sp(args.data_dir, split='train', use_mean_px=use_mean_px, use_coord=use_coord,
137 | gt_attn_threshold=gt_attn_threshold, attn_coef=args.alpha_ws)
138 |
139 | noises, color_noises = None, None
140 | if args.validation:
141 | n_val = 5000
142 | if args.dataset == 'mnist':
143 | train_dataset.train_data = train_dataset.train_data[:-n_val]
144 | train_dataset.train_labels = train_dataset.train_labels[:-n_val]
145 | test_dataset = MNIST(args.data_dir, train=True, download=True, transform=transforms.ToTensor())
146 | test_dataset.train_data = train_dataset.train_data[-n_val:]
147 | test_dataset.train_labels = train_dataset.train_labels[-n_val:]
148 | else:
149 | train_dataset.train_val_split(np.arange(0, train_dataset.n_samples - n_val))
150 | test_dataset = MNIST75sp(args.data_dir, split='train', use_mean_px=use_mean_px, use_coord=use_coord,
151 | gt_attn_threshold=gt_attn_threshold)
152 | test_dataset.train_val_split(np.arange(train_dataset.n_samples - n_val, train_dataset.n_samples))
153 | else:
154 | noise_file = pjoin(args.data_dir, '%s_noise.pt' % args.dataset.replace('-', '_'))
155 | color_noise_file = pjoin(args.data_dir, '%s_color_noise.pt' % args.dataset.replace('-', '_'))
156 | if args.dataset == 'mnist':
157 | test_dataset = MNIST(args.data_dir, train=False, download=True, transform=transforms.ToTensor())
158 | noise_shape = (len(test_dataset.test_labels), 28 * 28)
159 | else:
160 | test_dataset = MNIST75sp(args.data_dir, split='test', use_mean_px=use_mean_px, use_coord=use_coord,
161 | gt_attn_threshold=gt_attn_threshold)
162 | noise_shape = (len(test_dataset.labels), 75)
163 |
164 | # Generate/load noise (save it to make reproducible)
165 | noises = load_save_noise(noise_file, noise_shape)
166 | color_noises = load_save_noise(color_noise_file, (noise_shape[0], noise_shape[1], 3))
167 |
168 | if args.dataset == 'mnist':
169 | A, coord, mask = precompute_graph_images(train_dataset.train_data.shape[1])
170 | collate_fn = lambda batch: collate_batch_images(batch, A, mask, use_mean_px=use_mean_px,
171 | coord=coord if use_coord else None,
172 | gt_attn_threshold=gt_attn_threshold,
173 | replicate_features=args.img_noise_levels is not None)
174 | else:
175 | train_dataset.precompute_graph_data(replicate_features=args.img_noise_levels is not None, threads=12)
176 | test_dataset.precompute_graph_data(replicate_features=args.img_noise_levels is not None, threads=12)
177 | collate_fn = collate_batch
178 |
179 | loss_fn = F.cross_entropy
180 |
181 | in_features = 0 if args.img_noise_levels is None else 2
182 | for features in args.img_features:
183 | if features == 'mean':
184 | in_features += 1
185 | elif features == 'coord':
186 | in_features += 2
187 | else:
188 | raise NotImplementedError(features)
189 | in_features = np.max((in_features, 1)) # in_features=1 if neither mean nor coord are used (dummy features will be used in this case)
190 | out_features = 10
191 |
192 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, noises, color_noises
193 |
194 |
195 | def load_TU(args, cv_folds=5):
196 | loss_fn = F.cross_entropy
197 | collate_fn = collate_batch
198 | scale, init = args.scale, args.init
199 | n_hidden_attn = float(args.pool_arch[2]) if (args.pool_arch is not None and len(args.pool_arch) > 2) else 0
200 | if args.pool is None:
201 | # Global pooling models
202 | datareader = DataReader(data_dir=args.data_dir, N_nodes=args.n_nodes, rnd_state=rnd_data, folds=0)
203 | train_dataset = GraphData(datareader, None, 'train_val')
204 | test_dataset = GraphData(datareader, None, 'test')
205 | in_features = train_dataset.num_features
206 | out_features = train_dataset.num_classes
207 | pool = args.pool
208 | kl_weight = args.kl_weight
209 | elif args.pool[1] == 'gt':
210 | raise ValueError('ground truth attention for TU datasets is not available')
211 | elif args.pool[1] in ['sup', 'unsup']:
212 | datareader = DataReader(data_dir=args.data_dir, N_nodes=args.n_nodes, rnd_state=rnd_data, folds=cv_folds)
213 | if args.ax:
214 | # Cross-validation using Ax (recommended way), Python3 must be used
215 | best_parameters = ax_optimize(datareader, args, collate_fn, loss_fn, None, folds=cv_folds,
216 | threads=args.cv_threads, n_trials=args.ax_trials)
217 | pool = args.pool
218 | kl_weight = best_parameters['kl_weight']
219 | if args.tune_init:
220 | scale, init = best_parameters['scale'], best_parameters['init']
221 | n_hidden_attn, layer = best_parameters['n_hidden_attn'], 1
222 | if layer == 0:
223 | pool = copy.deepcopy(args.pool)
224 | del pool[3]
225 |
226 | pool = set_pool(best_parameters['pool'], pool)
227 |
228 | else:
229 | if not args.cv:
230 | # Run with some fixed parameters without cross-validation
231 | pool_thresh_values = np.array([float(args.pool[-1])])
232 | n_hiddens = [n_hidden_attn]
233 | layers = [1]
234 | elif args.debug:
235 | pool_thresh_values = np.array([1e-4, 1e-1])
236 | n_hiddens = [n_hidden_attn]
237 | layers = [1]
238 | else:
239 |                 # Cross-validation using grid search (not recommended: time-consuming and less effective than Ax)
240 | if args.data_dir.lower().find('proteins') >= 0:
241 | pool_thresh_values = np.array([2e-3, 5e-3, 1e-2, 3e-2, 5e-2])
242 | elif args.data_dir.lower().find('dd') >= 0:
243 | pool_thresh_values = np.array([1e-4, 1e-3, 2e-3, 5e-3, 1e-2, 3e-2, 5e-2, 1e-1])
244 | elif args.data_dir.lower().find('collab') >= 0:
245 | pool_thresh_values = np.array([1e-3, 2e-3, 5e-3, 1e-2, 3e-2, 5e-2, 1e-1])
246 | else:
247 | raise NotImplementedError('this dataset is not supported currently')
248 |                 n_hiddens = np.array([0, 32]) # hidden units in the attention model
249 |                 layers = np.array([0, 1]) # layer to which the attention model is attached
250 |
251 | if args.pool[1] == 'sup' and not args.debug and args.cv:
252 | kl_weight_values = np.array([0.25, 1, 2, 10])
253 | else:
254 | kl_weight_values = np.array([args.kl_weight]) # any value (ignored for unsupervised training)
255 |
256 |
257 | if len(pool_thresh_values) > 1 or len(kl_weight_values) > 1 or len(n_hiddens) > 1 or len(layers) > 1:
258 | val_acc = np.zeros((len(layers), len(n_hiddens), len(pool_thresh_values), len(kl_weight_values)))
259 | for i_, layer in enumerate(layers):
260 | if layer == 0:
261 | pool = copy.deepcopy(args.pool)
262 | del pool[3]
263 | else:
264 | pool = args.pool
265 | for j_, n_hidden_attn in enumerate(n_hiddens):
266 | for k_, pool_thresh in enumerate(pool_thresh_values):
267 | for m_, kl_weight in enumerate(kl_weight_values):
268 | val_acc[i_, j_, k_, m_] = \
269 | cross_validation(datareader, args, collate_fn, loss_fn, set_pool(pool_thresh, pool),
270 | kl_weight, None, n_hidden_attn=n_hidden_attn, folds=cv_folds, threads=args.cv_threads)
271 | ind1, ind2, ind3, ind4 = np.where(val_acc == np.max(val_acc)) # np.argmax returns only first occurrence
272 | print(val_acc)
273 | print(ind1, ind2, ind3, ind4, layers[ind1], n_hiddens[ind2], pool_thresh_values[ind3], kl_weight_values[ind4],
274 | val_acc[ind1[0], ind2[0], ind3[0], ind4[0]])
275 |
276 | layer = layers[ind1[0]]
277 | if layer == 0:
278 | pool = copy.deepcopy(args.pool)
279 | del pool[3]
280 | else:
281 | pool = args.pool
282 | n_hidden_attn = n_hiddens[ind2[0]]
283 | pool = set_pool(pool_thresh_values[ind3[0]], pool)
284 | kl_weight = kl_weight_values[ind4[0]]
285 | else:
286 | pool = args.pool
287 | kl_weight = args.kl_weight
288 |
289 | train_dataset = GraphData(datareader, None, 'train_val')
290 | test_dataset = GraphData(datareader, None, 'test')
291 | in_features = train_dataset.num_features
292 | out_features = train_dataset.num_classes
293 |
294 | if args.pool[1] == 'sup':
295 | # Train a model with global pooling first
296 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.threads,
297 | collate_fn=collate_fn)
298 | train_loader_test = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=False,
299 | num_workers=args.threads, collate_fn=collate_fn)
300 | # Train global pooling model
301 | start_epoch, model, optimizer, scheduler = create_model_optimizer(in_features, out_features, None, kl_weight,
302 | args, scale=scale, init=init, n_hidden_attn=n_hidden_attn)
303 | for epoch in range(start_epoch, args.epochs + 1):
304 | scheduler.step()
305 | train_loss, acc = train(model, train_loader, optimizer, epoch, args, loss_fn, None)
306 | train_loss, train_acc, attn_WS = test(model, train_loader_test, epoch, loss_fn, 'train', args, None,
307 | eval_attn=True)[:3]
308 | train_dataset = GraphData(datareader, None, 'train_val', attn_labels=attn_WS)
309 | else:
310 | raise NotImplementedError(args.pool)
311 |
312 | return train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, pool, kl_weight, scale, init, n_hidden_attn
313 |
314 |
315 | if __name__ == '__main__':
316 |
317 | # mp.set_start_method('spawn')
318 | dt = datetime.datetime.now()
319 | print('start time:', dt)
320 | args = parse_args()
321 | args.experiment_ID = '%06d' % dt.microsecond
322 | print('experiment_ID: ', args.experiment_ID)
323 |
324 | if args.cv_threads > 1 and args.dataset == 'TU':
325 | # this requires python3
326 | torch.multiprocessing.set_start_method('spawn')
327 |
328 | print('gpus: ', torch.cuda.device_count())
329 |
330 | if args.results not in [None, 'None'] and not os.path.isdir(args.results):
331 | os.mkdir(args.results)
332 |
333 | rnd, rnd_data = set_seed(args.seed, args.seed_data)
334 |
335 | pool = args.pool
336 | kl_weight = args.kl_weight
337 | scale = args.scale
338 | init = args.init
339 | n_hidden_attn = float(args.pool_arch[2]) if (args.pool_arch is not None and len(args.pool_arch) > 2) else 0
340 | if args.dataset.find('colors') >= 0 or args.dataset == 'triangles':
341 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features = load_synthetic(args)
342 | elif args.dataset in ['mnist', 'mnist-75sp']:
343 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, noises, color_noises = load_mnist(args)
344 |
345 | else:
346 | train_dataset, test_dataset, loss_fn, collate_fn, in_features, out_features, pool, kl_weight, scale, init, n_hidden_attn = \
347 | load_TU(args, cv_folds=args.cv_folds)
348 |
349 |
350 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.threads,
351 | collate_fn=collate_fn)
352 |     # A loader to test and evaluate attention on the training set (it should not be shuffled and should use a larger batch size, ideally a multiple of 50)
353 | train_loader_test = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.threads, collate_fn=collate_fn)
354 | print('test_dataset', test_dataset.split)
355 | test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False,
356 | num_workers=args.threads, collate_fn=collate_fn)
357 |
358 | start_epoch, model, optimizer, scheduler = create_model_optimizer(in_features, out_features, pool, kl_weight, args,
359 | scale=scale, init=init, n_hidden_attn=n_hidden_attn)
360 |
361 | feature_stats = None
362 | if args.dataset in ['mnist', 'mnist-75sp']:
363 | feature_stats = compute_feature_stats(model, train_loader, args.device, n_batches=1000)
364 |
365 | # Test function wrapper
366 | def test_fn(loader, epoch, split, eval_attn):
367 | test_loss, acc, _, _ = test(model, loader, epoch, loss_fn, split, args, feature_stats,
368 | noises=None, img_noise_level=None, eval_attn=eval_attn, alpha_WS_name='orig')
369 | if args.dataset in ['mnist', 'mnist-75sp'] and split == 'test' and args.img_noise_levels is not None:
370 | test(model, loader, epoch, loss_fn, split, args, feature_stats,
371 | noises=noises, img_noise_level=args.img_noise_levels[0], eval_attn=eval_attn, alpha_WS_name='noisy')
372 | test(model, loader, epoch, loss_fn, split, args, feature_stats,
373 | noises=color_noises, img_noise_level=args.img_noise_levels[1], eval_attn=eval_attn, alpha_WS_name='noisy-c')
374 | return test_loss, acc
375 |
376 | if start_epoch > args.epochs:
377 | print('evaluating the model')
378 | test_fn(test_loader, start_epoch - 1, 'val' if args.validation else 'test', args.eval_attn_test)
379 | else:
380 | for epoch in range(start_epoch, args.epochs + 1):
381 |             eval_epoch = epoch <= 1 or epoch == args.epochs # evaluate at epoch 1 to make sure the test function works on this test set before training all the way to the last epoch
382 | scheduler.step()
383 | train_loss, acc = train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats)
384 | if eval_epoch:
385 | save_checkpoint(model, scheduler, optimizer, args, epoch)
386 | # Report Training accuracy and other metrics on the training set
387 | test_fn(train_loader_test, epoch, 'train', (epoch == args.epochs) and args.eval_attn_train)
388 |
389 | if args.validation:
390 | test_fn(test_loader, epoch, 'val', (epoch == args.epochs) and args.eval_attn_test)
391 | elif eval_epoch or args.debug:
392 | test_fn(test_loader, epoch, 'test', (epoch == args.epochs) and args.eval_attn_test)
393 |
394 | print('done in {}'.format(datetime.datetime.now() - dt))
395 |
--------------------------------------------------------------------------------
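parse_args in main.py above turns several comma- or underscore-delimited strings into lists; in particular, the --pool value used throughout the scripts is split on underscores, and later code indexes into the resulting list (pool[1] selects the attention mode, pool[-1] is the threshold). A short illustration, using the value passed in scripts/colors.sh:

pool = 'attn_sup_threshold_skip_0.05'.split('_')
print(pool)             # ['attn', 'sup', 'threshold', 'skip', '0.05']
print(pool[1])          # 'sup'  -> supervised-attention branch checked in load_TU/load_mnist
print(float(pool[-1]))  # 0.05   -> pooling threshold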
/notebooks/TRIANGLES_eval_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import matplotlib as mpl\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "import numpy as np\n",
14 | "import pickle\n",
15 | "from skimage.segmentation import slic\n",
16 | "import scipy.ndimage\n",
17 | "import scipy.spatial\n",
18 | "import torch\n",
19 | "from torchvision import datasets\n",
20 | "from torchvision import datasets\n",
21 | "import sys\n",
22 | "sys.path.append(\"../\")\n",
23 | "from chebygin import ChebyGIN\n",
24 | "from extract_superpixels import process_image\n",
25 | "from graphdata import comput_adjacency_matrix_images\n",
26 | "from train_test import load_save_noise\n",
27 | "from utils import list_to_torch, data_to_device, normalize_zero_one"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "# TRIANGLES"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "dict_keys(['Adj_matrices', 'GT_attn', 'graph_labels', 'N_edges', 'Max_degree'])\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "data_dir = '/scratch/ssd/data/graph_attention_pool/'\n",
52 | "checkpoints_dir = '../checkpoints'\n",
53 | "device = 'cuda'\n",
54 | "\n",
55 | "with open('%s/random_graphs_triangles_test.pkl' % data_dir, 'rb') as f:\n",
56 | " data = pickle.load(f)\n",
57 | " \n",
58 | "print(data.keys())\n",
59 | "targets = torch.from_numpy(data['graph_labels']).long()\n",
60 | "Node_degrees = [np.sum(A, 1).astype(np.int32) for A in data['Adj_matrices']]\n",
61 | "\n",
62 | "feature_dim = data['Max_degree'] + 1\n",
63 | "node_features = []\n",
64 | "for i in range(len(data['Adj_matrices'])):\n",
65 | " N = data['Adj_matrices'][i].shape[0]\n",
66 | " D_onehot = np.zeros((N, feature_dim ))\n",
67 | " D_onehot[np.arange(N), Node_degrees[i]] = 1\n",
68 | " node_features.append(D_onehot)"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "metadata": {
75 | "collapsed": true
76 | },
77 | "outputs": [],
78 | "source": [
79 | "def acc(pred):\n",
80 | " n = len(pred)\n",
81 | " return torch.mean((torch.stack(pred).view(n) == targets[:len(pred)].view(n)).float()).item() * 100\n",
82 | "\n",
83 | "def test(model, index, show_img=False): \n",
84 | " N_nodes = data['Adj_matrices'][index].shape[0]\n",
85 | " mask = torch.ones(1, N_nodes, dtype=torch.uint8)\n",
86 | " x = torch.from_numpy(node_features[index]).unsqueeze(0).float() \n",
87 | " A = torch.from_numpy(data['Adj_matrices'][index].astype(np.float32)).float().unsqueeze(0)\n",
88 | " y, other_outputs = model(data_to_device([x, A, mask, -1, {'N_nodes': torch.zeros(1, 1) + N_nodes}], \n",
89 | " device)) \n",
90 | " y = y.round().long().data.cpu()[0][0]\n",
91 | " alpha = other_outputs['alpha'][0].data.cpu() if 'alpha' in other_outputs else [] \n",
92 | " return y, alpha\n",
93 | "\n",
94 | "\n",
95 | "# This function returns predictions for the entire clean and noise test sets\n",
96 | "def get_predictions(model_path):\n",
97 | " state = torch.load(model_path)\n",
98 | " args = state['args']\n",
99 | " model = ChebyGIN(in_features=14,\n",
100 | " out_features=1,\n",
101 | " filters=args.filters,\n",
102 | " K=args.filter_scale,\n",
103 | " n_hidden=args.n_hidden,\n",
104 | " aggregation=args.aggregation,\n",
105 | " dropout=args.dropout,\n",
106 | " readout=args.readout,\n",
107 | " pool=args.pool,\n",
108 | " pool_arch=args.pool_arch)\n",
109 | " model.load_state_dict(state['state_dict'])\n",
110 | " model = model.eval().to(device)\n",
111 | "# print(model) \n",
112 | "\n",
113 | " # Get predictions\n",
114 | " pred, alpha = [], []\n",
115 | " for index in range(len(data['Adj_matrices'])):\n",
116 | " y = test(model, index, index == 0)\n",
117 | " pred.append(y[0])\n",
118 | " alpha.append(y[1])\n",
119 | " if len(pred) % 1000 == 0:\n",
120 | " print('{}/{}, acc on the combined test set={:.2f}%'.format(len(pred), len(data['Adj_matrices']), acc(pred)))\n",
121 | " return pred, alpha"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Weakly-supervised attention model"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 4,
134 | "metadata": {
135 | "scrolled": true
136 | },
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "ChebyGINLayer torch.Size([64, 98]) tensor([0.5568, 0.5545, 0.5580, 0.5656, 0.5318, 0.5698, 0.5655, 0.5937, 0.6087,\n",
143 | " 0.5437], grad_fn=)\n",
144 | "ChebyGINLayer torch.Size([32, 128]) tensor([0.5730, 0.5968, 0.5778, 0.5940, 0.5981, 0.5787, 0.5619, 0.5798, 0.5741,\n",
145 | " 0.5833], grad_fn=)\n",
146 | "ChebyGINLayer torch.Size([32, 64]) tensor([0.5703, 0.5380, 0.5825, 0.5836, 0.5649, 0.5537, 0.6568, 0.6129, 0.6161,\n",
147 | " 0.5258], grad_fn=)\n",
148 | "ChebyGINLayer torch.Size([1, 64]) tensor([0.5634], grad_fn=)\n",
149 | "ChebyGINLayer torch.Size([64, 448]) tensor([0.5923, 0.5840, 0.5608, 0.5615, 0.5799, 0.5668, 0.5924, 0.5840, 0.5709,\n",
150 | " 0.5637], grad_fn=)\n",
151 | "ChebyGINLayer torch.Size([32, 128]) tensor([0.5606, 0.5821, 0.5540, 0.5596, 0.6033, 0.6147, 0.5738, 0.5865, 0.5981,\n",
152 | " 0.5800], grad_fn=)\n",
153 | "ChebyGINLayer torch.Size([32, 64]) tensor([0.5938, 0.6073, 0.5995, 0.5230, 0.6091, 0.6070, 0.5901, 0.5752, 0.5594,\n",
154 | " 0.5499], grad_fn=)\n",
155 | "ChebyGINLayer torch.Size([1, 64]) tensor([0.6102], grad_fn=)\n",
156 | "ChebyGINLayer torch.Size([64, 448]) tensor([0.5877, 0.5797, 0.5591, 0.5688, 0.5758, 0.5645, 0.5483, 0.5846, 0.5883,\n",
157 | " 0.5961], grad_fn=)\n",
158 | "1000/10000, acc on the combined test set=83.00%\n",
159 | "2000/10000, acc on the combined test set=76.30%\n",
160 | "3000/10000, acc on the combined test set=72.23%\n",
161 | "4000/10000, acc on the combined test set=68.73%\n",
162 | "5000/10000, acc on the combined test set=66.82%\n",
163 | "6000/10000, acc on the combined test set=59.90%\n",
164 | "7000/10000, acc on the combined test set=55.04%\n",
165 | "8000/10000, acc on the combined test set=51.60%\n",
166 | "9000/10000, acc on the combined test set=49.02%\n",
167 | "10000/10000, acc on the combined test set=46.69%\n"
168 | ]
169 | }
170 | ],
171 | "source": [
172 | "pred, alpha = get_predictions('%s/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar' % checkpoints_dir)"
173 | ]
174 | }
175 | ],
176 | "metadata": {
177 | "kernelspec": {
178 | "display_name": "Python 3",
179 | "language": "python",
180 | "name": "python3"
181 | },
182 | "language_info": {
183 | "codemirror_mode": {
184 | "name": "ipython",
185 | "version": 3
186 | },
187 | "file_extension": ".py",
188 | "mimetype": "text/x-python",
189 | "name": "python",
190 | "nbconvert_exporter": "python",
191 | "pygments_lexer": "ipython3",
192 | "version": "3.6.7"
193 | }
194 | },
195 | "nbformat": 4,
196 | "nbformat_minor": 2
197 | }
198 |
--------------------------------------------------------------------------------
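The notebook above rebuilds a ChebyGIN model from a saved checkpoint; the same checkpoints can also be inspected directly. A minimal sketch (the checkpoint name is one of the files shipped in ./checkpoints and is used here only as an example):

import torch

# With recent PyTorch you may additionally need weights_only=False, since the
# checkpoint stores an argparse Namespace alongside the state dict.
state = torch.load('./checkpoints/checkpoint_triangles_230187_epoch100_seed0000111.pth.tar',
                   map_location='cpu')
print(sorted(state.keys()))      # expected to include 'args' and 'state_dict'
print(state['args'])             # the argparse Namespace the model was trained with
print(len(state['state_dict']))  # number of parameter tensors in the model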
/notebooks/convert2TU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import pickle\n",
11 | "import os"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "data_dir = '../data/'"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 3,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "def write_data_TU(data, out_path, dim_test=None):\n",
30 | " c = 0\n",
31 | " nodes = 1\n",
32 | " with open('%s_A.txt' % out_path, 'w') as f:\n",
33 | " for A in data['Adj_matrices']:\n",
34 | " N = A.shape[0]\n",
35 | " for i in range(N):\n",
36 | " for j in range(N):\n",
37 | " if A[i, j] > 0:\n",
38 | " f.write('%d, %d\\n' % (i + nodes, j + nodes))\n",
39 | " nodes += N\n",
40 | " c += np.sum(A) / 2\n",
41 | " print(c)\n",
42 | "\n",
43 | " c = 1\n",
44 | " with open('%s_graph_indicator.txt' % out_path, 'w') as f:\n",
45 | " for A in data['Adj_matrices']:\n",
46 | " N = A.shape[0]\n",
47 | " for j in range(N):\n",
48 | " f.write('%d\\n' % c)\n",
49 | " c += 1\n",
50 | " print(c)\n",
51 | "\n",
52 | " with open('%s_graph_attributes.txt' % out_path, 'w') as f:\n",
53 | " for lbl in data['graph_labels']:\n",
54 | " f.write('%d\\n' % lbl)\n",
55 | "\n",
56 | " \n",
57 | " with open('%s_node_attributes.txt' % out_path, 'w') as f:\n",
58 | " for node_id in range(len(data['GT_attn'])):\n",
59 | " attn = data['GT_attn'][node_id]\n",
60 | " N = attn.shape[0]\n",
61 | " for i in range(N):\n",
62 | " f.write('%d' % attn[i])\n",
63 | " if dim_test is not None:\n",
64 | " f.write(', ')\n",
65 | " x = data['node_features'][node_id]\n",
66 | " for j in range(dim_test): \n",
67 | " if j < x.shape[1]: \n",
68 | " f.write('%d' % x[i, j])\n",
69 | " else:\n",
70 | " f.write('0')\n",
71 | " if j < dim_test - 1:\n",
72 | " f.write(', ') \n",
73 | " f.write('\\n') \n",
74 | " \n",
75 | "# with open('%s_node_attention.txt' % out_path, 'w') as f:\n",
76 | "# for x in data['GT_attn']:\n",
77 | "# N = x.shape[0]\n",
78 | "# for i in range(N):\n",
79 | "# f.write('%d\\n' % x[i])"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "train Adj_matrices 500\n",
92 | "train GT_attn 500\n",
93 | "train graph_labels 500\n",
94 | "train node_features 500\n",
95 | "train N_edges 500\n",
96 | "\n",
97 | "\n",
98 | "val Adj_matrices 3000\n",
99 | "val GT_attn 3000\n",
100 | "val graph_labels 3000\n",
101 | "val node_features 3000\n",
102 | "val N_edges 3000\n",
103 | "\n",
104 | "\n",
105 | "test Adj_matrices 10500\n",
106 | "test GT_attn 10500\n",
107 | "test graph_labels 10500\n",
108 | "test node_features 10500\n",
109 | "test N_edges 10500\n",
110 | "\n",
111 | "\n",
112 | "955788.0\n",
113 | "10501\n",
114 | "10500 61.31171428571429 60.52053469623652 4 200\n",
115 | "10500 91.02742857142857 93.66902222113747 4 397\n"
116 | ]
117 | }
118 | ],
119 | "source": [
120 | "for dim in [3]:\n",
121 | " dim_test = dim + 1\n",
122 | " data = {}\n",
123 | " out_dir = '%s/COLORS-%d' % (data_dir, dim)\n",
124 | " try:\n",
125 | " os.mkdir(out_dir)\n",
126 | " except Exception as e:\n",
127 | " print(e)\n",
128 | "\n",
129 | " for split in ['train', 'val', 'test']:\n",
130 | " \n",
131 | " with open('%s/random_graphs_colors_dim%d_%s.pkl' % (data_dir, dim, split), 'rb') as f:\n",
132 | " data_tmp = pickle.load(f) \n",
133 | " for key in data_tmp: \n",
134 | " if split == 'train':\n",
135 | " data[key] = data_tmp[key]\n",
136 | " else:\n",
137 | " if isinstance(data[key], list):\n",
138 | " data[key].extend(data_tmp[key])\n",
139 | " else:\n",
140 | " data[key] = np.concatenate((data[key], data_tmp[key]))\n",
141 | " print(split, key, len(data[key]))\n",
142 | " print('\\n')\n",
143 | " write_data_TU(data, '%s/COLORS-%d' % (out_dir, dim), dim_test)\n",
144 | " \n",
145 | "nodes = [A.shape[0] for A in data['Adj_matrices']]\n",
146 | "edges = [np.sum(A) // 2 for A in data['Adj_matrices']]\n",
147 | "print(len(nodes), np.mean(nodes), np.std(nodes), np.min(nodes), np.max(nodes))\n",
148 | "print(len(edges), np.mean(edges), np.std(edges), np.min(edges), np.max(edges)) "
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 5,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "name": "stdout",
158 | "output_type": "stream",
159 | "text": [
160 | "\n",
161 | "\n",
162 | "736732.0\n",
163 | "30001\n",
164 | "[Errno 17] File exists: '../data//TRIANGLES'\n",
165 | "val Adj_matrices 35000\n",
166 | "val GT_attn 35000\n",
167 | "val graph_labels 35000\n",
168 | "val N_edges 35000\n",
169 | "val Max_degree 14\n",
170 | "\n",
171 | "\n",
172 | "859328.0\n",
173 | "35001\n",
174 | "[Errno 17] File exists: '../data//TRIANGLES'\n",
175 | "test Adj_matrices 45000\n",
176 | "test GT_attn 45000\n",
177 | "test graph_labels 45000\n",
178 | "test N_edges 45000\n",
179 | "test Max_degree 14\n",
180 | "\n",
181 | "\n",
182 | "1473512.0\n",
183 | "45001\n",
184 | "45000 20.854177777777778 17.460003254790188 4 100\n",
185 | "45000 32.74471111111111 28.14843482573703 4 198\n"
186 | ]
187 | }
188 | ],
189 | "source": [
190 | "data = {}\n",
191 | "for split in ['train', 'val', 'test']:\n",
192 | " out_dir = '%s/TRIANGLES' % (data_dir)\n",
193 | " try:\n",
194 | " os.mkdir(out_dir)\n",
195 | " except Exception as e:\n",
196 | " print(e)\n",
197 | " with open('%s/random_graphs_triangles_%s.pkl' % (data_dir, split), 'rb') as f:\n",
198 | " data_tmp = pickle.load(f)\n",
199 | " for key in data_tmp:\n",
200 | " if split == 'train':\n",
201 | " data[key] = data_tmp[key]\n",
202 | " else:\n",
203 | " if key == 'Max_degree':\n",
204 | " print(split, key, data[key])\n",
205 | " data[key] = np.max((data[key], data_tmp[key]))\n",
206 | " else:\n",
207 | " if isinstance(data[key], list):\n",
208 | " data[key].extend(data_tmp[key])\n",
209 | " else:\n",
210 | " data[key] = np.concatenate((data[key], data_tmp[key]))\n",
211 | " print(split, key, len(data[key]))\n",
212 | " print('\\n')\n",
213 | " write_data_TU(data, '%s/TRIANGLES' % (out_dir))\n",
214 | " \n",
215 | "nodes = [A.shape[0] for A in data['Adj_matrices']]\n",
216 | "edges = [np.sum(A) // 2 for A in data['Adj_matrices']]\n",
217 | "print(len(nodes), np.mean(nodes), np.std(nodes), np.min(nodes), np.max(nodes))\n",
218 | "print(len(edges), np.mean(edges), np.std(edges), np.min(edges), np.max(edges)) "
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {
225 | "collapsed": true
226 | },
227 | "outputs": [],
228 | "source": []
229 | }
230 | ],
231 | "metadata": {
232 | "kernelspec": {
233 | "display_name": "Python 3",
234 | "language": "python",
235 | "name": "python3"
236 | },
237 | "language_info": {
238 | "codemirror_mode": {
239 | "name": "ipython",
240 | "version": 3
241 | },
242 | "file_extension": ".py",
243 | "mimetype": "text/x-python",
244 | "name": "python",
245 | "nbconvert_exporter": "python",
246 | "pygments_lexer": "ipython3",
247 | "version": "3.7.3"
248 | }
249 | },
250 | "nbformat": 4,
251 | "nbformat_minor": 2
252 | }
253 |
--------------------------------------------------------------------------------
/scripts/colors.sh:
--------------------------------------------------------------------------------
1 | dataset="colors-3"
2 | seed=111
3 | checkpoints_dir=./checkpoints/
4 | params="-D $dataset --test_batch_size 100 -K 2 -f 64,64 --aggregation mean --n_hidden 0 --readout sum --dropout 0 --threads 0 --pool_arch fc_prev --seed $seed --results $checkpoints_dir -d ./data"
5 |
6 | logs_dir=./logs/
7 |
8 | # Global pooling
9 | python main.py $params --eval_attn_train --eval_attn_test --epochs 100 --lr_decay_step 90 | tee $logs_dir/"$dataset"_global_max_seed"$seed".log;
10 |
11 | # Unsupervised attention
12 | thresh=0.03
13 | pool=unsup
14 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log;
15 |
16 | # Supervised attention
17 | thresh=0.05
18 | pool=sup
19 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log;
20 |
21 | # Weakly-supervised attention
22 | thresh=0.05
23 | python main.py $params --pool attn_sup_threshold_skip_"$thresh" --epochs 300 --lr_decay_step 280 --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log;
24 |
--------------------------------------------------------------------------------
/scripts/mnist_75sp.sh:
--------------------------------------------------------------------------------
1 | dataset="mnist-75sp"
2 | seed=3026
3 | checkpoints_dir=./checkpoints/
4 | params="-D $dataset --epochs 30 --lr_decay_step 20,25 --test_batch_size 100 -K 4 -f 4,64,512 --aggregation mean --n_hidden 0 --readout max --dropout 0.5 --threads 0 --img_features mean,coord --img_noise_levels 0.4,0.6 --pool_arch fc_prev --kl_weight 100 --seed $seed --results $checkpoints_dir -d ./data"
5 |
6 | logs_dir=./logs/
7 | thresh=0.01
8 |
9 | # Global pooling
10 | python main.py $params --eval_attn_train --eval_attn_test | tee $logs_dir/"$dataset"_global_max_seed"$seed".log;
11 |
12 | # Unsupervised and supervised attention
13 | for pool in unsup sup;
14 | do python main.py $params --pool attn_"$pool"_threshold_skip_skip_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log;
15 | done
16 |
17 | # Weakly-supervised attention
18 | python main.py $params --pool attn_sup_threshold_skip_skip_"$thresh" --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log;
19 |
--------------------------------------------------------------------------------
/scripts/prepare_data.sh:
--------------------------------------------------------------------------------
1 | date
2 | seed=111
3 |
4 | # Generate Colors data
5 | out_dir=./data
6 | for dim in 3 8 16 32; do python generate_data.py --dim $dim -o $out_dir --seed $seed; done
7 |
8 | # Generate Triangles data
9 | python generate_data.py -D triangles --N_train 30000 --N_val 5000 --N_test 5000 --label_min 1 --label_max 10 --N_max 100 -o $out_dir --seed $seed
10 |
11 | # Generate MNIST-75sp data
12 | for split in train test; do python extract_superpixels.py -s $split -o $out_dir --seed $seed; done
13 |
14 | # Generate CIFAR-10-150sp data
15 | #for split in train test; do python extract_superpixels.py -D cifar10 -c 10 -n 150 -s $split -t 0 -o $out_dir; done
16 |
17 | # Generate noise for MNIST-75sp
18 | python -c "import sys,torch; print('seed=%s\nout file noise=%s\nout file color noise=%s' % (sys.argv[1], sys.argv[2], sys.argv[3])); torch.manual_seed(int(sys.argv[1])); noise=torch.randn(10000,75, dtype=torch.float); torch.save(noise, sys.argv[2]); colornoise=torch.randn(10000,75,3, dtype=torch.float); torch.save(colornoise, sys.argv[3]);" $seed "$out_dir"/mnist_75sp_noise.pt "$out_dir"/mnist_75sp_color_noise.pt
19 |
20 | # Download and unzip COLLAB, PROTEINS and D&D
21 | for dataset in COLLAB PROTEINS DD;
22 | do wget https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/"$dataset".zip -P $out_dir; unzip "$out_dir"/"$dataset".zip -d $out_dir; done
23 |
24 | date
25 |
--------------------------------------------------------------------------------
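The last data-generation step above packs the MNIST-75sp noise creation into a single `python -c` one-liner. The following is an equivalent, more readable sketch of what that command does (same seed handling, tensor shapes and output files; this script is not part of the repository itself).

# Readable equivalent of the inline noise-generation command in prepare_data.sh above.
import sys
import torch

seed, noise_file, color_noise_file = int(sys.argv[1]), sys.argv[2], sys.argv[3]
print('seed=%d\nout file noise=%s\nout file color noise=%s' % (seed, noise_file, color_noise_file))
torch.manual_seed(seed)                                     # make the test-time noise reproducible
noise = torch.randn(10000, 75, dtype=torch.float)           # one value per superpixel for the 10k test graphs
torch.save(noise, noise_file)
color_noise = torch.randn(10000, 75, 3, dtype=torch.float)  # 3-channel variant for color noise
torch.save(color_noise, color_noise_file)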
/scripts/triangles.sh:
--------------------------------------------------------------------------------
1 | dataset="triangles"
2 | seed=111
3 | checkpoints_dir=./checkpoints/
4 | params="-D $dataset --epochs 100 --lr_decay_step 85,95 --test_batch_size 100 -K 7 -f 64,64,64 --aggregation sum --n_hidden 64 --readout max --dropout 0 --threads 0 --pool_arch gnn_curr --seed $seed --results $checkpoints_dir -d ./data"
5 |
6 | logs_dir=./logs/
7 |
8 | # Global pooling
9 | python main.py $params --eval_attn_train --eval_attn_test | tee $logs_dir/"$dataset"_global_max_seed"$seed".log;
10 |
11 | # Unsupervised attention
12 | thresh=0.0001
13 | pool=unsup
14 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh"_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log;
15 |
16 | # Supervised attention
17 | thresh=0.001
18 | pool=sup
19 | python main.py $params --pool attn_"$pool"_threshold_skip_"$thresh"_"$thresh" | tee $logs_dir/"$dataset"_"$pool"_seed"$seed".log;
20 |
21 | # Weakly-supervised attention
22 | thresh=0.01
23 | python main.py $params --pool attn_sup_threshold_skip_"$thresh"_"$thresh" --alpha_ws $checkpoints_dir/"$dataset"_alpha_WS_train_seed"$seed"_orig.pkl | tee $logs_dir/"$dataset"_weaksup_seed"$seed".log;
24 |
--------------------------------------------------------------------------------
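Judging from the scripts above and from set_pool() in train_test.py below, the --pool argument appears to be an underscore-separated configuration string: the method ('attn'), the supervision mode ('unsup'/'sup'), the pooling operation ('threshold'), and then one entry per graph-convolution layer, either 'skip' (no pooling at that layer) or a numeric attention threshold. A minimal illustration of that reading; the split on '_' is an assumption about main.py, not confirmed here.

# Hypothetical parsing of a --pool string (the '_'-split is an assumption about main.py).
tokens = 'attn_sup_threshold_skip_skip_0.01'.split('_')
# ['attn', 'sup', 'threshold', 'skip', 'skip', '0.01']
#  method  supervision  operation  one entry per layer: 'skip' or a threshold
# set_pool() in train_test.py replaces the numeric entries ('0.01' here) with a tuned value.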
/train_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch.utils.data import DataLoader
4 | import torch.optim as optim
5 | import torch.optim.lr_scheduler as lr_scheduler
6 | from chebygin import *
7 | from utils import *
8 | from graphdata import *
9 | import torch.multiprocessing as mp
10 | import multiprocessing
11 | try:
12 | import ax
13 | from ax.service.managed_loop import optimize
14 | except Exception as e:
15 | print('AX is not available: %s' % str(e))
16 |
17 |
 18 | def set_pool(pool_thresh, args_pool):
 19 |     pool = copy.deepcopy(args_pool)  # replace numeric (threshold) tokens with pool_thresh
 20 |     for i, s in enumerate(pool):
 21 |         try:
 22 |             float(s)  # token parses as a number, so it is a threshold
 23 |             pool[i] = str(pool_thresh)
 24 |         except ValueError:
 25 |             continue
 26 |     return pool
27 |
28 |
29 | def train_evaluate(datareader, args, collate_fn, loss_fn, feature_stats, parameterization, folds=10, threads=5):
30 |
31 | print('parameterization', parameterization)
32 |
33 | pool_thresh, kl_weight = parameterization['pool'], parameterization['kl_weight']
34 | pool = args.pool
35 |
36 | if args.tune_init:
37 | scale, init = parameterization['scale'], parameterization['init']
38 | else:
39 | scale, init = args.scale, args.init
40 |
 41 |     n_hidden_attn, layer = parameterization['n_hidden_attn'], 1  # layer is hard-coded to 1, so the branch below never triggers
42 | if layer == 0:
43 | pool = copy.deepcopy(args.pool)
44 | del pool[3]
45 |
46 | pool = set_pool(pool_thresh, pool)
47 |
48 | manager = multiprocessing.Manager()
49 | val_acc = manager.dict()
50 | assert threads <= folds, (threads, folds)
51 | n_it = int(np.ceil(float(folds) / threads))
52 | for i in range(n_it):
53 | processes = []
54 | if threads <= 1:
55 | single_job(i * threads, datareader, args, collate_fn, loss_fn, pool, kl_weight,
56 | feature_stats, val_acc, scale=scale, init=init, n_hidden_attn=n_hidden_attn)
57 | else:
58 | for fold in range(threads):
59 | p = mp.Process(target=single_job,
60 | args=(i * threads + fold, datareader, args, collate_fn, loss_fn, pool, kl_weight,
61 | feature_stats, val_acc, scale, init, n_hidden_attn))
62 | p.start()
63 | processes.append(p)
64 |
65 | for p in processes:
66 | p.join()
67 |
68 | print(val_acc)
69 | val_acc = list(val_acc.values())
70 | print('average and std over {} folds: {} +- {}'.format(folds, np.mean(val_acc), np.std(val_acc)))
71 | metric = np.mean(val_acc) - np.std(val_acc) # large std is considered bad
72 | print('metric: avg acc - std: {}'.format(metric))
73 | return metric
74 |
75 |
76 | def ax_optimize(datareader, args, collate_fn, loss_fn, feature_stats, folds=10, threads=5, n_trials=30):
77 | parameters = [
78 | {"name": "pool", "type": "range", "bounds": [1e-4, 2e-2], "log_scale": False},
79 | {"name": "kl_weight", "type": "range", "bounds": [0.1, 10.], "log_scale": False},
80 | {"name": "n_hidden_attn", "type": "choice", "values": [0, 32]} # hidden units in the attention layer (0: no hidden layer)
81 | ]
82 |
83 | if args.tune_init:
84 | parameters.extend([{"name": "scale", "type": "range", "bounds": [0.1, 2.], "log_scale": False},
85 | {"name": "init", "type": "choice", "values": ['normal', 'uniform']}])
86 |
87 | best_parameters, values, experiment, model = optimize(
88 | parameters=parameters,
89 | evaluation_function=lambda parameterization: train_evaluate(datareader,
90 | args, collate_fn, loss_fn,
91 | feature_stats, parameterization, folds=folds,
92 | threads=threads),
93 | total_trials=n_trials,
94 | objective_name='accuracy',
95 | )
96 |
97 | print('best_parameters', best_parameters)
98 | print('values', values)
99 | return best_parameters
100 |
101 |
102 | def train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats=None, log=True):
103 | model.train()
104 | optimizer.zero_grad()
105 | n_samples, correct, train_loss = 0, 0, 0
106 | alpha_pred, alpha_GT = {}, {}
107 | start = time.time()
108 |
109 | # with torch.autograd.set_detect_anomaly(True):
110 | for batch_idx, data in enumerate(train_loader):
111 | data = data_to_device(data, args.device)
112 | if feature_stats is not None:
113 | data[0] = (data[0] - feature_stats[0]) / feature_stats[1]
114 | if batch_idx == 0 and epoch <= 1:
115 | sanity_check(model.eval(), data) # to disable the effect of dropout or other regularizers that can change behavior from batch to batch
116 | model.train()
117 | optimizer.zero_grad()
118 | mask = [data[2].view(len(data[2]), -1)]
119 | output, other_outputs = model(data)
120 | other_losses = other_outputs['reg'] if 'reg' in other_outputs else []
121 | alpha = other_outputs['alpha'] if 'alpha' in other_outputs else []
122 | mask.extend(other_outputs['mask'] if 'mask' in other_outputs else [])
123 | targets = data[3]
124 | loss = loss_fn(output, targets)
125 | for l in other_losses:
126 | loss += l
127 | loss_item = loss.item()
128 | train_loss += loss_item
129 | n_samples += len(targets)
130 | loss.backward() # accumulates gradient
131 | optimizer.step() # update weights
132 | time_iter = time.time() - start
133 | correct += count_correct(output.detach(), targets.detach())
134 | update_attn(data, alpha, alpha_pred, alpha_GT, mask)
135 | acc = 100. * correct / n_samples # average over all examples in the dataset
136 | train_loss_avg = train_loss / (batch_idx + 1)
137 |
138 | if log and ((batch_idx > 0 and batch_idx % args.log_interval == 0) or batch_idx == len(train_loader) - 1):
139 | print('Train set (epoch {}): [{}/{} ({:.0f}%)]\tLoss: {:.4f} (avg: {:.4f}), other losses: {}\tAcc metric: {}/{} ({:.2f}%)\t AttnAUC: {}\t avg sec/iter: {:.4f}'.format(
140 | epoch, n_samples, len(train_loader.dataset), 100. * n_samples / len(train_loader.dataset),
141 | loss_item, train_loss_avg, ['%.4f' % l.item() for l in other_losses],
142 | correct, n_samples, acc, ['%.2f' % a for a in attn_AUC(alpha_GT, alpha_pred)],
143 | time_iter / (batch_idx + 1)))
144 |
145 | assert n_samples == len(train_loader.dataset), (n_samples, len(train_loader.dataset))
146 |
147 | return train_loss, acc
148 |
149 |
150 | def test(model, test_loader, epoch, loss_fn, split, args, feature_stats=None, noises=None,
151 | img_noise_level=None, eval_attn=False, alpha_WS_name=''):
152 | model.eval()
153 | n_samples, correct, test_loss = 0, 0, 0
154 | pred, targets, N_nodes = [], [], []
155 | start = time.time()
156 | alpha_pred, alpha_GT = {}, {}
157 | if eval_attn:
158 | alpha_pred[0] = []
159 | print('testing with evaluation of attention: takes longer time')
160 | if args.debug:
161 | debug_data = {}
162 |
163 | with torch.no_grad():
164 | for batch_idx, data in enumerate(test_loader):
165 | data = data_to_device(data, args.device)
166 | if feature_stats is not None:
167 | assert feature_stats[0].shape[2] == feature_stats[1].shape[2] == data[0].shape[2], \
168 | (feature_stats[0].shape, feature_stats[1].shape, data[0].shape)
169 | data[0] = (data[0] - feature_stats[0]) / feature_stats[1]
170 | if batch_idx == 0 and epoch <= 1:
171 | sanity_check(model, data)
172 |
173 | if noises is not None:
174 | noise = noises[n_samples:n_samples + len(data[0])].to(args.device) * img_noise_level
175 | if len(noise.shape) == 2:
176 | noise = noise.unsqueeze(2)
177 | data[0][:, :, :3] = data[0][:, :, :3] + noise
178 |
179 | mask = [data[2].view(len(data[2]), -1)]
180 | N_nodes.append(data[4]['N_nodes'].detach())
181 | targets.append(data[3].detach())
182 | output, other_outputs = model(data)
183 | other_losses = other_outputs['reg'] if 'reg' in other_outputs else []
184 | alpha = other_outputs['alpha'] if 'alpha' in other_outputs else []
185 | mask.extend(other_outputs['mask'] if 'mask' in other_outputs else [])
186 | if args.debug:
187 | for key in other_outputs:
188 | if key.find('debug') >= 0:
189 | if key not in debug_data:
190 | debug_data[key] = []
191 | debug_data[key].append([d.data.cpu().numpy() for d in other_outputs[key]])
192 |             if args.torch.find('1.') == 0:  # PyTorch >= 1.0: 'reduction' kwarg; older versions use the deprecated 'reduce' flag
193 | loss = loss_fn(output, data[3], reduction='sum')
194 | else:
195 | loss = loss_fn(output, data[3], reduce=False).sum()
196 | for l in other_losses:
197 | loss += l
198 | test_loss += loss.item()
199 | pred.append(output.detach())
200 |
201 |
202 | update_attn(data, alpha, alpha_pred, alpha_GT, mask)
203 | if eval_attn:
204 | assert len(alpha) == 0, ('invalid mode, eval_attn should be false for this type of pooling')
205 | alpha_pred[0].extend(attn_heatmaps(model, args.device, data, output.data, test_loader.batch_size, constant_mask=args.dataset=='mnist'))
206 |
207 | n_samples += len(data[0])
208 | if eval_attn and (n_samples % 100 == 0 or n_samples == len(test_loader.dataset)):
209 | print('{}/{} samples processed'.format(n_samples, len(test_loader.dataset)))
210 |
211 | assert n_samples == len(test_loader.dataset), (n_samples, len(test_loader.dataset))
212 |
213 | pred = torch.cat(pred)
214 | targets = torch.cat(targets)
215 | N_nodes = torch.cat(N_nodes)
216 | if args.dataset.find('colors') >= 0:
217 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=25)
218 | if pred.shape[0] > 2500:
219 | correct += count_correct(pred[2500:5000], targets[2500:5000], N_nodes=N_nodes[2500:5000], N_nodes_min=26, N_nodes_max=200)
220 | correct += count_correct(pred[5000:], targets[5000:], N_nodes=N_nodes[5000:], N_nodes_min=26, N_nodes_max=200)
221 | elif args.dataset == 'triangles':
222 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=25)
223 | if pred.shape[0] > 5000:
224 | correct += count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=26, N_nodes_max=100)
225 | else:
226 | correct = count_correct(pred, targets, N_nodes=N_nodes, N_nodes_min=0, N_nodes_max=1e5)
227 |
228 | time_iter = time.time() - start
229 |
230 | test_loss_avg = test_loss / n_samples
231 | acc = 100. * correct / n_samples # average over all examples in the dataset
232 | print('{} set (epoch {}): Avg loss: {:.4f}, Acc metric: {}/{} ({:.2f}%)\t AttnAUC: {}\t avg sec/iter: {:.4f}\n'.format(
233 | split.capitalize(), epoch, test_loss_avg, correct, n_samples, acc,
234 | ['%.2f' % a for a in attn_AUC(alpha_GT, alpha_pred)], time_iter / (batch_idx + 1)))
235 |
236 | if args.debug:
237 | for key in debug_data:
238 | for layer in range(len(debug_data[key][0])):
239 | print('{} (layer={}): {:.5f}'.format(key, layer, np.mean([d[layer] for d in debug_data[key]])))
240 |
241 | if eval_attn:
242 | alpha_pred = alpha_pred[0]
243 | if args.results in [None, 'None', ''] or alpha_WS_name == '':
244 | print('skip saving alpha values, invalid results dir (%s) or alpha_WS_name (%s)' % (args.results, alpha_WS_name))
245 | else:
246 | file_path = pjoin(args.results, '%s_alpha_WS_%s_seed%d_%s.pkl' % (args.dataset, split, args.seed, alpha_WS_name))
247 | if os.path.isfile(file_path):
248 | print('WARNING: file %s exists and will be overwritten' % file_path)
249 | with open(file_path, 'wb') as f:
250 | pickle.dump(alpha_pred, f, protocol=2)
251 |
252 | return test_loss, acc, alpha_pred, pred
253 |
254 |
255 | def update_attn(data, alpha, alpha_pred, alpha_GT, mask):
256 | key = 'node_attn_eval'
257 | for layer in range(len(mask)):
258 | mask[layer] = mask[layer].data.cpu().numpy() > 0
259 | if key in data[4]:
260 | if not isinstance(data[4][key], list):
261 | data[4][key] = [data[4][key]]
262 | for layer in range(len(data[4][key])):
263 | if layer not in alpha_GT:
264 | alpha_GT[layer] = []
265 | # print(key, layer, len(data[4][key]), len(mask))
266 | alpha_GT[layer].extend(masked_alpha(data[4][key][layer].data.cpu().numpy(), mask[layer]))
267 | for layer in range(len(alpha)):
268 | if layer not in alpha_pred:
269 | alpha_pred[layer] = []
270 | alpha_pred[layer].extend(masked_alpha(alpha[layer].data.cpu().numpy(), mask[layer]))
271 |
272 |
273 | def masked_alpha(alpha, mask):
274 | alpha_lst = []
275 | for i in range(len(alpha)):
276 | # print('gt', len(alpha), alpha[i].shape, mask[i].shape, alpha[i][mask[i] > 0].shape, mask[i].sum(), mask[i].min(), mask[i].max(), mask[i].dtype)
277 | alpha_lst.append(alpha[i][mask[i]])
278 | return alpha_lst
279 |
280 |
281 | def attn_heatmaps(model, device, data, output_org, batch_size=1, constant_mask=False):
282 |     labels = torch.argmax(output_org, dim=1)  # predicted class of each unperturbed graph
283 | B, N_nodes_max, C = data[0].shape # N_nodes should be the same in the batch
284 | alpha_WS = []
285 | if N_nodes_max > 1000:
286 | print('WARNING: graph is too large (%d nodes) and not supported by this function (evaluation will be incorrect for graphs in this batch).' % N_nodes_max)
287 | for b in range(B):
288 | n = data[2][b].sum().item()
289 |             alpha_WS.append(np.zeros((1, n)) + 1. / n)  # fall back to uniform attention over the n valid nodes
290 | return alpha_WS
291 |
292 | if constant_mask:
293 | mask = torch.ones(N_nodes_max, N_nodes_max - 1).to(device)
294 |
295 | # Indices of nodes such that in each row one index (i.e. one node) is removed
296 | node_ids = torch.arange(start=0, end=N_nodes_max, device=device).view(1, -1).repeat(N_nodes_max, 1)
297 | node_ids[np.diag_indices(N_nodes_max, 2)] = -1
298 | node_ids = node_ids[node_ids >= 0].view(N_nodes_max, N_nodes_max - 1).long()
299 |
300 | with torch.no_grad():
301 | for b in range(B):
302 | x = torch.gather(data[0][b].unsqueeze(0).expand(N_nodes_max, -1, -1), dim=1, index=node_ids.unsqueeze(2).expand(-1, -1, C))
303 | if not constant_mask:
304 | mask = torch.gather(data[2][b].unsqueeze(0).expand(N_nodes_max, -1), dim=1, index=node_ids)
305 | A = torch.gather(data[1][b].unsqueeze(0).expand(N_nodes_max, -1, -1), dim=1, index=node_ids.unsqueeze(2).expand(-1, -1, N_nodes_max))
306 | A = torch.gather(A, dim=2, index=node_ids.unsqueeze(1).expand(-1, N_nodes_max - 1, -1))
307 | output = torch.zeros(N_nodes_max).to(device)
308 | n_chunks = int(np.ceil(N_nodes_max / float(batch_size)))
309 | for i in range(n_chunks):
310 | idx = np.arange(i * batch_size, (i + 1) * batch_size) if i < n_chunks - 1 else np.arange(i * batch_size, N_nodes_max)
311 | output[idx] = model([x[idx], A[idx], mask[idx], None, {}])[0][:, labels[b]].data
312 |
313 | alpha = torch.abs(output - output_org[b, labels[b]]).view(1, N_nodes_max) #* mask_org[b].view(1, N_nodes_max)
314 | if not constant_mask:
315 | alpha = alpha[data[2][b].view(1, N_nodes_max)]
316 | alpha_WS.append(normalize(alpha).data.cpu().numpy())
317 |
318 | return alpha_WS
319 |
320 |
321 | def save_checkpoint(model, scheduler, optimizer, args, epoch):
322 | if args.results in [None, 'None']:
323 | print('skip saving checkpoint, invalid results dir: %s' % args.results)
324 | return
325 | file_path = '%s/checkpoint_%s_%s_epoch%d_seed%07d.pth.tar' % (args.results, args.dataset, args.experiment_ID, epoch, args.seed)
326 | try:
327 | print('saving the model to %s' % file_path)
328 | state = {
329 | 'epoch': epoch,
330 | 'args': args,
331 | 'state_dict': model.state_dict(),
332 | 'scheduler': scheduler.state_dict(),
333 | 'optimizer': optimizer.state_dict(),
334 | }
335 | if os.path.isfile(file_path):
336 | print('WARNING: file %s exists and will be overwritten' % file_path)
337 | torch.save(state, file_path)
338 | except Exception as e:
339 | print('error saving the model', e)
340 |
341 |
342 | def load_checkpoint(model, optimizer, scheduler, file_path):
343 | print('loading the model from %s' % file_path)
344 | state = torch.load(file_path)
345 | model.load_state_dict(state['state_dict'])
346 | optimizer.load_state_dict(state['optimizer'])
347 | scheduler.load_state_dict(state['scheduler'])
348 | print('loading from epoch %d done' % state['epoch'])
349 | return state['epoch'] + 1 # +1 because we already finished training for this epoch
350 |
351 |
352 | def create_model_optimizer(in_features, out_features, pool, kl_weight, args, scale=None, init=None, n_hidden_attn=None):
353 | set_seed(args.seed, seed_data=None)
354 | model = ChebyGIN(in_features=in_features,
355 | out_features=out_features,
356 | filters=args.filters,
357 | K=args.filter_scale,
358 | n_hidden=args.n_hidden,
359 | aggregation=args.aggregation,
360 | dropout=args.dropout,
361 | readout=args.readout,
362 | pool=pool,
363 | pool_arch=args.pool_arch if n_hidden_attn in [None, 0] else args.pool_arch[:2] + ['%d' % n_hidden_attn],
364 | large_graph=args.dataset.lower() == 'mnist',
365 | kl_weight=float(kl_weight),
366 | init=args.init if init is None else init,
367 | scale=args.scale if scale is None else scale,
368 | debug=args.debug)
369 | print(model)
370 | # Compute the total number of trainable parameters
371 | print('model capacity: %d' %
372 | np.sum([np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()]))
373 |
374 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay, betas=(0.5, 0.999))
375 | scheduler = lr_scheduler.MultiStepLR(optimizer, args.lr_decay_step, gamma=0.1)
376 | epoch = 1
377 | if args.resume not in [None, 'None']:
378 | epoch = load_checkpoint(model, optimizer, scheduler, args.resume)
379 | if epoch < args.epochs + 1:
380 | print('resuming training for epoch %d' % epoch)
381 |
382 | model.to(args.device)
383 |
384 | return epoch, model, optimizer, scheduler
385 |
386 |
387 | def single_job(fold, datareader, args, collate_fn, loss_fn, pool, kl_weight, feature_stats, val_acc,
388 | scale=None, init=None, n_hidden_attn=None):
389 |
390 | set_seed(args.seed, seed_data=None)
391 |
392 | wsup = args.pool[1] == 'sup'
393 | train_loader = DataLoader(GraphData(datareader, fold, 'train'), batch_size=args.batch_size, shuffle=True,
394 | num_workers=args.threads, collate_fn=collate_fn)
395 | val_loader = DataLoader(GraphData(datareader, fold, 'val'), batch_size=args.test_batch_size, shuffle=False,
396 | num_workers=args.threads, collate_fn=collate_fn)
397 | start_epoch, model, optimizer, scheduler = create_model_optimizer(train_loader.dataset.num_features,
398 | train_loader.dataset.num_classes,
399 | None if wsup else pool, kl_weight, args,
400 | scale=scale, init=init, n_hidden_attn=n_hidden_attn)
401 |
402 | for epoch in range(start_epoch, args.epochs + 1):
403 | scheduler.step()
404 | train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats, log=False)
405 |
406 | if wsup:
407 | train_loader_test = DataLoader(GraphData(datareader, fold, 'train'), batch_size=args.test_batch_size, shuffle=False,
408 | num_workers=args.threads, collate_fn=collate_fn)
409 | train_loss, train_acc, attn_WS = test(model, train_loader_test, epoch, loss_fn, 'train', args, feature_stats, eval_attn=True)[:3] # test_loss, acc, alpha_pred, pred
410 | train_loader = DataLoader(GraphData(datareader, fold, 'train', attn_labels=attn_WS),
411 | batch_size=args.batch_size, shuffle=True,
412 | num_workers=args.threads, collate_fn=collate_fn)
413 | val_loader = DataLoader(GraphData(datareader, fold, 'val'), batch_size=args.test_batch_size, shuffle=False,
414 | num_workers=args.threads, collate_fn=collate_fn)
415 | start_epoch, model, optimizer, scheduler = create_model_optimizer(train_loader.dataset.num_features,
416 | train_loader.dataset.num_classes,
417 | pool, kl_weight, args,
418 | scale=scale, init=init, n_hidden_attn=n_hidden_attn)
419 | for epoch in range(start_epoch, args.epochs + 1):
420 | scheduler.step()
421 | train(model, train_loader, optimizer, epoch, args, loss_fn, feature_stats, log=False)
422 |
423 | acc = test(model, val_loader, epoch, loss_fn, 'val', args, feature_stats)[1]
424 |
425 | val_acc[fold] = acc
426 |
427 |
428 | def cross_validation(datareader, args, collate_fn, loss_fn, pool, kl_weight, feature_stats, n_hidden_attn=None, folds=10, threads=5):
429 | print('%d-fold cross-validation' % folds)
430 | manager = multiprocessing.Manager()
431 | val_acc = manager.dict()
432 | assert threads <= folds, (threads, folds)
433 | n_it = int(np.ceil(float(folds) / threads))
434 | for i in range(n_it):
435 | processes = []
436 | if threads <= 1:
437 | single_job(i * threads, datareader, args, collate_fn, loss_fn, pool, kl_weight,
438 | feature_stats, val_acc, scale=args.scale, init=args.init, n_hidden_attn=n_hidden_attn)
439 | else:
440 | for fold in range(threads):
441 | p = mp.Process(target=single_job, args=(i * threads + fold, datareader, args, collate_fn, loss_fn, pool, kl_weight,
442 | feature_stats, val_acc, args.scale, args.init, n_hidden_attn))
443 | p.start()
444 | processes.append(p)
445 |
446 | for p in processes:
447 | p.join()
448 |
449 | print(val_acc)
450 | val_acc = list(val_acc.values())
451 | print('average and std over {} folds: {} +- {}'.format(folds, np.mean(val_acc), np.std(val_acc)))
452 | metric = np.mean(val_acc) - np.std(val_acc)
453 | print('metric: avg acc - std: {}'.format(metric))
454 | return metric
455 |
--------------------------------------------------------------------------------
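attn_heatmaps() above produces the weakly-supervised attention targets by node occlusion: each node is removed in turn, the graph is re-classified, and the absolute change in the score of the originally predicted class becomes that node's importance (normalized to sum to 1). Below is a simplified single-graph sketch of the same idea; the model and its call signature are placeholders, whereas the real function batches all N occlusions and handles padding masks.

import torch

def occlusion_attention(model, x, A, eps=1e-7):
    """Node-occlusion importance for one graph (illustrative; placeholder model signature).
    x: (N, F) node features, A: (N, N) adjacency matrix."""
    with torch.no_grad():
        out = model(x.unsqueeze(0), A.unsqueeze(0))               # (1, n_classes)
        label = out.argmax(dim=1).item()                          # class predicted for the intact graph
        N = x.shape[0]
        scores = torch.zeros(N)
        for i in range(N):
            keep = torch.arange(N) != i                           # drop node i
            out_i = model(x[keep].unsqueeze(0), A[keep][:, keep].unsqueeze(0))
            scores[i] = (out_i[0, label] - out[0, label]).abs()   # change in the predicted-class score
    return scores / (scores.sum() + eps)                          # normalize, as utils.normalize() does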
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import copy
5 | from graphdata import *
6 | import torch.nn.functional as F
7 | from torchvision import datasets, transforms
8 | from sklearn.metrics import roc_auc_score
9 | import numbers
10 | import random
11 |
12 |
13 | def load_save_noise(f, noise_shape):
14 | if os.path.isfile(f):
15 | print('loading noise from %s' % f)
16 | noises = torch.load(f)
17 | else:
18 | noises = torch.randn(noise_shape, dtype=torch.float)
19 | # np.save(f, noises.numpy())
20 | torch.save(noises, f)
21 | return noises
22 |
23 |
24 | def list_to_torch(data):
25 | for i in range(len(data)):
26 | if data[i] is None:
27 | continue
28 | elif isinstance(data[i], np.ndarray):
29 | if data[i].dtype == np.bool:
30 | data[i] = data[i].astype(np.float32)
31 | data[i] = torch.from_numpy(data[i]).float()
32 | elif isinstance(data[i], list):
33 | data[i] = list_to_torch(data[i])
34 | return data
35 |
36 |
37 | def data_to_device(data, device):
38 | if isinstance(data, dict):
39 | keys = list(data.keys())
40 | else:
41 | keys = range(len(data))
42 | for i in keys:
43 | if isinstance(data[i], list) or isinstance(data[i], dict):
44 | data[i] = data_to_device(data[i], device)
45 | else:
46 | if isinstance(data[i], torch.Tensor):
47 | try:
48 | data[i] = data[i].to(device)
49 | except:
50 | print('error', i, data[i], type(data[i]))
51 | raise
52 | return data
53 |
54 |
55 | def count_correct(output, target, N_nodes=None, N_nodes_min=0, N_nodes_max=25):
56 | if output.shape[1] == 1:
57 | # Regression
58 | pred = output.round().long()
59 | else:
60 | # Classification
61 | pred = output.max(1, keepdim=True)[1]
62 | target = target.long().squeeze().cpu() # for older pytorch
63 | pred = pred.squeeze().cpu() # for older pytorch
64 | if N_nodes is not None:
65 | idx = (N_nodes >= N_nodes_min) & (N_nodes <= N_nodes_max)
66 | if idx.sum() > 0:
67 | correct = pred[idx].eq(target[idx]).sum().item()
68 | for lbl in torch.unique(target, sorted=True):
69 | idx_lbl = target[idx] == lbl
70 | eq = (pred[idx][idx_lbl] == target[idx][idx_lbl]).float()
71 | print('lbl: {}, avg acc: {:2.2f}% ({}/{})'.format(lbl, 100 * eq.mean(), int(eq.sum()),
72 | int(idx_lbl.float().sum())))
73 |
74 | eq = (pred[idx] == target[idx]).float()
75 | print('{} <= N_nodes <= {} (min={}, max={}), avg acc: {:2.2f}% ({}/{})'.format(N_nodes_min,
76 | N_nodes_max,
77 | N_nodes[idx].min(),
78 | N_nodes[idx].max(),
79 | 100 * eq.mean(),
80 | int(eq.sum()), int(idx.sum())))
81 | else:
82 | correct = 0
83 | print('no graphs with nodes >= {} and <= {}'.format(N_nodes_min, N_nodes_max))
84 | else:
85 | correct = pred.eq(target).sum().item()
86 |
87 | return correct
88 |
89 |
90 | def attn_AUC(alpha_GT, alpha):
91 | auc = []
92 | if len(alpha) > 0 and alpha_GT is not None and len(alpha_GT) > 0:
93 | for layer in alpha:
94 | alpha_gt = np.concatenate([a.flatten() for a in alpha_GT[layer]]) > 0
95 | if len(np.unique(alpha_gt)) <= 1:
96 | print('Only one class ({}) present in y_true. ROC AUC score is not defined in that case.'.format(np.unique(alpha_gt)))
97 | auc.append(np.nan)
98 | else:
99 | auc.append(100 * roc_auc_score(y_true=alpha_gt,
100 | y_score=np.concatenate([a.flatten() for a in alpha[layer]])))
101 | return auc
102 |
103 |
104 | def stats(arr):
105 | return np.mean(arr), np.std(arr), np.min(arr), np.max(arr)
106 |
107 |
108 | def normalize(x, eps=1e-7):
109 | return x / (x.sum() + eps)
110 |
111 |
112 | def normalize_batch(x, dim=1, eps=1e-7):
113 | return x / (x.sum(dim=dim, keepdim=True) + eps)
114 |
115 |
116 | def normalize_zero_one(im, eps=1e-7):
117 | m1 = im.min()
118 | m2 = im.max()
119 | return (im - m1) / (m2 - m1 + eps)
120 |
121 |
122 | def mse_loss(target, output, reduction='mean', reduce=None):
123 | loss = (target.float().squeeze() - output.float().squeeze()) ** 2
124 | if reduce is None:
125 | if reduction == 'mean':
126 | return torch.mean(loss)
127 | elif reduction == 'sum':
128 | return torch.sum(loss)
129 | elif reduction == 'none':
130 | return loss
131 |         else:
132 |             raise NotImplementedError(reduction)
133 |     elif not reduce:
134 |         return loss
135 |     else:
136 |         raise NotImplementedError('use reduction if reduce=True')
137 |
138 |
139 | def shuffle_nodes(batch):
140 | x, A, mask, labels, params_dict = batch
141 | for b in range(x.shape[0]):
142 | idx = np.random.permutation(x.shape[1])
143 | x[b] = x[b, idx]
144 | A[b] = A[b, :, idx][idx, :]
145 | mask[b] = mask[b, idx]
146 | if 'node_attn' in params_dict:
147 | params_dict['node_attn'][b] = params_dict['node_attn'][b, idx]
148 | return [x, A, mask, labels, params_dict]
149 |
150 |
151 | def copy_batch(data):
152 | data_cp = []
153 | for i in range(len(data)):
154 | if isinstance(data[i], dict):
155 | data_cp.append({key: data[i][key].clone() for key in data[i]})
156 | else:
157 | data_cp.append(data[i].clone())
158 | return data_cp
159 |
160 |
161 | def sanity_check(model, data):
162 | with torch.no_grad():
163 | output1 = model(copy_batch(data))[0]
164 | output2 = model(shuffle_nodes(copy_batch(data)))[0]
165 | if not torch.allclose(output1, output2, rtol=1e-02, atol=1e-03):
166 | print('WARNING: model outputs different depending on the nodes order', (torch.norm(output1 - output2),
167 | torch.max(output1 - output2),
168 | torch.max(output1),
169 | torch.max(output2)))
170 | print('model is checked for nodes shuffling')
171 |
172 |
173 | def set_seed(seed, seed_data=None):
174 | random.seed(seed) # for some libraries
175 | rnd = np.random.RandomState(seed)
176 | if seed_data is not None:
177 | rnd_data = np.random.RandomState(seed_data)
178 | else:
179 | rnd_data = rnd
180 | torch.backends.cudnn.deterministic = True
181 |     torch.backends.cudnn.benchmark = True  # note: benchmark=True may select non-deterministic algorithms; set False for strict reproducibility
182 | torch.manual_seed(seed)
183 | torch.cuda.manual_seed(seed)
184 | torch.cuda.manual_seed_all(seed)
185 | return rnd, rnd_data
186 |
187 |
188 | def compute_feature_stats(model, train_loader, device, n_batches=100):
189 | print('computing mean and std of input features')
190 | model.eval()
191 | x = []
192 | with torch.no_grad():
193 | for batch_idx, data in enumerate(train_loader):
194 | x.append(data[0].data.cpu().numpy()) # B,N,F
195 | if batch_idx > n_batches:
196 | break
197 | x = np.concatenate(x, axis=1).reshape(-1, x[0].shape[-1])
198 | print('features shape loaded', x.shape)
199 |
200 | mn = x.mean(axis=0, keepdims=True)
201 | sd = x.std(axis=0, keepdims=True)
202 | print('mn', mn)
203 | print('std', sd)
204 | sd[sd < 1e-2] = 1 # to prevent dividing by a small number
205 | print('corrected (non zeros) std', sd)#.data.cpu().numpy())
206 |
207 | mn = torch.from_numpy(mn).float().to(device).unsqueeze(0)
208 | sd = torch.from_numpy(sd).float().to(device).unsqueeze(0)
209 | return mn, sd
210 |
211 |
212 | def copy_data(data, idx):
213 | data_new = {}
214 | for key in data:
215 | if key == 'Max_degree':
216 | data_new[key] = data[key]
217 | print(key, data_new[key])
218 | else:
219 | data_new[key] = copy.deepcopy([data[key][i] for i in idx])
220 | if key in ['graph_labels', 'N_edges']:
221 | data_new[key] = np.array(data_new[key], np.int32)
222 | print(key, len(data_new[key]))
223 |
224 | return data_new
225 |
226 |
227 | def concat_data(data):
228 | data_new = {}
229 | for key in data[0]:
230 | if key == 'Max_degree':
231 | data_new[key] = np.max(np.array([ d[key] for d in data ]))
232 | print(key, data_new[key])
233 | else:
234 | if key in ['graph_labels', 'N_edges']:
235 | data_new[key] = np.concatenate([ d[key] for d in data ])
236 | else:
237 | lst = []
238 | for d in data:
239 | lst.extend(d[key])
240 | data_new[key] = lst
241 | print(key, len(data_new[key]))
242 |
243 | return data_new
244 |
--------------------------------------------------------------------------------
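The statistics returned by compute_feature_stats() are applied in train()/test() as data[0] = (data[0] - mn) / sd. A small illustrative snippet of the broadcasting involved; B, N, F and the tensors here are made up for the example.

import torch

B, N, F = 4, 75, 5                                             # made-up batch, node and feature sizes
x = torch.randn(B, N, F)                                       # padded batch of node features
mn = x.reshape(-1, F).mean(dim=0, keepdim=True).unsqueeze(0)   # (1, 1, F), as in compute_feature_stats
sd = x.reshape(-1, F).std(dim=0, keepdim=True).unsqueeze(0)    # (1, 1, F)
sd[sd < 1e-2] = 1                                              # same guard against near-zero std
x_norm = (x - mn) / sd                                         # broadcasts over the batch and node dimensions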