├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── DD
│   │   ├── TUs_graph_classification_DiffPool_DD_100k.json
│   │   └── TUs_graph_classification_MEWISPool_DD_100k.json
│   ├── FRANKENSTEIN
│   │   ├── TUs_graph_classification_DiffPool_FRANKENSTEIN_100k.json
│   │   ├── TUs_graph_classification_GAT_FRANKENSTEIN_100k.json
│   │   └── TUs_graph_classification_MEWISPool_FRANKENSTEIN_100k.json
│   ├── MUTAG
│   │   ├── TUs_graph_classification_DiffPool_MUTAG_100k.json
│   │   ├── TUs_graph_classification_GAT_MUTAG_100k.json
│   │   └── TUs_graph_classification_MEWISPool_MUTAG_100k.json
│   ├── NCI1
│   │   ├── TUs_graph_classification_DiffPool_NCI1_100k.json
│   │   ├── TUs_graph_classification_GAT_NCI1_100k.json
│   │   └── TUs_graph_classification_MEWISPool_NCI1_100k.json
│   ├── NCI109
│   │   ├── TUs_graph_classification_DiffPool_NCI109_100k.json
│   │   ├── TUs_graph_classification_GAT_NCI109_100k.json
│   │   └── TUs_graph_classification_MEWISPool_NCI109_100k.json
│   ├── PROTEINS_full
│   │   ├── TUs_graph_classification_DiffPool_PROTEINS_full_100k.json
│   │   └── TUs_graph_classification_MEWISPool_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_3WLGNN_DD_100k.json
│   ├── TUs_graph_classification_3WLGNN_ENZYMES_100k.json
│   ├── TUs_graph_classification_3WLGNN_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_GAT_DD_100k.json
│   ├── TUs_graph_classification_GAT_ENZYMES_100k.json
│   ├── TUs_graph_classification_GAT_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_GCN_DD_100k.json
│   ├── TUs_graph_classification_GCN_ENZYMES_100k.json
│   ├── TUs_graph_classification_GCN_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_GIN_DD_100k.json
│   ├── TUs_graph_classification_GIN_ENZYMES_100k.json
│   ├── TUs_graph_classification_GIN_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_GatedGCN_DD_100k.json
│   ├── TUs_graph_classification_GatedGCN_ENZYMES_100k.json
│   ├── TUs_graph_classification_GatedGCN_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_GraphSage_DD_100k.json
│   ├── TUs_graph_classification_GraphSage_ENZYMES_100k.json
│   ├── TUs_graph_classification_GraphSage_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_MLP_DD_100k.json
│   ├── TUs_graph_classification_MLP_ENZYMES_100k.json
│   ├── TUs_graph_classification_MLP_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_MoNet_DD_100k.json
│   ├── TUs_graph_classification_MoNet_ENZYMES_100k.json
│   ├── TUs_graph_classification_MoNet_PROTEINS_full_100k.json
│   ├── TUs_graph_classification_RingGNN_DD_100k.json
│   ├── TUs_graph_classification_RingGNN_ENZYMES_100k.json
│   ├── TUs_graph_classification_RingGNN_PROTEINS_full_100k.json
│   └── Tox21_AhR_training
│       ├── TUs_graph_classification_DiffPool_Tox21_AhR_training_100k.json
│       ├── TUs_graph_classification_GAT_Tox21_AhR_training_100k.json
│       └── TUs_graph_classification_MEWISPool_Tox21_AhR_training_100k.json
├── data
│   ├── TUs.py
│   ├── TUs
│   │   ├── COLLAB_test.index
│   │   ├── COLLAB_train.index
│   │   ├── COLLAB_val.index
│   │   ├── DD_test.index
│   │   ├── DD_train.index
│   │   ├── DD_val.index
│   │   ├── ENZYMES_test.index
│   │   ├── ENZYMES_train.index
│   │   ├── ENZYMES_val.index
│   │   ├── PROTEINS_full_test.index
│   │   ├── PROTEINS_full_train.index
│   │   ├── PROTEINS_full_val.index
│   │   └── README.md
│   └── data.py
├── layers
│   ├── diffpool_layer.py
│   ├── gat_layer.py
│   ├── gated_gcn_layer.py
│   ├── gcn_layer.py
│   ├── gin_layer.py
│   ├── gmm_layer.py
│   ├── graphsage_layer.py
│   ├── mlp_readout_layer.py
│   ├── ring_gnn_equiv_layer.py
│   └── three_wl_gnn_layers.py
├── main.py
├── metrics.py
├── nets
│   ├── gat_net.py
│   ├── gated_gcn_net.py
│   ├── gcn_net.py
│   ├── gin_net.py
│   ├── graphsage_net.py
│   ├── load_net.py
│   ├── mlp_net.py
│   ├── mo_net.py
│   ├── ring_gnn_net.py
│   └── three_wl_gnn_net.py
├── pooling
│   ├── diffpool.py
│   ├── edge_sparse_max.py
│   ├── hgpslpool.py
│   ├── mewispool.py
│   ├── sagpool.py
│   └── topkpool.py
├── train_TUs_graph_classification.py
└── visualization
    ├── visualizations_dataset_stats.ipynb
    └── visualizations_single_graph.ipynb
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 | .dmypy.json
111 | dmypy.json
112 | # OS / IDE / experiment artifacts
113 | .DS_Store
114 | .idea
115 | *.bak
116 | *.pkl
117 | save
118 | log
119 | log.test
120 | log.txt
121 | outputs
122 | out
123 | tmp
124 | tmp1.sh
125 | tmp2.sh
126 | tmp3.sh
127 | tmp4.sh
128 | result/
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 |
134 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-Classification-Baseline
2 |
3 | A framework of graph classification baselines, including a TUDataset loader, GNN models, and visualization tools.
4 |
5 | ## Features
6 |
7 | - Datasets
8 |
9 |   Implemented based on `LegacyTUDataset` from `dgl.data`. All [TUDataset](https://chrsmrrs.github.io/datasets/docs/datasets/) datasets are available and are downloaded automatically on first use (a minimal loader sketch follows this feature list).
10 |
11 | - GNN models
12 |
13 |   Including GCN, GAT, GatedGCN, GIN, GraphSAGE, etc. Default configurations can be found in `configs/`.
14 |
15 | **Models with configs having 500k trainable parameters**
16 |
17 | |Rank|Model | #Params | Test Acc ± s.d. | Links |
18 | |----| ---------- |------------:| :--------:|:-------:|
19 | |1|GatedGCN-PE | 505421 | 86.363 ± 0.127| [Paper](https://arxiv.org/abs/2003.00982) |
20 | |2|RingGNN | 504766 | 86.244 ± 0.025 |[Paper](https://papers.nips.cc/paper/9718-on-the-equivalence-between-graph-isomorphism-testing-and-function-approximation-with-gnns) |
21 | |3|MoNet | 511487 | 85.582 ± 0.038 | [Paper](https://arxiv.org/abs/1611.08402) |
22 | |4|GatedGCN | 502223 | 85.568 ± 0.088 | [Paper](https://arxiv.org/abs/1711.07553) |
23 | |5|GIN | 508574 | 85.387 ± 0.136 | [Paper](https://arxiv.org/abs/1810.00826)|
24 | |6|3WLGNN | 502872 | 85.341 ± 0.207 | [Paper](https://arxiv.org/abs/1905.11136) |
25 | |7|GAT | 526990 | 78.271 ± 0.186 | [Paper](https://arxiv.org/abs/1710.10903) |
26 | |8|GCN | 500823 | 71.892 ± 0.334 | [Paper](https://arxiv.org/abs/1609.02907) |
27 | |9|GraphSage | 502842 | 50.492 ± 0.001 | [Paper](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) |
28 |
29 | - Visualization
30 |
31 |   - [Visualizing dataset statistics](./visualization/visualizations_dataset_stats.ipynb)
32 |   - [Visualizing a single graph](./visualization/visualizations_single_graph.ipynb)
33 |
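A minimal sketch of what the loader does under the hood, mirroring `data/TUs.py` (the `hidden_size=1` argument matches the value used there):

```
from dgl.data import LegacyTUDataset

# any TUDataset name works; the data is downloaded automatically on first use
dataset = LegacyTUDataset('DD', hidden_size=1)
input_dim, label_dim, max_num_node = dataset.statistics()
graph, label = dataset[0]
print(input_dim, label_dim, graph.number_of_nodes())
```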
34 |
35 | ## Setup
36 |
37 | All tests were run on server 201 with Python 3.9 and CUDA 11.1. A ready-made environment is available at `/home/jiaxing/anaconda3/envs/xjx`.
38 |
39 | ```
40 | pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html
41 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
42 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
43 | pip install matplotlib
44 | pip install tensorboardX
45 | ```
46 |
47 | ## Usage
48 |
49 | ### Quick Start
50 |
51 | ```
52 | python main.py --config configs/TUs_graph_classification_GCN_DD_100k.json
53 | ```
54 |
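The same pattern works for any model/dataset pair that has a config file under `configs/`, e.g.:

```
python main.py --config configs/MUTAG/TUs_graph_classification_GAT_MUTAG_100k.json
```
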
55 | ### Output and Checkpoints
56 |
57 | Output results are located in the folder defined by the variable `out_dir` in the corresponding config file.
58 |
59 | If `out_dir = 'out/TUs_graph_classification/'`, then
60 |
61 | - Go to `out/TUs_graph_classification/results` to view all result text files.
62 | - Directory `out/TUs_graph_classification/checkpoints` contains model checkpoints.
63 | - To view the training logs in TensorBoard on a local machine:
64 | 
65 |   - Go to the logs directory, i.e. `out/TUs_graph_classification/logs/`.
66 | 
67 |   - Run `tensorboard --logdir='./' --port 6006`.
68 | 
69 |   - Open `http://localhost:6006` in your browser. Note that the port (6006 here, but it may change) is printed to the terminal immediately after starting TensorBoard.
70 |
71 | ## Design your own model
72 |
73 |
74 | ### New graph layer
75 |
76 | Add a class `MyGraphLayer()` in a `my_graph_layer.py` file in the `layers/` directory. A standard skeleton is:
77 |
78 | ```
79 | import torch
80 | import torch.nn as nn
81 | import dgl
82 |
83 | class MyGraphLayer(nn.Module):
84 |
85 |     def __init__(self, in_dim, out_dim, dropout):
86 |         super().__init__()
87 | 
88 |         # write your code here
89 | 
90 |     def forward(self, g_in, h_in, e_in):
91 | 
92 |         # write your code here
93 |         # write the dgl reduce and update function calls here
94 | 
95 |         return h_out, e_out
96 | ```
97 | The `layers/` directory contains the layer classes for all graph networks, as well as standard layers such as the *MLP* used for *readout*.
98 |
99 | For instance, the GCN layer class `GCNLayer()` is defined in the [layers/gcn_layer.py](./layers/gcn_layer.py) file.
100 |
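As a concrete illustration of this skeleton (a minimal mean-aggregation layer written for this sketch, not the repository's actual `GCNLayer`):

```
import torch
import torch.nn as nn
import dgl.function as fn

class MyGraphLayer(nn.Module):

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, g_in, h_in, e_in):
        with g_in.local_scope():
            g_in.ndata['h'] = h_in
            # dgl message/reduce: copy each neighbor's features and average them
            g_in.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_agg'))
            h_out = torch.relu(self.linear(g_in.ndata['h_agg']))
        h_out = self.dropout(h_out)
        # edge features are passed through unchanged in this illustration
        return h_out, e_in
```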
101 |
102 |
103 |
104 |
105 |
106 | ### New graph network
107 |
108 | Add a class `MyGraphNetwork()` in a `my_gcn_net.py` file in the `nets/` directory. The `loss()` function of the network is also defined in the `MyGraphNetwork()` class.
109 |
110 | ```
111 | import torch
112 | import torch.nn as nn
113 | import dgl
114 |
115 | from layers.my_graph_layer import MyGraphLayer
116 |
117 | class MyGraphNetwork(nn.Module):
118 |
119 |     def __init__(self, in_dim, out_dim, dropout):
120 |         super().__init__()
121 | 
122 |         # write your code here
123 |         self.layer = MyGraphLayer()
124 | 
125 |     def forward(self, g_in, h_in, e_in):
126 | 
127 |         # write your code here
128 |         # write the dgl reduce and update function calls here
129 | 
130 |         return h_out
131 | 
132 |     def loss(self, pred, label):
133 | 
134 |         # write your loss function here
135 | 
136 |         return loss
137 | ```
138 |
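For graph classification on the TU datasets, a plain cross-entropy criterion is a reasonable default for `loss()` (a minimal illustration; the nets shipped in `nets/` may use a class-weighted variant):

```
    def loss(self, pred, label):
        # unweighted cross-entropy over graph classes
        criterion = nn.CrossEntropyLoss()
        return criterion(pred, label)
```
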
139 | Then register a name, e.g. `MyGNN`, for the new graph network class in the `load_net.py` file in the `nets/` directory.
140 |
141 | ```
142 | from nets.my_gcn_net import MyGraphNetwork
143 |
144 | def MyGNN(net_params):
145 |     return MyGraphNetwork(net_params)
146 | 
147 | def gnn_model(MODEL_NAME, net_params):
148 |     models = {
149 |         'MyGNN': MyGNN
150 |     }
151 |     return models[MODEL_NAME](net_params)
152 | ```
153 |
154 |
155 | For example, `GCNNet()` in [nets/gcn_net.py](./nets/gcn_net.py) is given the GNN name `GCN` in [nets/load_net.py](./nets/load_net.py).
156 |
157 |
158 | ## Reference
159 |
160 | Benchmarking Graph Neural Networks [[paper](https://arxiv.org/pdf/2003.00982.pdf)] [[Github](https://github.com/graphdeeplearning/benchmarking-gnns)]
161 | DGL PyTorch examples: https://codechina.csdn.net/mirrors/dmlc/dgl/-/tree/master/examples/pytorch
--------------------------------------------------------------------------------
/configs/DD/TUs_graph_classification_DiffPool_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 10,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 32,
28 | "out_dim": 32,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/DD/TUs_graph_classification_MEWISPool_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/FRANKENSTEIN/TUs_graph_classification_DiffPool_FRANKENSTEIN_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "FRANKENSTEIN",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/FRANKENSTEIN/TUs_graph_classification_GAT_FRANKENSTEIN_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "FRANKENSTEIN",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/FRANKENSTEIN/TUs_graph_classification_MEWISPool_FRANKENSTEIN_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "FRANKENSTEIN",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/MUTAG/TUs_graph_classification_DiffPool_MUTAG_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "MUTAG",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/MUTAG/TUs_graph_classification_GAT_MUTAG_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "MUTAG",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/MUTAG/TUs_graph_classification_MEWISPool_MUTAG_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "MUTAG",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/NCI1/TUs_graph_classification_DiffPool_NCI1_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "NCI1",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/NCI1/TUs_graph_classification_GAT_NCI1_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "NCI1",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/NCI1/TUs_graph_classification_MEWISPool_NCI1_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "NCI1",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/NCI109/TUs_graph_classification_DiffPool_NCI109_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "NCI109",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/NCI109/TUs_graph_classification_GAT_NCI109_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "NCI109",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/NCI109/TUs_graph_classification_MEWISPool_NCI109_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "NCI109",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/PROTEINS_full/TUs_graph_classification_DiffPool_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/PROTEINS_full/TUs_graph_classification_MEWISPool_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_3WLGNN_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "3WLGNN",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 3,
27 | "hidden_dim": 74,
28 | "depth_of_mlp": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_3WLGNN_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "3WLGNN",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 3,
27 | "hidden_dim": 80,
28 | "depth_of_mlp": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_3WLGNN_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "3WLGNN",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-5,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 3,
27 | "hidden_dim": 80,
28 | "depth_of_mlp": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GAT_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GAT_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 18,
28 | "out_dim": 144,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GAT_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 18,
28 | "out_dim": 144,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": false
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GCN_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GCN",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 138,
28 | "out_dim": 138,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GCN_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GCN",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GCN_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GCN",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GIN_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GIN",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 106,
28 | "residual": true,
29 | "readout": "sum",
30 | "n_mlp_GIN": 2,
31 | "learn_eps_GIN": true,
32 | "neighbor_aggr_GIN": "sum",
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GIN_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GIN",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 110,
28 | "residual": true,
29 | "readout": "sum",
30 | "n_mlp_GIN": 2,
31 | "learn_eps_GIN": true,
32 | "neighbor_aggr_GIN": "sum",
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GIN_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GIN",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 110,
28 | "residual": true,
29 | "readout": "sum",
30 | "n_mlp_GIN": 2,
31 | "learn_eps_GIN": true,
32 | "neighbor_aggr_GIN": "sum",
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GatedGCN_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GatedGCN",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 66,
28 | "out_dim": 66,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "edge_feat": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GatedGCN_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GatedGCN",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 69,
28 | "out_dim": 69,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "edge_feat": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GatedGCN_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GatedGCN",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 69,
28 | "out_dim": 69,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "edge_feat": false
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GraphSage_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GraphSage",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool"
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GraphSage_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GraphSage",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 90,
28 | "out_dim": 90,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool"
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_GraphSage_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GraphSage",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 88,
28 | "out_dim": 88,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool"
35 | }
36 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MLP_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MLP",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 154,
28 | "out_dim": 154,
29 | "readout": "mean",
30 | "gated": false,
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0
33 | }
34 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MLP_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MLP",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 164,
28 | "out_dim": 164,
29 | "readout": "mean",
30 | "gated": false,
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0
33 | }
34 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MLP_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MLP",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 162,
28 | "out_dim": 162,
29 | "readout": "mean",
30 | "gated": false,
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0
33 | }
34 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MoNet_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MoNet",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "kernel": 3,
32 | "pseudo_dim_MoNet": 2,
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MoNet_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MoNet",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 1e-3,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 90,
28 | "out_dim": 90,
29 | "residual": true,
30 | "readout": "mean",
31 | "kernel": 3,
32 | "pseudo_dim_MoNet": 2,
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_MoNet_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MoNet",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 89,
28 | "out_dim": 89,
29 | "residual": true,
30 | "readout": "mean",
31 | "kernel": 3,
32 | "pseudo_dim_MoNet": 2,
33 | "in_feat_dropout": 0.0,
34 | "dropout": 0.0,
35 | "batch_norm": true
36 | }
37 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_RingGNN_DD_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "RingGNN",
8 | "dataset": "DD",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 20,
28 | "radius": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_RingGNN_ENZYMES_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "RingGNN",
8 | "dataset": "ENZYMES",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 1e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 38,
28 | "radius": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/TUs_graph_classification_RingGNN_PROTEINS_full_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "RingGNN",
8 | "dataset": "PROTEINS_full",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 4,
16 | "init_lr": 7e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 12
23 | },
24 |
25 | "net_params": {
26 | "L": 2,
27 | "hidden_dim": 35,
28 | "radius": 2,
29 | "residual": false,
30 | "dropout": 0.0,
31 | "layer_norm": false
32 | }
33 | }
--------------------------------------------------------------------------------
/configs/Tox21_AhR_training/TUs_graph_classification_DiffPool_Tox21_AhR_training_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "DiffPool",
8 | "dataset": "Tox21_AhR_training",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 86,
28 | "out_dim": 86,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "sage_aggregator": "maxpool",
35 | "self_loop": false,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/Tox21_AhR_training/TUs_graph_classification_GAT_Tox21_AhR_training_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "GAT",
8 | "dataset": "Tox21_AhR_training",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 5e-5,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 17,
28 | "out_dim": 136,
29 | "residual": true,
30 | "readout": "mean",
31 | "n_heads": 8,
32 | "in_feat_dropout": 0.0,
33 | "dropout": 0.0,
34 | "batch_norm": true,
35 | "self_loop": true,
36 | "edge_feat": false
37 | }
38 | }
--------------------------------------------------------------------------------
/configs/Tox21_AhR_training/TUs_graph_classification_MEWISPool_Tox21_AhR_training_100k.json:
--------------------------------------------------------------------------------
1 | {
2 | "gpu": {
3 | "use": true,
4 | "id": 0
5 | },
6 |
7 | "model": "MEWISPool",
8 | "dataset": "Tox21_AhR_training",
9 |
10 | "out_dir": "out/TUs_graph_classification/",
11 |
12 | "params": {
13 | "seed": 41,
14 | "epochs": 1000,
15 | "batch_size": 20,
16 | "init_lr": 7e-4,
17 | "lr_reduce_factor": 0.5,
18 | "lr_schedule_patience": 25,
19 | "min_lr": 1e-6,
20 | "weight_decay": 0.0,
21 | "print_epoch_interval": 5,
22 | "max_time": 30
23 | },
24 |
25 | "net_params": {
26 | "L": 4,
27 | "hidden_dim": 146,
28 | "out_dim": 146,
29 | "residual": true,
30 | "readout": "mean",
31 | "in_feat_dropout": 0.0,
32 | "dropout": 0.0,
33 | "batch_norm": true,
34 | "self_loop": false,
35 | "edge_feat": false
36 | }
37 | }
--------------------------------------------------------------------------------
/data/TUs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | import torch.utils.data
4 | import time
5 | import os
6 | import numpy as np
7 |
8 | import csv
9 |
10 | import dgl
11 | from dgl.data import TUDataset
12 | from dgl.data import LegacyTUDataset
13 |
14 | import random
15 | random.seed(42)
16 |
17 | from sklearn.model_selection import StratifiedKFold, train_test_split
18 |
19 |
20 | class DGLFormDataset(torch.utils.data.Dataset):
21 | """
22 | DGLFormDataset wrapping graph list and label list as per pytorch Dataset.
23 | *lists (list): lists of 'graphs' and 'labels' with same len().
24 | """
25 | def __init__(self, *lists):
26 | assert all(len(lists[0]) == len(li) for li in lists)
27 | self.lists = lists
28 | self.graph_lists = lists[0]
29 | self.graph_labels = lists[1]
30 |
31 | def __getitem__(self, index):
32 | return tuple(li[index] for li in self.lists)
33 |
34 | def __len__(self):
35 | return len(self.lists[0])
36 |
37 |
38 |
39 | def self_loop(g):
40 | """
41 | Utility function only, to be used only when necessary as per user self_loop flag
42 | : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat']
43 |
44 |
45 | This function is called inside a function in TUsDataset class.
46 | """
47 | new_g = dgl.DGLGraph()
48 | new_g.add_nodes(g.number_of_nodes())
49 | new_g.ndata['feat'] = g.ndata['feat']
50 |
51 | src, dst = g.all_edges(order="eid")
52 | src = dgl.backend.zerocopy_to_numpy(src)
53 | dst = dgl.backend.zerocopy_to_numpy(dst)
54 | non_self_edges_idx = src != dst
55 | nodes = np.arange(g.number_of_nodes())
56 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
57 | new_g.add_edges(nodes, nodes)
58 |
59 | # This new edata is not used since this function gets called only for GCN, GAT
60 | # However, we need this for the generic requirement of ndata and edata
61 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
62 | return new_g
63 |
64 |
65 | class TUsDataset(torch.utils.data.Dataset):
66 |     def __init__(self, name):
67 |         t0 = time.time()
68 |         self.name = name
69 | 
70 |         #dataset = TUDataset(self.name, hidden_size=1)
71 |         dataset = LegacyTUDataset(self.name, hidden_size=1)  # dgl 4.0
72 |         self.input_dim, self.label_dim, self.max_num_node = dataset.statistics()
73 | 
74 |         # FRANKENSTEIN and MUTAG come with labels 0 and 2; remap them to 0 and 1
75 |         if self.name in ["FRANKENSTEIN", "MUTAG"]:
76 |             dataset.graph_labels = np.array([1 if x==2 else x for x in dataset.graph_labels])
77 | 
78 |         print("[!] Dataset: ", self.name)
79 | 
80 |         # this function splits data into train/val/test and returns the indices
81 |         self.all_idx = self.get_all_split_idx(dataset)
82 | 
83 |         self.all = dataset
84 | self.train = [self.format_dataset([dataset[idx] for idx in self.all_idx['train'][split_num]]) for split_num in range(10)]
85 | self.val = [self.format_dataset([dataset[idx] for idx in self.all_idx['val'][split_num]]) for split_num in range(10)]
86 | self.test = [self.format_dataset([dataset[idx] for idx in self.all_idx['test'][split_num]]) for split_num in range(10)]
87 |
88 | print("Time taken: {:.4f}s".format(time.time()-t0))
89 |
90 | def get_all_split_idx(self, dataset):
91 | """
92 | - Split total number of graphs into 3 (train, val and test) in 80:10:10
93 | - Stratified split proportionate to original distribution of data with respect to classes
94 | - Using sklearn to perform the split and then save the indexes
95 | - Preparing 10 such combinations of indexes split to be used in Graph NNs
96 | - As with KFold, each of the 10 fold have unique test set.
97 | """
98 | root_idx_dir = './data/TUs/'
99 | if not os.path.exists(root_idx_dir):
100 | os.makedirs(root_idx_dir)
101 | all_idx = {}
102 |
103 | # If there are no idx files, do the split and store the files
104 | if not (os.path.exists(root_idx_dir + dataset.name + '_train.index')):
105 | print("[!] Splitting the data into train/val/test ...")
106 |
107 | # Using 10-fold cross val to compare with benchmark papers
108 | k_splits = 10
109 |
110 | cross_val_fold = StratifiedKFold(n_splits=k_splits, shuffle=True)
111 | k_data_splits = []
112 |
113 | # this is a temporary index assignment, to be used below for val splitting
114 | for i in range(len(dataset.graph_lists)):
115 | dataset[i][0].a = lambda: None
116 | setattr(dataset[i][0].a, 'index', i)
117 |
118 | for indexes in cross_val_fold.split(dataset.graph_lists, dataset.graph_labels):
119 | remain_index, test_index = indexes[0], indexes[1]
120 |
121 | remain_set = self.format_dataset([dataset[index] for index in remain_index])
122 |
123 |                 # Gets final 'train' and 'val'; test_size=0.111 ~= 1/9 of the remaining 90%, i.e. ~10% of the full set
124 |                 train, val, _, __ = train_test_split(remain_set,
125 |                                                      range(len(remain_set.graph_lists)),
126 |                                                      test_size=0.111,
127 |                                                      stratify=remain_set.graph_labels)
128 |
129 | train, val = self.format_dataset(train), self.format_dataset(val)
130 | test = self.format_dataset([dataset[index] for index in test_index])
131 |
132 | # Extracting only idx
133 | idx_train = [item[0].a.index for item in train]
134 | idx_val = [item[0].a.index for item in val]
135 | idx_test = [item[0].a.index for item in test]
136 |
137 |                 with open(root_idx_dir + dataset.name + '_train.index', 'a+') as f_train, \
138 |                      open(root_idx_dir + dataset.name + '_val.index', 'a+') as f_val, \
139 |                      open(root_idx_dir + dataset.name + '_test.index', 'a+') as f_test:
140 |                     csv.writer(f_train).writerow(idx_train)
141 |                     csv.writer(f_val).writerow(idx_val)
142 |                     csv.writer(f_test).writerow(idx_test)
143 |
144 |
145 | print("[!] Splitting done!")
146 |
147 | # reading idx from the files
148 | for section in ['train', 'val', 'test']:
149 | with open(root_idx_dir + dataset.name + '_'+ section + '.index', 'r') as f:
150 | reader = csv.reader(f)
151 | all_idx[section] = [list(map(int, idx)) for idx in reader]
152 | return all_idx
153 |
154 | def format_dataset(self, dataset):
155 | """
156 | Utility function to recover data,
157 | INTO-> dgl/pytorch compatible format
158 | """
159 | graphs = [data[0] for data in dataset]
160 | labels = [data[1] for data in dataset]
161 |
162 | for graph in graphs:
163 | #graph.ndata['feat'] = torch.FloatTensor(graph.ndata['feat'])
164 |             graph.ndata['feat'] = graph.ndata['feat'].float()  # dgl 0.4.x API
165 | # adding edge features for Residual Gated ConvNet, if not there
166 | if 'feat' not in graph.edata.keys():
167 | edge_feat_dim = graph.ndata['feat'].shape[1] # dim same as node feature dim
168 | graph.edata['feat'] = torch.ones(graph.number_of_edges(), edge_feat_dim)
169 |
170 | return DGLFormDataset(graphs, labels)
171 |
172 |
173 | # form a mini batch from a given list of samples = [(graph, label) pairs]
174 | def collate(self, samples):
175 | # The input samples is a list of pairs (graph, label).
176 | graphs, labels = map(list, zip(*samples))
177 | labels = torch.tensor(np.array(labels))
178 | # tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
179 | # tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
180 | # snorm_n = torch.cat(tab_snorm_n).sqrt()
181 | # tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))]
182 | # tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ]
183 | # snorm_e = torch.cat(tab_snorm_e).sqrt()
184 | batched_graph = dgl.batch(graphs)
185 |
186 | return batched_graph, labels
187 |
188 |
189 |     # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
190 | def collate_dense_gnn(self, samples):
191 | # The input samples is a list of pairs (graph, label).
192 | graphs, labels = map(list, zip(*samples))
193 | labels = torch.tensor(np.array(labels))
194 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
195 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
196 | #snorm_n = tab_snorm_n[0][0].sqrt()
197 |
198 | #batched_graph = dgl.batch(graphs)
199 |
200 |         g = graphs[0]  # the dense collate assumes batch_size == 1, so only the first graph is used
201 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())
202 | """
203 | Adapted from https://github.com/leichen2018/Ring-GNN/
204 | Assigning node and edge feats::
205 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}.
206 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix.
207 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the node feature of node i.
208 | The off diagonal T[1+d_n:, i, j] store edge features of edge(i, j).
209 | """
210 |
211 | zero_adj = torch.zeros_like(adj)
212 |
213 | in_dim = g.ndata['feat'].shape[1]
214 |
215 | # use node feats to prepare adj
216 | adj_node_feat = torch.stack([zero_adj for j in range(in_dim)])
217 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0)
218 |
219 | for node, node_feat in enumerate(g.ndata['feat']):
220 | adj_node_feat[1:, node, node] = node_feat
221 |
222 | x_node_feat = adj_node_feat.unsqueeze(0)
223 |
224 | return x_node_feat, labels
225 |
226 |     def _sym_normalize_adj(self, adj):
227 |         # symmetric normalization: D^{-1/2} A D^{-1/2}, with 0 where deg == 0
228 |         deg = torch.sum(adj, dim=0)
229 |         deg_inv_sqrt = torch.diag(torch.where(deg > 0, 1. / torch.sqrt(deg), torch.zeros(deg.size())))
230 |         return torch.mm(deg_inv_sqrt, torch.mm(adj, deg_inv_sqrt))
231 |
232 |
233 | def _add_self_loops(self):
234 |
235 | # function for adding self loops
236 | # this function will be called only if self_loop flag is True
237 | for split_num in range(10):
238 | self.train[split_num].graph_lists = [self_loop(g) for g in self.train[split_num].graph_lists]
239 | self.val[split_num].graph_lists = [self_loop(g) for g in self.val[split_num].graph_lists]
240 | self.test[split_num].graph_lists = [self_loop(g) for g in self.test[split_num].graph_lists]
241 |
242 | for split_num in range(10):
243 | self.train[split_num] = DGLFormDataset(self.train[split_num].graph_lists, self.train[split_num].graph_labels)
244 | self.val[split_num] = DGLFormDataset(self.val[split_num].graph_lists, self.val[split_num].graph_labels)
245 | self.test[split_num] = DGLFormDataset(self.test[split_num].graph_lists, self.test[split_num].graph_labels)
246 |
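247 |
248 | # ------------------------------------------------------------------
249 | # Example usage (editor's sketch; not part of the training pipeline):
250 | # load ENZYMES, batch fold 0 with the sparse collate, then show the
251 | # dense collate used by RingGNN/3WLGNN (which assumes batch_size 1).
252 | # ------------------------------------------------------------------
253 | if __name__ == '__main__':
254 |     from torch.utils.data import DataLoader
255 |     dataset = TUsDataset('ENZYMES')
256 |     loader = DataLoader(dataset.train[0], batch_size=20, shuffle=True,
257 |                         collate_fn=dataset.collate)
258 |     batched_graph, labels = next(iter(loader))
259 |     print(batched_graph.number_of_nodes(), labels.shape)
260 |     x_dense, y = dataset.collate_dense_gnn([dataset.train[0][0]])
261 |     print(x_dense.shape)  # (1, 1 + node_feat_dim, n, n)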
--------------------------------------------------------------------------------
/data/TUs/DD_test.index:
--------------------------------------------------------------------------------
1 | 25,36,44,53,54,60,66,68,81,95,99,113,115,128,132,146,155,185,190,195,196,197,209,223,225,226,244,254,272,301,308,311,312,315,331,340,349,366,375,397,411,414,418,433,435,438,447,467,472,477,478,479,497,526,529,542,566,576,605,608,609,620,623,628,633,640,642,650,657,689,697,698,699,715,723,725,744,753,760,764,787,799,807,810,831,838,863,869,870,874,900,903,925,932,943,950,978,992,994,997,1008,1009,1018,1026,1037,1040,1074,1082,1083,1087,1113,1119,1125,1132,1135,1142,1143,1151,1175
2 | 8,42,47,48,55,86,111,114,116,119,129,140,144,148,150,160,161,176,178,221,243,255,257,268,273,276,289,295,297,302,321,343,353,354,360,367,376,378,379,389,401,428,431,434,436,449,453,474,476,480,493,499,500,505,523,527,540,543,557,581,585,597,629,630,649,660,668,675,685,693,694,707,720,727,728,737,756,768,773,774,780,783,784,786,792,801,805,809,814,818,820,822,824,826,829,835,837,846,851,857,909,910,913,919,923,936,953,963,975,981,1101,1107,1108,1128,1130,1147,1165,1172
3 | 3,19,21,31,61,97,104,118,123,124,134,138,142,152,159,171,179,194,217,231,256,269,296,298,299,310,327,341,348,350,362,364,380,381,383,393,398,399,405,408,409,417,430,432,442,456,459,460,485,489,492,530,533,548,553,556,558,568,572,583,584,588,589,614,627,632,654,672,686,700,703,735,769,772,775,777,782,788,832,848,856,860,884,890,891,892,899,906,917,935,944,952,955,959,970,972,973,980,982,1014,1015,1016,1019,1029,1036,1050,1069,1078,1081,1088,1129,1139,1148,1153,1156,1163,1164,1174
4 | 0,14,18,35,38,49,69,72,74,90,93,102,127,131,135,156,157,182,189,216,218,220,222,251,261,282,284,317,328,334,363,370,385,413,415,422,452,464,488,491,504,510,522,544,547,554,567,569,582,587,590,593,594,600,604,607,616,618,634,635,641,651,653,656,662,665,677,680,682,692,695,712,717,718,734,745,762,794,797,798,811,812,827,834,843,861,886,897,905,915,929,937,941,942,962,964,966,993,1003,1020,1027,1033,1034,1035,1043,1053,1058,1068,1075,1096,1100,1110,1112,1118,1122,1123,1134,1141
5 | 33,51,75,76,77,82,83,85,87,88,91,98,100,121,122,147,154,167,173,205,208,245,249,260,277,287,300,316,319,326,330,332,346,347,357,361,368,374,390,416,426,427,429,448,466,471,483,487,495,496,512,517,519,536,541,551,552,565,570,573,586,601,606,610,611,624,645,676,684,710,721,729,740,742,754,779,785,793,796,836,842,864,880,885,889,896,912,918,920,931,954,957,960,971,983,984,989,1002,1004,1024,1049,1054,1062,1067,1071,1089,1120,1126,1140,1145,1146,1149,1152,1154,1162,1168,1173,1176
6 | 4,27,28,34,39,62,80,92,125,126,143,153,166,168,187,192,201,212,215,219,230,233,237,241,246,258,259,267,271,275,278,281,290,305,313,318,339,342,344,345,352,356,386,387,394,400,412,424,444,470,503,508,520,528,535,537,549,561,562,564,577,621,625,643,648,652,661,678,688,711,724,730,736,739,747,752,755,770,791,806,821,828,840,844,862,867,868,876,881,883,902,904,908,916,933,934,940,948,949,965,979,986,990,995,996,1011,1045,1048,1059,1064,1070,1091,1111,1115,1136,1138,1150,1157
7 | 5,9,13,23,26,29,30,41,50,58,59,63,64,67,70,78,117,130,137,165,169,170,175,177,183,184,188,199,203,210,214,236,248,253,266,306,314,320,322,323,324,337,338,358,369,371,388,395,407,450,451,465,481,490,498,502,507,555,595,596,598,613,622,638,659,667,679,681,690,704,738,758,766,771,776,778,781,789,802,803,813,833,841,853,858,859,872,879,888,898,911,924,951,961,976,987,988,998,1012,1013,1017,1021,1025,1038,1046,1047,1051,1072,1086,1090,1095,1098,1103,1104,1105,1106,1124,1167
8 | 10,12,15,22,46,56,57,89,103,105,109,110,120,151,164,180,193,213,227,235,238,247,252,263,264,274,280,286,294,303,304,309,335,351,396,403,419,425,445,455,468,469,501,506,509,521,525,532,539,546,550,563,571,578,579,580,591,602,603,619,636,637,639,644,646,655,669,673,674,709,713,731,733,741,748,749,751,761,790,808,815,839,845,847,849,850,852,873,921,928,938,939,946,956,958,968,969,974,985,991,1000,1005,1006,1010,1030,1031,1044,1055,1057,1066,1073,1077,1092,1109,1114,1131,1169
9 | 7,16,20,32,40,43,52,79,84,94,101,108,136,145,149,172,186,191,198,200,204,206,228,232,242,262,270,283,291,292,307,325,355,359,373,382,391,402,404,410,421,423,440,441,454,457,458,463,484,486,511,513,514,515,516,524,531,538,560,575,592,617,631,647,658,664,666,670,683,696,702,716,722,726,743,750,757,795,804,816,817,819,855,865,866,875,877,882,887,893,895,901,907,922,927,945,977,999,1007,1041,1042,1052,1056,1060,1063,1065,1076,1079,1093,1094,1102,1117,1121,1158,1160,1161,1166
10 | 1,2,6,11,17,24,37,45,65,71,73,96,106,107,112,133,139,141,158,162,163,174,181,202,207,211,224,229,234,239,240,250,265,279,285,288,293,329,333,336,365,372,377,384,392,406,420,437,439,443,446,461,462,473,475,482,494,518,534,545,559,574,599,612,615,626,663,671,687,691,701,705,706,708,714,719,732,746,759,763,765,767,800,823,825,830,854,871,878,894,914,926,930,947,967,1001,1022,1023,1028,1032,1039,1061,1080,1084,1085,1097,1099,1116,1127,1133,1137,1144,1155,1159,1170,1171,1177
11 |
--------------------------------------------------------------------------------
/data/TUs/DD_val.index:
--------------------------------------------------------------------------------
1 | 443,729,804,373,710,741,1139,212,394,584,899,791,233,663,432,1012,813,16,112,538,639,268,763,738,127,579,749,783,803,728,423,917,731,702,238,611,974,158,485,942,1000,1141,150,701,455,889,569,554,984,462,85,415,774,850,43,422,1003,1098,221,11,680,550,263,204,1089,599,162,245,987,525,348,631,552,622,953,280,97,471,409,841,880,802,718,876,3,506,1120,1169,1069,716,775,638,32,30,1090,491,110,570,581,475,539,1045,344,896,319,752,1173,299,811,6,927,482,214,21,403,362,269,386
2 | 663,752,384,196,495,526,17,369,518,971,231,1100,716,575,811,443,152,46,703,1035,833,755,100,870,1070,115,859,533,993,1078,978,471,323,99,506,892,347,502,1047,821,475,1145,1002,957,341,854,402,199,264,4,723,486,122,118,606,977,77,20,529,680,750,188,881,1084,1056,1015,628,80,92,1025,239,831,955,206,1095,766,237,326,984,609,154,695,288,542,985,602,277,856,79,218,896,1124,612,782,1104,659,283,437,1148,232,558,1022,507,717,303,908,91,433,1053,274,1176,21,555,653,370,803,300,593
3 | 818,701,316,606,314,984,476,502,288,840,778,404,767,127,214,549,338,667,1008,926,1144,339,649,1056,347,29,1112,520,472,739,765,23,604,644,37,493,655,156,293,873,263,389,161,253,695,513,30,666,1011,939,908,354,773,181,685,731,1167,172,646,812,940,680,128,813,25,108,657,1169,861,587,1033,315,762,391,185,720,41,495,901,481,707,598,932,620,630,1068,69,192,59,1060,724,852,692,776,467,264,1123,796,1018,1001,511,186,824,507,877,708,895,33,143,1162,91,490,225,738,577,938,671,593
4 | 337,125,280,1025,551,949,810,1037,82,540,1078,7,633,721,80,921,733,359,945,609,1056,116,953,719,833,28,924,47,500,1069,664,1169,918,1152,809,162,518,631,200,351,689,589,96,148,754,1054,555,987,983,1109,394,299,755,585,666,606,454,1177,581,295,511,1015,188,648,643,973,477,523,467,297,329,155,1061,598,1175,357,137,503,698,787,644,913,647,638,674,455,744,969,628,432,1004,1080,158,957,1024,946,457,1038,424,1089,469,629,1008,259,1121,434,20,358,980,526,214,686,830,642,820,940,873,97
5 | 452,1086,1042,516,415,4,915,299,877,533,373,678,1085,1131,19,701,501,238,1,56,278,321,369,317,66,766,1123,1000,394,310,1156,612,1017,198,1081,746,470,250,698,575,0,196,455,994,538,1112,445,289,1097,638,351,161,375,292,505,53,20,847,112,718,970,814,244,8,648,359,945,86,884,1047,241,1005,759,479,901,65,203,274,1011,647,525,262,212,981,917,894,25,715,873,1023,675,739,851,273,549,1104,965,724,1010,192,674,1074,157,852,862,84,67,821,843,253,234,947,209,727,738,201,634,232
6 | 328,558,950,338,36,1096,248,1041,713,185,1055,984,839,268,639,607,882,532,358,239,1165,68,977,690,1088,1006,1176,705,11,378,1144,701,798,410,885,645,65,891,833,321,238,63,640,284,375,95,72,510,440,608,133,698,1086,283,362,357,150,106,666,486,953,9,751,861,519,732,330,206,808,728,848,670,842,880,294,935,493,824,513,277,501,800,814,710,637,487,605,205,436,282,114,60,837,1065,780,422,56,426,382,1090,35,754,363,3,871,700,340,204,574,70,1071,874,823,392,807,859,1018,368
7 | 877,724,1066,636,527,653,621,21,469,602,313,572,890,31,152,367,433,1074,1154,945,953,17,897,966,298,978,363,1069,352,723,16,763,1091,687,590,1019,1083,129,774,670,1144,228,470,305,856,100,647,548,1031,715,285,574,846,609,788,735,57,714,1119,739,860,1176,820,699,504,1003,247,271,168,1121,434,384,643,847,20,307,947,48,661,222,139,145,237,215,1110,747,1131,272,965,484,1071,1166,956,478,116,674,402,842,796,505,40,173,134,179,510,80,806,212,204,18,99,1045,594,665,335,878,258,932
8 | 932,1076,793,308,745,1042,1126,158,128,783,905,156,872,786,772,1024,402,581,421,311,399,95,1091,449,127,322,435,154,479,1120,957,737,1059,534,1054,99,705,552,560,337,1105,920,410,738,708,391,37,347,488,628,973,211,665,1033,408,629,721,291,217,462,1117,152,328,342,117,699,1003,385,312,915,478,348,47,828,5,814,119,446,595,1021,514,1072,148,517,190,1138,4,961,577,168,1107,801,323,989,53,710,1040,429,732,1047,504,822,706,395,897,31,1130,165,1087,1084,338,254,555,201,614,298,255,334
9 | 297,885,177,1104,75,797,383,464,1066,706,994,905,157,782,503,152,988,765,661,455,1070,481,747,1006,479,651,934,605,490,432,980,171,87,1087,334,1154,1164,818,1120,756,655,586,1109,606,540,129,494,156,212,42,71,1045,837,926,777,409,103,970,1080,97,869,168,997,533,85,1000,734,387,14,1097,238,253,123,1040,1037,11,584,261,243,221,148,962,565,316,1124,1168,786,167,815,330,309,453,990,53,708,159,134,187,589,1096,367,919,667,1122,197,215,142,955,491,891,258,438,1103,257,898,338,572,548
10 | 620,206,284,185,84,186,1148,587,826,36,690,63,1043,694,150,604,502,1079,1029,327,64,1112,648,655,693,851,1083,1164,1117,331,1113,943,465,778,405,201,886,286,789,454,312,1027,83,353,579,902,480,1110,531,490,93,302,1111,991,1045,1086,844,833,893,850,20,196,489,535,1106,776,123,488,38,960,584,233,530,156,25,1122,698,990,666,729,47,403,659,1105,471,1042,1114,23,457,289,117,161,832,841,622,1124,792,219,616,627,330,251,49,624,455,1152,790,379,1038,309,846,956,401,1173,595,866,505,674
11 |
--------------------------------------------------------------------------------
/data/TUs/ENZYMES_test.index:
--------------------------------------------------------------------------------
1 | 0,13,14,20,25,33,39,63,68,85,102,113,117,129,144,161,167,192,194,195,221,223,241,244,267,273,277,281,282,287,316,332,333,341,344,361,368,379,391,393,405,411,414,416,420,456,467,472,487,491,500,504,505,508,514,526,533,536,580,589
2 | 32,47,54,57,60,65,67,86,91,99,106,120,124,126,159,160,166,170,173,191,246,258,259,260,264,268,269,278,284,297,321,326,331,342,348,357,364,370,378,386,404,406,424,437,438,447,448,455,478,496,518,528,544,548,558,569,575,588,596,599
3 | 5,12,21,26,55,71,72,92,93,94,101,110,123,127,130,155,175,180,181,199,201,205,207,213,232,233,256,295,296,299,300,305,319,327,337,345,352,365,388,399,430,434,435,452,474,475,482,484,485,489,522,523,529,530,549,552,560,567,572,590
4 | 2,3,6,19,27,28,46,59,89,90,116,121,136,165,168,176,179,184,186,193,202,204,215,225,245,254,257,289,290,294,313,314,318,335,349,373,375,380,389,392,402,421,433,442,462,471,476,492,497,499,501,510,534,554,556,557,565,570,581,584
5 | 8,15,31,44,45,48,81,82,84,95,103,114,125,139,142,147,163,177,190,198,203,214,217,219,229,234,247,250,252,262,301,302,310,346,351,354,366,369,395,398,401,408,415,425,429,439,444,446,470,488,515,517,519,531,540,550,559,563,571,592
6 | 18,23,24,29,35,36,38,73,77,79,122,134,138,141,143,148,169,182,187,189,212,224,227,230,239,242,249,261,271,293,304,311,315,324,328,329,339,340,371,385,400,403,412,417,432,445,464,469,479,493,509,520,521,524,543,562,578,579,587,595
7 | 30,34,43,49,56,64,66,69,88,98,132,145,151,152,153,158,164,172,174,185,206,209,210,226,251,253,263,274,276,285,303,306,312,323,353,355,359,367,374,381,407,409,426,453,454,457,459,463,477,498,525,527,538,545,551,561,574,576,583,586
8 | 1,4,41,50,62,74,75,76,83,96,109,112,115,119,131,154,156,157,162,196,211,220,222,228,240,243,248,270,280,298,320,322,334,336,343,358,377,384,394,397,410,428,436,440,443,460,461,486,490,494,506,512,513,553,555,564,566,593,594,597
9 | 7,9,16,40,42,51,52,53,70,87,107,111,118,128,135,137,140,178,183,188,208,218,235,237,255,272,275,283,288,291,307,325,330,347,350,356,360,363,372,396,413,422,431,441,449,458,466,473,480,483,502,507,532,535,537,539,541,577,585,598
10 | 10,11,17,22,37,58,61,78,80,97,100,104,105,108,133,146,149,150,171,197,200,216,231,236,238,265,266,279,286,292,308,309,317,338,362,376,382,383,387,390,418,419,423,427,450,451,465,468,481,495,503,511,516,542,546,547,568,573,582,591
11 |
--------------------------------------------------------------------------------
/data/TUs/ENZYMES_val.index:
--------------------------------------------------------------------------------
1 | 328,431,464,451,9,78,153,236,4,80,182,441,225,257,83,199,245,88,426,327,513,390,126,463,157,301,170,304,594,554,540,204,401,380,324,330,497,578,178,310,134,567,384,296,572,547,97,17,565,545,50,400,65,276,492,235,168,252,218,171
2 | 319,531,435,393,385,581,12,372,339,110,460,49,276,343,107,142,174,74,271,51,461,398,175,266,564,58,467,506,306,204,436,469,7,30,407,92,257,273,101,526,240,598,38,138,380,237,507,495,395,33,123,210,504,256,464,111,592,454,562,155
3 | 165,77,22,331,69,119,202,240,393,45,497,573,315,193,222,28,56,397,162,8,509,317,373,408,163,192,442,234,84,525,355,188,289,496,504,219,301,441,133,269,142,4,288,445,339,292,440,505,524,279,18,469,536,562,423,343,554,416,117,557
4 | 228,69,237,170,244,488,112,406,97,164,326,291,191,64,56,466,238,541,574,520,397,33,100,493,309,347,354,563,22,503,310,583,549,430,319,31,98,321,317,53,276,78,118,438,585,230,542,207,457,423,135,139,288,419,129,261,390,490,101,517
5 | 73,525,36,496,378,206,59,566,154,367,221,431,174,162,150,167,159,402,197,209,141,228,134,35,386,280,11,584,474,452,565,467,244,430,309,594,233,362,288,72,164,537,92,389,508,486,521,316,208,66,23,359,490,242,65,538,341,328,494,523
6 | 210,309,377,104,590,565,388,240,491,576,500,117,391,76,247,527,406,273,552,467,132,176,7,248,37,449,156,450,363,147,473,14,82,250,56,599,418,214,383,43,1,322,443,302,170,396,85,181,151,364,559,413,287,135,253,280,526,546,6,472
7 | 247,334,188,262,455,341,548,31,96,248,521,203,11,268,54,195,423,495,298,180,40,92,336,537,7,491,170,432,479,313,166,591,131,549,587,510,292,223,392,340,384,534,228,39,469,1,378,442,149,530,544,147,106,431,62,319,317,227,130,496
8 | 210,535,332,402,524,173,596,335,598,585,216,423,400,374,453,559,16,372,309,275,174,376,338,287,99,544,256,179,236,323,30,439,3,499,205,182,357,548,446,108,5,496,509,429,237,107,47,147,92,134,71,144,12,325,403,186,207,234,504,84
9 | 432,105,572,203,38,271,273,524,499,582,589,301,172,163,368,193,560,13,200,412,130,421,108,72,401,554,547,386,450,277,84,513,553,407,404,355,119,494,220,467,233,364,323,590,20,294,239,319,329,150,246,104,59,376,181,49,73,25,357,18
10 | 119,490,226,48,291,272,274,564,303,557,474,124,82,85,320,242,377,203,483,459,380,112,111,98,86,441,401,79,62,597,99,453,354,388,430,121,198,38,561,164,153,344,509,329,539,438,350,540,49,165,235,569,553,447,148,224,375,500,211,297
11 |
--------------------------------------------------------------------------------
/data/TUs/PROTEINS_full_test.index:
--------------------------------------------------------------------------------
1 | 3,6,19,36,42,45,53,55,64,69,82,84,89,97,122,150,167,174,183,192,193,202,203,211,229,233,243,249,251,264,297,303,308,309,313,353,356,363,374,382,389,391,393,398,408,417,441,446,448,455,460,461,499,512,518,532,533,547,566,569,579,598,602,605,611,652,661,665,668,673,676,679,684,692,698,715,723,744,753,756,765,772,782,783,801,828,830,851,864,877,892,893,896,899,904,921,930,939,948,949,953,959,962,975,983,994,1009,1043,1044,1069,1075,1080
2 | 17,18,25,44,66,68,71,86,98,100,114,121,124,127,145,152,154,169,170,176,180,184,194,215,223,236,244,255,276,279,293,307,310,312,349,367,375,394,395,403,406,421,430,439,464,472,473,501,510,528,530,550,558,562,567,572,577,580,586,594,599,609,615,618,631,636,645,686,694,697,699,705,706,713,720,734,743,745,755,762,769,786,788,793,809,811,821,850,856,866,876,884,903,920,925,929,931,977,992,1007,1020,1028,1042,1053,1056,1057,1063,1073,1077,1082,1089,1096
3 | 0,1,2,4,20,22,24,35,40,48,51,63,74,83,107,112,120,139,143,144,153,165,190,197,204,210,217,219,222,240,248,253,260,270,278,286,320,321,343,345,347,358,370,379,383,404,411,414,423,442,470,474,477,479,482,486,490,504,560,568,600,603,620,623,649,656,657,671,672,682,687,714,716,722,741,778,790,792,812,815,816,817,820,825,831,849,861,863,865,881,895,901,916,918,935,938,941,942,957,973,1003,1004,1005,1008,1017,1024,1025,1029,1046,1048,1078,1093
4 | 28,29,31,32,61,76,77,106,115,116,128,129,133,136,140,159,160,173,212,214,220,234,235,265,274,282,283,287,294,295,315,318,324,327,333,335,338,351,361,366,376,405,416,419,428,431,435,440,458,475,485,487,491,502,520,522,534,539,542,555,556,576,593,608,614,617,664,670,689,696,707,708,711,721,733,742,746,751,757,766,779,785,787,827,844,848,857,859,872,873,878,898,902,919,922,933,963,965,978,988,991,1000,1002,1027,1037,1040,1041,1054,1079,1094,1106
5 | 5,11,16,37,46,49,50,52,59,78,119,132,146,147,163,172,178,179,201,232,259,267,275,289,311,322,325,352,360,362,380,388,399,412,413,420,432,457,459,465,478,481,492,494,505,511,513,521,527,531,537,548,549,571,575,582,585,616,619,627,633,639,650,651,655,658,666,677,683,691,717,718,725,731,735,739,760,775,795,802,806,810,813,818,824,860,888,900,905,907,908,927,945,954,958,966,968,972,979,982,995,1006,1012,1051,1055,1074,1090,1102,1103,1104,1110
6 | 9,15,41,54,56,57,60,72,92,96,104,123,125,126,131,135,138,151,161,162,171,175,187,189,195,205,207,250,257,258,263,268,284,301,314,316,317,331,348,359,390,397,401,427,437,444,449,450,454,456,484,489,497,508,516,517,523,544,557,563,570,573,588,621,646,648,667,703,704,727,728,730,747,749,771,776,780,781,803,814,822,829,853,868,874,880,897,912,913,926,944,947,960,964,974,976,980,990,999,1032,1045,1050,1060,1061,1068,1070,1076,1084,1091,1097,1108
7 | 8,23,34,39,47,65,73,85,91,95,101,105,141,148,155,156,164,177,181,185,186,196,216,218,225,237,241,246,247,256,261,269,273,277,280,285,291,334,336,342,350,364,371,377,384,407,410,415,424,443,445,493,503,509,525,564,589,596,601,606,612,624,632,635,638,659,669,680,685,688,693,695,701,709,712,738,740,750,759,784,797,798,845,855,858,862,890,923,932,936,937,951,952,956,961,970,985,1010,1014,1021,1022,1033,1047,1052,1058,1062,1081,1085,1105,1107,1112
8 | 10,13,21,26,30,75,88,94,102,108,109,113,142,168,191,198,208,213,227,228,231,239,252,302,305,306,328,355,357,369,373,381,392,400,402,409,425,426,429,452,453,467,469,471,476,483,495,498,500,536,551,552,553,559,578,587,595,610,622,628,641,643,644,653,654,660,674,678,710,724,736,737,752,761,768,794,804,839,843,846,847,852,854,870,871,875,910,915,924,928,943,967,969,987,996,997,998,1031,1036,1039,1049,1059,1064,1065,1072,1083,1087,1092,1095,1101,1111
9 | 7,14,33,58,70,80,90,93,111,117,130,149,166,188,199,206,209,221,230,242,254,266,271,272,281,288,290,292,296,319,323,330,332,337,368,372,378,385,387,396,418,422,434,462,463,468,515,519,524,535,540,541,543,561,574,581,584,590,591,597,604,629,630,637,640,642,663,681,690,748,758,764,774,777,789,791,805,807,808,823,826,832,835,836,837,838,842,869,885,886,887,906,909,914,917,934,971,981,986,989,993,1011,1013,1015,1018,1019,1035,1038,1086,1098,1100
10 | 12,27,38,43,62,67,79,81,87,99,103,110,118,134,137,157,158,182,200,224,226,238,245,262,298,299,300,304,326,329,339,340,341,344,346,354,365,386,433,436,438,447,451,466,480,488,496,506,507,514,526,529,538,545,546,554,565,583,592,607,613,625,626,634,647,662,675,700,702,719,726,729,732,754,763,767,770,773,796,799,800,819,833,834,840,841,867,879,882,883,889,891,894,911,940,946,950,955,984,1001,1016,1023,1026,1030,1034,1066,1067,1071,1088,1099,1109
11 |
--------------------------------------------------------------------------------
/data/TUs/PROTEINS_full_val.index:
--------------------------------------------------------------------------------
1 | 179,29,726,285,501,1005,780,790,738,1104,259,775,789,956,540,403,232,854,815,694,468,147,134,95,1112,331,616,539,184,778,1092,205,258,226,469,436,933,66,695,867,634,239,479,1036,526,641,796,745,48,168,549,112,434,632,873,320,794,626,462,1033,27,599,261,856,596,741,131,482,360,488,31,875,387,804,367,72,437,580,527,1026,761,1089,766,642,895,846,16,123,728,30,344,369,336,770,879,885,971,1,83,629,720,476,998,729,100,615,806,260,333,888,10,901
2 | 1001,1006,716,38,117,444,535,313,292,930,346,843,1031,527,372,965,829,663,613,398,801,476,167,291,344,318,569,806,650,216,452,72,1090,687,629,242,1081,9,649,985,675,1032,161,563,963,228,571,348,973,831,199,896,1014,241,841,1092,587,628,1040,245,747,287,898,369,708,728,750,867,736,434,685,1095,122,368,319,80,387,524,270,871,139,534,153,288,988,1038,30,150,168,511,94,357,410,297,551,928,1049,975,471,208,196,777,574,90,460,1002,937,909,386,197,826,97
3 | 680,106,1094,754,1033,400,629,993,85,569,272,265,739,49,1034,262,335,87,732,312,638,685,424,1075,91,706,230,191,99,186,1069,681,1106,384,1053,401,991,237,513,824,10,231,937,931,786,854,914,127,608,182,611,780,105,522,52,1009,77,480,333,870,689,348,562,227,1090,315,161,719,247,903,902,50,519,667,1002,758,1111,517,1021,956,823,542,832,920,156,663,555,521,983,999,468,434,924,516,923,619,141,494,90,229,192,175,808,661,501,371,573,394,70,124,126,635
4 | 903,166,826,1111,663,1004,227,1056,19,1069,937,797,585,1101,208,256,72,65,413,436,1003,382,132,247,653,211,814,762,602,4,311,97,1038,1029,444,589,811,528,786,433,778,319,334,641,486,293,1044,155,403,237,296,192,998,149,38,425,18,99,983,363,940,798,392,258,1108,277,720,592,1070,3,348,1021,563,355,188,1030,418,175,337,889,784,645,544,526,1051,996,148,1010,1092,157,196,40,701,974,402,1007,1033,1016,286,1055,134,739,788,759,559,774,183,239,853,64,412,540
5 | 458,895,623,706,987,934,479,234,628,914,263,105,522,112,716,766,823,1005,345,819,308,975,142,1086,872,1075,386,17,1046,547,967,361,1099,327,744,771,93,1087,796,539,550,19,1111,40,227,430,703,485,253,64,745,174,84,45,74,428,1031,123,826,742,250,87,789,215,898,638,1082,1048,409,682,858,596,787,1069,942,378,192,204,439,205,264,851,480,125,970,422,356,307,406,769,922,177,792,614,1011,570,440,365,330,302,350,221,1018,593,162,136,66,1065,394,334,711,246
6 | 283,320,581,40,107,796,363,922,148,286,1008,461,298,613,1067,373,85,17,153,408,357,904,1044,115,720,227,1081,38,719,247,59,36,838,946,672,700,480,1011,821,925,3,26,567,1043,524,655,394,696,302,1017,948,1090,334,622,729,238,1052,686,322,170,155,439,494,324,593,50,276,718,654,519,716,629,927,958,101,942,45,475,109,329,1004,1087,438,981,865,82,333,799,597,663,694,811,1023,1082,274,147,446,920,409,105,992,380,124,687,128,1096,929,398,1051,396,413,243
7 | 121,301,697,993,184,552,1039,425,226,1084,478,853,579,636,1024,434,90,1088,403,399,929,962,1064,592,1104,887,634,182,647,431,288,637,149,305,142,40,231,194,708,598,518,946,700,718,117,220,769,104,912,720,631,973,719,968,347,143,442,618,1051,353,992,307,471,864,1063,599,433,396,77,528,474,777,728,1025,488,1092,832,457,55,397,88,945,950,297,903,696,1108,234,1056,374,130,536,771,439,1002,860,611,114,27,268,665,420,479,668,679,646,607,990,608,192,376,917
8 | 1050,439,703,950,197,615,939,938,124,352,248,627,573,722,300,694,581,269,82,760,853,14,1103,877,840,130,490,122,819,71,896,535,849,488,280,1033,458,284,279,781,640,177,497,390,1081,925,676,1051,708,98,1085,226,972,1053,78,519,482,693,68,353,404,919,335,816,391,199,485,472,407,304,624,667,360,651,769,746,958,235,258,39,663,580,807,477,1088,388,188,491,1029,812,154,234,356,1082,246,569,325,116,19,792,450,650,1009,898,86,827,1017,1019,1002,303,35,657
9 | 486,458,653,238,1007,279,624,929,27,237,85,773,772,298,293,918,617,593,761,202,1025,677,487,962,722,18,945,363,957,725,351,1014,169,204,717,155,825,586,570,697,488,411,529,657,847,394,979,1088,527,1001,852,181,325,24,299,71,215,36,562,481,861,28,447,820,695,276,633,927,643,902,724,1042,201,1068,938,311,683,824,261,11,560,522,498,244,616,1055,804,526,995,546,343,104,457,648,65,510,456,850,1045,16,855,706,844,747,868,460,878,273,453,109,606,1096
10 | 757,999,756,250,926,644,824,464,434,281,26,476,1027,343,165,1076,629,568,922,132,254,385,906,236,486,356,289,1053,1044,440,826,59,16,169,860,713,593,500,1013,896,167,608,311,885,475,127,282,930,694,82,1068,112,901,393,408,652,996,521,1022,936,980,44,571,806,248,1105,186,1064,1075,780,705,187,28,363,146,697,271,168,572,898,588,704,51,904,163,893,1058,720,192,302,351,516,175,960,306,557,673,285,864,858,1039,69,561,56,998,272,654,508,442,251,74,944
11 |
--------------------------------------------------------------------------------
/data/TUs/README.md:
--------------------------------------------------------------------------------
1 | ### Details:
2 | - Split the set of graphs into train/val/test in an 80:10:10 ratio
3 | - Stratified split, proportionate to the original class distribution
4 | - Use sklearn to perform the split, then save the indices to disk
5 | - Prepare 10 such index splits to be used by the graph NNs
6 | - As with KFold, each of the 10 folds has a unique test set.
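7 |
8 | The i-th row of each `*_train.index` / `*_val.index` / `*_test.index` file holds the graph indices of fold i. A minimal sketch for reading one split back (assuming the index files above are present):
9 |
10 | ```python
11 | import csv
12 |
13 | with open('data/TUs/ENZYMES_train.index', 'r') as f:
14 |     folds = [list(map(int, row)) for row in csv.reader(f) if row]
15 |
16 | print(len(folds), len(folds[0]))  # 10 folds, each covering ~80% of the graphs
17 | ```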
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | """
2 | File to load dataset based on user control from main file
3 | """
4 | from data.TUs import TUsDataset
5 |
6 |
7 | def LoadData(DATASET_NAME):
8 | """
9 |         This function is called from the main.py file.
10 |         returns:
11 |         the dataset object
12 | """
13 |
14 |     # handling for the TU Datasets; the names match the configs/ directories
15 |     TU_DATASETS = ['ENZYMES', 'DD', 'PROTEINS_full', 'FRANKENSTEIN', 'MUTAG', 'NCI1', 'NCI109', 'Tox21_AhR_training']
16 | if DATASET_NAME in TU_DATASETS:
17 | return TUsDataset(DATASET_NAME)
18 |
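19 |
20 | # Example usage (editor's sketch): mirrors the call made in main.py.
21 | if __name__ == '__main__':
22 |     dataset = LoadData('PROTEINS_full')  # downloads/caches the TU data on first use
23 |     print(dataset.name, len(dataset.train[0]), len(dataset.val[0]), len(dataset.test[0]))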
--------------------------------------------------------------------------------
/layers/diffpool_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from scipy.linalg import block_diag
6 |
7 | from torch.autograd import Function
8 |
9 | """
10 | DIFFPOOL:
11 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec,
12 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018)
13 | https://arxiv.org/pdf/1806.08804.pdf
14 |
15 | ! code started from dgl diffpool examples dir
16 | """
17 |
18 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage
19 |
20 |
21 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True,
22 | mask_fill_value=-1e32):
23 | '''
24 | masked_softmax for dgl batch graph
25 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp)
26 | '''
27 | if mask is None:
28 | result = torch.nn.functional.softmax(matrix, dim=dim)
29 | else:
30 | mask = mask.float()
31 | while mask.dim() < matrix.dim():
32 | mask = mask.unsqueeze(1)
33 | if not memory_efficient:
34 | result = torch.nn.functional.softmax(matrix * mask, dim=dim)
35 | result = result * mask
36 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
37 | else:
38 |             masked_matrix = matrix.masked_fill((1 - mask).bool(),  # bool mask; .byte() masks are rejected by newer torch
39 | mask_fill_value)
40 | result = torch.nn.functional.softmax(masked_matrix, dim=dim)
41 | return result
42 |
43 |
44 | class EntropyLoss(nn.Module):
45 | # Return Scalar
46 | # loss used in diffpool
47 | def forward(self, adj, anext, s_l):
48 | entropy = (torch.distributions.Categorical(
49 | probs=s_l).entropy()).sum(-1).mean(-1)
50 | assert not torch.isnan(entropy)
51 | return entropy
52 |
53 |
54 | class DiffPoolLayer(nn.Module):
55 |
56 | def __init__(self, input_dim, assign_dim, output_feat_dim,
57 | activation, dropout, aggregator_type, link_pred, batch_norm):
58 | super().__init__()
59 | self.embedding_dim = input_dim
60 | self.assign_dim = assign_dim
61 | self.hidden_dim = output_feat_dim
62 | self.link_pred = link_pred
63 | self.feat_gc = GraphSageLayer(
64 | input_dim,
65 | output_feat_dim,
66 | activation,
67 | dropout,
68 | aggregator_type,
69 | batch_norm)
70 | self.pool_gc = GraphSageLayer(
71 | input_dim,
72 | assign_dim,
73 | activation,
74 | dropout,
75 | aggregator_type,
76 | batch_norm)
77 | self.reg_loss = nn.ModuleList([])
78 | self.loss_log = {}
79 | self.reg_loss.append(EntropyLoss())
80 |
81 | def forward(self, g, h, e=None):
82 | feat, e = self.feat_gc(g, h, e)
83 | assign_tensor, e = self.pool_gc(g, h, e)
84 | # print(assign_tensor.shape)
85 | device = feat.device
86 | assign_tensor_masks = []
87 | batch_size = len(g.batch_num_nodes())
88 | for g_n_nodes in g.batch_num_nodes():
89 | mask = torch.ones((g_n_nodes,
90 | int(assign_tensor.size()[1] / batch_size)))
91 | assign_tensor_masks.append(mask)
92 | # print(assign_tensor_masks)
93 | """
94 | The first pooling layer is computed on batched graph.
95 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
96 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
97 | """
98 |         # block-diagonal mask over the batched graph, matching the
99 |         # block-diagonal assignment matrix
100 |         mask = torch.FloatTensor(
101 |             block_diag(*assign_tensor_masks)).to(device=device)
102 |
103 |
104 | assign_tensor = masked_softmax(assign_tensor, mask,
105 | memory_efficient=False)
106 | # print(assign_tensor.shape)
107 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper
108 | adj = g.adjacency_matrix(ctx=device)
109 |
110 | adj_new = torch.sparse.mm(adj, assign_tensor)
111 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper
112 |
113 | if self.link_pred:
114 | current_lp_loss = torch.norm(adj.to_dense() -
115 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2)
116 | self.loss_log['LinkPredLoss'] = current_lp_loss
117 |
118 | for loss_layer in self.reg_loss:
119 | loss_name = str(type(loss_layer).__name__)
120 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor)
121 | return adj_new, h
122 |
123 |
124 | class LinkPredLoss(nn.Module):
125 | # loss used in diffpool
126 | def forward(self, adj, anext, s_l):
127 | link_pred_loss = (
128 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2))
129 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
130 | return link_pred_loss.mean()
131 |
132 |
133 | class DenseDiffPool(nn.Module):
134 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True):
135 | super().__init__()
136 | self.link_pred = link_pred
137 | self.log = {}
138 |         self.link_pred_layer = LinkPredLoss()  # bugfix: LinkPredLoss is a module-level class, not an attribute of nn.Module
139 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True)
140 | self.assign = DiffPoolAssignment(nfeat, nnext)
141 | self.reg_loss = nn.ModuleList([])
142 | self.loss_log = {}
143 | if link_pred:
144 | self.reg_loss.append(LinkPredLoss())
145 | if entropy:
146 | self.reg_loss.append(EntropyLoss())
147 |
148 | def forward(self, x, adj, log=False):
149 | z_l = self.embed(x, adj)
150 | s_l = self.assign(x, adj)
151 | if log:
152 | self.log['s'] = s_l.cpu().numpy()
153 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l)
154 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l)
155 |
156 | for loss_layer in self.reg_loss:
157 | loss_name = str(type(loss_layer).__name__)
158 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l)
159 | if log:
160 | self.log['a'] = anext.cpu().numpy()
161 | return xnext, anext
162 |
163 | class DiffPoolAssignment(nn.Module):
164 | def __init__(self, nfeat, nnext):
165 | super().__init__()
166 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True)
167 |
168 | def forward(self, x, adj, log=False):
169 | s_l_init = self.assign_mat(x, adj)
170 | s_l = F.softmax(s_l_init, dim=-1)
171 | return s_l
172 |
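173 |
174 | # ------------------------------------------------------------------
175 | # Example (editor's sketch): masked_softmax zeroes out masked logits,
176 | # and DenseDiffPool pools a random 6-node dense graph down to 3
177 | # clusters. Assumes DenseGraphSage accepts batched dense tensors of
178 | # shape (B, N, F), as in the DGL diffpool example it is adapted from.
179 | # ------------------------------------------------------------------
180 | if __name__ == '__main__':
181 |     print(masked_softmax(torch.ones(1, 3), torch.tensor([[1., 1., 0.]])))  # ~[[0.5, 0.5, 0.0]]
182 |     x = torch.rand(1, 6, 16)                    # node features (B, N, F)
183 |     adj = (torch.rand(1, 6, 6) > 0.5).float()   # random dense adjacency
184 |     pool = DenseDiffPool(nfeat=16, nnext=3, nhid=16)
185 |     xnext, anext = pool(x, adj)
186 |     print(xnext.shape, anext.shape)             # (1, 3, 16), (1, 3, 3)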
--------------------------------------------------------------------------------
/layers/gat_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from dgl.nn.pytorch import GATConv
6 |
7 | """
8 | GAT: Graph Attention Network
9 | Graph Attention Networks (Veličković et al., ICLR 2018)
10 | https://arxiv.org/abs/1710.10903
11 | """
12 |
13 | class GATLayer(nn.Module):
14 | """
15 | Parameters
16 | ----------
17 | in_dim :
18 | Number of input features.
19 | out_dim :
20 | Number of output features.
21 | num_heads : int
22 | Number of heads in Multi-Head Attention.
23 | dropout :
24 | Required for dropout of attn and feat in GATConv
25 | batch_norm :
26 | boolean flag for batch_norm layer.
27 | residual :
28 | If True, use residual connection inside this layer. Default: ``False``.
29 | activation : callable activation function/layer or None, optional.
30 | If not None, applies an activation function to the updated node features.
31 |
32 | Using dgl builtin GATConv by default:
33 | https://github.com/graphdeeplearning/benchmarking-gnns/commit/206e888ecc0f8d941c54e061d5dffcc7ae2142fc
34 | """
35 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=False, activation=F.elu):
36 | super().__init__()
37 | self.residual = residual
38 | self.activation = activation
39 | self.batch_norm = batch_norm
40 |
41 | if in_dim != (out_dim*num_heads):
42 | self.residual = False
43 |
44 | self.gatconv = GATConv(in_dim, out_dim, num_heads, dropout, dropout)
45 |
46 | if self.batch_norm:
47 | self.batchnorm_h = nn.BatchNorm1d(out_dim * num_heads)
48 |
49 | def forward(self, g, h):
50 | h_in = h # for residual connection
51 |
52 | h = self.gatconv(g, h).flatten(1)
53 |
54 | if self.batch_norm:
55 | h = self.batchnorm_h(h)
56 |
57 | if self.activation:
58 | h = self.activation(h)
59 |
60 | if self.residual:
61 | h = h_in + h # residual connection
62 |
63 | return h
64 |
65 |
66 | ##############################################################
67 | #
68 | # Additional layers for edge feature/representation analysis
69 | #
70 | ##############################################################
71 |
72 |
73 | class CustomGATHeadLayer(nn.Module):
74 | def __init__(self, in_dim, out_dim, dropout, batch_norm):
75 | super().__init__()
76 | self.dropout = dropout
77 | self.batch_norm = batch_norm
78 |
79 | self.fc = nn.Linear(in_dim, out_dim, bias=False)
80 | self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
81 | self.batchnorm_h = nn.BatchNorm1d(out_dim)
82 |
83 | def edge_attention(self, edges):
84 | z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
85 | a = self.attn_fc(z2)
86 | return {'e': F.leaky_relu(a)}
87 |
88 | def message_func(self, edges):
89 | return {'z': edges.src['z'], 'e': edges.data['e']}
90 |
91 | def reduce_func(self, nodes):
92 | alpha = F.softmax(nodes.mailbox['e'], dim=1)
93 | alpha = F.dropout(alpha, self.dropout, training=self.training)
94 | h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
95 | return {'h': h}
96 |
97 | def forward(self, g, h):
98 | z = self.fc(h)
99 | g.ndata['z'] = z
100 | g.apply_edges(self.edge_attention)
101 | g.update_all(self.message_func, self.reduce_func)
102 | h = g.ndata['h']
103 |
104 | if self.batch_norm:
105 | h = self.batchnorm_h(h)
106 |
107 | h = F.elu(h)
108 |
109 | h = F.dropout(h, self.dropout, training=self.training)
110 |
111 | return h
112 |
113 |
114 | class CustomGATLayer(nn.Module):
115 | """
116 | Param: [in_dim, out_dim, n_heads]
117 | """
118 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True):
119 | super().__init__()
120 |
121 | self.in_channels = in_dim
122 | self.out_channels = out_dim
123 | self.num_heads = num_heads
124 | self.residual = residual
125 |
126 | if in_dim != (out_dim*num_heads):
127 | self.residual = False
128 |
129 | self.heads = nn.ModuleList()
130 | for i in range(num_heads):
131 | self.heads.append(CustomGATHeadLayer(in_dim, out_dim, dropout, batch_norm))
132 | self.merge = 'cat'
133 |
134 | def forward(self, g, h, e):
135 | h_in = h # for residual connection
136 |
137 | head_outs = [attn_head(g, h) for attn_head in self.heads]
138 |
139 | if self.merge == 'cat':
140 | h = torch.cat(head_outs, dim=1)
141 | else:
142 |             h = torch.mean(torch.stack(head_outs), dim=0)  # bugfix: average over heads, not over all dims
143 |
144 | if self.residual:
145 | h = h_in + h # residual connection
146 |
147 | return h, e
148 |
149 | def __repr__(self):
150 | return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
151 | self.in_channels,
152 | self.out_channels, self.num_heads, self.residual)
153 |
154 |
155 | ##############################################################
156 |
157 |
158 | class CustomGATHeadLayerEdgeReprFeat(nn.Module):
159 | def __init__(self, in_dim, out_dim, dropout, batch_norm):
160 | super().__init__()
161 | self.dropout = dropout
162 | self.batch_norm = batch_norm
163 |
164 | self.fc_h = nn.Linear(in_dim, out_dim, bias=False)
165 | self.fc_e = nn.Linear(in_dim, out_dim, bias=False)
166 | self.fc_proj = nn.Linear(3* out_dim, out_dim)
167 | self.attn_fc = nn.Linear(3* out_dim, 1, bias=False)
168 | self.batchnorm_h = nn.BatchNorm1d(out_dim)
169 | self.batchnorm_e = nn.BatchNorm1d(out_dim)
170 |
171 | def edge_attention(self, edges):
172 | z = torch.cat([edges.data['z_e'], edges.src['z_h'], edges.dst['z_h']], dim=1)
173 | e_proj = self.fc_proj(z)
174 | attn = F.leaky_relu(self.attn_fc(z))
175 | return {'attn': attn, 'e_proj': e_proj}
176 |
177 | def message_func(self, edges):
178 | return {'z': edges.src['z_h'], 'attn': edges.data['attn']}
179 |
180 | def reduce_func(self, nodes):
181 | alpha = F.softmax(nodes.mailbox['attn'], dim=1)
182 | h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
183 | return {'h': h}
184 |
185 | def forward(self, g, h, e):
186 | z_h = self.fc_h(h)
187 | z_e = self.fc_e(e)
188 | g.ndata['z_h'] = z_h
189 | g.edata['z_e'] = z_e
190 |
191 | g.apply_edges(self.edge_attention)
192 |
193 | g.update_all(self.message_func, self.reduce_func)
194 |
195 | h = g.ndata['h']
196 | e = g.edata['e_proj']
197 |
198 | if self.batch_norm:
199 | h = self.batchnorm_h(h)
200 | e = self.batchnorm_e(e)
201 |
202 | h = F.elu(h)
203 | e = F.elu(e)
204 |
205 | h = F.dropout(h, self.dropout, training=self.training)
206 | e = F.dropout(e, self.dropout, training=self.training)
207 |
208 | return h, e
209 |
210 |
211 | class CustomGATLayerEdgeReprFeat(nn.Module):
212 | """
213 | Param: [in_dim, out_dim, n_heads]
214 | """
215 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True):
216 | super().__init__()
217 |
218 | self.in_channels = in_dim
219 | self.out_channels = out_dim
220 | self.num_heads = num_heads
221 | self.residual = residual
222 |
223 | if in_dim != (out_dim*num_heads):
224 | self.residual = False
225 |
226 | self.heads = nn.ModuleList()
227 | for i in range(num_heads):
228 | self.heads.append(CustomGATHeadLayerEdgeReprFeat(in_dim, out_dim, dropout, batch_norm))
229 | self.merge = 'cat'
230 |
231 | def forward(self, g, h, e):
232 | h_in = h # for residual connection
233 | e_in = e
234 |
235 | head_outs_h = []
236 | head_outs_e = []
237 | for attn_head in self.heads:
238 | h_temp, e_temp = attn_head(g, h, e)
239 | head_outs_h.append(h_temp)
240 | head_outs_e.append(e_temp)
241 |
242 | if self.merge == 'cat':
243 | h = torch.cat(head_outs_h, dim=1)
244 | e = torch.cat(head_outs_e, dim=1)
245 | else:
246 | raise NotImplementedError
247 |
248 | if self.residual:
249 | h = h_in + h # residual connection
250 | e = e_in + e
251 |
252 | return h, e
253 |
254 | def __repr__(self):
255 | return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
256 | self.in_channels,
257 | self.out_channels, self.num_heads, self.residual)
258 |
259 |
260 | ##############################################################
261 |
262 |
263 | class CustomGATHeadLayerIsotropic(nn.Module):
264 | def __init__(self, in_dim, out_dim, dropout, batch_norm):
265 | super().__init__()
266 | self.dropout = dropout
267 | self.batch_norm = batch_norm
268 |
269 | self.fc = nn.Linear(in_dim, out_dim, bias=False)
270 | self.batchnorm_h = nn.BatchNorm1d(out_dim)
271 |
272 | def message_func(self, edges):
273 | return {'z': edges.src['z']}
274 |
275 | def reduce_func(self, nodes):
276 | h = torch.sum(nodes.mailbox['z'], dim=1)
277 | return {'h': h}
278 |
279 | def forward(self, g, h):
280 | z = self.fc(h)
281 | g.ndata['z'] = z
282 | g.update_all(self.message_func, self.reduce_func)
283 | h = g.ndata['h']
284 |
285 | if self.batch_norm:
286 | h = self.batchnorm_h(h)
287 |
288 | h = F.elu(h)
289 |
290 | h = F.dropout(h, self.dropout, training=self.training)
291 |
292 | return h
293 |
294 |
295 | class CustomGATLayerIsotropic(nn.Module):
296 | """
297 | Param: [in_dim, out_dim, n_heads]
298 | """
299 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True):
300 | super().__init__()
301 |
302 | self.in_channels = in_dim
303 | self.out_channels = out_dim
304 | self.num_heads = num_heads
305 | self.residual = residual
306 |
307 | if in_dim != (out_dim*num_heads):
308 | self.residual = False
309 |
310 | self.heads = nn.ModuleList()
311 | for i in range(num_heads):
312 | self.heads.append(CustomGATHeadLayerIsotropic(in_dim, out_dim, dropout, batch_norm))
313 | self.merge = 'cat'
314 |
315 | def forward(self, g, h, e):
316 | h_in = h # for residual connection
317 |
318 | head_outs = [attn_head(g, h) for attn_head in self.heads]
319 |
320 | if self.merge == 'cat':
321 | h = torch.cat(head_outs, dim=1)
322 | else:
323 |             h = torch.mean(torch.stack(head_outs), dim=0)  # bugfix: average over heads, not over all dims
324 |
325 | if self.residual:
326 | h = h_in + h # residual connection
327 |
328 | return h, e
329 |
330 | def __repr__(self):
331 | return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__,
332 | self.in_channels,
333 | self.out_channels, self.num_heads, self.residual)
334 |
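335 | # ------------------------------------------------------------------
336 | # Example (editor's sketch): run the builtin-GATConv GATLayer on a
337 | # 3-node directed cycle; the flattened output dim is out_dim * num_heads.
338 | # ------------------------------------------------------------------
339 | if __name__ == '__main__':
340 |     import dgl
341 |     g = dgl.DGLGraph()
342 |     g.add_nodes(3)
343 |     g.add_edges([0, 1, 2], [1, 2, 0])
344 |     layer = GATLayer(in_dim=8, out_dim=4, num_heads=2, dropout=0.0, batch_norm=True)
345 |     out = layer(g, torch.rand(3, 8))
346 |     print(out.shape)  # torch.Size([3, 8])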
--------------------------------------------------------------------------------
/layers/gated_gcn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import dgl.function as fn
5 |
6 | """
7 | ResGatedGCN: Residual Gated Graph ConvNets
8 | An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018)
9 | https://arxiv.org/pdf/1711.07553v2.pdf
10 | """
11 |
12 | class GatedGCNLayer(nn.Module):
13 | """
14 | Param: []
15 | """
16 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
17 | super().__init__()
18 | self.in_channels = input_dim
19 | self.out_channels = output_dim
20 | self.dropout = dropout
21 | self.batch_norm = batch_norm
22 | self.residual = residual
23 |
24 | if input_dim != output_dim:
25 | self.residual = False
26 |
27 | self.A = nn.Linear(input_dim, output_dim, bias=True)
28 | self.B = nn.Linear(input_dim, output_dim, bias=True)
29 | self.C = nn.Linear(input_dim, output_dim, bias=True)
30 | self.D = nn.Linear(input_dim, output_dim, bias=True)
31 | self.E = nn.Linear(input_dim, output_dim, bias=True)
32 | self.bn_node_h = nn.BatchNorm1d(output_dim)
33 | self.bn_node_e = nn.BatchNorm1d(output_dim)
34 |
35 | def forward(self, g, h, e):
36 |
37 | h_in = h # for residual connection
38 | e_in = e # for residual connection
39 |
40 | g.ndata['h'] = h
41 | g.ndata['Ah'] = self.A(h)
42 | g.ndata['Bh'] = self.B(h)
43 | g.ndata['Dh'] = self.D(h)
44 | g.ndata['Eh'] = self.E(h)
45 | g.edata['e'] = e
46 | g.edata['Ce'] = self.C(e)
47 |
48 | g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh'))
49 | g.edata['e'] = g.edata['DEh'] + g.edata['Ce']
50 | g.edata['sigma'] = torch.sigmoid(g.edata['e'])
51 | g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
52 | g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
53 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6)
54 | #g.update_all(self.message_func,self.reduce_func)
55 | h = g.ndata['h'] # result of graph convolution
56 | e = g.edata['e'] # result of graph convolution
57 |
58 | if self.batch_norm:
59 | h = self.bn_node_h(h) # batch normalization
60 | e = self.bn_node_e(e) # batch normalization
61 |
62 | h = F.relu(h) # non-linear activation
63 | e = F.relu(e) # non-linear activation
64 |
65 | if self.residual:
66 | h = h_in + h # residual connection
67 | e = e_in + e # residual connection
68 |
69 | h = F.dropout(h, self.dropout, training=self.training)
70 | e = F.dropout(e, self.dropout, training=self.training)
71 |
72 | return h, e
73 |
74 | def __repr__(self):
75 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
76 | self.in_channels,
77 | self.out_channels)
78 |
79 |
80 | ##############################################################
81 | #
82 | # Additional layers for edge feature/representation analysis
83 | #
84 | ##############################################################
85 |
86 |
87 | class GatedGCNLayerEdgeFeatOnly(nn.Module):
88 | """
89 | Param: []
90 | """
91 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
92 | super().__init__()
93 | self.in_channels = input_dim
94 | self.out_channels = output_dim
95 | self.dropout = dropout
96 | self.batch_norm = batch_norm
97 | self.residual = residual
98 |
99 | if input_dim != output_dim:
100 | self.residual = False
101 |
102 | self.A = nn.Linear(input_dim, output_dim, bias=True)
103 | self.B = nn.Linear(input_dim, output_dim, bias=True)
104 | self.D = nn.Linear(input_dim, output_dim, bias=True)
105 | self.E = nn.Linear(input_dim, output_dim, bias=True)
106 | self.bn_node_h = nn.BatchNorm1d(output_dim)
107 |
108 |
109 | def forward(self, g, h, e):
110 |
111 | h_in = h # for residual connection
112 |
113 | g.ndata['h'] = h
114 | g.ndata['Ah'] = self.A(h)
115 | g.ndata['Bh'] = self.B(h)
116 | g.ndata['Dh'] = self.D(h)
117 | g.ndata['Eh'] = self.E(h)
118 | #g.update_all(self.message_func,self.reduce_func)
119 | g.apply_edges(fn.u_add_v('Dh', 'Eh', 'e'))
120 | g.edata['sigma'] = torch.sigmoid(g.edata['e'])
121 | g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h'))
122 | g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma'))
123 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6)
124 | h = g.ndata['h'] # result of graph convolution
125 |
126 | if self.batch_norm:
127 | h = self.bn_node_h(h) # batch normalization
128 |
129 | h = F.relu(h) # non-linear activation
130 |
131 | if self.residual:
132 | h = h_in + h # residual connection
133 |
134 | h = F.dropout(h, self.dropout, training=self.training)
135 |
136 | return h, e
137 |
138 | def __repr__(self):
139 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
140 | self.in_channels,
141 | self.out_channels)
142 |
143 |
144 | ##############################################################
145 |
146 |
147 | class GatedGCNLayerIsotropic(nn.Module):
148 | """
149 | Param: []
150 | """
151 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False):
152 | super().__init__()
153 | self.in_channels = input_dim
154 | self.out_channels = output_dim
155 | self.dropout = dropout
156 | self.batch_norm = batch_norm
157 | self.residual = residual
158 |
159 | if input_dim != output_dim:
160 | self.residual = False
161 |
162 | self.A = nn.Linear(input_dim, output_dim, bias=True)
163 | self.B = nn.Linear(input_dim, output_dim, bias=True)
164 | self.bn_node_h = nn.BatchNorm1d(output_dim)
165 |
166 |
167 | def forward(self, g, h, e):
168 |
169 | h_in = h # for residual connection
170 |
171 | g.ndata['h'] = h
172 | g.ndata['Ah'] = self.A(h)
173 | g.ndata['Bh'] = self.B(h)
174 | #g.update_all(self.message_func,self.reduce_func)
175 | g.update_all(fn.copy_u('Bh', 'm'), fn.sum('m', 'sum_h'))
176 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_h']
177 | h = g.ndata['h'] # result of graph convolution
178 |
179 | if self.batch_norm:
180 | h = self.bn_node_h(h) # batch normalization
181 |
182 | h = F.relu(h) # non-linear activation
183 |
184 | if self.residual:
185 | h = h_in + h # residual connection
186 |
187 | h = F.dropout(h, self.dropout, training=self.training)
188 |
189 | return h, e
190 |
191 | def __repr__(self):
192 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__,
193 | self.in_channels,
194 | self.out_channels)
195 |
196 |
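197 | # ------------------------------------------------------------------
198 | # Example (editor's sketch): one GatedGCNLayer step on a 4-node cycle,
199 | # updating node and edge features jointly (dims must match for residual).
200 | # ------------------------------------------------------------------
201 | if __name__ == '__main__':
202 |     import dgl
203 |     g = dgl.DGLGraph()
204 |     g.add_nodes(4)
205 |     g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
206 |     layer = GatedGCNLayer(input_dim=16, output_dim=16, dropout=0.0, batch_norm=True, residual=True)
207 |     h, e = layer(g, torch.rand(4, 16), torch.rand(4, 16))
208 |     print(h.shape, e.shape)  # torch.Size([4, 16]) torch.Size([4, 16])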
--------------------------------------------------------------------------------
/layers/gcn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 | import dgl.function as fn
7 | from dgl.nn.pytorch import GraphConv
8 |
9 | """
10 | GCN: Graph Convolutional Networks
11 | Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017)
12 | http://arxiv.org/abs/1609.02907
13 | """
14 |
15 | # Sends a message of node feature h
16 | # Equivalent to => return {'m': edges.src['h']}
17 | msg = fn.copy_src(src='h', out='m')
18 | reduce = fn.mean('m', 'h')
19 |
20 | class NodeApplyModule(nn.Module):
21 | # Update node feature h_v with (Wh_v+b)
22 | def __init__(self, in_dim, out_dim):
23 | super().__init__()
24 | self.linear = nn.Linear(in_dim, out_dim)
25 |
26 | def forward(self, node):
27 | h = self.linear(node.data['h'])
28 | return {'h': h}
29 |
30 | class GCNLayer(nn.Module):
31 | """
32 | Param: [in_dim, out_dim]
33 | """
34 | def __init__(self, in_dim, out_dim, activation, dropout, batch_norm, residual=False, dgl_builtin=False):
35 | super().__init__()
36 | self.in_channels = in_dim
37 | self.out_channels = out_dim
38 | self.batch_norm = batch_norm
39 | self.residual = residual
40 | self.dgl_builtin = dgl_builtin
41 |
42 | if in_dim != out_dim:
43 | self.residual = False
44 |
45 | self.batchnorm_h = nn.BatchNorm1d(out_dim)
46 | self.activation = activation
47 | self.dropout = nn.Dropout(dropout)
48 | if self.dgl_builtin == False:
49 | self.apply_mod = NodeApplyModule(in_dim, out_dim)
50 | elif dgl.__version__ < "0.5":
51 | self.conv = GraphConv(in_dim, out_dim)
52 | else:
53 | self.conv = GraphConv(in_dim, out_dim, allow_zero_in_degree=True)
54 |
55 |
56 | def forward(self, g, feature):
57 | h_in = feature # to be used for residual connection
58 |
59 | if self.dgl_builtin == False:
60 | g.ndata['h'] = feature
61 | g.update_all(msg, reduce)
62 | g.apply_nodes(func=self.apply_mod)
63 | h = g.ndata['h'] # result of graph convolution
64 | else:
65 | h = self.conv(g, feature)
66 |
67 | if self.batch_norm:
68 | h = self.batchnorm_h(h) # batch normalization
69 |
70 | if self.activation:
71 | h = self.activation(h)
72 |
73 | if self.residual:
74 | h = h_in + h # residual connection
75 |
76 | h = self.dropout(h)
77 | return h
78 |
79 | def __repr__(self):
80 | return '{}(in_channels={}, out_channels={}, residual={})'.format(self.__class__.__name__,
81 | self.in_channels,
82 | self.out_channels, self.residual)
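
A smoke-test sketch for GCNLayer under the same assumptions (toy graph via dgl.rand_graph). Note the two code paths are not numerically identical: the manual path mean-aggregates neighbor features, while DGL's built-in GraphConv applies symmetric degree normalization.

    import dgl
    import torch
    import torch.nn.functional as F

    g = dgl.rand_graph(8, 20)
    feat = torch.randn(8, 16)

    manual = GCNLayer(16, 16, F.relu, dropout=0.0, batch_norm=True, residual=True, dgl_builtin=False)
    builtin = GCNLayer(16, 16, F.relu, dropout=0.0, batch_norm=True, residual=True, dgl_builtin=True)
    print(manual(g, feat).shape, builtin(g, feat).shape)   # both torch.Size([8, 16])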
--------------------------------------------------------------------------------
/layers/gin_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import dgl.function as fn
5 |
6 | """
7 | GIN: Graph Isomorphism Networks
8 | HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019)
9 | https://arxiv.org/pdf/1810.00826.pdf
10 | """
11 |
12 | class GINLayer(nn.Module):
13 | """
14 | [!] code adapted from dgl implementation of GINConv
15 |
16 | Parameters
17 | ----------
18 | apply_func : callable activation function/layer or None
19 | If not None, apply this function to the updated node feature,
20 | the :math:`f_\Theta` in the formula.
21 | aggr_type :
22 | Aggregator type to use (``sum``, ``max`` or ``mean``).
23 | out_dim :
24 |         Required for batch norm layer; should match out_dim of apply_func if not None.
25 | dropout :
26 | Required for dropout of output features.
27 | batch_norm :
28 | boolean flag for batch_norm layer.
29 | residual :
30 | boolean flag for using residual connection.
31 | init_eps : optional
32 | Initial :math:`\epsilon` value, default: ``0``.
33 | learn_eps : bool, optional
34 | If True, :math:`\epsilon` will be a learnable parameter.
35 |
36 | """
37 | def __init__(self, apply_func, aggr_type, dropout, batch_norm, residual=False, init_eps=0, learn_eps=False):
38 | super().__init__()
39 | self.apply_func = apply_func
40 |
41 | if aggr_type == 'sum':
42 | self._reducer = fn.sum
43 | elif aggr_type == 'max':
44 | self._reducer = fn.max
45 | elif aggr_type == 'mean':
46 | self._reducer = fn.mean
47 | else:
48 | raise KeyError('Aggregator type {} not recognized.'.format(aggr_type))
49 |
50 | self.batch_norm = batch_norm
51 | self.residual = residual
52 | self.dropout = dropout
53 |
54 | in_dim = apply_func.mlp.input_dim
55 | out_dim = apply_func.mlp.output_dim
56 |
57 | if in_dim != out_dim:
58 | self.residual = False
59 |
60 | # to specify whether eps is trainable or not.
61 | if learn_eps:
62 | self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
63 | else:
64 | self.register_buffer('eps', torch.FloatTensor([init_eps]))
65 |
66 | self.bn_node_h = nn.BatchNorm1d(out_dim)
67 |
68 | def forward(self, g, h):
69 | h_in = h # for residual connection
70 |
71 | g = g.local_var()
72 | g.ndata['h'] = h
73 | g.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
74 | h = (1 + self.eps) * h + g.ndata['neigh']
75 | if self.apply_func is not None:
76 | h = self.apply_func(h)
77 |
78 | if self.batch_norm:
79 | h = self.bn_node_h(h) # batch normalization
80 |
81 | h = F.relu(h) # non-linear activation
82 |
83 | if self.residual:
84 | h = h_in + h # residual connection
85 |
86 | h = F.dropout(h, self.dropout, training=self.training)
87 |
88 | return h
89 |
90 |
91 | class ApplyNodeFunc(nn.Module):
92 | """
93 | This class is used in class GINNet
94 | Update the node feature hv with MLP
95 | """
96 | def __init__(self, mlp):
97 | super().__init__()
98 | self.mlp = mlp
99 |
100 | def forward(self, h):
101 | h = self.mlp(h)
102 | return h
103 |
104 |
105 | class MLP(nn.Module):
106 | """MLP with linear output"""
107 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
108 |
109 | super().__init__()
110 | self.linear_or_not = True # default is linear model
111 | self.num_layers = num_layers
112 | self.output_dim = output_dim
113 | self.input_dim = input_dim
114 |
115 | if num_layers < 1:
116 | raise ValueError("number of layers should be positive!")
117 | elif num_layers == 1:
118 | # Linear model
119 | self.linear = nn.Linear(input_dim, output_dim)
120 | else:
121 | # Multi-layer model
122 | self.linear_or_not = False
123 | self.linears = torch.nn.ModuleList()
124 | self.batch_norms = torch.nn.ModuleList()
125 |
126 | self.linears.append(nn.Linear(input_dim, hidden_dim))
127 | for layer in range(num_layers - 2):
128 | self.linears.append(nn.Linear(hidden_dim, hidden_dim))
129 | self.linears.append(nn.Linear(hidden_dim, output_dim))
130 |
131 | for layer in range(num_layers - 1):
132 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
133 |
134 | def forward(self, x):
135 | if self.linear_or_not:
136 | # If linear model
137 | return self.linear(x)
138 | else:
139 | # If MLP
140 | h = x
141 | for i in range(self.num_layers - 1):
142 | h = F.relu(self.batch_norms[i](self.linears[i](h)))
143 | return self.linears[-1](h)
144 |
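How the three classes above compose, mirroring their use in nets/gin_net.py (toy sizes; dgl.rand_graph is assumed available):

    import dgl
    import torch

    mlp = MLP(num_layers=2, input_dim=16, hidden_dim=32, output_dim=16)
    gin = GINLayer(ApplyNodeFunc(mlp), aggr_type='sum', dropout=0.0,
                   batch_norm=True, residual=True, init_eps=0, learn_eps=True)

    g = dgl.rand_graph(10, 40)
    h = torch.randn(10, 16)
    # node update: h_v <- MLP((1 + eps) * h_v + sum_{u in N(v)} h_u)
    print(gin(g, h).shape)                     # torch.Size([10, 16])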
--------------------------------------------------------------------------------
/layers/gmm_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | import dgl.function as fn
6 |
7 | """
8 | GMM: Gaussian Mixture Model Convolution layer
9 | Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs (Federico Monti et al., CVPR 2017)
10 | https://arxiv.org/pdf/1611.08402.pdf
11 | """
12 |
13 | class GMMLayer(nn.Module):
14 | """
15 | [!] code adapted from dgl implementation of GMMConv
16 |
17 | Parameters
18 | ----------
19 | in_dim :
20 | Number of input features.
21 | out_dim :
22 | Number of output features.
23 | dim :
24 |         Dimensionality of the pseudo-coordinates.
25 | kernel :
26 | Number of kernels :math:`K`.
27 | aggr_type :
28 | Aggregator type (``sum``, ``mean``, ``max``).
29 | dropout :
30 | Required for dropout of output features.
31 | batch_norm :
32 | boolean flag for batch_norm layer.
33 | residual :
34 | If True, use residual connection inside this layer. Default: ``False``.
35 | bias :
36 | If True, adds a learnable bias to the output. Default: ``True``.
37 |
38 | """
39 | def __init__(self, in_dim, out_dim, dim, kernel, aggr_type, dropout,
40 | batch_norm, residual=False, bias=True):
41 | super().__init__()
42 |
43 | self.in_dim = in_dim
44 | self.out_dim = out_dim
45 | self.dim = dim
46 | self.kernel = kernel
47 | self.batch_norm = batch_norm
48 | self.residual = residual
49 | self.dropout = dropout
50 |
51 | if aggr_type == 'sum':
52 | self._reducer = fn.sum
53 | elif aggr_type == 'mean':
54 | self._reducer = fn.mean
55 | elif aggr_type == 'max':
56 | self._reducer = fn.max
57 | else:
58 | raise KeyError("Aggregator type {} not recognized.".format(aggr_type))
59 |
60 | self.mu = nn.Parameter(torch.Tensor(kernel, dim))
61 | self.inv_sigma = nn.Parameter(torch.Tensor(kernel, dim))
62 | self.fc = nn.Linear(in_dim, kernel * out_dim, bias=False)
63 |
64 | self.bn_node_h = nn.BatchNorm1d(out_dim)
65 |
66 | if in_dim != out_dim:
67 | self.residual = False
68 |
69 | if bias:
70 | self.bias = nn.Parameter(torch.Tensor(out_dim))
71 | else:
72 | self.register_buffer('bias', None)
73 | self.reset_parameters()
74 |
75 | def reset_parameters(self):
76 | """Reinitialize learnable parameters."""
77 | gain = init.calculate_gain('relu')
78 | init.xavier_normal_(self.fc.weight, gain=gain)
79 | init.normal_(self.mu.data, 0, 0.1)
80 | init.constant_(self.inv_sigma.data, 1)
81 | if self.bias is not None:
82 | init.zeros_(self.bias.data)
83 |
84 | def forward(self, g, h, pseudo):
85 | h_in = h # for residual connection
86 |
87 | g = g.local_var()
88 | g.ndata['h'] = self.fc(h).view(-1, self.kernel, self.out_dim)
89 | E = g.number_of_edges()
90 |
91 | # compute gaussian weight
92 | gaussian = -0.5 * ((pseudo.view(E, 1, self.dim) -
93 | self.mu.view(1, self.kernel, self.dim)) ** 2)
94 | gaussian = gaussian * (self.inv_sigma.view(1, self.kernel, self.dim) ** 2)
95 | gaussian = torch.exp(gaussian.sum(dim=-1, keepdim=True)) # (E, K, 1)
96 | g.edata['w'] = gaussian
97 | g.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
98 | h = g.ndata['h'].sum(1)
99 |
100 | if self.batch_norm:
101 | h = self.bn_node_h(h) # batch normalization
102 |
103 | h = F.relu(h) # non-linear activation
104 |
105 | if self.residual:
106 | h = h_in + h # residual connection
107 |
108 | if self.bias is not None:
109 | h = h + self.bias
110 |
111 | h = F.dropout(h, self.dropout, training=self.training)
112 |
113 | return h
114 |
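A forward-pass sketch for GMMLayer; the 2-d pseudo-coordinates below are random stand-ins for the degree-based ones computed per edge in nets/mo_net.py:

    import dgl
    import torch

    layer = GMMLayer(in_dim=16, out_dim=16, dim=2, kernel=3, aggr_type='sum',
                     dropout=0.0, batch_norm=True, residual=True)
    g = dgl.rand_graph(8, 24)
    h = torch.randn(8, 16)
    pseudo = torch.randn(24, 2)                # one pseudo-coordinate per edge
    print(layer(g, h, pseudo).shape)           # torch.Size([8, 16])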
--------------------------------------------------------------------------------
/layers/mlp_readout_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | MLP Layer used after graph vector representation
7 | """
8 |
9 | class MLPReadout(nn.Module):
10 |
11 | def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers
12 | super().__init__()
13 | list_FC_layers = [ nn.Linear( input_dim//2**l , input_dim//2**(l+1) , bias=True ) for l in range(L) ]
14 | list_FC_layers.append(nn.Linear( input_dim//2**L , output_dim , bias=True ))
15 | self.FC_layers = nn.ModuleList(list_FC_layers)
16 | self.L = L
17 |
18 | def forward(self, x):
19 | y = x
20 | for l in range(self.L):
21 | y = self.FC_layers[l](y)
22 | y = F.relu(y)
23 | y = self.FC_layers[self.L](y)
24 | return y
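
The hidden width halves at each of the L layers; e.g. with input_dim=64, output_dim=10 and the default L=2, the stack is Linear(64->32) -> ReLU -> Linear(32->16) -> ReLU -> Linear(16->10). A quick check:

    import torch

    readout = MLPReadout(64, 10)               # L defaults to 2
    print(readout(torch.randn(4, 64)).shape)   # torch.Size([4, 10])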
--------------------------------------------------------------------------------
/layers/ring_gnn_equiv_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Ring-GNN equi 2 to 2 layer file
7 | On the equivalence between graph isomorphism testing and function approximation with GNNs (Chen et al, 2019)
8 | https://arxiv.org/pdf/1905.12560v1.pdf
9 |
10 | CODE ADAPTED FROM https://github.com/leichen2018/Ring-GNN/
11 | """
12 |
13 | class RingGNNEquivLayer(nn.Module):
14 | def __init__(self, device, input_dim, output_dim, layer_norm, residual, dropout,
15 | normalization='inf', normalization_val=1.0, radius=2, k2_init = 0.1):
16 | super().__init__()
17 | self.device = device
18 | basis_dimension = 15
19 | self.radius = radius
20 | self.layer_norm = layer_norm
21 | self.residual = residual
22 | self.dropout = dropout
23 |
24 | coeffs_values = lambda i, j, k: torch.randn([i, j, k]) * torch.sqrt(2. / (i + j).float())
25 | self.diag_bias_list = nn.ParameterList([])
26 |
27 | for i in range(radius):
28 | for j in range(i+1):
29 | self.diag_bias_list.append(nn.Parameter(torch.zeros(1, output_dim, 1, 1)))
30 |
31 | self.all_bias = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
32 | self.coeffs_list = nn.ParameterList([])
33 |
34 | for i in range(radius):
35 | for j in range(i+1):
36 | self.coeffs_list.append(nn.Parameter(coeffs_values(input_dim, output_dim, basis_dimension)))
37 |
38 | self.switch = nn.ParameterList([nn.Parameter(torch.FloatTensor([1])), nn.Parameter(torch.FloatTensor([k2_init]))])
39 | self.output_dim = output_dim
40 |
41 | self.normalization = normalization
42 | self.normalization_val = normalization_val
43 |
44 | if self.layer_norm:
45 | self.ln_x = LayerNorm(output_dim.item())
46 |
47 | if self.residual:
48 | self.res_x = nn.Linear(input_dim.item(), output_dim.item())
49 |
50 | def forward(self, inputs):
51 | m = inputs.size()[3]
52 |
53 | ops_out = ops_2_to_2(inputs, m, normalization=self.normalization)
54 | ops_out = torch.stack(ops_out, dim = 2)
55 |
56 |
57 | output_list = []
58 |
59 | for i in range(self.radius):
60 | for j in range(i+1):
61 | output_i = torch.einsum('dsb,ndbij->nsij', self.coeffs_list[i*(i+1)//2 + j], ops_out)
62 |
63 | mat_diag_bias = torch.eye(inputs.size()[3]).unsqueeze(0).unsqueeze(0).to(self.device) * self.diag_bias_list[i*(i+1)//2 + j]
64 | # mat_diag_bias = torch.eye(inputs.size()[3]).to('cuda:0').unsqueeze(0).unsqueeze(0) * self.diag_bias_list[i*(i+1)//2 + j]
65 | if j == 0:
66 | output = output_i + mat_diag_bias
67 | else:
68 | output = torch.einsum('abcd,abde->abce', output_i, output)
69 | output_list.append(output)
70 |
71 | output = 0
72 | for i in range(self.radius):
73 | output += output_list[i] * self.switch[i]
74 |
75 | output = output + self.all_bias
76 |
77 | if self.layer_norm:
78 |             # Now, changing shapes from [1xdxnxn] to [nxnxd] for LayerNorm
79 | output = output.permute(3,2,1,0).squeeze()
80 |
81 | # output = self.bn_x(output.reshape(m*m, self.output_dim.item())) # batch normalization
82 | output = self.ln_x(output) # layer normalization
83 |
84 | # Returning output back to original shape
85 | output = output.reshape(m, m, self.output_dim.item())
86 | output = output.permute(2,1,0).unsqueeze(0)
87 |
88 | output = F.relu(output) # non-linear activation
89 |
90 | if self.residual:
91 | # Now, changing shapes from [1xdxnxn] to [nxnxd] for Linear() layer
92 | inputs, output = inputs.permute(3,2,1,0).squeeze(), output.permute(3,2,1,0).squeeze()
93 |
94 | residual_ = self.res_x(inputs)
95 | output = residual_ + output # residual connection
96 |
97 | # Returning output back to original shape
98 | output = output.permute(2,1,0).unsqueeze(0)
99 |
100 | output = F.dropout(output, self.dropout, training=self.training)
101 |
102 | return output
103 |
104 |
105 | def ops_2_to_2(inputs, dim, normalization='inf', normalization_val=1.0): # N x D x m x m
106 | # input: N x D x m x m
107 | diag_part = torch.diagonal(inputs, dim1 = 2, dim2 = 3) # N x D x m
108 | sum_diag_part = torch.sum(diag_part, dim=2, keepdim = True) # N x D x 1
109 | sum_of_rows = torch.sum(inputs, dim=3) # N x D x m
110 | sum_of_cols = torch.sum(inputs, dim=2) # N x D x m
111 | sum_all = torch.sum(sum_of_rows, dim=2) # N x D
112 |
113 | # op1 - (1234) - extract diag
114 | op1 = torch.diag_embed(diag_part) # N x D x m x m
115 |
116 | # op2 - (1234) + (12)(34) - place sum of diag on diag
117 | op2 = torch.diag_embed(sum_diag_part.repeat(1, 1, dim))
118 |
119 | # op3 - (1234) + (123)(4) - place sum of row i on diag ii
120 | op3 = torch.diag_embed(sum_of_rows)
121 |
122 | # op4 - (1234) + (124)(3) - place sum of col i on diag ii
123 | op4 = torch.diag_embed(sum_of_cols)
124 |
125 | # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag
126 | op5 = torch.diag_embed(sum_all.unsqueeze(2).repeat(1, 1, dim))
127 |
128 | # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i
129 | op6 = sum_of_cols.unsqueeze(3).repeat(1, 1, 1, dim)
130 |
131 | # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i
132 | op7 = sum_of_rows.unsqueeze(3).repeat(1, 1, 1, dim)
133 |
134 | # op8 - (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i
135 | op8 = sum_of_cols.unsqueeze(2).repeat(1, 1, dim, 1)
136 |
137 | # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i
138 | op9 = sum_of_rows.unsqueeze(2).repeat(1, 1, dim, 1)
139 |
140 | # op10 - (1234) + (14)(23) - identity
141 | op10 = inputs
142 |
143 | # op11 - (1234) + (13)(24) - transpose
144 | op11 = torch.transpose(inputs, -2, -1)
145 |
146 | # op12 - (1234) + (234)(1) - place ii element in row i
147 | op12 = diag_part.unsqueeze(3).repeat(1, 1, 1, dim)
148 |
149 | # op13 - (1234) + (134)(2) - place ii element in col i
150 | op13 = diag_part.unsqueeze(2).repeat(1, 1, dim, 1)
151 |
152 | # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries
153 | op14 = sum_diag_part.unsqueeze(3).repeat(1, 1, dim, dim)
154 |
155 | # op15 - sum of all ops - place sum of all entries in all entries
156 | op15 = sum_all.unsqueeze(2).unsqueeze(3).repeat(1, 1, dim, dim)
157 |
158 | #A_2 = torch.einsum('abcd,abde->abce', inputs, inputs)
159 | #A_4 = torch.einsum('abcd,abde->abce', A_2, A_2)
160 | #op16 = torch.where(A_4>1, torch.ones(A_4.size()), A_4)
161 |
162 | if normalization is not None:
163 | float_dim = float(dim)
164 |         if normalization == 'inf':  # '==' compares string values; 'is' would test identity
165 | op2 = torch.div(op2, float_dim)
166 | op3 = torch.div(op3, float_dim)
167 | op4 = torch.div(op4, float_dim)
168 | op5 = torch.div(op5, float_dim**2)
169 | op6 = torch.div(op6, float_dim)
170 | op7 = torch.div(op7, float_dim)
171 | op8 = torch.div(op8, float_dim)
172 | op9 = torch.div(op9, float_dim)
173 | op14 = torch.div(op14, float_dim)
174 | op15 = torch.div(op15, float_dim**2)
175 |
176 | #return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15, op16]
184 | return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
185 |
186 |
187 | class LayerNorm(nn.Module):
188 | def __init__(self, d):
189 | super().__init__()
190 | self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d
191 | self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d
192 |
193 | def forward(self, x):
194 | # x tensor of the shape n x n x d
195 | mean = x.mean(dim=(0,1), keepdim=True)
196 | var = x.var(dim=(0,1), keepdim=True, unbiased=False)
197 | x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b # shape is n x n x d
198 | return x
199 |
200 |
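A quick numerical check of two of the fifteen equivariant ops on a 1 x 1 x 3 x 3 toy input (normalization disabled; purely illustrative):

    import torch

    x = torch.arange(9.).reshape(1, 1, 3, 3)
    ops = ops_2_to_2(x, 3, normalization=None)
    print(ops[0])    # op1: diagonal kept in place, off-diagonal zeroed
    print(ops[10])   # op11: the transpose of the input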
--------------------------------------------------------------------------------
/layers/three_wl_gnn_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Layers used for
7 | 3WLGNN
8 | Provably Powerful Graph Networks (Maron et al., 2019)
9 | https://papers.nips.cc/paper/8488-provably-powerful-graph-networks.pdf
10 |
11 | CODE adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/
12 | """
13 |
14 | class RegularBlock(nn.Module):
15 | """
16 |     Inputs: N x input_depth x m x m
17 | Take the input through 2 parallel MLP routes, multiply the result, and add a skip-connection at the end.
18 | At the skip-connection, reduce the dimension back to output_depth
19 | """
20 | def __init__(self, depth_of_mlp, in_features, out_features, residual=False):
21 | super().__init__()
22 |
23 | self.residual = residual
24 |
25 | self.mlp1 = MlpBlock(in_features, out_features, depth_of_mlp)
26 | self.mlp2 = MlpBlock(in_features, out_features, depth_of_mlp)
27 |
28 | self.skip = SkipConnection(in_features+out_features, out_features)
29 |
30 | if self.residual:
31 | self.res_x = nn.Linear(in_features, out_features)
32 |
33 | def forward(self, inputs):
34 | mlp1 = self.mlp1(inputs)
35 | mlp2 = self.mlp2(inputs)
36 |
37 | mult = torch.matmul(mlp1, mlp2)
38 |
39 | out = self.skip(in1=inputs, in2=mult)
40 |
41 | if self.residual:
42 | # Now, changing shapes from [1xdxnxn] to [nxnxd] for Linear() layer
43 | inputs, out = inputs.permute(3,2,1,0).squeeze(), out.permute(3,2,1,0).squeeze()
44 |
45 | residual_ = self.res_x(inputs)
46 | out = residual_ + out # residual connection
47 |
48 | # Returning output back to original shape
49 | out = out.permute(2,1,0).unsqueeze(0)
50 |
51 | return out
52 |
53 |
54 | class MlpBlock(nn.Module):
55 | """
56 | Block of MLP layers with activation function after each (1x1 conv layers).
57 | """
58 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn=nn.functional.relu):
59 | super().__init__()
60 | self.activation = activation_fn
61 | self.convs = nn.ModuleList()
62 | for i in range(depth_of_mlp):
63 | self.convs.append(nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True))
64 | _init_weights(self.convs[-1])
65 | in_features = out_features
66 |
67 | def forward(self, inputs):
68 | out = inputs
69 | for conv_layer in self.convs:
70 | out = self.activation(conv_layer(out))
71 |
72 | return out
73 |
74 |
75 | class SkipConnection(nn.Module):
76 | """
77 | Connects the two given inputs with concatenation
78 | :param in1: earlier input tensor of shape N x d1 x m x m
79 | :param in2: later input tensor of shape N x d2 x m x m
80 | :param in_features: d1+d2
81 | :param out_features: output num of features
82 | :return: Tensor of shape N x output_depth x m x m
83 | """
84 | def __init__(self, in_features, out_features):
85 | super().__init__()
86 | self.conv = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True)
87 | _init_weights(self.conv)
88 |
89 | def forward(self, in1, in2):
90 | # in1: N x d1 x m x m
91 | # in2: N x d2 x m x m
92 | out = torch.cat((in1, in2), dim=1)
93 | out = self.conv(out)
94 | return out
95 |
96 |
97 | class FullyConnected(nn.Module):
98 | def __init__(self, in_features, out_features, activation_fn=nn.functional.relu):
99 | super().__init__()
100 |
101 | self.fc = nn.Linear(in_features, out_features)
102 | _init_weights(self.fc)
103 |
104 | self.activation = activation_fn
105 |
106 | def forward(self, input):
107 | out = self.fc(input)
108 | if self.activation is not None:
109 | out = self.activation(out)
110 |
111 | return out
112 |
113 |
114 | def diag_offdiag_maxpool(input):
115 | N = input.shape[-1]
116 |
117 | max_diag = torch.max(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)[0] # BxS
118 |
119 | # with torch.no_grad():
120 | max_val = torch.max(max_diag)
121 | min_val = torch.max(-1 * input)
122 | val = torch.abs(torch.add(max_val, min_val))
123 |
124 | min_mat = torch.mul(val, torch.eye(N, device=input.device)).view(1, 1, N, N)
125 |
126 | max_offdiag = torch.max(torch.max(input - min_mat, dim=3)[0], dim=2)[0] # BxS
127 |
128 | return torch.cat((max_diag, max_offdiag), dim=1) # output Bx2S
129 |
130 | def _init_weights(layer):
131 | """
132 | Init weights of the layer
133 | :param layer:
134 | :return:
135 | """
136 | nn.init.xavier_uniform_(layer.weight)
137 | # nn.init.xavier_normal_(layer.weight)
138 | if layer.bias is not None:
139 | nn.init.zeros_(layer.bias)
140 |
141 |
142 | class LayerNorm(nn.Module):
143 | def __init__(self, d):
144 | super().__init__()
145 | self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d
146 | self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d
147 |
148 | def forward(self, x):
149 | # x tensor of the shape n x n x d
150 | mean = x.mean(dim=(0,1), keepdim=True)
151 | var = x.var(dim=(0,1), keepdim=True, unbiased=False)
152 | x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b # shape is n x n x d
153 | return x
154 |
155 |
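A shape check for diag_offdiag_maxpool: a B x S x N x N input yields B x 2S, the per-channel max over the diagonal concatenated with the max over the off-diagonal entries (toy sizes):

    import torch

    x = torch.randn(2, 5, 7, 7)                # B=2, S=5, N=7
    print(diag_offdiag_maxpool(x).shape)       # torch.Size([2, 10])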
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from sklearn.metrics import confusion_matrix
6 | from sklearn.metrics import f1_score
7 | import numpy as np
8 |
9 |
10 | def MAE(scores, targets):
11 | MAE = F.l1_loss(scores, targets)
12 | MAE = MAE.detach().item()
13 | return MAE
14 |
15 |
16 | def accuracy_TU(scores, targets):
17 | scores = scores.detach().argmax(dim=1)
18 | acc = (scores==targets).float().sum().item()
19 | return acc
20 |
21 |
22 | def accuracy_MNIST_CIFAR(scores, targets):
23 | scores = scores.detach().argmax(dim=1)
24 | acc = (scores==targets).float().sum().item()
25 | return acc
26 |
27 | def accuracy_CITATION_GRAPH(scores, targets):
28 | scores = scores.detach().argmax(dim=1)
29 | acc = (scores==targets).float().sum().item()
30 | acc = acc / len(targets)
31 | return acc
32 |
33 |
34 | def accuracy_SBM(scores, targets):
35 | S = targets.cpu().numpy()
36 | C = np.argmax( torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy() , axis=1 )
37 | CM = confusion_matrix(S,C).astype(np.float32)
38 | nb_classes = CM.shape[0]
39 | targets = targets.cpu().detach().numpy()
40 | nb_non_empty_classes = 0
41 | pr_classes = np.zeros(nb_classes)
42 | for r in range(nb_classes):
43 | cluster = np.where(targets==r)[0]
44 | if cluster.shape[0] != 0:
45 | pr_classes[r] = CM[r,r]/ float(cluster.shape[0])
46 | if CM[r,r]>0:
47 | nb_non_empty_classes += 1
48 | else:
49 | pr_classes[r] = 0.0
50 | acc = 100.* np.sum(pr_classes)/ float(nb_classes)
51 | return acc
52 |
53 |
54 | def binary_f1_score(scores, targets):
55 | """Computes the F1 score using scikit-learn for binary class labels.
56 |
57 | Returns the F1 score for the positive class, i.e. labelled '1'.
58 | """
59 | y_true = targets.cpu().numpy()
60 | y_pred = scores.argmax(dim=1).cpu().numpy()
61 | return f1_score(y_true, y_pred, average='binary')
62 |
63 |
64 | def accuracy_VOC(scores, targets):
65 | scores = scores.detach().argmax(dim=1).cpu()
66 | targets = targets.cpu().detach().numpy()
67 |     acc = f1_score(targets, scores, average='weighted')  # sklearn's f1_score expects (y_true, y_pred)
68 | return acc
69 |
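Note the convention here: accuracy_TU (and accuracy_MNIST_CIFAR) returns a raw count of correct predictions rather than a ratio, so callers are evidently expected to accumulate it and divide by the number of samples. A worked illustration:

    import torch

    scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
    targets = torch.tensor([0, 1, 1])
    print(accuracy_TU(scores, targets))        # 2.0 (two of the three predictions are correct)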
--------------------------------------------------------------------------------
/nets/gat_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | """
8 | GAT: Graph Attention Network
9 | Graph Attention Networks (Veličković et al., ICLR 2018)
10 | https://arxiv.org/abs/1710.10903
11 | """
12 | from layers.gat_layer import GATLayer
13 | from layers.mlp_readout_layer import MLPReadout
14 |
15 | class GATNet(nn.Module):
16 | def __init__(self, net_params):
17 | super().__init__()
18 | in_dim = net_params['in_dim']
19 | hidden_dim = net_params['hidden_dim']
20 | out_dim = net_params['out_dim']
21 | n_classes = net_params['n_classes']
22 | num_heads = net_params['n_heads']
23 | in_feat_dropout = net_params['in_feat_dropout']
24 | dropout = net_params['dropout']
25 | n_layers = net_params['L']
26 | self.readout = net_params['readout']
27 | self.batch_norm = net_params['batch_norm']
28 | self.residual = net_params['residual']
29 |
30 | self.dropout = dropout
31 |
32 | self.embedding_h = nn.Linear(in_dim, hidden_dim * num_heads)
33 | self.in_feat_dropout = nn.Dropout(in_feat_dropout)
34 |
35 | self.layers = nn.ModuleList([GATLayer(hidden_dim * num_heads, hidden_dim, num_heads,
36 | dropout, self.batch_norm, self.residual) for _ in range(n_layers-1)])
37 | self.layers.append(GATLayer(hidden_dim * num_heads, out_dim, 1, dropout, self.batch_norm, self.residual))
38 | self.MLP_layer = MLPReadout(out_dim, n_classes)
39 |
40 | def forward(self, g, h, e):
41 | h = self.embedding_h(h)
42 | h = self.in_feat_dropout(h)
43 | for conv in self.layers:
44 | h = conv(g, h)
45 | g.ndata['h'] = h
46 |
47 | if self.readout == "sum":
48 | hg = dgl.sum_nodes(g, 'h')
49 | elif self.readout == "max":
50 | hg = dgl.max_nodes(g, 'h')
51 | elif self.readout == "mean":
52 | hg = dgl.mean_nodes(g, 'h')
53 | else:
54 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
55 |
56 | return self.MLP_layer(hg)
57 |
58 | def loss(self, pred, label):
59 | criterion = nn.CrossEntropyLoss()
60 | loss = criterion(pred, label)
61 | return loss
62 |
63 |
64 |
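A hypothetical net_params dict for GATNet, with keys mirroring those read in __init__ above (placeholder values, not tuned settings from configs/):

    net_params = {
        'in_dim': 16, 'hidden_dim': 8, 'out_dim': 64, 'n_classes': 2,
        'n_heads': 8, 'in_feat_dropout': 0.0, 'dropout': 0.0, 'L': 4,
        'readout': 'mean', 'batch_norm': True, 'residual': True,
    }
    model = GATNet(net_params)   # 3 hidden GATLayers plus a final single-head layer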
--------------------------------------------------------------------------------
/nets/gated_gcn_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | """
8 | ResGatedGCN: Residual Gated Graph ConvNets
9 | An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018)
10 | https://arxiv.org/pdf/1711.07553v2.pdf
11 | """
12 | from layers.gated_gcn_layer import GatedGCNLayer
13 | from layers.mlp_readout_layer import MLPReadout
14 |
15 | class GatedGCNNet(nn.Module):
16 |
17 | def __init__(self, net_params):
18 | super().__init__()
19 | in_dim = net_params['in_dim']
20 | hidden_dim = net_params['hidden_dim']
21 | out_dim = net_params['out_dim']
22 | n_classes = net_params['n_classes']
23 | dropout = net_params['dropout']
24 | n_layers = net_params['L']
25 | self.readout = net_params['readout']
26 | self.batch_norm = net_params['batch_norm']
27 | self.residual = net_params['residual']
28 |
29 | self.embedding_h = nn.Linear(in_dim, hidden_dim)
30 | self.embedding_e = nn.Linear(in_dim, hidden_dim)
31 | self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout,
32 | self.batch_norm, self.residual) for _ in range(n_layers-1) ])
33 | self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual))
34 | self.MLP_layer = MLPReadout(out_dim, n_classes)
35 |
36 | def forward(self, g, h, e):
37 | h = self.embedding_h(h)
38 | e = self.embedding_e(e)
39 |
40 | # convnets
41 | for conv in self.layers:
42 | h, e = conv(g, h, e)
43 | g.ndata['h'] = h
44 |
45 | if self.readout == "sum":
46 | hg = dgl.sum_nodes(g, 'h')
47 | elif self.readout == "max":
48 | hg = dgl.max_nodes(g, 'h')
49 | elif self.readout == "mean":
50 | hg = dgl.mean_nodes(g, 'h')
51 | else:
52 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
53 |
54 | return self.MLP_layer(hg)
55 |
56 | def loss(self, pred, label):
57 | criterion = nn.CrossEntropyLoss()
58 | loss = criterion(pred, label)
59 | return loss
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/nets/gcn_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | """
8 | GCN: Graph Convolutional Networks
9 | Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017)
10 | http://arxiv.org/abs/1609.02907
11 | """
12 | from layers.gcn_layer import GCNLayer
13 | from layers.mlp_readout_layer import MLPReadout
14 |
15 | class GCNNet(nn.Module):
16 | def __init__(self, net_params):
17 | super().__init__()
18 | in_dim = net_params['in_dim']
19 | hidden_dim = net_params['hidden_dim']
20 | out_dim = net_params['out_dim']
21 | n_classes = net_params['n_classes']
22 | in_feat_dropout = net_params['in_feat_dropout']
23 | dropout = net_params['dropout']
24 | n_layers = net_params['L']
25 | self.readout = net_params['readout']
26 | self.batch_norm = net_params['batch_norm']
27 | self.residual = net_params['residual']
28 |
29 | self.embedding_h = nn.Linear(in_dim, hidden_dim)
30 | self.in_feat_dropout = nn.Dropout(in_feat_dropout)
31 |
32 | self.layers = nn.ModuleList([GCNLayer(hidden_dim, hidden_dim, F.relu, dropout,
33 | self.batch_norm, self.residual) for _ in range(n_layers-1)])
34 | self.layers.append(GCNLayer(hidden_dim, out_dim, F.relu, dropout, self.batch_norm, self.residual))
35 | self.MLP_layer = MLPReadout(out_dim, n_classes)
36 |
37 | def forward(self, g, h, e):
38 | h = self.embedding_h(h)
39 | h = self.in_feat_dropout(h)
40 | for conv in self.layers:
41 | h = conv(g, h)
42 | g.ndata['h'] = h
43 |
44 | if self.readout == "sum":
45 | hg = dgl.sum_nodes(g, 'h')
46 | elif self.readout == "max":
47 | hg = dgl.max_nodes(g, 'h')
48 | elif self.readout == "mean":
49 | hg = dgl.mean_nodes(g, 'h')
50 | else:
51 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
52 |
53 | return self.MLP_layer(hg)
54 |
55 | def loss(self, pred, label):
56 | criterion = nn.CrossEntropyLoss()
57 | loss = criterion(pred, label)
58 | return loss
59 |
60 |
--------------------------------------------------------------------------------
/nets/gin_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
7 |
8 | """
9 | GIN: Graph Isomorphism Networks
10 | HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019)
11 | https://arxiv.org/pdf/1810.00826.pdf
12 | """
13 |
14 | from layers.gin_layer import GINLayer, ApplyNodeFunc, MLP
15 |
16 | class GINNet(nn.Module):
17 |
18 | def __init__(self, net_params):
19 | super().__init__()
20 | in_dim = net_params['in_dim']
21 | hidden_dim = net_params['hidden_dim']
22 | n_classes = net_params['n_classes']
23 | dropout = net_params['dropout']
24 | self.n_layers = net_params['L']
25 | n_mlp_layers = net_params['n_mlp_GIN'] # GIN
26 | learn_eps = net_params['learn_eps_GIN'] # GIN
27 | neighbor_aggr_type = net_params['neighbor_aggr_GIN'] # GIN
28 | readout = net_params['readout'] # this is graph_pooling_type
29 | batch_norm = net_params['batch_norm']
30 | residual = net_params['residual']
31 |
32 | # List of MLPs
33 | self.ginlayers = torch.nn.ModuleList()
34 |
35 | self.embedding_h = nn.Linear(in_dim, hidden_dim)
36 |
37 | for layer in range(self.n_layers):
38 | mlp = MLP(n_mlp_layers, hidden_dim, hidden_dim, hidden_dim)
39 |
40 | self.ginlayers.append(GINLayer(ApplyNodeFunc(mlp), neighbor_aggr_type,
41 | dropout, batch_norm, residual, 0, learn_eps))
42 |
43 | # Linear function for graph poolings (readout) of output of each layer
44 | # which maps the output of different layers into a prediction score
45 | self.linears_prediction = torch.nn.ModuleList()
46 |
47 | for layer in range(self.n_layers+1):
48 | self.linears_prediction.append(nn.Linear(hidden_dim, n_classes))
49 |
50 | if readout == 'sum':
51 | self.pool = SumPooling()
52 | elif readout == 'mean':
53 | self.pool = AvgPooling()
54 | elif readout == 'max':
55 | self.pool = MaxPooling()
56 | else:
57 | raise NotImplementedError
58 |
59 | def forward(self, g, h, e):
60 |
61 | h = self.embedding_h(h)
62 |
63 | # list of hidden representation at each layer (including input)
64 | hidden_rep = [h]
65 |
66 | for i in range(self.n_layers):
67 | h = self.ginlayers[i](g, h)
68 | hidden_rep.append(h)
69 |
70 | score_over_layer = 0
71 |
72 | # perform pooling over all nodes in each graph in every layer
73 | for i, h in enumerate(hidden_rep):
74 | pooled_h = self.pool(g, h)
75 | score_over_layer += self.linears_prediction[i](pooled_h)
76 |
77 | return score_over_layer
78 |
79 | def loss(self, pred, label):
80 | criterion = nn.CrossEntropyLoss()
81 | loss = criterion(pred, label)
82 | return loss
--------------------------------------------------------------------------------
/nets/graphsage_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | """
8 | GraphSAGE:
9 | William L. Hamilton, Rex Ying, Jure Leskovec, Inductive Representation Learning on Large Graphs (NeurIPS 2017)
10 | https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
11 | """
12 |
13 | from layers.graphsage_layer import GraphSageLayer
14 | from layers.mlp_readout_layer import MLPReadout
15 |
16 | class GraphSageNet(nn.Module):
17 | """
18 |     GraphSAGE network with multiple GraphSageLayer layers
19 | """
20 | def __init__(self, net_params):
21 | super().__init__()
22 | in_dim = net_params['in_dim']
23 | hidden_dim = net_params['hidden_dim']
24 | out_dim = net_params['out_dim']
25 | n_classes = net_params['n_classes']
26 | in_feat_dropout = net_params['in_feat_dropout']
27 | dropout = net_params['dropout']
28 | aggregator_type = net_params['sage_aggregator']
29 | n_layers = net_params['L']
30 | batch_norm = net_params['batch_norm']
31 | residual = net_params['residual']
32 | self.readout = net_params['readout']
33 |
34 | self.embedding_h = nn.Linear(in_dim, hidden_dim)
35 | self.in_feat_dropout = nn.Dropout(in_feat_dropout)
36 |
37 | self.layers = nn.ModuleList([GraphSageLayer(hidden_dim, hidden_dim, F.relu,
38 | dropout, aggregator_type, batch_norm, residual) for _ in range(n_layers-1)])
39 | self.layers.append(GraphSageLayer(hidden_dim, out_dim, F.relu, dropout, aggregator_type, batch_norm, residual))
40 | self.MLP_layer = MLPReadout(out_dim, n_classes)
41 |
42 | def forward(self, g, h, e):
43 | h = self.embedding_h(h)
44 | h = self.in_feat_dropout(h)
45 | for conv in self.layers:
46 | h = conv(g, h)
47 | g.ndata['h'] = h
48 |
49 | if self.readout == "sum":
50 | hg = dgl.sum_nodes(g, 'h')
51 | elif self.readout == "max":
52 | hg = dgl.max_nodes(g, 'h')
53 | elif self.readout == "mean":
54 | hg = dgl.mean_nodes(g, 'h')
55 | else:
56 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
57 |
58 | return self.MLP_layer(hg)
59 |
60 |
61 | def loss(self, pred, label):
62 | criterion = nn.CrossEntropyLoss()
63 | loss = criterion(pred, label)
64 | return loss
65 |
--------------------------------------------------------------------------------
/nets/load_net.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility file to select the GraphNN model
3 | specified by the user
4 | """
5 |
6 | from nets.gated_gcn_net import GatedGCNNet
7 | from nets.gcn_net import GCNNet
8 | from nets.gat_net import GATNet
9 | from nets.graphsage_net import GraphSageNet
10 | from nets.gin_net import GINNet
11 | from nets.mo_net import MoNet as MoNet_
12 | from nets.mlp_net import MLPNet
13 | from nets.ring_gnn_net import RingGNNNet
14 | from nets.three_wl_gnn_net import ThreeWLGNNNet
15 |
16 | def GatedGCN(net_params):
17 | return GatedGCNNet(net_params)
18 |
19 | def GCN(net_params):
20 | return GCNNet(net_params)
21 |
22 | def GAT(net_params):
23 | return GATNet(net_params)
24 |
25 | def GraphSage(net_params):
26 | return GraphSageNet(net_params)
27 |
28 | def GIN(net_params):
29 | return GINNet(net_params)
30 |
31 | def MoNet(net_params):
32 | return MoNet_(net_params)
33 |
34 | def MLP(net_params):
35 | return MLPNet(net_params)
36 |
37 | def RingGNN(net_params):
38 | return RingGNNNet(net_params)
39 |
40 | def ThreeWLGNN(net_params):
41 | return ThreeWLGNNNet(net_params)
42 |
43 |
44 | def gnn_model(MODEL_NAME, net_params):
45 | models = {
46 | 'GatedGCN': GatedGCN,
47 | 'GCN': GCN,
48 | 'GAT': GAT,
49 | 'GraphSage': GraphSage,
50 | 'GIN': GIN,
51 | 'MoNet': MoNet_,
52 | 'MLP': MLP,
53 | 'RingGNN': RingGNN,
54 | '3WLGNN': ThreeWLGNN
55 | }
56 |
57 | return models[MODEL_NAME](net_params)
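
Selecting a model by name, as main.py does; a minimal sketch with the GCN-specific keys (placeholder values):

    net_params = {
        'in_dim': 16, 'hidden_dim': 32, 'out_dim': 32, 'n_classes': 2,
        'in_feat_dropout': 0.0, 'dropout': 0.0, 'L': 4,
        'readout': 'mean', 'batch_norm': True, 'residual': True,
    }
    model = gnn_model('GCN', net_params)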
--------------------------------------------------------------------------------
/nets/mlp_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | from layers.mlp_readout_layer import MLPReadout
8 |
9 | class MLPNet(nn.Module):
10 | def __init__(self, net_params):
11 | super().__init__()
12 | in_dim = net_params['in_dim']
13 | hidden_dim = net_params['hidden_dim']
14 | n_classes = net_params['n_classes']
15 | in_feat_dropout = net_params['in_feat_dropout']
16 | dropout = net_params['dropout']
17 | n_layers = net_params['L']
18 | self.gated = net_params['gated']
19 |
20 | self.in_feat_dropout = nn.Dropout(in_feat_dropout)
21 |
22 | feat_mlp_modules = [
23 | nn.Linear(in_dim, hidden_dim, bias=True),
24 | nn.ReLU(),
25 | nn.Dropout(dropout),
26 | ]
27 | for _ in range(n_layers-1):
28 | feat_mlp_modules.append(nn.Linear(hidden_dim, hidden_dim, bias=True))
29 | feat_mlp_modules.append(nn.ReLU())
30 | feat_mlp_modules.append(nn.Dropout(dropout))
31 | self.feat_mlp = nn.Sequential(*feat_mlp_modules)
32 |
33 | if self.gated:
34 | self.gates = nn.Linear(hidden_dim, hidden_dim, bias=True)
35 |
36 | self.readout_mlp = MLPReadout(hidden_dim, n_classes)
37 |
38 | def forward(self, g, h, e):
39 | h = self.in_feat_dropout(h)
40 | h = self.feat_mlp(h)
41 | if self.gated:
42 | h = torch.sigmoid(self.gates(h)) * h
43 | g.ndata['h'] = h
44 | hg = dgl.sum_nodes(g, 'h')
45 | # hg = torch.cat(
46 | # (
47 | # dgl.sum_nodes(g, 'h'),
48 | # dgl.max_nodes(g, 'h')
49 | # ),
50 | # dim=1
51 | # )
52 |
53 | else:
54 | g.ndata['h'] = h
55 | hg = dgl.mean_nodes(g, 'h')
56 |
57 | return self.readout_mlp(hg)
58 |
59 |
60 | def loss(self, pred, label):
61 | criterion = nn.CrossEntropyLoss()
62 | loss = criterion(pred, label)
63 | return loss
64 |
--------------------------------------------------------------------------------
/nets/mo_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 |
7 | import numpy as np
8 |
9 | """
10 | GMM: Gaussian Mixture Model Convolution layer
11 | Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs (Federico Monti et al., CVPR 2017)
12 | https://arxiv.org/pdf/1611.08402.pdf
13 | """
14 |
15 | from layers.gmm_layer import GMMLayer
16 | from layers.mlp_readout_layer import MLPReadout
17 |
18 | class MoNet(nn.Module):
19 | def __init__(self, net_params):
20 | super().__init__()
21 | self.name = 'MoNet'
22 | in_dim = net_params['in_dim']
23 | hidden_dim = net_params['hidden_dim']
24 | out_dim = net_params['out_dim']
25 | kernel = net_params['kernel'] # for MoNet
26 | dim = net_params['pseudo_dim_MoNet'] # for MoNet
27 | n_classes = net_params['n_classes']
28 | dropout = net_params['dropout']
29 | n_layers = net_params['L']
30 | self.readout = net_params['readout']
31 | batch_norm = net_params['batch_norm']
32 | residual = net_params['residual']
33 | self.device = net_params['device']
34 |
35 | aggr_type = "sum" # default for MoNet
36 |
37 | self.embedding_h = nn.Linear(in_dim, hidden_dim)
38 |
39 | self.layers = nn.ModuleList()
40 | self.pseudo_proj = nn.ModuleList()
41 |
42 | # Hidden layer
43 | for _ in range(n_layers-1):
44 | self.layers.append(GMMLayer(hidden_dim, hidden_dim, dim, kernel, aggr_type,
45 | dropout, batch_norm, residual))
46 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))
47 |
48 | # Output layer
49 | self.layers.append(GMMLayer(hidden_dim, out_dim, dim, kernel, aggr_type,
50 | dropout, batch_norm, residual))
51 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh()))
52 |
53 | self.MLP_layer = MLPReadout(out_dim, n_classes)
54 |
55 | def forward(self, g, h, e):
56 | h = self.embedding_h(h)
57 |
58 | # computing the 'pseudo' named tensor which depends on node degrees
59 | g.ndata['deg'] = g.in_degrees()
60 | g.apply_edges(self.compute_pseudo)
61 | pseudo = g.edata['pseudo'].to(self.device).float()
62 |
63 | for i in range(len(self.layers)):
64 | h = self.layers[i](g, h, self.pseudo_proj[i](pseudo))
65 | g.ndata['h'] = h
66 |
67 | if self.readout == "sum":
68 | hg = dgl.sum_nodes(g, 'h')
69 | elif self.readout == "max":
70 | hg = dgl.max_nodes(g, 'h')
71 | elif self.readout == "mean":
72 | hg = dgl.mean_nodes(g, 'h')
73 | else:
74 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes
75 |
76 | return self.MLP_layer(hg)
77 |
78 | def compute_pseudo(self, edges):
79 | # compute pseudo edge features for MoNet
80 |         # to avoid division by zero when in_degree is 0, we add a constant 1 to all node degrees (denoting a self-loop)
81 | srcs = 1/np.sqrt(edges.src['deg']+1)
82 | dsts = 1/np.sqrt(edges.dst['deg']+1)
83 | pseudo = torch.cat((srcs.unsqueeze(-1), dsts.unsqueeze(-1)), dim=1)
84 | return {'pseudo': pseudo}
85 |
86 | def loss(self, pred, label):
87 | criterion = nn.CrossEntropyLoss()
88 | loss = criterion(pred, label)
89 | return loss
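
A worked instance of compute_pseudo: for an edge i -> j with in-degrees deg_i = 3 and deg_j = 0, the pseudo feature is (1/sqrt(3+1), 1/sqrt(0+1)) = (0.5, 1.0); the +1 is what guards the zero-degree case:

    import torch

    deg_src, deg_dst = torch.tensor([3.]), torch.tensor([0.])
    pseudo = torch.cat(((1 / torch.sqrt(deg_src + 1)).unsqueeze(-1),
                        (1 / torch.sqrt(deg_dst + 1)).unsqueeze(-1)), dim=1)
    print(pseudo)                              # tensor([[0.5000, 1.0000]])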
--------------------------------------------------------------------------------
/nets/ring_gnn_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 | import time
7 |
8 | """
9 | Ring-GNN
10 | On the equivalence between graph isomorphism testing and function approximation with GNNs (Chen et al, 2019)
11 | https://arxiv.org/pdf/1905.12560v1.pdf
12 | """
13 | from layers.ring_gnn_equiv_layer import RingGNNEquivLayer
14 | from layers.mlp_readout_layer import MLPReadout
15 |
16 | class RingGNNNet(nn.Module):
17 | def __init__(self, net_params):
18 | super().__init__()
19 | self.in_dim_node = net_params['in_dim']
20 | avg_node_num = net_params['avg_node_num']
21 | radius = net_params['radius']
22 | hidden_dim = net_params['hidden_dim']
23 | n_classes = net_params['n_classes']
24 | dropout = net_params['dropout']
25 | n_layers = net_params['L']
26 | self.layer_norm = net_params['layer_norm']
27 | self.residual = net_params['residual']
28 | self.device = net_params['device']
29 |
30 | self.depth = [torch.LongTensor([1+self.in_dim_node])] + [torch.LongTensor([hidden_dim])] * n_layers
31 |
32 | self.equi_modulelist = nn.ModuleList([RingGNNEquivLayer(self.device, m, n,
33 | layer_norm=self.layer_norm,
34 | residual=self.residual,
35 | dropout=dropout,
36 | radius=radius,
37 | k2_init=0.5/avg_node_num) for m, n in zip(self.depth[:-1], self.depth[1:])])
38 |
39 | self.prediction = MLPReadout(torch.sum(torch.stack(self.depth)).item(), n_classes)
40 |
41 | def forward(self, x):
42 | """
43 |         CODE ADAPTED FROM https://github.com/leichen2018/Ring-GNN/
44 | """
45 |
46 | # this x is the tensor with all info available => adj, node feat
47 |
48 | x_list = [x]
49 | for layer in self.equi_modulelist:
50 | x = layer(x)
51 | x_list.append(x)
52 |
53 | # # readout
54 | x_list = [torch.sum(torch.sum(x, dim=3), dim=2) for x in x_list]
55 | x_list = torch.cat(x_list, dim=1)
56 |
57 | x_out = self.prediction(x_list)
58 |
59 | return x_out
60 |
61 | def loss(self, pred, label):
62 | criterion = nn.CrossEntropyLoss()
63 | loss = criterion(pred, label)
64 | return loss
65 |
66 |
--------------------------------------------------------------------------------
/nets/three_wl_gnn_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import dgl
6 | import time
7 |
8 | """
9 | 3WLGNN / ThreeWLGNN
10 | Provably Powerful Graph Networks (Maron et al., 2019)
11 | https://papers.nips.cc/paper/8488-provably-powerful-graph-networks.pdf
12 |
13 | CODE adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/
14 | """
15 |
16 | from layers.three_wl_gnn_layers import RegularBlock, MlpBlock, SkipConnection, FullyConnected, diag_offdiag_maxpool
17 | from layers.mlp_readout_layer import MLPReadout
18 |
19 | class ThreeWLGNNNet(nn.Module):
20 | def __init__(self, net_params):
21 | super().__init__()
22 |
23 | self.in_dim_node = net_params['in_dim']
24 | depth_of_mlp = net_params['depth_of_mlp']
25 | hidden_dim = net_params['hidden_dim']
26 | n_classes = net_params['n_classes']
27 | dropout = net_params['dropout']
28 | n_layers = net_params['L']
29 | self.layer_norm = net_params['layer_norm']
30 | self.residual = net_params['residual']
31 | self.device = net_params['device']
32 | self.diag_pool_readout = True # if True, uses the new_suffix readout from original code
33 |
34 | block_features = [hidden_dim] * n_layers # L here is the block number
35 |
36 | original_features_num = self.in_dim_node + 1 # Number of features of the input
37 |
38 | # sequential mlp blocks
39 | last_layer_features = original_features_num
40 | self.reg_blocks = nn.ModuleList()
41 | for layer, next_layer_features in enumerate(block_features):
42 | mlp_block = RegularBlock(depth_of_mlp, last_layer_features, next_layer_features, self.residual)
43 | self.reg_blocks.append(mlp_block)
44 | last_layer_features = next_layer_features
45 |
46 | if self.diag_pool_readout:
47 | self.fc_layers = nn.ModuleList()
48 | for output_features in block_features:
49 | # each block's output will be pooled (thus have 2*output_features), and pass through a fully connected
50 | fc = FullyConnected(2*output_features, n_classes, activation_fn=None)
51 | self.fc_layers.append(fc)
52 | else:
53 | self.mlp_prediction = MLPReadout(sum(block_features)+original_features_num, n_classes)
54 |
55 | def forward(self, x):
56 | if self.diag_pool_readout:
57 | scores = torch.tensor(0, device=self.device, dtype=x.dtype)
58 | else:
59 | x_list = [x]
60 |
61 | for i, block in enumerate(self.reg_blocks):
62 |
63 | x = block(x)
64 | if self.diag_pool_readout:
65 | scores = self.fc_layers[i](diag_offdiag_maxpool(x)) + scores
66 | else:
67 | x_list.append(x)
68 |
69 | if self.diag_pool_readout:
70 | return scores
71 | else:
72 | # readout like RingGNN
73 | x_list = [torch.sum(torch.sum(x, dim=3), dim=2) for x in x_list]
74 | x_list = torch.cat(x_list, dim=1)
75 |
76 | x_out = self.mlp_prediction(x_list)
77 | return x_out
78 |
79 |
80 | def loss(self, pred, label):
81 | criterion = nn.CrossEntropyLoss()
82 | loss = criterion(pred, label)
83 | return loss
84 |
--------------------------------------------------------------------------------
/pooling/diffpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | import time
7 | import numpy as np
8 | from scipy.linalg import block_diag
9 |
10 | import dgl
11 |
12 | """
13 |
14 |
15 | DIFFPOOL:
16 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec,
17 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018)
18 | https://arxiv.org/pdf/1806.08804.pdf
19 |
20 | ! code started from dgl diffpool examples dir
21 | """
22 |
23 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage  # GraphSageLayer and its dense (tensorized) variant
24 | from layers.diffpool_layer import DiffPoolLayer, DenseDiffPool  # DiffPoolLayer and its dense (tensorized) variant
25 | # from .graphsage_net import GraphSageNet # this is GraphSage
26 | # replace BatchedDiffPool with DenseDiffPool and BatchedGraphSAGE with DenseGraphSage
27 |
28 |
29 | class DiffPoolNet(nn.Module):
30 | """
31 |     DiffPool: fuses GNN layers and pooling layers in sequence
32 | """
33 |
34 | def __init__(self, net_params, pool_ratio=0.5):
35 |
36 | super().__init__()
37 | input_dim = net_params['in_dim']
38 | hidden_dim = net_params['hidden_dim']
39 | embedding_dim = net_params['hidden_dim'] # net_params['embedding_dim']
40 | label_dim = net_params['n_classes']
41 | activation = F.relu
42 | n_layers = net_params['L'] # this is the gnn_per_block param
43 | dropout = net_params['dropout']
44 | self.batch_norm = net_params['batch_norm']
45 | self.residual = net_params['residual']
46 | aggregator_type = net_params['sage_aggregator']
47 | # pool_ratio = net_params['pool_ratio']
48 |
49 | self.device = net_params['device']
50 | self.link_pred = True # net_params['linkpred']
51 | self.concat = False # net_params['cat']
52 | self.n_pooling = 1 # net_params['num_pool']
53 | self.batch_size = net_params['batch_size']
54 | self.e_feat = net_params['edge_feat']
55 | self.link_pred_loss = []
56 | self.entropy_loss = []
57 |
58 | self.embedding_h = nn.Linear(input_dim, hidden_dim)
59 |
60 | # list of GNN modules before the first diffpool operation
61 | self.gc_before_pool = nn.ModuleList()
62 |
63 | self.assign_dim = int(net_params['max_num_node'] * pool_ratio) * self.batch_size # net_params['assign_dim']
64 | self.bn = True
65 | self.num_aggs = 1
66 |
67 | # constructing layers
68 | # layers before diffpool
69 | assert n_layers >= 3, "n_layers too few"
70 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, hidden_dim, activation,
71 | dropout, aggregator_type, self.residual, self.bn))
72 |
73 | for _ in range(n_layers - 2):
74 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, hidden_dim, activation,
75 | dropout, aggregator_type, self.residual, self.bn))
76 |
77 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, embedding_dim, None, dropout, aggregator_type, self.residual))
78 |
79 |
80 | assign_dims = []
81 | assign_dims.append(self.assign_dim)
82 | if self.concat:
83 |             # the diffpool layer receives a pool_embedding_dim node feature tensor
84 |             # and returns pool_embedding_dim node embeddings
85 | pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim
86 | else:
87 |
88 | pool_embedding_dim = embedding_dim
89 |
90 | self.first_diffpool_layer = DiffPoolLayer(pool_embedding_dim, self.assign_dim, hidden_dim,
91 | activation, dropout, aggregator_type, self.link_pred, self.batch_norm)
92 | gc_after_per_pool = nn.ModuleList()
93 |
94 | # list of list of GNN modules, each list after one diffpool operation
95 | self.gc_after_pool = nn.ModuleList()
96 |
97 | for _ in range(n_layers - 1):
98 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, hidden_dim, self.residual))
99 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, embedding_dim, self.residual))
100 | self.gc_after_pool.append(gc_after_per_pool)
101 |
102 | self.assign_dim = int(self.assign_dim * pool_ratio)
103 |
104 | self.diffpool_layers = nn.ModuleList()
105 | # each pooling module
106 | for _ in range(self.n_pooling - 1):
107 | self.diffpool_layers.append(DenseDiffPool(pool_embedding_dim, self.assign_dim, hidden_dim, self.link_pred))
108 |
109 | gc_after_per_pool = nn.ModuleList()
110 |
111 | for _ in range(n_layers - 1):
112 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, hidden_dim, self.residual))
113 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, embedding_dim, self.residual))
114 | self.gc_after_pool.append(gc_after_per_pool)
115 |
116 | assign_dims.append(self.assign_dim)
117 | self.assign_dim = int(self.assign_dim * pool_ratio)
118 |
119 | # predicting layer
120 | if self.concat:
121 | self.pred_input_dim = pool_embedding_dim * \
122 | self.num_aggs * (self.n_pooling + 1)
123 | else:
124 | self.pred_input_dim = embedding_dim * self.num_aggs
125 | self.pred_layer = nn.Linear(self.pred_input_dim, label_dim)
126 |
127 | # weight initialization
128 | for m in self.modules():
129 | if isinstance(m, nn.Linear):
130 | m.weight.data = init.xavier_uniform_(m.weight.data,
131 | gain=nn.init.calculate_gain('relu'))
132 | if m.bias is not None:
133 | m.bias.data = init.constant_(m.bias.data, 0.0)
134 |
135 | def gcn_forward(self, g, h, e, gc_layers, cat=False):
136 | """
137 | Return gc_layer embedding cat.
138 | """
139 | block_readout = []
140 | for gc_layer in gc_layers[:-1]:
141 | h, e = gc_layer(g, h, e)
142 | block_readout.append(h)
143 | h, e = gc_layers[-1](g, h, e)
144 | block_readout.append(h)
145 | if cat:
146 | block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ...
147 | else:
148 | block = h
149 | return block
150 |
151 | def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False):
152 | block_readout = []
153 | for gc_layer in gc_layers:
154 | h = gc_layer(h, adj)
155 | block_readout.append(h)
156 | if cat:
157 | block = torch.cat(block_readout, dim=2) # N x F, F = F1 + F2 + ...
158 | else:
159 | block = h
160 | return block
161 |
162 | def forward(self, g, h, e):
163 | self.link_pred_loss = []
164 | self.entropy_loss = []
165 |
166 | # node feature for assignment matrix computation is the same as the
167 | # original node feature
168 | h = self.embedding_h(h)
169 | h_a = h
170 |
171 | out_all = []
172 |
173 | # we use GCN blocks to get an embedding first
174 | g_embedding = self.gcn_forward(g, h, e, self.gc_before_pool, self.concat)
175 |
176 | g.ndata['h'] = g_embedding
177 |
178 | readout = dgl.sum_nodes(g, 'h')
179 | out_all.append(readout)
180 | if self.num_aggs == 2:
181 | readout = dgl.max_nodes(g, 'h')
182 | out_all.append(readout)
183 |
184 | adj, h = self.first_diffpool_layer(g, g_embedding)
185 | node_per_pool_graph = int(adj.size()[0] / self.batch_size)
186 |
187 | h, adj = self.batch2tensor(adj, h, node_per_pool_graph)
188 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat)
189 |
190 | readout = torch.sum(h, dim=1)
191 | out_all.append(readout)
192 | if self.num_aggs == 2:
193 | readout, _ = torch.max(h, dim=1)
194 | out_all.append(readout)
195 |
196 | for i, diffpool_layer in enumerate(self.diffpool_layers):
197 | h, adj = diffpool_layer(h, adj)
198 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[i + 1], self.concat)
199 |
200 | readout = torch.sum(h, dim=1)
201 | out_all.append(readout)
202 |
203 | if self.num_aggs == 2:
204 | readout, _ = torch.max(h, dim=1)
205 | out_all.append(readout)
206 |
207 | if self.concat or self.num_aggs > 1:
208 | final_readout = torch.cat(out_all, dim=1)
209 | else:
210 | final_readout = readout
211 | ypred = self.pred_layer(final_readout)
212 | return ypred
213 |
214 | def batch2tensor(self, batch_adj, batch_feat, node_per_pool_graph):
215 | """
216 |         Transform a batched graph into a batched adjacency tensor and a node feature tensor.
217 | """
218 |         batch_size = int(batch_adj.size()[0] / node_per_pool_graph)
219 |         adj_list = []
220 |         feat_list = []
221 |
222 |         # 1/sqrt(V) normalization applied to the node features;
223 |         # the factor is identical for every graph in the batch
224 |         snorm_n = torch.FloatTensor(node_per_pool_graph, 1).fill_(1./float(node_per_pool_graph)).sqrt().to(self.device)
225 |
226 |         for i in range(batch_size):
227 |             start = i * node_per_pool_graph
228 |             end = (i + 1) * node_per_pool_graph
229 |             adj_list.append(batch_adj[start:end, start:end])
230 |             feat_list.append(batch_feat[start:end, :] * snorm_n)
231 |
232 |         # stack the per-graph slices into (B, k, k) and (B, k, F) tensors
233 |         adj = torch.stack(adj_list, dim=0)
234 |         feat = torch.stack(feat_list, dim=0)
235 |
236 |         return feat, adj
237 |
238 | def loss(self, pred, label):
239 | '''
240 | loss function
241 | '''
242 |         # softmax + cross-entropy, plus the auxiliary losses logged by each diffpool layer
243 |         criterion = nn.CrossEntropyLoss()
244 |         loss = criterion(pred, label)
245 |         for diffpool_layer in self.diffpool_layers:
246 |             for value in diffpool_layer.loss_log.values():
247 |                 loss += value
248 | return loss
249 |
250 |
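251 |
252 | # ------------------------------------------------------------------
253 | # Minimal standalone sketch (an illustration, not part of the original
254 | # module) of the unbatching performed by batch2tensor above: a
255 | # block-diagonal pooled adjacency of shape (B*k, B*k) and a feature
256 | # matrix of shape (B*k, F) are cut into per-graph tensors of shape
257 | # (B, k, k) and (B, k, F), with the 1/sqrt(k) feature normalization.
258 | if __name__ == '__main__':
259 |     B, k, F_dim = 2, 3, 4  # batch size, nodes per pooled graph, feature dim
260 |     batch_adj = torch.rand(B * k, B * k)
261 |     batch_feat = torch.rand(B * k, F_dim)
262 |     snorm = torch.full((k, 1), 1. / k).sqrt()
263 |     adj = torch.stack([batch_adj[i * k:(i + 1) * k, i * k:(i + 1) * k] for i in range(B)])
264 |     feat = torch.stack([batch_feat[i * k:(i + 1) * k] * snorm for i in range(B)])
265 |     assert adj.shape == (B, k, k) and feat.shape == (B, k, F_dim)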
--------------------------------------------------------------------------------
/pooling/edge_sparse_max.py:
--------------------------------------------------------------------------------
1 | """
2 | An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
3 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
4 | See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
5 | for detailed description.
6 |
7 | Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
8 | with the same node as end-node in graphs.
9 | """
10 | import dgl
11 | import torch
12 | from dgl.backend import astype
13 | from dgl.base import ALL, is_all
14 | from dgl.heterograph_index import HeteroGraphIndex
15 | from dgl.sparse import _gsddmm, _gspmm
16 | from torch import Tensor
17 | from torch.autograd import Function
18 |
19 |
20 | def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_degrees:Tensor):
21 | """Sort edge scores for each node"""
22 | num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())
23 |
 24 |     # Compute the index into a dense score matrix of size (N x D_{max}).
 25 |     # Note that end_n_ids here is the end-node tensor of the dgl graph,
 26 |     # which is not grouped by node id (i.e. not in the form 0,0,1,1,1,...,N,N).
 27 |     # We therefore sort the end-node tensor first, which makes it easier to
 28 |     # compute indices into the dense edge score matrix. Since the original order
 29 |     # is needed for the subsequent gspmm and gsddmm operations, we also keep
 30 |     # the reverse mapping (reverse_perm).
31 | end_n_ids, perm = torch.sort(end_n_ids)
32 | scores = scores[perm]
33 | _, reverse_perm = torch.sort(perm)
34 |
35 | index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
36 | index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
37 | index = index.long()
38 |
39 | dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
40 | dense_scores[index] = scores
41 | dense_scores = dense_scores.view(num_nodes, max_in_degree)
42 |
43 | sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
44 | _, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
45 | dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
46 | dense_reverse_perm = dense_reverse_perm.view(-1)
47 | cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
48 | sorted_dense_scores = sorted_dense_scores.view(-1)
49 | arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
50 | arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)
51 |
52 | valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
53 | sorted_scores = sorted_dense_scores[valid_mask]
54 | cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
55 | arange_vec = arange_vec[valid_mask]
56 | dense_reverse_perm = dense_reverse_perm[valid_mask].long()
57 |
58 | return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm
59 |
60 |
61 | def _threshold_and_support_graph(gidx:HeteroGraphIndex, scores:Tensor, end_n_ids:Tensor):
62 | """Find the threshold for each node and its edges"""
63 | in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
64 | cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)
65 |
66 | # perform sort on edges for each node
67 | sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(scores, end_n_ids,
68 | in_degrees, cum_in_degrees)
69 | cumsum_scores = cumsum_scores - 1.
70 | support = rhos * sorted_scores > cumsum_scores
71 | support = support[dense_reverse_perm] # from sorted order to unsorted order
72 | support = support[reverse_perm] # from src-dst order to eid order
73 |
74 | support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
75 | support_size = support_size.long()
76 | idx = support_size + cum_in_degrees - 1
77 |
 78 |     # mask invalid indices; for example, if the batch does not start from 0 or is not contiguous, negative indices may appear
79 | mask = idx < 0
80 | idx[mask] = 0
81 | tau = cumsum_scores.gather(0, idx.long())
82 | tau /= support_size.to(scores.dtype)
83 |
84 | return tau, support_size
85 |
86 |
87 | class EdgeSparsemaxFunction(Function):
88 | r"""
89 | Description
90 | -----------
91 | Pytorch Auto-Grad Function for edge sparsemax.
92 |
 93 |     We define this auto-grad function here since
 94 |     sparsemax involves sort and select operations,
 95 |     which are not differentiable.
 96 |     """
97 | @staticmethod
98 | def forward(ctx, gidx:HeteroGraphIndex, scores:Tensor,
99 | eids:Tensor, end_n_ids:Tensor, norm_by:str):
100 | if not is_all(eids):
101 | gidx = gidx.edge_subgraph([eids], True).graph
102 | if norm_by == "src":
103 | gidx = gidx.reverse()
104 |
105 | # use feat - max(feat) for numerical stability.
106 | scores = scores.float()
107 | scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
108 | scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")
109 |
110 | # find threshold for each node and perform ReLU(u-t(u)) operation.
111 | tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
112 | out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
113 | ctx.backward_cache = gidx
114 | ctx.save_for_backward(supp_size, out)
115 | torch.cuda.empty_cache()
116 | return out
117 |
118 | @staticmethod
119 | def backward(ctx, grad_out):
120 | gidx = ctx.backward_cache
121 | supp_size, out = ctx.saved_tensors
122 | grad_in = grad_out.clone()
123 |
124 | # grad for ReLU
125 | grad_in[out == 0] = 0
126 |
127 | # dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
128 | v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
129 | grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
130 | grad_in = torch.where(out != 0, grad_in_modify, grad_in)
131 | del gidx
132 | torch.cuda.empty_cache()
133 |
134 | return None, grad_in, None, None, None
135 |
136 |
137 | def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
138 | r"""
139 | Description
140 | -----------
141 | Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
142 |
143 | .. math::
144 |         a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
145 |
146 |     where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
147 |     called a logit in the context of sparsemax. :math:`\tau` is the threshold
148 |     function described in the `From Softmax to Sparsemax <https://arxiv.org/abs/1602.02068>`
149 |     paper.
150 |
151 | NOTE: currently only homogeneous graphs are supported.
152 |
153 | Parameters
154 | ----------
155 | graph : DGLGraph
156 | The graph to perform edge sparsemax on.
157 | logits : torch.Tensor
158 | The input edge feature.
159 |     eids : torch.Tensor or ALL, optional
160 |         A tensor of edge indices on which to apply edge sparsemax. If ALL, apply
161 |         edge sparsemax to all edges in the graph. Default: ALL.
162 |     norm_by : str, could be 'src' or 'dst'
163 |         Normalize by source nodes or destination nodes. Default: `dst`.
164 |
165 | Returns
166 | -------
167 | Tensor
168 | Sparsemax value.
169 | """
170 | # we get edge index tensors here since it is
171 | # hard to get edge index with HeteroGraphIndex
172 | # object without other information like edge_type.
173 | row, col = graph.all_edges(order="eid")
174 | assert norm_by in ["dst", "src"]
175 | end_n_ids = col if norm_by == "dst" else row
176 | if not is_all(eids):
177 | eids = astype(eids, graph.idtype)
178 | end_n_ids = end_n_ids[eids]
179 | return EdgeSparsemaxFunction.apply(graph._graph, logits,
180 | eids, end_n_ids, norm_by)
181 |
182 |
183 | class EdgeSparsemax(torch.nn.Module):
184 | r"""
185 | Description
186 | -----------
187 | Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
188 |
189 | .. math::
190 |         a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
191 |
192 |     where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
193 |     called a logit in the context of sparsemax. :math:`\tau` is the threshold
194 |     function described in the `From Softmax to Sparsemax <https://arxiv.org/abs/1602.02068>`
195 |     paper.
196 |
197 | Parameters
198 | ----------
199 | graph : DGLGraph
200 | The graph to perform edge sparsemax on.
201 | logits : torch.Tensor
202 | The input edge feature.
203 |     eids : torch.Tensor or ALL, optional
204 |         A tensor of edge indices on which to apply edge sparsemax. If ALL, apply
205 |         edge sparsemax to all edges in the graph. Default: ALL.
206 |     norm_by : str, could be 'src' or 'dst'
207 |         Normalize by source nodes or destination nodes. Default: `dst`.
208 |
209 | NOTE: currently only homogeneous graphs are supported.
210 |
211 | Returns
212 | -------
213 | Tensor
214 | Sparsemax value.
215 | """
216 | def __init__(self):
217 | super(EdgeSparsemax, self).__init__()
218 |
219 | def forward(self, graph, logits, eids=ALL, norm_by="dst"):
220 | return edge_sparsemax(graph, logits, eids, norm_by)
221 |
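222 |
223 | # ------------------------------------------------------------------
224 | # Minimal dense sanity check (an illustration, not part of the original
225 | # module): 1-D sparsemax over a single node's incoming-edge scores,
226 | # using the same threshold rule as _threshold_and_support_graph above:
227 | #   k = max{ j : j * z_(j) > cumsum(z)_(j) - 1 },  tau = (cumsum(z)_(k) - 1) / k
228 | if __name__ == '__main__':
229 |     z = torch.tensor([2.0, 1.0, 0.1])
230 |     z_sorted, _ = torch.sort(z, descending=True)
231 |     cumsum = z_sorted.cumsum(0) - 1.
232 |     rhos = torch.arange(1, z.numel() + 1, dtype=z.dtype)
233 |     support = rhos * z_sorted > cumsum
234 |     k = support.sum()
235 |     tau = cumsum[k - 1] / k
236 |     out = torch.clamp(z - tau, min=0)
237 |     print(out, out.sum())  # tensor([1., 0., 0.]); sparse and sums to 1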
--------------------------------------------------------------------------------
/pooling/mewispool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import networkx as nx
5 | import dgl
6 | from dgl.nn import GINConv
7 | from scipy import sparse
8 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
9 |
10 |
11 | class MLP(nn.Module):
12 | def __init__(self, input_dim, hidden_dim, output_dim, enhance=False):
13 | super(MLP, self).__init__()
14 |
15 | self.enhance = enhance
16 |
17 | self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
18 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
19 | self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
20 |
21 | if enhance:
22 | self.bn1 = nn.BatchNorm1d(hidden_dim)
23 | self.bn2 = nn.BatchNorm1d(hidden_dim)
24 | self.dropout = nn.Dropout(0.2)
25 |
26 | def forward(self, x):
27 | x = self.fc1(x)
28 | if self.enhance:
29 | x = self.bn1(x)
30 | x = torch.relu(x)
31 | if self.enhance:
32 | x = self.dropout(x)
33 |
34 | x = self.fc2(x)
35 | if self.enhance:
36 | x = self.bn2(x)
37 | x = torch.relu(x)
38 | if self.enhance:
39 | x = self.dropout(x)
40 |
41 | x = self.fc3(x)
42 |
43 | return x
44 |
45 |
46 | class MEWISPool(nn.Module):
47 | def __init__(self, hidden_dim):
48 | super(MEWISPool, self).__init__()
49 |
50 | self.gc1 = GINConv(MLP(1, hidden_dim, hidden_dim), 'sum')
51 | self.gc2 = GINConv(MLP(hidden_dim, hidden_dim, hidden_dim), 'sum')
52 | self.gc3 = GINConv(MLP(hidden_dim, hidden_dim, 1), 'sum')
53 |
54 | def forward(self, g, x, batch):
55 | # computing the graph laplacian and adjacency matrix
56 | batch_nodes = x.size(0)
57 | if g.num_edges() != 0:
 58 |             num_nodes = g.number_of_nodes()
 59 |             A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
 60 |             # combinatorial Laplacian L = D - A, so that compute_entropy below
 61 |             # reduces x*(Lx) - x*(Ax) + A(x*x) to the local variation
 62 |             # sum_j A_ij * (x_i - x_j)^2 of Eq. (5)
 63 |             degrees = np.asarray(A.sum(axis=1)).flatten()
 64 |             L = sparse.diags(degrees) - A
65 |
66 | A = self.matrix2tensor(A, device=x.device)
67 | L = self.matrix2tensor(L, device=x.device)
68 |
69 | # entropy computation
70 | entropies = self.compute_entropy(x, L, A, batch) # Eq. (8)
71 | else:
 72 |             A = torch.zeros([batch_nodes, batch_nodes], device=x.device)
 73 |             # no edges: fall back to uniform unit entropies for every node
 74 |             entropies = torch.ones(batch_nodes, 1, device=x.device)
75 |
76 | # graph convolution and probability scores
77 | probabilities = self.gc1(g, entropies)
78 | probabilities = self.gc2(g, probabilities)
79 | probabilities = self.gc3(g, probabilities)
80 | probabilities = torch.sigmoid(probabilities)
81 |
82 | # conditional expectation; Algorithm 1
83 | gamma = entropies.sum()
84 | loss = self.loss_fn(entropies, probabilities, A, gamma) # Eq. (9)
85 |
86 | mewis = self.conditional_expectation(entropies, probabilities, A, loss, gamma)
87 |
88 | # graph reconstruction; Eq. (10)
89 | g, h, adj_pooled = self.graph_reconstruction(mewis, x, A, device=x.device)
90 | edge_index_pooled, batch_pooled = self.to_edge_index(adj_pooled, mewis, batch)
91 |
92 | return g, h, loss, batch_pooled
93 |
94 | @staticmethod
95 | def matrix2tensor(csr_matrix, device):
96 | coo_matrix = csr_matrix.tocoo()
97 | values = coo_matrix.data
98 | indices = np.vstack((coo_matrix.row, coo_matrix.col))
99 |
100 | i = torch.tensor(indices, device=device).long()
101 | v = torch.tensor(values, device=device)
102 | shape = coo_matrix.shape
103 |
104 | return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to_dense().to(torch.float32)
105 |
106 | @staticmethod
107 | def compute_entropy(x, L, A, batch):
108 | # computing local variations; Eq. (5)
109 | V = x * torch.matmul(L, x) - x * torch.matmul(A, x) + torch.matmul(A, x * x)
110 | V = torch.norm(V, dim=1)
111 |
112 | # computing the probability distributions based on the local variations; Eq. (7)
113 | P = torch.cat([torch.softmax(V[batch == i], dim=0) for i in torch.unique(batch)])
114 |         P[P == 0.] += 1  # avoid log(0); these entries contribute zero entropy
115 | # computing the entropies; Eq. (8)
116 | H = -P * torch.log(P)
117 |
118 | return H.unsqueeze(-1)
119 |
120 | @staticmethod
121 | def loss_fn(entropies, probabilities, A, gamma):
122 | term1 = -torch.matmul(entropies.t(), probabilities)[0, 0]
123 |
124 | term2 = torch.matmul(torch.matmul(probabilities.t(), A), probabilities).sum()
125 |
126 | return gamma + term1 + term2
127 |
128 | def conditional_expectation(self, entropies, probabilities, A, threshold, gamma):
129 | sorted_probabilities = torch.sort(probabilities, descending=True, dim=0)
130 |
131 | dummy_probabilities = probabilities.detach().clone()
132 | selected = set()
133 | rejected = set()
134 |
135 | for i in range(sorted_probabilities.values.size(0)):
136 | node_index = sorted_probabilities.indices[i].item()
137 | neighbors = torch.where(A[node_index] == 1)[0]
138 | if len(neighbors) == 0:
139 | selected.add(node_index)
140 | continue
141 | if node_index not in rejected and node_index not in selected:
142 | s = dummy_probabilities.clone()
143 | s[node_index] = 1
144 | s[neighbors] = 0
145 |
146 | loss = self.loss_fn(entropies, s, A, gamma)
147 |
148 | if loss <= threshold:
149 | selected.add(node_index)
150 | for n in neighbors.tolist():
151 | rejected.add(n)
152 |
153 | dummy_probabilities[node_index] = 1
154 | dummy_probabilities[neighbors] = 0
155 |
156 | mewis = list(selected)
157 | mewis = sorted(mewis)
158 |
159 | return mewis
160 |
161 | def graph_reconstruction(self, mewis, x, A, device):
162 | x_pooled = x[mewis]
163 |
164 | A2 = torch.matmul(A, A)
165 | A3 = torch.matmul(A2, A)
166 |
167 | A2 = A2[mewis][:, mewis]
168 | A3 = A3[mewis][:, mewis]
169 |
170 | I = torch.eye(len(mewis), device=device)
171 | one = torch.ones([len(mewis), len(mewis)], device=device)
172 |
173 | adj_pooled = (one - I) * torch.clamp(A2 + A3, min=0, max=1)
174 |
175 | src, dst = np.nonzero(adj_pooled.cpu().numpy())
176 | g = dgl.graph((src, dst), num_nodes=adj_pooled.size(0)).to(device)
177 | g.ndata['h'] = x_pooled
178 |
179 | return g, x_pooled, adj_pooled
180 |
181 | @staticmethod
182 | def to_edge_index(adj_pooled, mewis, batch):
183 | row1, row2 = torch.where(adj_pooled > 0)
184 | edge_index_pooled = torch.cat([row1.unsqueeze(0), row2.unsqueeze(0)], dim=0)
185 | batch_pooled = batch[mewis]
186 |
187 | return edge_index_pooled, batch_pooled
188 |
189 |
190 | class MEWISPoolNet(nn.Module):
191 | def __init__(self, net_params, readout='sum'):
192 | super(MEWISPoolNet, self).__init__()
193 |
194 | input_dim = net_params['in_dim']
195 | hidden_dim = net_params['hidden_dim']
196 | n_classes = net_params['n_classes']
197 | self.e_feat = net_params['edge_feat']
198 |
199 | self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
200 | self.pool1 = MEWISPool(hidden_dim=hidden_dim)
201 | self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
202 | self.pool2 = MEWISPool(hidden_dim=hidden_dim)
203 | self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
204 | self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
205 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=n_classes)
206 |
207 | if readout == 'sum':
208 | self.pool = SumPooling()
209 | elif readout == 'mean':
210 | self.pool = AvgPooling()
211 | elif readout == 'max':
212 | self.pool = MaxPooling()
213 | else:
214 | raise NotImplementedError
215 |
216 | def forward(self, g, h, e):
217 | batch = []
218 | for i, g_n_nodes in enumerate(g.batch_num_nodes()):
219 |             batch += [i] * int(g_n_nodes)
220 | batch = torch.tensor(batch, device=h.device)
221 |
222 | x = torch.relu(self.gc1(g, h))
223 | g_pooled1, x_pooled1, loss1, batch_pooled1 = self.pool1(g, x, batch)
224 | x_pooled1 = torch.relu(self.gc2(g_pooled1, x_pooled1))
225 | g_pooled2, x_pooled2, loss2, batch_pooled2 = self.pool2(g_pooled1, x_pooled1, batch_pooled1)
226 | x_pooled2 = self.gc3(g_pooled2, x_pooled2)
227 |
228 | # readout = torch.cat([self.pool(g, x), self.pool(g_pooled1, x_pooled1), self.pool(g_pooled2, x_pooled2)], dim=1)
229 | # readout = self.pool(g_pooled2, x_pooled2)
230 | readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)], dim=0)
231 |
232 | out = torch.relu(self.fc1(readout))
233 | out = self.fc2(out)
234 |
235 | return torch.log_softmax(out, dim=-1), loss1 + loss2
236 |
237 | def loss(self, pred, label, loss_pool):
238 | criterion = nn.CrossEntropyLoss()
239 | loss_classification = criterion(pred, label.view(-1))
240 | loss = loss_classification + 0.01 * loss_pool
241 | return loss
242 |
243 |
244 | class Net2(nn.Module):
245 | def __init__(self, input_dim, hidden_dim, num_classes):
246 | super(Net2, self).__init__()
247 |
248 |         self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
249 |         self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
250 |         self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
251 |         self.pool1 = MEWISPool(hidden_dim=hidden_dim)
252 |         self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
253 | self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
254 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)
255 |
256 |     def forward(self, g, x, batch):
257 |         x = self.gc1(g, x)
258 |         x = torch.relu(x)
259 |
260 |         x = self.gc2(g, x)
261 |         x = torch.relu(x)
262 |
263 |         x = self.gc3(g, x)
264 |         x = torch.relu(x)
265 |         readout2 = torch.cat([x[batch == i].mean(0).unsqueeze(0) for i in torch.unique(batch)], dim=0)
266 |
267 |         g_pooled1, x_pooled1, loss1, batch_pooled1 = self.pool1(g, x, batch)
268 |
269 |         x_pooled1 = self.gc4(g_pooled1, x_pooled1)
270 |
271 |         readout = torch.cat([x_pooled1[batch_pooled1 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled1)],
272 |                             dim=0)
273 |
274 |         out = readout2 + readout
275 |
276 |         out = self.fc1(out)
277 |         out = torch.relu(out)
278 |         out = self.fc2(out)
279 |
280 |         return torch.log_softmax(out, dim=-1), loss1
281 |
282 |
283 | class Net3(nn.Module):
284 | def __init__(self, input_dim, hidden_dim, num_classes):
285 | super(Net3, self).__init__()
286 |
287 |         self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
288 |         self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
289 |         self.pool1 = MEWISPool(hidden_dim=hidden_dim)
290 |         self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
291 |         self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
292 |         self.pool2 = MEWISPool(hidden_dim=hidden_dim)
293 |         self.gc5 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
294 | self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
295 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)
296 |
297 |     def forward(self, g, x, batch):
298 |         x = self.gc1(g, x)
299 |         x = torch.relu(x)
300 |
301 |         x = self.gc2(g, x)
302 |         x = torch.relu(x)
303 |
304 |         g_pooled1, x_pooled1, loss1, batch_pooled1 = self.pool1(g, x, batch)
305 |
306 |         x_pooled1 = self.gc3(g_pooled1, x_pooled1)
307 |         x_pooled1 = torch.relu(x_pooled1)
308 |
309 |         x_pooled1 = self.gc4(g_pooled1, x_pooled1)
310 |         x_pooled1 = torch.relu(x_pooled1)
311 |
312 |         g_pooled2, x_pooled2, loss2, batch_pooled2 = self.pool2(g_pooled1, x_pooled1,
313 |                                                                 batch_pooled1)
314 |
315 |         x_pooled2 = self.gc5(g_pooled2, x_pooled2)
316 |         x_pooled2 = torch.relu(x_pooled2)
317 |
318 |         readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)],
319 |                             dim=0)
320 |
321 |         out = self.fc1(readout)
322 |         out = torch.relu(out)
323 |         out = self.fc2(out)
324 |
325 |         return torch.log_softmax(out, dim=-1), loss1 + loss2
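326 |
327 |
328 | # ------------------------------------------------------------------
329 | # Minimal usage sketch (an illustration, not part of the original
330 | # module), assuming a DGL version that still provides
331 | # DGLGraph.adjacency_matrix_scipy as used in MEWISPool.forward:
332 | # run MEWISPoolNet on a batch of two small graphs.
333 | if __name__ == '__main__':
334 |     net_params = {'in_dim': 8, 'hidden_dim': 16, 'n_classes': 2, 'edge_feat': False}
335 |     net = MEWISPoolNet(net_params)
336 |     g1 = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
337 |     g2 = dgl.graph(([0, 1], [1, 0]), num_nodes=2)
338 |     g = dgl.batch([g1, g2])
339 |     h = torch.rand(g.num_nodes(), 8)
340 |     log_probs, loss_pool = net(g, h, None)
341 |     print(log_probs.shape, float(loss_pool))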
--------------------------------------------------------------------------------
/pooling/sagpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional as F
4 | import dgl
5 | from dgl.nn import GraphConv, AvgPooling, MaxPooling
6 | from pooling.topkpool import TopKPooling, get_batch_id
7 | from layers.gcn_layer import GCNLayer
8 | from layers.mlp_readout_layer import MLPReadout
9 |
10 |
11 | class SAGPool(torch.nn.Module):
 12 |     """The Self-Attention Pooling layer from the paper
 13 |     `Self-Attention Graph Pooling <https://arxiv.org/abs/1904.08082>`.
 14 |     Node scores are computed with a single-output :obj:`dgl.nn.GraphConv`.
 15 |
 16 |     Args:
 17 |         in_dim (int): The dimension of node features.
 18 |         ratio (float, optional): The pool ratio, i.e. the fraction of nodes
 19 |             retained after pooling. (default: :obj:`0.5`)
 20 |         non_linearity (Callable, optional): The non-linearity applied to the
 21 |             scores of the kept nodes, a pytorch function.
 22 |             (default: :obj:`torch.tanh`)
 23 |     """
24 | def __init__(self, in_dim:int, ratio=0.5, non_linearity=torch.tanh):
25 | super(SAGPool, self).__init__()
26 | self.in_dim = in_dim
27 | self.ratio = ratio
28 | if dgl.__version__ < "0.5":
29 | self.score_layer = GraphConv(in_dim, 1)
30 | else:
31 | self.score_layer = GraphConv(in_dim, 1, allow_zero_in_degree=True)
32 | self.non_linearity = non_linearity
 33 |         self.softmax = torch.nn.Softmax(dim=0)  # scores form a 1-D tensor over all batched nodes
34 |
35 | def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor, e_feat=None):
36 | score = self.score_layer(graph, feature).squeeze()
37 | perm, next_batch_num_nodes = TopKPooling(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes())
38 | feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
39 | graph = dgl.node_subgraph(graph, perm)
40 |
41 | # node_subgraph currently does not support batch-graph,
42 | # the 'batch_num_nodes' of the result subgraph is None.
43 | # So we manually set the 'batch_num_nodes' here.
44 | # Since global pooling has nothing to do with 'batch_num_edges',
45 | # we can leave it to be None or unchanged.
46 | graph.set_batch_num_nodes(next_batch_num_nodes)
47 |
 48 |         score = self.softmax(score)
 49 |         if torch.nonzero(torch.isnan(score)).size(0) > 0:
 50 |             print(score[torch.nonzero(torch.isnan(score))])
 51 |             raise ValueError('NaN encountered in SAGPool node scores')
52 |
53 | if e_feat is not None:
54 | e_feat = graph.edata['feat'].unsqueeze(-1)
55 |
56 | return graph, feature, perm, score, e_feat
57 |
58 |
59 | class SAGPoolBlock(torch.nn.Module):
 60 |     """A SAGPool layer followed by a concatenated (mean||max) readout
 61 |     and a linear layer that maps the graph embedding back to the
 62 |     node-embedding dimension.
 63 |     """
64 | def __init__(self, in_dim:int, pool_ratio=0.5):
65 | super(SAGPoolBlock, self).__init__()
66 | self.pool = SAGPool(in_dim, ratio=pool_ratio)
67 | self.avgpool = AvgPooling()
68 | self.maxpool = MaxPooling()
69 | self.mlp = torch.nn.Linear(in_dim * 2, in_dim)
70 |
71 | def forward(self, graph, feature):
72 | graph, out, _, _, _ = self.pool(graph, feature)
73 | g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
74 | g_out = F.relu(self.mlp(g_out))
75 | return graph, out, g_out
76 |
77 |
78 | class SAGPoolReadout(torch.nn.Module):
 79 |     """A combination of a GCN layer and a SAGPool layer,
 80 |     followed by a concatenated (mean||max) readout operation.
81 | """
82 | def __init__(self, net_params, pool_ratio=0.5, pool=True):
83 | super(SAGPoolReadout, self).__init__()
84 | in_dim = net_params['in_dim']
85 | out_dim = net_params['out_dim']
86 | dropout = net_params['dropout']
87 | n_classes = net_params['n_classes']
88 | self.batch_norm = net_params['batch_norm']
89 | self.residual = net_params['residual']
90 | self.e_feat = net_params['edge_feat']
91 | self.conv = GCNLayer(in_dim, out_dim, F.relu, dropout, self.batch_norm, self.residual, e_feat=self.e_feat)
92 | self.use_pool = pool
93 | self.pool = SAGPool(out_dim, ratio=pool_ratio)
94 | self.avgpool = AvgPooling()
95 | self.maxpool = MaxPooling()
96 | self.MLP_layer = MLPReadout(out_dim * 2, n_classes)
97 |
98 | def forward(self, graph, feature, e_feat=None):
99 | out, e_feat = self.conv(graph, feature, e_feat)
100 | if self.use_pool:
101 | graph, out, _, _, e_feat = self.pool(graph, out, e_feat)
102 | hg = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
103 | scores = self.MLP_layer(hg)
104 | return scores
105 |
106 | def loss(self, pred, label, cluster=False):
107 |
108 | criterion = torch.nn.CrossEntropyLoss()
109 | loss = criterion(pred, label)
110 |
111 | return loss
112 |
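113 |
114 | # ------------------------------------------------------------------
115 | # Minimal usage sketch (an illustration, not part of the original
116 | # module), assuming GCNLayer accepts e_feat=None when edge features
117 | # are disabled: classify a batch of two random toy graphs.
118 | if __name__ == '__main__':
119 |     net_params = {'in_dim': 8, 'out_dim': 16, 'dropout': 0.0, 'n_classes': 2,
120 |                   'batch_norm': True, 'residual': False, 'edge_feat': False}
121 |     model = SAGPoolReadout(net_params, pool_ratio=0.5)
122 |     g = dgl.batch([dgl.rand_graph(6, 12), dgl.rand_graph(4, 8)])
123 |     feat = torch.rand(g.num_nodes(), 8)
124 |     scores = model(g, feat)
125 |     print(scores.shape)  # (2, n_classes)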
--------------------------------------------------------------------------------
/pooling/topkpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from scipy.stats import t
4 | import math
5 |
6 |
7 | def get_batch_id(num_nodes:torch.Tensor):
8 | """Convert the num_nodes array obtained from batch graph to batch_id array
9 | for each node.
10 |
11 | Args:
12 | num_nodes (torch.Tensor): The tensor whose element is the number of nodes
13 | in each graph in the batch graph.
14 | """
15 | batch_size = num_nodes.size(0)
16 | batch_ids = []
17 | for i in range(batch_size):
18 | item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device)
19 | batch_ids.append(item)
20 | return torch.cat(batch_ids)
21 |
22 |
23 | def TopKPooling(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor):
24 | """The top-k pooling method. Given a graph batch, this method will pool out some
25 | nodes from input node feature tensor for each graph according to the given ratio.
26 |
27 | Args:
28 | x (torch.Tensor): The input node feature batch-tensor to be pooled.
 29 |         ratio (float): The pool ratio. For example if :obj:`ratio=0.5` then
 30 |             half of the nodes in each graph will be kept after pooling.
31 | batch_id (torch.Tensor): The batch_id of each element in the input tensor.
32 | num_nodes (torch.Tensor): The number of nodes of each graph in batch.
33 |
 34 |     Returns:
 35 |         perm (torch.Tensor): Indices (into the batched node tensor) of the kept nodes.
 36 |         k (torch.Tensor): The number of nodes kept for each graph.
37 | """
38 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
39 |
40 | cum_num_nodes = torch.cat(
41 | [num_nodes.new_zeros(1),
42 | num_nodes.cumsum(dim=0)[:-1]], dim=0)
43 |
44 | index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device)
45 | index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes)
46 |
47 | dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min)
48 | dense_x[index] = x
49 | dense_x = dense_x.view(batch_size, max_num_nodes)
50 |
51 | _, perm = dense_x.sort(dim=-1, descending=True)
52 | perm = perm + cum_num_nodes.view(-1, 1)
53 | perm = perm.view(-1)
54 |
55 | k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
56 | mask = [
57 | torch.arange(k[i], dtype=torch.long, device=x.device) +
58 | i * max_num_nodes for i in range(batch_size)]
59 |
60 | mask = torch.cat(mask, dim=0)
61 | perm = perm[mask]
62 |
63 | return perm, k
64 |
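 65 |
 66 | # ------------------------------------------------------------------
 67 | # Minimal usage sketch (an illustration, not part of the original
 68 | # module): keep the top half of the nodes, ranked by score, from a
 69 | # batch of two graphs with 3 and 2 nodes.
 70 | if __name__ == '__main__':
 71 |     num_nodes = torch.tensor([3, 2])
 72 |     batch_id = get_batch_id(num_nodes)  # tensor([0, 0, 0, 1, 1])
 73 |     score = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.7])
 74 |     perm, k = TopKPooling(score, 0.5, batch_id, num_nodes)
 75 |     print(perm, k)  # perm = tensor([1, 2, 4]), k = tensor([2, 1])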
--------------------------------------------------------------------------------
/train_TUs_graph_classification.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for training one epoch
3 | and evaluating one epoch
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import math
8 |
9 | from metrics import accuracy_TU as accuracy
10 |
11 | """
12 | For GCNs
13 | """
14 | def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
15 | model.train()
16 | epoch_loss = 0
17 | epoch_train_acc = 0
18 | nb_data = 0
19 | gpu_mem = 0
20 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
21 | batch_graphs = batch_graphs.to(device)
22 | batch_x = batch_graphs.ndata['feat'].to(device) # num x feat
23 | batch_e = batch_graphs.edata['feat'].to(device)
24 | batch_labels = batch_labels.to(device)
25 | optimizer.zero_grad()
26 |
27 | batch_scores = model.forward(batch_graphs, batch_x, batch_e)
28 | loss = model.loss(batch_scores, batch_labels)
29 | loss.backward()
30 | optimizer.step()
31 | epoch_loss += loss.detach().item()
32 | epoch_train_acc += accuracy(batch_scores, batch_labels)
33 | nb_data += batch_labels.size(0)
34 | epoch_loss /= (iter + 1)
35 | epoch_train_acc /= nb_data
36 |
37 | return epoch_loss, epoch_train_acc, optimizer
38 |
39 | def evaluate_network_sparse(model, device, data_loader, epoch):
40 | model.eval()
41 | epoch_test_loss = 0
42 | epoch_test_acc = 0
43 | nb_data = 0
44 | with torch.no_grad():
45 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
46 | batch_graphs = batch_graphs.to(device)
47 | batch_x = batch_graphs.ndata['feat'].to(device)
48 | batch_e = batch_graphs.edata['feat'].to(device)
49 | batch_labels = batch_labels.to(device)
50 |
51 | batch_scores = model.forward(batch_graphs, batch_x, batch_e)
52 | loss = model.loss(batch_scores, batch_labels)
53 | epoch_test_loss += loss.detach().item()
54 | epoch_test_acc += accuracy(batch_scores, batch_labels)
55 | nb_data += batch_labels.size(0)
56 | epoch_test_loss /= (iter + 1)
57 | epoch_test_acc /= nb_data
58 |
59 | return epoch_test_loss, epoch_test_acc
60 |
61 |
62 |
63 |
64 | """
65 | For WL-GNNs
66 | """
67 | def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size):
68 | model.train()
69 | epoch_loss = 0
70 | epoch_train_acc = 0
71 | nb_data = 0
72 | gpu_mem = 0
73 | optimizer.zero_grad()
74 | for iter, (x_with_node_feat, labels) in enumerate(data_loader):
75 | x_with_node_feat = x_with_node_feat.to(device)
76 | labels = labels.to(device)
77 |
78 | scores = model.forward(x_with_node_feat)
79 | loss = model.loss(scores, labels)
80 | loss.backward()
81 |
82 | if not (iter%batch_size):
83 | optimizer.step()
84 | optimizer.zero_grad()
85 |
86 | epoch_loss += loss.detach().item()
87 | epoch_train_acc += accuracy(scores, labels)
88 | nb_data += labels.size(0)
89 | epoch_loss /= (iter + 1)
90 | epoch_train_acc /= nb_data
91 |
92 | return epoch_loss, epoch_train_acc, optimizer
93 |
94 | def evaluate_network_dense(model, device, data_loader, epoch):
95 | model.eval()
96 | epoch_test_loss = 0
97 | epoch_test_acc = 0
98 | nb_data = 0
99 | with torch.no_grad():
100 | for iter, (x_with_node_feat, labels) in enumerate(data_loader):
101 | x_with_node_feat = x_with_node_feat.to(device)
102 | labels = labels.to(device)
103 |
104 | scores = model.forward(x_with_node_feat)
105 | loss = model.loss(scores, labels)
106 | epoch_test_loss += loss.detach().item()
107 | epoch_test_acc += accuracy(scores, labels)
108 | nb_data += labels.size(0)
109 | epoch_test_loss /= (iter + 1)
110 | epoch_test_acc /= nb_data
111 |
112 | return epoch_test_loss, epoch_test_acc
113 |
114 |
115 |
116 |
117 | def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, counter):
118 | if curr_loss < best_loss:
119 | counter = 0
120 | best_loss = curr_loss
121 | best_epoch = curr_epoch
122 | else:
123 | counter += 1
124 | return best_loss, best_epoch, counter
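125 |
126 |
127 | # ------------------------------------------------------------------
128 | # Minimal training-loop sketch (an illustration, not part of the
129 | # original module). _ToyNet and the hand-built list of batches are
130 | # hypothetical stand-ins for a net from nets/ and a DGL dataloader
131 | # yielding (batched_graph, labels) pairs, which is the interface the
132 | # functions above expect.
133 | if __name__ == '__main__':
134 |     import dgl
135 |     import torch.optim as optim
136 |
137 |     class _ToyNet(nn.Module):
138 |         def __init__(self, in_dim=4, n_classes=2):
139 |             super().__init__()
140 |             self.lin = nn.Linear(in_dim, n_classes)
141 |
142 |         def forward(self, g, h, e):
143 |             g.ndata['h'] = self.lin(h)
144 |             return dgl.sum_nodes(g, 'h')  # per-graph class scores
145 |
146 |         def loss(self, pred, label):
147 |             return nn.CrossEntropyLoss()(pred, label)
148 |
149 |     batches = []
150 |     for _ in range(3):
151 |         g = dgl.batch([dgl.rand_graph(5, 10), dgl.rand_graph(4, 8)])
152 |         g.ndata['feat'] = torch.rand(g.num_nodes(), 4)
153 |         g.edata['feat'] = torch.rand(g.num_edges(), 1)
154 |         batches.append((g, torch.tensor([0, 1])))
155 |
156 |     device = torch.device('cpu')
157 |     model = _ToyNet()
158 |     optimizer = optim.Adam(model.parameters(), lr=1e-3)
159 |     loss, acc, optimizer = train_epoch_sparse(model, optimizer, device, batches, 0)
160 |     val_loss, val_acc = evaluate_network_sparse(model, device, batches, 0)
161 |     print(loss, acc, val_loss, val_acc)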
--------------------------------------------------------------------------------