├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── DD │ ├── TUs_graph_classification_DiffPool_DD_100k.json │ └── TUs_graph_classification_MEWISPool_DD_100k.json ├── FRANKENSTEIN │ ├── TUs_graph_classification_DiffPool_FRANKENSTEIN_100k.json │ ├── TUs_graph_classification_GAT_FRANKENSTEIN_100k.json │ └── TUs_graph_classification_MEWISPool_FRANKENSTEIN_100k.json ├── MUTAG │ ├── TUs_graph_classification_DiffPool_MUTAG_100k.json │ ├── TUs_graph_classification_GAT_MUTAG_100k.json │ └── TUs_graph_classification_MEWISPool_MUTAG_100k.json ├── NCI1 │ ├── TUs_graph_classification_DiffPool_NCI1_100k.json │ ├── TUs_graph_classification_GAT_NCI1_100k.json │ └── TUs_graph_classification_MEWISPool_NCI1_100k.json ├── NCI109 │ ├── TUs_graph_classification_DiffPool_NCI109_100k.json │ ├── TUs_graph_classification_GAT_NCI109_100k.json │ └── TUs_graph_classification_MEWISPool_NCI109_100k.json ├── PROTEINS_full │ ├── TUs_graph_classification_DiffPool_PROTEINS_full_100k.json │ └── TUs_graph_classification_MEWISPool_PROTEINS_full_100k.json ├── TUs_graph_classification_3WLGNN_DD_100k.json ├── TUs_graph_classification_3WLGNN_ENZYMES_100k.json ├── TUs_graph_classification_3WLGNN_PROTEINS_full_100k.json ├── TUs_graph_classification_GAT_DD_100k.json ├── TUs_graph_classification_GAT_ENZYMES_100k.json ├── TUs_graph_classification_GAT_PROTEINS_full_100k.json ├── TUs_graph_classification_GCN_DD_100k.json ├── TUs_graph_classification_GCN_ENZYMES_100k.json ├── TUs_graph_classification_GCN_PROTEINS_full_100k.json ├── TUs_graph_classification_GIN_DD_100k.json ├── TUs_graph_classification_GIN_ENZYMES_100k.json ├── TUs_graph_classification_GIN_PROTEINS_full_100k.json ├── TUs_graph_classification_GatedGCN_DD_100k.json ├── TUs_graph_classification_GatedGCN_ENZYMES_100k.json ├── TUs_graph_classification_GatedGCN_PROTEINS_full_100k.json ├── TUs_graph_classification_GraphSage_DD_100k.json ├── TUs_graph_classification_GraphSage_ENZYMES_100k.json ├── TUs_graph_classification_GraphSage_PROTEINS_full_100k.json ├── TUs_graph_classification_MLP_DD_100k.json ├── TUs_graph_classification_MLP_ENZYMES_100k.json ├── TUs_graph_classification_MLP_PROTEINS_full_100k.json ├── TUs_graph_classification_MoNet_DD_100k.json ├── TUs_graph_classification_MoNet_ENZYMES_100k.json ├── TUs_graph_classification_MoNet_PROTEINS_full_100k.json ├── TUs_graph_classification_RingGNN_DD_100k.json ├── TUs_graph_classification_RingGNN_ENZYMES_100k.json ├── TUs_graph_classification_RingGNN_PROTEINS_full_100k.json └── Tox21_AhR_training │ ├── TUs_graph_classification_DiffPool_Tox21_AhR_training_100k.json │ ├── TUs_graph_classification_GAT_Tox21_AhR_training_100k.json │ └── TUs_graph_classification_MEWISPool_Tox21_AhR_training_100k.json ├── data ├── TUs.py ├── TUs │ ├── COLLAB_test.index │ ├── COLLAB_train.index │ ├── COLLAB_val.index │ ├── DD_test.index │ ├── DD_train.index │ ├── DD_val.index │ ├── ENZYMES_test.index │ ├── ENZYMES_train.index │ ├── ENZYMES_val.index │ ├── PROTEINS_full_test.index │ ├── PROTEINS_full_train.index │ ├── PROTEINS_full_val.index │ └── README.md └── data.py ├── layers ├── diffpool_layer.py ├── gat_layer.py ├── gated_gcn_layer.py ├── gcn_layer.py ├── gin_layer.py ├── gmm_layer.py ├── graphsage_layer.py ├── mlp_readout_layer.py ├── ring_gnn_equiv_layer.py └── three_wl_gnn_layers.py ├── main.py ├── metrics.py ├── nets ├── gat_net.py ├── gated_gcn_net.py ├── gcn_net.py ├── gin_net.py ├── graphsage_net.py ├── load_net.py ├── mlp_net.py ├── mo_net.py ├── ring_gnn_net.py └── three_wl_gnn_net.py ├── pooling ├── diffpool.py 
├── edge_sparse_max.py ├── hgpslpool.py ├── mewispool.py ├── sagpool.py └── topkpool.py ├── train_TUs_graph_classification.py └── visualization ├── visualizations_dataset_stats.ipynb └── visualizations_single_graph.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | # misc 113 | .DS_Store 114 | .idea 115 | *.bak 116 | *.pkl 117 | save 118 | log 119 | log.test 120 | log.txt 121 | outputs 122 | out 123 | tmp 124 | tmp1.sh 125 | tmp2.sh 126 | tmp3.sh 127 | tmp4.sh 128 | result/ 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | 134 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Graph-Classification-Baseline
2 | 
3 | A framework of graph classification baselines, including a TUDataset loader, GNN models, and visualization tools.
4 | 
5 | ## Features
6 | 
7 | - Datasets
8 | 
9 |   Implemented on top of `LegacyTUDataset` from `dgl.data`. Every [TUDataset](https://chrsmrrs.github.io/datasets/docs/datasets/) is available and is downloaded automatically on first use.
10 | 
11 | - GNN models
12 | 
13 |   Includes GCN, GAT, GatedGCN, GIN, GraphSAGE, etc. Default configurations can be found in `configs/`.
14 | 
15 | **Models with configs having ~500k trainable parameters**
16 | 
17 | |Rank|Model | #Params | Test Acc ± s.d. | Links |
18 | |----| ---------- |------------:| :--------:|:-------:|
19 | |1|GatedGCN-PE | 505421 | 86.363 ± 0.127| [Paper](https://arxiv.org/abs/2003.00982) |
20 | |2|RingGNN | 504766 | 86.244 ± 0.025 |[Paper](https://papers.nips.cc/paper/9718-on-the-equivalence-between-graph-isomorphism-testing-and-function-approximation-with-gnns) |
21 | |3|MoNet | 511487 | 85.582 ± 0.038 | [Paper](https://arxiv.org/abs/1611.08402) |
22 | |4|GatedGCN | 502223 | 85.568 ± 0.088 | [Paper](https://arxiv.org/abs/1711.07553) |
23 | |5|GIN | 508574 | 85.387 ± 0.136 | [Paper](https://arxiv.org/abs/1810.00826)|
24 | |6|3WLGNN | 502872 | 85.341 ± 0.207 | [Paper](https://arxiv.org/abs/1905.11136) |
25 | |7|GAT | 526990 | 78.271 ± 0.186 | [Paper](https://arxiv.org/abs/1710.10903) |
26 | |8|GCN | 500823 | 71.892 ± 0.334 | [Paper](https://arxiv.org/abs/1609.02907) |
27 | |9|GraphSage | 502842 | 50.492 ± 0.001 | [Paper](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) |
28 | 
29 | - Visualization
30 | 
31 |   - [Visualizing dataset statistics](./visualization/visualizations_dataset_stats.ipynb)
32 |   - [Visualizing a single graph](./visualization/visualizations_single_graph.ipynb)
33 | 
34 | 
35 | ## Setup
36 | 
37 | All tests were run on server 201 with Python 3.9 and CUDA 11.1. A ready-made conda environment is available at `/home/jiaxing/anaconda3/envs/xjx`.
38 | 
39 | ```
40 | pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html
41 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
42 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
43 | pip install matplotlib
44 | pip install tensorboardX
45 | ```
46 | 
47 | ## Usage
48 | 
49 | ### Quick Start
50 | 
51 | ```
52 | python main.py --config configs/TUs_graph_classification_GCN_DD_100k.json
53 | ```
54 | 
55 | ### Output and Checkpoints
56 | 
57 | Output results are located in the folder defined by the variable `out_dir` in the corresponding config file.
58 | 
59 | If `out_dir = 'out/TUs_graph_classification/'`, then
60 | 
61 | - Go to `out/TUs_graph_classification/results` to view all result text files.
62 | - Directory `out/TUs_graph_classification/checkpoints` contains model checkpoints.
63 | - To see the training logs in TensorBoard on your local machine:
64 | 
65 |   - Go to the logs directory, i.e. `out/TUs_graph_classification/logs/`.
66 | 
67 |   - Run `tensorboard --logdir='./' --port 6006`.
68 | 
69 |   - Open `http://localhost:6006` in your browser. Note that the port (here 6006, but it may change) is printed to the terminal immediately after TensorBoard starts.
70 | 
71 | ## Design your own model
72 | 
73 | 
74 | ### New graph layer
75 | 
76 | Add a class `MyGraphLayer()` in a `my_graph_layer.py` file in the `layers/` directory. A standard skeleton is:
77 | 
78 | ```
79 | import torch
80 | import torch.nn as nn
81 | import dgl
82 | 
83 | class MyGraphLayer(nn.Module):
84 | 
85 |     def __init__(self, in_dim, out_dim, dropout):
86 |         super().__init__()
87 | 
88 |         # write your code here
89 | 
90 |     def forward(self, g_in, h_in, e_in):
91 | 
92 |         # write your code here
93 |         # write the dgl reduce and update function calls here
94 | 
95 |         return h_out, e_out
96 | ```
97 | Directory `layers/` contains all layer classes for all graph networks, as well as standard layers such as the *MLP* used for *readout*.
98 | 
99 | For instance, the GCN class `GCNLayer()` is defined in the [layers/gcn_layer.py](./layers/gcn_layer.py) file. A self-contained toy example is sketched below.
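To make the skeleton concrete, here is a minimal sketch of a layer that mean-aggregates neighbor features using DGL's built-in message functions and then applies a linear map. This is illustrative only: `MeanPassLayer` is not part of this repository, and it simply passes edge features through unchanged.

```
import torch
import torch.nn as nn
import dgl.function as fn

class MeanPassLayer(nn.Module):
    """Toy layer: mean-aggregate neighbor features, then apply a linear map."""

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, h, e):
        with g.local_scope():
            g.ndata['h'] = h
            # copy each source node's feature onto its out-edges ('m'),
            # then mean-reduce the incoming messages at every destination node
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_agg'))
            h_out = torch.relu(self.linear(g.ndata['h_agg']))
        return self.dropout(h_out), e  # edge features returned unchanged
```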
100 | 
101 | 
102 | 
103 | 
104 | 
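Before wiring a new layer into a full network, it can help to smoke-test it on a real batch. A minimal sketch, reusing the illustrative `MeanPassLayer` above together with the `TUsDataset` loader from `data/TUs.py` (its `collate` method batches `(graph, label)` pairs via `dgl.batch`):

```
from torch.utils.data import DataLoader
from data.TUs import TUsDataset

dataset = TUsDataset('DD')  # downloads DD automatically on first use
loader = DataLoader(dataset.train[0], batch_size=20, shuffle=True,
                    collate_fn=dataset.collate)

batch_graphs, batch_labels = next(iter(loader))
h = batch_graphs.ndata['feat']  # node features of the batched graph
e = batch_graphs.edata['feat']  # edge features (ones, added by format_dataset when absent)

layer = MeanPassLayer(in_dim=h.shape[1], out_dim=32, dropout=0.0)
h_out, _ = layer(batch_graphs, h, e)
print(h_out.shape)  # -> (total number of nodes in the batch, 32)
```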
105 | 
106 | ### New graph network
107 | 
108 | Add a class `MyGraphNetwork()` in a `my_gcn_net.py` file in the `nets/` directory. The `loss()` function of the network is also defined in `MyGraphNetwork()`.
109 | 
110 | ```
111 | import torch
112 | import torch.nn as nn
113 | import dgl
114 | 
115 | from layers.my_graph_layer import MyGraphLayer
116 | 
117 | class MyGraphNetwork(nn.Module):
118 | 
119 |     def __init__(self, in_dim, out_dim, dropout):
120 |         super().__init__()
121 | 
122 |         # write your code here
123 |         self.layer = MyGraphLayer()
124 | 
125 |     def forward(self, g_in, h_in, e_in):
126 | 
127 |         # write your code here
128 |         # write the dgl reduce and update function calls here
129 | 
130 |         return h_out
131 | 
132 |     def loss(self, pred, label):
133 | 
134 |         # write your loss function here
135 | 
136 |         return loss
137 | ```
138 | 
139 | Then register a name, `MyGNN`, for the new graph network class in the `load_net.py` file in the `nets/` directory.
140 | 
141 | ```
142 | from nets.my_gcn_net import MyGraphNetwork
143 | 
144 | def MyGNN(net_params):
145 |     return MyGraphNetwork(net_params)
146 | 
147 | def gnn_model(MODEL_NAME, net_params):
148 |     models = {
149 |         'MyGNN': MyGNN
150 |     }
151 |     return models[MODEL_NAME](net_params)
152 | ```
153 | 
154 | 
155 | For example, `GCNNet()` in [nets/gcn_net.py](./nets/gcn_net.py) is given the GNN name `GCN` in [nets/load_net.py](./nets/load_net.py).
156 | 
157 | 
158 | ## Reference
159 | 
160 | Benchmarking Graph Neural Networks [[paper](https://arxiv.org/pdf/2003.00982.pdf)] [[GitHub](https://github.com/graphdeeplearning/benchmarking-gnns)]
161 | DGL PyTorch examples (mirror): https://codechina.csdn.net/mirrors/dmlc/dgl/-/tree/master/examples/pytorch
-------------------------------------------------------------------------------- /configs/DD/TUs_graph_classification_DiffPool_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 10, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 32, 28 | "out_dim": 32, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/DD/TUs_graph_classification_MEWISPool_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | }
37 | } -------------------------------------------------------------------------------- /configs/FRANKENSTEIN/TUs_graph_classification_DiffPool_FRANKENSTEIN_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "FRANKENSTEIN", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/FRANKENSTEIN/TUs_graph_classification_GAT_FRANKENSTEIN_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "FRANKENSTEIN", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/FRANKENSTEIN/TUs_graph_classification_MEWISPool_FRANKENSTEIN_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "FRANKENSTEIN", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/MUTAG/TUs_graph_classification_DiffPool_MUTAG_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "MUTAG", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | 
"max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/MUTAG/TUs_graph_classification_GAT_MUTAG_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "MUTAG", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/MUTAG/TUs_graph_classification_MEWISPool_MUTAG_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "MUTAG", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/NCI1/TUs_graph_classification_DiffPool_NCI1_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "NCI1", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/NCI1/TUs_graph_classification_GAT_NCI1_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "NCI1", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | 
"params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/NCI1/TUs_graph_classification_MEWISPool_NCI1_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "NCI1", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/NCI109/TUs_graph_classification_DiffPool_NCI109_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "NCI109", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/NCI109/TUs_graph_classification_GAT_NCI109_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "NCI109", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/NCI109/TUs_graph_classification_MEWISPool_NCI109_100k.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "NCI109", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/PROTEINS_full/TUs_graph_classification_DiffPool_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/PROTEINS_full/TUs_graph_classification_MEWISPool_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_3WLGNN_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "3WLGNN", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 3, 27 | "hidden_dim": 74, 28 | "depth_of_mlp": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } 
-------------------------------------------------------------------------------- /configs/TUs_graph_classification_3WLGNN_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "3WLGNN", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 3, 27 | "hidden_dim": 80, 28 | "depth_of_mlp": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_3WLGNN_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "3WLGNN", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-5, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 3, 27 | "hidden_dim": 80, 28 | "depth_of_mlp": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GAT_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GAT_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 18, 28 | "out_dim": 144, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false 36 | } 37 | } 
-------------------------------------------------------------------------------- /configs/TUs_graph_classification_GAT_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 18, 28 | "out_dim": 144, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": false 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GCN_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GCN", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 138, 28 | "out_dim": 138, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GCN_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GCN", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GCN_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GCN", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 
33 | "batch_norm": true, 34 | "self_loop": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GIN_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GIN", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 106, 28 | "residual": true, 29 | "readout": "sum", 30 | "n_mlp_GIN": 2, 31 | "learn_eps_GIN": true, 32 | "neighbor_aggr_GIN": "sum", 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GIN_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GIN", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 110, 28 | "residual": true, 29 | "readout": "sum", 30 | "n_mlp_GIN": 2, 31 | "learn_eps_GIN": true, 32 | "neighbor_aggr_GIN": "sum", 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GIN_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GIN", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 110, 28 | "residual": true, 29 | "readout": "sum", 30 | "n_mlp_GIN": 2, 31 | "learn_eps_GIN": true, 32 | "neighbor_aggr_GIN": "sum", 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GatedGCN_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GatedGCN", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | 
"hidden_dim": 66, 28 | "out_dim": 66, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "edge_feat": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GatedGCN_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GatedGCN", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 69, 28 | "out_dim": 69, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "edge_feat": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GatedGCN_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GatedGCN", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 69, 28 | "out_dim": 69, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "edge_feat": false 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GraphSage_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GraphSage", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool" 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GraphSage_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GraphSage", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | 
"print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 90, 28 | "out_dim": 90, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool" 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_GraphSage_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GraphSage", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 88, 28 | "out_dim": 88, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool" 35 | } 36 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MLP_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MLP", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 154, 28 | "out_dim": 154, 29 | "readout": "mean", 30 | "gated": false, 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0 33 | } 34 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MLP_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MLP", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 164, 28 | "out_dim": 164, 29 | "readout": "mean", 30 | "gated": false, 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0 33 | } 34 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MLP_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MLP", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | 
"max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 162, 28 | "out_dim": 162, 29 | "readout": "mean", 30 | "gated": false, 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0 33 | } 34 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MoNet_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MoNet", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "kernel": 3, 32 | "pseudo_dim_MoNet": 2, 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MoNet_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MoNet", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 1e-3, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 90, 28 | "out_dim": 90, 29 | "residual": true, 30 | "readout": "mean", 31 | "kernel": 3, 32 | "pseudo_dim_MoNet": 2, 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_MoNet_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MoNet", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 89, 28 | "out_dim": 89, 29 | "residual": true, 30 | "readout": "mean", 31 | "kernel": 3, 32 | "pseudo_dim_MoNet": 2, 33 | "in_feat_dropout": 0.0, 34 | "dropout": 0.0, 35 | "batch_norm": true 36 | } 37 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_RingGNN_DD_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "RingGNN", 8 | "dataset": "DD", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | 
"weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 20, 28 | "radius": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_RingGNN_ENZYMES_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "RingGNN", 8 | "dataset": "ENZYMES", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 1e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 38, 28 | "radius": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } -------------------------------------------------------------------------------- /configs/TUs_graph_classification_RingGNN_PROTEINS_full_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "RingGNN", 8 | "dataset": "PROTEINS_full", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 4, 16 | "init_lr": 7e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 12 23 | }, 24 | 25 | "net_params": { 26 | "L": 2, 27 | "hidden_dim": 35, 28 | "radius": 2, 29 | "residual": false, 30 | "dropout": 0.0, 31 | "layer_norm": false 32 | } 33 | } -------------------------------------------------------------------------------- /configs/Tox21_AhR_training/TUs_graph_classification_DiffPool_Tox21_AhR_training_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "DiffPool", 8 | "dataset": "Tox21_AhR_training", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 86, 28 | "out_dim": 86, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "sage_aggregator": "maxpool", 35 | "self_loop": false, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/Tox21_AhR_training/TUs_graph_classification_GAT_Tox21_AhR_training_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "GAT", 8 | "dataset": "Tox21_AhR_training", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 5e-5, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | 
"print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 17, 28 | "out_dim": 136, 29 | "residual": true, 30 | "readout": "mean", 31 | "n_heads": 8, 32 | "in_feat_dropout": 0.0, 33 | "dropout": 0.0, 34 | "batch_norm": true, 35 | "self_loop": true, 36 | "edge_feat": false 37 | } 38 | } -------------------------------------------------------------------------------- /configs/Tox21_AhR_training/TUs_graph_classification_MEWISPool_Tox21_AhR_training_100k.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "model": "MEWISPool", 8 | "dataset": "Tox21_AhR_training", 9 | 10 | "out_dir": "out/TUs_graph_classification/", 11 | 12 | "params": { 13 | "seed": 41, 14 | "epochs": 1000, 15 | "batch_size": 20, 16 | "init_lr": 7e-4, 17 | "lr_reduce_factor": 0.5, 18 | "lr_schedule_patience": 25, 19 | "min_lr": 1e-6, 20 | "weight_decay": 0.0, 21 | "print_epoch_interval": 5, 22 | "max_time": 30 23 | }, 24 | 25 | "net_params": { 26 | "L": 4, 27 | "hidden_dim": 146, 28 | "out_dim": 146, 29 | "residual": true, 30 | "readout": "mean", 31 | "in_feat_dropout": 0.0, 32 | "dropout": 0.0, 33 | "batch_norm": true, 34 | "self_loop": false, 35 | "edge_feat": false 36 | } 37 | } -------------------------------------------------------------------------------- /data/TUs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import torch.utils.data 4 | import time 5 | import os 6 | import numpy as np 7 | 8 | import csv 9 | 10 | import dgl 11 | from dgl.data import TUDataset 12 | from dgl.data import LegacyTUDataset 13 | 14 | import random 15 | random.seed(42) 16 | 17 | from sklearn.model_selection import StratifiedKFold, train_test_split 18 | 19 | 20 | class DGLFormDataset(torch.utils.data.Dataset): 21 | """ 22 | DGLFormDataset wrapping graph list and label list as per pytorch Dataset. 23 | *lists (list): lists of 'graphs' and 'labels' with same len(). 24 | """ 25 | def __init__(self, *lists): 26 | assert all(len(lists[0]) == len(li) for li in lists) 27 | self.lists = lists 28 | self.graph_lists = lists[0] 29 | self.graph_labels = lists[1] 30 | 31 | def __getitem__(self, index): 32 | return tuple(li[index] for li in self.lists) 33 | 34 | def __len__(self): 35 | return len(self.lists[0]) 36 | 37 | 38 | 39 | def self_loop(g): 40 | """ 41 | Utility function only, to be used only when necessary as per user self_loop flag 42 | : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat'] 43 | 44 | 45 | This function is called inside a function in TUsDataset class. 
46 | """ 47 | new_g = dgl.DGLGraph() 48 | new_g.add_nodes(g.number_of_nodes()) 49 | new_g.ndata['feat'] = g.ndata['feat'] 50 | 51 | src, dst = g.all_edges(order="eid") 52 | src = dgl.backend.zerocopy_to_numpy(src) 53 | dst = dgl.backend.zerocopy_to_numpy(dst) 54 | non_self_edges_idx = src != dst 55 | nodes = np.arange(g.number_of_nodes()) 56 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx]) 57 | new_g.add_edges(nodes, nodes) 58 | 59 | # This new edata is not used since this function gets called only for GCN, GAT 60 | # However, we need this for the generic requirement of ndata and edata 61 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges()) 62 | return new_g 63 | 64 | 65 | class TUsDataset(torch.utils.data.Dataset): 66 | def __init__(self, name): 67 | t0 = time.time() 68 | self.name = name 69 | 70 | #dataset = TUDataset(self.name, hidden_size=1) 71 | dataset = LegacyTUDataset(self.name, hidden_size=1) # dgl 4.0 72 | self.input_dim, self.label_dim, self.max_num_node = dataset.statistics() 73 | 74 | # frankenstein has labels 0 and 2; so correcting them as 0 and 1 75 | if self.name in ["FRANKENSTEIN", "MUTAG"]: 76 | dataset.graph_labels = np.array([1 if x==2 else x for x in dataset.graph_labels]) 77 | 78 | print("[!] Dataset: ", self.name) 79 | 80 | # this function splits data into train/val/test and returns the indices 81 | self.all_idx = self.get_all_split_idx(dataset) 82 | 83 | self.all = dataset 84 | self.train = [self.format_dataset([dataset[idx] for idx in self.all_idx['train'][split_num]]) for split_num in range(10)] 85 | self.val = [self.format_dataset([dataset[idx] for idx in self.all_idx['val'][split_num]]) for split_num in range(10)] 86 | self.test = [self.format_dataset([dataset[idx] for idx in self.all_idx['test'][split_num]]) for split_num in range(10)] 87 | 88 | print("Time taken: {:.4f}s".format(time.time()-t0)) 89 | 90 | def get_all_split_idx(self, dataset): 91 | """ 92 | - Split total number of graphs into 3 (train, val and test) in 80:10:10 93 | - Stratified split proportionate to original distribution of data with respect to classes 94 | - Using sklearn to perform the split and then save the indexes 95 | - Preparing 10 such combinations of indexes split to be used in Graph NNs 96 | - As with KFold, each of the 10 fold have unique test set. 97 | """ 98 | root_idx_dir = './data/TUs/' 99 | if not os.path.exists(root_idx_dir): 100 | os.makedirs(root_idx_dir) 101 | all_idx = {} 102 | 103 | # If there are no idx files, do the split and store the files 104 | if not (os.path.exists(root_idx_dir + dataset.name + '_train.index')): 105 | print("[!] 
Splitting the data into train/val/test ...") 106 | 107 | # Using 10-fold cross validation to compare with benchmark papers 108 | k_splits = 10 109 | 110 | cross_val_fold = StratifiedKFold(n_splits=k_splits, shuffle=True) 111 | k_data_splits = [] 112 | 113 | # this is a temporary index assignment, to be used below for val splitting 114 | for i in range(len(dataset.graph_lists)): 115 | dataset[i][0].a = lambda: None 116 | setattr(dataset[i][0].a, 'index', i) 117 | 118 | for indexes in cross_val_fold.split(dataset.graph_lists, dataset.graph_labels): 119 | remain_index, test_index = indexes[0], indexes[1] 120 | 121 | remain_set = self.format_dataset([dataset[index] for index in remain_index]) 122 | 123 | # Get the final 'train' and 'val' sets (0.111 of the remaining 90% is roughly 10% of the full set) 124 | train, val, _, __ = train_test_split(remain_set, 125 | range(len(remain_set.graph_lists)), 126 | test_size=0.111, 127 | stratify=remain_set.graph_labels) 128 | 129 | train, val = self.format_dataset(train), self.format_dataset(val) 130 | test = self.format_dataset([dataset[index] for index in test_index]) 131 | 132 | # Extracting only the indices 133 | idx_train = [item[0].a.index for item in train] 134 | idx_val = [item[0].a.index for item in val] 135 | idx_test = [item[0].a.index for item in test] 136 | 137 | # use context managers so the index files are flushed and closed after every fold 138 | with open(root_idx_dir + dataset.name + '_train.index', 'a+') as f_train, \ open(root_idx_dir + dataset.name + '_val.index', 'a+') as f_val, \ open(root_idx_dir + dataset.name + '_test.index', 'a+') as f_test: 141 | csv.writer(f_train).writerow(idx_train) 142 | csv.writer(f_val).writerow(idx_val) 143 | csv.writer(f_test).writerow(idx_test) 144 | 145 | print("[!] Splitting done!") 146 | 147 | # reading idx from the files 148 | for section in ['train', 'val', 'test']: 149 | with open(root_idx_dir + dataset.name + '_'+ section + '.index', 'r') as f: 150 | reader = csv.reader(f) 151 | all_idx[section] = [list(map(int, idx)) for idx in reader] 152 | return all_idx 153 | 154 | def format_dataset(self, dataset): 155 | """ 156 | Utility function to convert raw (graph, label) pairs 157 | into a dgl/pytorch-compatible DGLFormDataset 158 | """ 159 | graphs = [data[0] for data in dataset] 160 | labels = [data[1] for data in dataset] 161 | 162 | for graph in graphs: 163 | #graph.ndata['feat'] = torch.FloatTensor(graph.ndata['feat']) 164 | graph.ndata['feat'] = graph.ndata['feat'].float() # dgl 4.0 165 | # adding edge features for Residual Gated ConvNet, if not present 166 | if 'feat' not in graph.edata.keys(): 167 | edge_feat_dim = graph.ndata['feat'].shape[1] # dim same as node feature dim 168 | graph.edata['feat'] = torch.ones(graph.number_of_edges(), edge_feat_dim) 169 | 170 | return DGLFormDataset(graphs, labels) 171 | 172 | 173 | # form a mini batch from a given list of samples = [(graph, label) pairs] 174 | def collate(self, samples): 175 | # The input 'samples' is a list of (graph, label) pairs.
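        # e.g. samples = [(g1, 0), (g2, 1)] -> graphs = [g1, g2], labels = [0, 1];
        # dgl.batch below merges the graphs into one disconnected graph so node/edge ops stay vectorized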
176 | graphs, labels = map(list, zip(*samples)) 177 | labels = torch.tensor(np.array(labels)) 178 | # tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 179 | # tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 180 | # snorm_n = torch.cat(tab_snorm_n).sqrt() 181 | # tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))] 182 | # tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ] 183 | # snorm_e = torch.cat(tab_snorm_e).sqrt() 184 | batched_graph = dgl.batch(graphs) 185 | 186 | return batched_graph, labels 187 | 188 | 189 | # prepare dense tensors for GNNs that consume them, such as RingGNN and 3WLGNN 190 | def collate_dense_gnn(self, samples): 191 | # The input 'samples' is a list of (graph, label) pairs; only samples[0] is used below, so this collate assumes batch_size = 1. 192 | graphs, labels = map(list, zip(*samples)) 193 | labels = torch.tensor(np.array(labels)) 194 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 195 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 196 | #snorm_n = tab_snorm_n[0][0].sqrt() 197 | 198 | #batched_graph = dgl.batch(graphs) 199 | 200 | g = graphs[0] 201 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense()) 202 | """ 203 | Adapted from https://github.com/leichen2018/Ring-GNN/ 204 | Assigning node and edge feats:: 205 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features in R^{d_e}. 206 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix. 207 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, stores the node features of node i. 208 | The off-diagonal T[1+d_n:, i, j] stores the edge features of edge (i, j). Here only node features are packed, so T has shape (1 + d_n) x n x n. 209 | """ 210 | 211 | zero_adj = torch.zeros_like(adj) 212 | 213 | in_dim = g.ndata['feat'].shape[1] 214 | 215 | # use node feats to prepare adj 216 | adj_node_feat = torch.stack([zero_adj for j in range(in_dim)]) 217 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0) 218 | 219 | for node, node_feat in enumerate(g.ndata['feat']): 220 | adj_node_feat[1:, node, node] = node_feat 221 | 222 | x_node_feat = adj_node_feat.unsqueeze(0) 223 | 224 | return x_node_feat, labels 225 | 226 | def _sym_normalize_adj(self, adj): 227 | deg = torch.sum(adj, dim=0) #.squeeze() 228 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 229 | deg_inv = torch.diag(deg_inv) 230 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 231 | 232 | 233 | def _add_self_loops(self): 234 | 235 | # function for adding self loops 236 | # this function will be called only if the self_loop flag is True 237 | for split_num in range(10): 238 | self.train[split_num].graph_lists = [self_loop(g) for g in self.train[split_num].graph_lists] 239 | self.val[split_num].graph_lists = [self_loop(g) for g in self.val[split_num].graph_lists] 240 | self.test[split_num].graph_lists = [self_loop(g) for g in self.test[split_num].graph_lists] 241 | 242 | for split_num in range(10): 243 | self.train[split_num] = DGLFormDataset(self.train[split_num].graph_lists, self.train[split_num].graph_labels) 244 | self.val[split_num] = DGLFormDataset(self.val[split_num].graph_lists, self.val[split_num].graph_labels) 245 | self.test[split_num] = DGLFormDataset(self.test[split_num].graph_lists, self.test[split_num].graph_labels) 246 | -------------------------------------------------------------------------------- /data/TUs/DD_test.index:
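Each of the ten CSV rows below holds the test indices of one stratified fold (the train/val index files that follow are laid out the same way). A usage sketch tying the loader, the index files, and `collate` together (illustrative only; assumes it is run from the repository root so `./data/TUs/` resolves):

```python
from torch.utils.data import DataLoader
from data.data import LoadData  # defined later in this listing

dataset = LoadData('DD')   # builds or re-reads ./data/TUs/DD_{train,val,test}.index
split = 0                  # pick one of the 10 folds
train_loader = DataLoader(dataset.train[split], batch_size=20, shuffle=True,
                          collate_fn=dataset.collate)
batched_graph, labels = next(iter(train_loader))  # one dgl-batched graph + a 1-D label tensor
```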
-------------------------------------------------------------------------------- 1 | 25,36,44,53,54,60,66,68,81,95,99,113,115,128,132,146,155,185,190,195,196,197,209,223,225,226,244,254,272,301,308,311,312,315,331,340,349,366,375,397,411,414,418,433,435,438,447,467,472,477,478,479,497,526,529,542,566,576,605,608,609,620,623,628,633,640,642,650,657,689,697,698,699,715,723,725,744,753,760,764,787,799,807,810,831,838,863,869,870,874,900,903,925,932,943,950,978,992,994,997,1008,1009,1018,1026,1037,1040,1074,1082,1083,1087,1113,1119,1125,1132,1135,1142,1143,1151,1175 2 | 8,42,47,48,55,86,111,114,116,119,129,140,144,148,150,160,161,176,178,221,243,255,257,268,273,276,289,295,297,302,321,343,353,354,360,367,376,378,379,389,401,428,431,434,436,449,453,474,476,480,493,499,500,505,523,527,540,543,557,581,585,597,629,630,649,660,668,675,685,693,694,707,720,727,728,737,756,768,773,774,780,783,784,786,792,801,805,809,814,818,820,822,824,826,829,835,837,846,851,857,909,910,913,919,923,936,953,963,975,981,1101,1107,1108,1128,1130,1147,1165,1172 3 | 3,19,21,31,61,97,104,118,123,124,134,138,142,152,159,171,179,194,217,231,256,269,296,298,299,310,327,341,348,350,362,364,380,381,383,393,398,399,405,408,409,417,430,432,442,456,459,460,485,489,492,530,533,548,553,556,558,568,572,583,584,588,589,614,627,632,654,672,686,700,703,735,769,772,775,777,782,788,832,848,856,860,884,890,891,892,899,906,917,935,944,952,955,959,970,972,973,980,982,1014,1015,1016,1019,1029,1036,1050,1069,1078,1081,1088,1129,1139,1148,1153,1156,1163,1164,1174 4 | 0,14,18,35,38,49,69,72,74,90,93,102,127,131,135,156,157,182,189,216,218,220,222,251,261,282,284,317,328,334,363,370,385,413,415,422,452,464,488,491,504,510,522,544,547,554,567,569,582,587,590,593,594,600,604,607,616,618,634,635,641,651,653,656,662,665,677,680,682,692,695,712,717,718,734,745,762,794,797,798,811,812,827,834,843,861,886,897,905,915,929,937,941,942,962,964,966,993,1003,1020,1027,1033,1034,1035,1043,1053,1058,1068,1075,1096,1100,1110,1112,1118,1122,1123,1134,1141 5 | 33,51,75,76,77,82,83,85,87,88,91,98,100,121,122,147,154,167,173,205,208,245,249,260,277,287,300,316,319,326,330,332,346,347,357,361,368,374,390,416,426,427,429,448,466,471,483,487,495,496,512,517,519,536,541,551,552,565,570,573,586,601,606,610,611,624,645,676,684,710,721,729,740,742,754,779,785,793,796,836,842,864,880,885,889,896,912,918,920,931,954,957,960,971,983,984,989,1002,1004,1024,1049,1054,1062,1067,1071,1089,1120,1126,1140,1145,1146,1149,1152,1154,1162,1168,1173,1176 6 | 4,27,28,34,39,62,80,92,125,126,143,153,166,168,187,192,201,212,215,219,230,233,237,241,246,258,259,267,271,275,278,281,290,305,313,318,339,342,344,345,352,356,386,387,394,400,412,424,444,470,503,508,520,528,535,537,549,561,562,564,577,621,625,643,648,652,661,678,688,711,724,730,736,739,747,752,755,770,791,806,821,828,840,844,862,867,868,876,881,883,902,904,908,916,933,934,940,948,949,965,979,986,990,995,996,1011,1045,1048,1059,1064,1070,1091,1111,1115,1136,1138,1150,1157 7 | 5,9,13,23,26,29,30,41,50,58,59,63,64,67,70,78,117,130,137,165,169,170,175,177,183,184,188,199,203,210,214,236,248,253,266,306,314,320,322,323,324,337,338,358,369,371,388,395,407,450,451,465,481,490,498,502,507,555,595,596,598,613,622,638,659,667,679,681,690,704,738,758,766,771,776,778,781,789,802,803,813,833,841,853,858,859,872,879,888,898,911,924,951,961,976,987,988,998,1012,1013,1017,1021,1025,1038,1046,1047,1051,1072,1086,1090,1095,1098,1103,1104,1105,1106,1124,1167 8 | 
10,12,15,22,46,56,57,89,103,105,109,110,120,151,164,180,193,213,227,235,238,247,252,263,264,274,280,286,294,303,304,309,335,351,396,403,419,425,445,455,468,469,501,506,509,521,525,532,539,546,550,563,571,578,579,580,591,602,603,619,636,637,639,644,646,655,669,673,674,709,713,731,733,741,748,749,751,761,790,808,815,839,845,847,849,850,852,873,921,928,938,939,946,956,958,968,969,974,985,991,1000,1005,1006,1010,1030,1031,1044,1055,1057,1066,1073,1077,1092,1109,1114,1131,1169 9 | 7,16,20,32,40,43,52,79,84,94,101,108,136,145,149,172,186,191,198,200,204,206,228,232,242,262,270,283,291,292,307,325,355,359,373,382,391,402,404,410,421,423,440,441,454,457,458,463,484,486,511,513,514,515,516,524,531,538,560,575,592,617,631,647,658,664,666,670,683,696,702,716,722,726,743,750,757,795,804,816,817,819,855,865,866,875,877,882,887,893,895,901,907,922,927,945,977,999,1007,1041,1042,1052,1056,1060,1063,1065,1076,1079,1093,1094,1102,1117,1121,1158,1160,1161,1166 10 | 1,2,6,11,17,24,37,45,65,71,73,96,106,107,112,133,139,141,158,162,163,174,181,202,207,211,224,229,234,239,240,250,265,279,285,288,293,329,333,336,365,372,377,384,392,406,420,437,439,443,446,461,462,473,475,482,494,518,534,545,559,574,599,612,615,626,663,671,687,691,701,705,706,708,714,719,732,746,759,763,765,767,800,823,825,830,854,871,878,894,914,926,930,947,967,1001,1022,1023,1028,1032,1039,1061,1080,1084,1085,1097,1099,1116,1127,1133,1137,1144,1155,1159,1170,1171,1177 11 | -------------------------------------------------------------------------------- /data/TUs/DD_val.index: -------------------------------------------------------------------------------- 1 | 443,729,804,373,710,741,1139,212,394,584,899,791,233,663,432,1012,813,16,112,538,639,268,763,738,127,579,749,783,803,728,423,917,731,702,238,611,974,158,485,942,1000,1141,150,701,455,889,569,554,984,462,85,415,774,850,43,422,1003,1098,221,11,680,550,263,204,1089,599,162,245,987,525,348,631,552,622,953,280,97,471,409,841,880,802,718,876,3,506,1120,1169,1069,716,775,638,32,30,1090,491,110,570,581,475,539,1045,344,896,319,752,1173,299,811,6,927,482,214,21,403,362,269,386 2 | 663,752,384,196,495,526,17,369,518,971,231,1100,716,575,811,443,152,46,703,1035,833,755,100,870,1070,115,859,533,993,1078,978,471,323,99,506,892,347,502,1047,821,475,1145,1002,957,341,854,402,199,264,4,723,486,122,118,606,977,77,20,529,680,750,188,881,1084,1056,1015,628,80,92,1025,239,831,955,206,1095,766,237,326,984,609,154,695,288,542,985,602,277,856,79,218,896,1124,612,782,1104,659,283,437,1148,232,558,1022,507,717,303,908,91,433,1053,274,1176,21,555,653,370,803,300,593 3 | 818,701,316,606,314,984,476,502,288,840,778,404,767,127,214,549,338,667,1008,926,1144,339,649,1056,347,29,1112,520,472,739,765,23,604,644,37,493,655,156,293,873,263,389,161,253,695,513,30,666,1011,939,908,354,773,181,685,731,1167,172,646,812,940,680,128,813,25,108,657,1169,861,587,1033,315,762,391,185,720,41,495,901,481,707,598,932,620,630,1068,69,192,59,1060,724,852,692,776,467,264,1123,796,1018,1001,511,186,824,507,877,708,895,33,143,1162,91,490,225,738,577,938,671,593 4 | 337,125,280,1025,551,949,810,1037,82,540,1078,7,633,721,80,921,733,359,945,609,1056,116,953,719,833,28,924,47,500,1069,664,1169,918,1152,809,162,518,631,200,351,689,589,96,148,754,1054,555,987,983,1109,394,299,755,585,666,606,454,1177,581,295,511,1015,188,648,643,973,477,523,467,297,329,155,1061,598,1175,357,137,503,698,787,644,913,647,638,674,455,744,969,628,432,1004,1080,158,957,1024,946,457,1038,424,1089,469,629,1008,259,1121,434,20,358,980,526,214,686,830,642,820,940,873,97 
5 | 452,1086,1042,516,415,4,915,299,877,533,373,678,1085,1131,19,701,501,238,1,56,278,321,369,317,66,766,1123,1000,394,310,1156,612,1017,198,1081,746,470,250,698,575,0,196,455,994,538,1112,445,289,1097,638,351,161,375,292,505,53,20,847,112,718,970,814,244,8,648,359,945,86,884,1047,241,1005,759,479,901,65,203,274,1011,647,525,262,212,981,917,894,25,715,873,1023,675,739,851,273,549,1104,965,724,1010,192,674,1074,157,852,862,84,67,821,843,253,234,947,209,727,738,201,634,232 6 | 328,558,950,338,36,1096,248,1041,713,185,1055,984,839,268,639,607,882,532,358,239,1165,68,977,690,1088,1006,1176,705,11,378,1144,701,798,410,885,645,65,891,833,321,238,63,640,284,375,95,72,510,440,608,133,698,1086,283,362,357,150,106,666,486,953,9,751,861,519,732,330,206,808,728,848,670,842,880,294,935,493,824,513,277,501,800,814,710,637,487,605,205,436,282,114,60,837,1065,780,422,56,426,382,1090,35,754,363,3,871,700,340,204,574,70,1071,874,823,392,807,859,1018,368 7 | 877,724,1066,636,527,653,621,21,469,602,313,572,890,31,152,367,433,1074,1154,945,953,17,897,966,298,978,363,1069,352,723,16,763,1091,687,590,1019,1083,129,774,670,1144,228,470,305,856,100,647,548,1031,715,285,574,846,609,788,735,57,714,1119,739,860,1176,820,699,504,1003,247,271,168,1121,434,384,643,847,20,307,947,48,661,222,139,145,237,215,1110,747,1131,272,965,484,1071,1166,956,478,116,674,402,842,796,505,40,173,134,179,510,80,806,212,204,18,99,1045,594,665,335,878,258,932 8 | 932,1076,793,308,745,1042,1126,158,128,783,905,156,872,786,772,1024,402,581,421,311,399,95,1091,449,127,322,435,154,479,1120,957,737,1059,534,1054,99,705,552,560,337,1105,920,410,738,708,391,37,347,488,628,973,211,665,1033,408,629,721,291,217,462,1117,152,328,342,117,699,1003,385,312,915,478,348,47,828,5,814,119,446,595,1021,514,1072,148,517,190,1138,4,961,577,168,1107,801,323,989,53,710,1040,429,732,1047,504,822,706,395,897,31,1130,165,1087,1084,338,254,555,201,614,298,255,334 9 | 297,885,177,1104,75,797,383,464,1066,706,994,905,157,782,503,152,988,765,661,455,1070,481,747,1006,479,651,934,605,490,432,980,171,87,1087,334,1154,1164,818,1120,756,655,586,1109,606,540,129,494,156,212,42,71,1045,837,926,777,409,103,970,1080,97,869,168,997,533,85,1000,734,387,14,1097,238,253,123,1040,1037,11,584,261,243,221,148,962,565,316,1124,1168,786,167,815,330,309,453,990,53,708,159,134,187,589,1096,367,919,667,1122,197,215,142,955,491,891,258,438,1103,257,898,338,572,548 10 | 620,206,284,185,84,186,1148,587,826,36,690,63,1043,694,150,604,502,1079,1029,327,64,1112,648,655,693,851,1083,1164,1117,331,1113,943,465,778,405,201,886,286,789,454,312,1027,83,353,579,902,480,1110,531,490,93,302,1111,991,1045,1086,844,833,893,850,20,196,489,535,1106,776,123,488,38,960,584,233,530,156,25,1122,698,990,666,729,47,403,659,1105,471,1042,1114,23,457,289,117,161,832,841,622,1124,792,219,616,627,330,251,49,624,455,1152,790,379,1038,309,846,956,401,1173,595,866,505,674 11 | -------------------------------------------------------------------------------- /data/TUs/ENZYMES_test.index: -------------------------------------------------------------------------------- 1 | 0,13,14,20,25,33,39,63,68,85,102,113,117,129,144,161,167,192,194,195,221,223,241,244,267,273,277,281,282,287,316,332,333,341,344,361,368,379,391,393,405,411,414,416,420,456,467,472,487,491,500,504,505,508,514,526,533,536,580,589 2 | 
32,47,54,57,60,65,67,86,91,99,106,120,124,126,159,160,166,170,173,191,246,258,259,260,264,268,269,278,284,297,321,326,331,342,348,357,364,370,378,386,404,406,424,437,438,447,448,455,478,496,518,528,544,548,558,569,575,588,596,599 3 | 5,12,21,26,55,71,72,92,93,94,101,110,123,127,130,155,175,180,181,199,201,205,207,213,232,233,256,295,296,299,300,305,319,327,337,345,352,365,388,399,430,434,435,452,474,475,482,484,485,489,522,523,529,530,549,552,560,567,572,590 4 | 2,3,6,19,27,28,46,59,89,90,116,121,136,165,168,176,179,184,186,193,202,204,215,225,245,254,257,289,290,294,313,314,318,335,349,373,375,380,389,392,402,421,433,442,462,471,476,492,497,499,501,510,534,554,556,557,565,570,581,584 5 | 8,15,31,44,45,48,81,82,84,95,103,114,125,139,142,147,163,177,190,198,203,214,217,219,229,234,247,250,252,262,301,302,310,346,351,354,366,369,395,398,401,408,415,425,429,439,444,446,470,488,515,517,519,531,540,550,559,563,571,592 6 | 18,23,24,29,35,36,38,73,77,79,122,134,138,141,143,148,169,182,187,189,212,224,227,230,239,242,249,261,271,293,304,311,315,324,328,329,339,340,371,385,400,403,412,417,432,445,464,469,479,493,509,520,521,524,543,562,578,579,587,595 7 | 30,34,43,49,56,64,66,69,88,98,132,145,151,152,153,158,164,172,174,185,206,209,210,226,251,253,263,274,276,285,303,306,312,323,353,355,359,367,374,381,407,409,426,453,454,457,459,463,477,498,525,527,538,545,551,561,574,576,583,586 8 | 1,4,41,50,62,74,75,76,83,96,109,112,115,119,131,154,156,157,162,196,211,220,222,228,240,243,248,270,280,298,320,322,334,336,343,358,377,384,394,397,410,428,436,440,443,460,461,486,490,494,506,512,513,553,555,564,566,593,594,597 9 | 7,9,16,40,42,51,52,53,70,87,107,111,118,128,135,137,140,178,183,188,208,218,235,237,255,272,275,283,288,291,307,325,330,347,350,356,360,363,372,396,413,422,431,441,449,458,466,473,480,483,502,507,532,535,537,539,541,577,585,598 10 | 10,11,17,22,37,58,61,78,80,97,100,104,105,108,133,146,149,150,171,197,200,216,231,236,238,265,266,279,286,292,308,309,317,338,362,376,382,383,387,390,418,419,423,427,450,451,465,468,481,495,503,511,516,542,546,547,568,573,582,591 11 | -------------------------------------------------------------------------------- /data/TUs/ENZYMES_val.index: -------------------------------------------------------------------------------- 1 | 328,431,464,451,9,78,153,236,4,80,182,441,225,257,83,199,245,88,426,327,513,390,126,463,157,301,170,304,594,554,540,204,401,380,324,330,497,578,178,310,134,567,384,296,572,547,97,17,565,545,50,400,65,276,492,235,168,252,218,171 2 | 319,531,435,393,385,581,12,372,339,110,460,49,276,343,107,142,174,74,271,51,461,398,175,266,564,58,467,506,306,204,436,469,7,30,407,92,257,273,101,526,240,598,38,138,380,237,507,495,395,33,123,210,504,256,464,111,592,454,562,155 3 | 165,77,22,331,69,119,202,240,393,45,497,573,315,193,222,28,56,397,162,8,509,317,373,408,163,192,442,234,84,525,355,188,289,496,504,219,301,441,133,269,142,4,288,445,339,292,440,505,524,279,18,469,536,562,423,343,554,416,117,557 4 | 228,69,237,170,244,488,112,406,97,164,326,291,191,64,56,466,238,541,574,520,397,33,100,493,309,347,354,563,22,503,310,583,549,430,319,31,98,321,317,53,276,78,118,438,585,230,542,207,457,423,135,139,288,419,129,261,390,490,101,517 5 | 73,525,36,496,378,206,59,566,154,367,221,431,174,162,150,167,159,402,197,209,141,228,134,35,386,280,11,584,474,452,565,467,244,430,309,594,233,362,288,72,164,537,92,389,508,486,521,316,208,66,23,359,490,242,65,538,341,328,494,523 6 | 
210,309,377,104,590,565,388,240,491,576,500,117,391,76,247,527,406,273,552,467,132,176,7,248,37,449,156,450,363,147,473,14,82,250,56,599,418,214,383,43,1,322,443,302,170,396,85,181,151,364,559,413,287,135,253,280,526,546,6,472 7 | 247,334,188,262,455,341,548,31,96,248,521,203,11,268,54,195,423,495,298,180,40,92,336,537,7,491,170,432,479,313,166,591,131,549,587,510,292,223,392,340,384,534,228,39,469,1,378,442,149,530,544,147,106,431,62,319,317,227,130,496 8 | 210,535,332,402,524,173,596,335,598,585,216,423,400,374,453,559,16,372,309,275,174,376,338,287,99,544,256,179,236,323,30,439,3,499,205,182,357,548,446,108,5,496,509,429,237,107,47,147,92,134,71,144,12,325,403,186,207,234,504,84 9 | 432,105,572,203,38,271,273,524,499,582,589,301,172,163,368,193,560,13,200,412,130,421,108,72,401,554,547,386,450,277,84,513,553,407,404,355,119,494,220,467,233,364,323,590,20,294,239,319,329,150,246,104,59,376,181,49,73,25,357,18 10 | 119,490,226,48,291,272,274,564,303,557,474,124,82,85,320,242,377,203,483,459,380,112,111,98,86,441,401,79,62,597,99,453,354,388,430,121,198,38,561,164,153,344,509,329,539,438,350,540,49,165,235,569,553,447,148,224,375,500,211,297 11 | -------------------------------------------------------------------------------- /data/TUs/PROTEINS_full_test.index: -------------------------------------------------------------------------------- 1 | 3,6,19,36,42,45,53,55,64,69,82,84,89,97,122,150,167,174,183,192,193,202,203,211,229,233,243,249,251,264,297,303,308,309,313,353,356,363,374,382,389,391,393,398,408,417,441,446,448,455,460,461,499,512,518,532,533,547,566,569,579,598,602,605,611,652,661,665,668,673,676,679,684,692,698,715,723,744,753,756,765,772,782,783,801,828,830,851,864,877,892,893,896,899,904,921,930,939,948,949,953,959,962,975,983,994,1009,1043,1044,1069,1075,1080 2 | 17,18,25,44,66,68,71,86,98,100,114,121,124,127,145,152,154,169,170,176,180,184,194,215,223,236,244,255,276,279,293,307,310,312,349,367,375,394,395,403,406,421,430,439,464,472,473,501,510,528,530,550,558,562,567,572,577,580,586,594,599,609,615,618,631,636,645,686,694,697,699,705,706,713,720,734,743,745,755,762,769,786,788,793,809,811,821,850,856,866,876,884,903,920,925,929,931,977,992,1007,1020,1028,1042,1053,1056,1057,1063,1073,1077,1082,1089,1096 3 | 0,1,2,4,20,22,24,35,40,48,51,63,74,83,107,112,120,139,143,144,153,165,190,197,204,210,217,219,222,240,248,253,260,270,278,286,320,321,343,345,347,358,370,379,383,404,411,414,423,442,470,474,477,479,482,486,490,504,560,568,600,603,620,623,649,656,657,671,672,682,687,714,716,722,741,778,790,792,812,815,816,817,820,825,831,849,861,863,865,881,895,901,916,918,935,938,941,942,957,973,1003,1004,1005,1008,1017,1024,1025,1029,1046,1048,1078,1093 4 | 28,29,31,32,61,76,77,106,115,116,128,129,133,136,140,159,160,173,212,214,220,234,235,265,274,282,283,287,294,295,315,318,324,327,333,335,338,351,361,366,376,405,416,419,428,431,435,440,458,475,485,487,491,502,520,522,534,539,542,555,556,576,593,608,614,617,664,670,689,696,707,708,711,721,733,742,746,751,757,766,779,785,787,827,844,848,857,859,872,873,878,898,902,919,922,933,963,965,978,988,991,1000,1002,1027,1037,1040,1041,1054,1079,1094,1106 5 | 
5,11,16,37,46,49,50,52,59,78,119,132,146,147,163,172,178,179,201,232,259,267,275,289,311,322,325,352,360,362,380,388,399,412,413,420,432,457,459,465,478,481,492,494,505,511,513,521,527,531,537,548,549,571,575,582,585,616,619,627,633,639,650,651,655,658,666,677,683,691,717,718,725,731,735,739,760,775,795,802,806,810,813,818,824,860,888,900,905,907,908,927,945,954,958,966,968,972,979,982,995,1006,1012,1051,1055,1074,1090,1102,1103,1104,1110 6 | 9,15,41,54,56,57,60,72,92,96,104,123,125,126,131,135,138,151,161,162,171,175,187,189,195,205,207,250,257,258,263,268,284,301,314,316,317,331,348,359,390,397,401,427,437,444,449,450,454,456,484,489,497,508,516,517,523,544,557,563,570,573,588,621,646,648,667,703,704,727,728,730,747,749,771,776,780,781,803,814,822,829,853,868,874,880,897,912,913,926,944,947,960,964,974,976,980,990,999,1032,1045,1050,1060,1061,1068,1070,1076,1084,1091,1097,1108 7 | 8,23,34,39,47,65,73,85,91,95,101,105,141,148,155,156,164,177,181,185,186,196,216,218,225,237,241,246,247,256,261,269,273,277,280,285,291,334,336,342,350,364,371,377,384,407,410,415,424,443,445,493,503,509,525,564,589,596,601,606,612,624,632,635,638,659,669,680,685,688,693,695,701,709,712,738,740,750,759,784,797,798,845,855,858,862,890,923,932,936,937,951,952,956,961,970,985,1010,1014,1021,1022,1033,1047,1052,1058,1062,1081,1085,1105,1107,1112 8 | 10,13,21,26,30,75,88,94,102,108,109,113,142,168,191,198,208,213,227,228,231,239,252,302,305,306,328,355,357,369,373,381,392,400,402,409,425,426,429,452,453,467,469,471,476,483,495,498,500,536,551,552,553,559,578,587,595,610,622,628,641,643,644,653,654,660,674,678,710,724,736,737,752,761,768,794,804,839,843,846,847,852,854,870,871,875,910,915,924,928,943,967,969,987,996,997,998,1031,1036,1039,1049,1059,1064,1065,1072,1083,1087,1092,1095,1101,1111 9 | 7,14,33,58,70,80,90,93,111,117,130,149,166,188,199,206,209,221,230,242,254,266,271,272,281,288,290,292,296,319,323,330,332,337,368,372,378,385,387,396,418,422,434,462,463,468,515,519,524,535,540,541,543,561,574,581,584,590,591,597,604,629,630,637,640,642,663,681,690,748,758,764,774,777,789,791,805,807,808,823,826,832,835,836,837,838,842,869,885,886,887,906,909,914,917,934,971,981,986,989,993,1011,1013,1015,1018,1019,1035,1038,1086,1098,1100 10 | 12,27,38,43,62,67,79,81,87,99,103,110,118,134,137,157,158,182,200,224,226,238,245,262,298,299,300,304,326,329,339,340,341,344,346,354,365,386,433,436,438,447,451,466,480,488,496,506,507,514,526,529,538,545,546,554,565,583,592,607,613,625,626,634,647,662,675,700,702,719,726,729,732,754,763,767,770,773,796,799,800,819,833,834,840,841,867,879,882,883,889,891,894,911,940,946,950,955,984,1001,1016,1023,1026,1030,1034,1066,1067,1071,1088,1099,1109 11 | -------------------------------------------------------------------------------- /data/TUs/PROTEINS_full_val.index: -------------------------------------------------------------------------------- 1 | 179,29,726,285,501,1005,780,790,738,1104,259,775,789,956,540,403,232,854,815,694,468,147,134,95,1112,331,616,539,184,778,1092,205,258,226,469,436,933,66,695,867,634,239,479,1036,526,641,796,745,48,168,549,112,434,632,873,320,794,626,462,1033,27,599,261,856,596,741,131,482,360,488,31,875,387,804,367,72,437,580,527,1026,761,1089,766,642,895,846,16,123,728,30,344,369,336,770,879,885,971,1,83,629,720,476,998,729,100,615,806,260,333,888,10,901 2 | 
1001,1006,716,38,117,444,535,313,292,930,346,843,1031,527,372,965,829,663,613,398,801,476,167,291,344,318,569,806,650,216,452,72,1090,687,629,242,1081,9,649,985,675,1032,161,563,963,228,571,348,973,831,199,896,1014,241,841,1092,587,628,1040,245,747,287,898,369,708,728,750,867,736,434,685,1095,122,368,319,80,387,524,270,871,139,534,153,288,988,1038,30,150,168,511,94,357,410,297,551,928,1049,975,471,208,196,777,574,90,460,1002,937,909,386,197,826,97 3 | 680,106,1094,754,1033,400,629,993,85,569,272,265,739,49,1034,262,335,87,732,312,638,685,424,1075,91,706,230,191,99,186,1069,681,1106,384,1053,401,991,237,513,824,10,231,937,931,786,854,914,127,608,182,611,780,105,522,52,1009,77,480,333,870,689,348,562,227,1090,315,161,719,247,903,902,50,519,667,1002,758,1111,517,1021,956,823,542,832,920,156,663,555,521,983,999,468,434,924,516,923,619,141,494,90,229,192,175,808,661,501,371,573,394,70,124,126,635 4 | 903,166,826,1111,663,1004,227,1056,19,1069,937,797,585,1101,208,256,72,65,413,436,1003,382,132,247,653,211,814,762,602,4,311,97,1038,1029,444,589,811,528,786,433,778,319,334,641,486,293,1044,155,403,237,296,192,998,149,38,425,18,99,983,363,940,798,392,258,1108,277,720,592,1070,3,348,1021,563,355,188,1030,418,175,337,889,784,645,544,526,1051,996,148,1010,1092,157,196,40,701,974,402,1007,1033,1016,286,1055,134,739,788,759,559,774,183,239,853,64,412,540 5 | 458,895,623,706,987,934,479,234,628,914,263,105,522,112,716,766,823,1005,345,819,308,975,142,1086,872,1075,386,17,1046,547,967,361,1099,327,744,771,93,1087,796,539,550,19,1111,40,227,430,703,485,253,64,745,174,84,45,74,428,1031,123,826,742,250,87,789,215,898,638,1082,1048,409,682,858,596,787,1069,942,378,192,204,439,205,264,851,480,125,970,422,356,307,406,769,922,177,792,614,1011,570,440,365,330,302,350,221,1018,593,162,136,66,1065,394,334,711,246 6 | 283,320,581,40,107,796,363,922,148,286,1008,461,298,613,1067,373,85,17,153,408,357,904,1044,115,720,227,1081,38,719,247,59,36,838,946,672,700,480,1011,821,925,3,26,567,1043,524,655,394,696,302,1017,948,1090,334,622,729,238,1052,686,322,170,155,439,494,324,593,50,276,718,654,519,716,629,927,958,101,942,45,475,109,329,1004,1087,438,981,865,82,333,799,597,663,694,811,1023,1082,274,147,446,920,409,105,992,380,124,687,128,1096,929,398,1051,396,413,243 7 | 121,301,697,993,184,552,1039,425,226,1084,478,853,579,636,1024,434,90,1088,403,399,929,962,1064,592,1104,887,634,182,647,431,288,637,149,305,142,40,231,194,708,598,518,946,700,718,117,220,769,104,912,720,631,973,719,968,347,143,442,618,1051,353,992,307,471,864,1063,599,433,396,77,528,474,777,728,1025,488,1092,832,457,55,397,88,945,950,297,903,696,1108,234,1056,374,130,536,771,439,1002,860,611,114,27,268,665,420,479,668,679,646,607,990,608,192,376,917 8 | 1050,439,703,950,197,615,939,938,124,352,248,627,573,722,300,694,581,269,82,760,853,14,1103,877,840,130,490,122,819,71,896,535,849,488,280,1033,458,284,279,781,640,177,497,390,1081,925,676,1051,708,98,1085,226,972,1053,78,519,482,693,68,353,404,919,335,816,391,199,485,472,407,304,624,667,360,651,769,746,958,235,258,39,663,580,807,477,1088,388,188,491,1029,812,154,234,356,1082,246,569,325,116,19,792,450,650,1009,898,86,827,1017,1019,1002,303,35,657 9 | 
486,458,653,238,1007,279,624,929,27,237,85,773,772,298,293,918,617,593,761,202,1025,677,487,962,722,18,945,363,957,725,351,1014,169,204,717,155,825,586,570,697,488,411,529,657,847,394,979,1088,527,1001,852,181,325,24,299,71,215,36,562,481,861,28,447,820,695,276,633,927,643,902,724,1042,201,1068,938,311,683,824,261,11,560,522,498,244,616,1055,804,526,995,546,343,104,457,648,65,510,456,850,1045,16,855,706,844,747,868,460,878,273,453,109,606,1096 10 | 757,999,756,250,926,644,824,464,434,281,26,476,1027,343,165,1076,629,568,922,132,254,385,906,236,486,356,289,1053,1044,440,826,59,16,169,860,713,593,500,1013,896,167,608,311,885,475,127,282,930,694,82,1068,112,901,393,408,652,996,521,1022,936,980,44,571,806,248,1105,186,1064,1075,780,705,187,28,363,146,697,271,168,572,898,588,704,51,904,163,893,1058,720,192,302,351,516,175,960,306,557,673,285,864,858,1039,69,561,56,998,272,654,508,442,251,74,944 11 | -------------------------------------------------------------------------------- /data/TUs/README.md: -------------------------------------------------------------------------------- 1 | ### Details: 2 | - Split total number of graphs into 3 (train, val and test) in 80:10:10 3 | - Stratified split proportionate to original distribution of data with respect to classes 4 | - Using sklearn to perform the split and then save the indexes 5 | - Preparing 10 such combinations of indexes split to be used in Graph NNs 6 | - As with KFold, each of the 10 fold have unique test set. -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | File to load dataset based on user control from main file 3 | """ 4 | from data.TUs import TUsDataset 5 | 6 | 7 | def LoadData(DATASET_NAME): 8 | """ 9 | This function is called in the main.py file 10 | returns: 11 | ; dataset object 12 | """ 13 | 14 | # handling for the TU Datasets 15 | TU_DATASETS = ['ENZYMES', 'DD', 'PROTEINS_full'] 16 | if DATASET_NAME in TU_DATASETS: 17 | return TUsDataset(DATASET_NAME) 18 | -------------------------------------------------------------------------------- /layers/diffpool_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from scipy.linalg import block_diag 6 | 7 | from torch.autograd import Function 8 | 9 | """ 10 | DIFFPOOL: 11 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec, 12 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018) 13 | https://arxiv.org/pdf/1806.08804.pdf 14 | 15 | ! 
code started from dgl diffpool examples dir 16 | """ 17 | 18 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage 19 | 20 | 21 | def masked_softmax(matrix, mask, dim=-1, memory_efficient=True, 22 | mask_fill_value=-1e32): 23 | ''' 24 | masked_softmax for dgl batch graph 25 | code snippet contributed by AllenNLP (https://github.com/allenai/allennlp) 26 | ''' 27 | if mask is None: 28 | result = torch.nn.functional.softmax(matrix, dim=dim) 29 | else: 30 | mask = mask.float() 31 | while mask.dim() < matrix.dim(): 32 | mask = mask.unsqueeze(1) 33 | if not memory_efficient: 34 | result = torch.nn.functional.softmax(matrix * mask, dim=dim) 35 | result = result * mask 36 | result = result / (result.sum(dim=dim, keepdim=True) + 1e-13) 37 | else: 38 | masked_matrix = matrix.masked_fill((1 - mask).byte(), 39 | mask_fill_value) 40 | result = torch.nn.functional.softmax(masked_matrix, dim=dim) 41 | return result 42 | 43 | 44 | class EntropyLoss(nn.Module): 45 | # Return Scalar 46 | # loss used in diffpool 47 | def forward(self, adj, anext, s_l): 48 | entropy = (torch.distributions.Categorical( 49 | probs=s_l).entropy()).sum(-1).mean(-1) 50 | assert not torch.isnan(entropy) 51 | return entropy 52 | 53 | 54 | class DiffPoolLayer(nn.Module): 55 | 56 | def __init__(self, input_dim, assign_dim, output_feat_dim, 57 | activation, dropout, aggregator_type, link_pred, batch_norm): 58 | super().__init__() 59 | self.embedding_dim = input_dim 60 | self.assign_dim = assign_dim 61 | self.hidden_dim = output_feat_dim 62 | self.link_pred = link_pred 63 | self.feat_gc = GraphSageLayer( 64 | input_dim, 65 | output_feat_dim, 66 | activation, 67 | dropout, 68 | aggregator_type, 69 | batch_norm) 70 | self.pool_gc = GraphSageLayer( 71 | input_dim, 72 | assign_dim, 73 | activation, 74 | dropout, 75 | aggregator_type, 76 | batch_norm) 77 | self.reg_loss = nn.ModuleList([]) 78 | self.loss_log = {} 79 | self.reg_loss.append(EntropyLoss()) 80 | 81 | def forward(self, g, h, e=None): 82 | feat, e = self.feat_gc(g, h, e) 83 | assign_tensor, e = self.pool_gc(g, h, e) 84 | # print(assign_tensor.shape) 85 | device = feat.device 86 | assign_tensor_masks = [] 87 | batch_size = len(g.batch_num_nodes()) 88 | for g_n_nodes in g.batch_num_nodes(): 89 | mask = torch.ones((g_n_nodes, 90 | int(assign_tensor.size()[1] / batch_size))) 91 | assign_tensor_masks.append(mask) 92 | # print(assign_tensor_masks) 93 | """ 94 | The first pooling layer is computed on batched graph. 95 | We first take the adjacency matrix of the batched graph, which is block-wise diagonal. 
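For example (a shapes-only sketch using the block_diag import at the top of this file): for a batch of two graphs with 2 and 3 nodes and 4 clusters reserved per graph,
    mask = block_diag(np.ones((2, 4)), np.ones((3, 4)))   # shape (5, 8)
is block-diagonal, so the masked softmax cannot assign a node to a cluster reserved for another graph in the batch.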
96 | We then compute the assignment matrix for the whole batch graph, which will also be block diagonal 97 | """ 98 | mask = torch.FloatTensor( 99 | block_diag( 100 | * 101 | assign_tensor_masks)).to( 102 | device=device) 103 | 104 | assign_tensor = masked_softmax(assign_tensor, mask, 105 | memory_efficient=False) 106 | # print(assign_tensor.shape) 107 | h = torch.matmul(torch.t(assign_tensor), feat) # equation (3) of DIFFPOOL paper 108 | adj = g.adjacency_matrix(ctx=device) 109 | 110 | adj_new = torch.sparse.mm(adj, assign_tensor) 111 | adj_new = torch.mm(torch.t(assign_tensor), adj_new) # equation (4) of DIFFPOOL paper 112 | 113 | if self.link_pred: 114 | current_lp_loss = torch.norm(adj.to_dense() - 115 | torch.mm(assign_tensor, torch.t(assign_tensor))) / np.power(g.number_of_nodes(), 2) 116 | self.loss_log['LinkPredLoss'] = current_lp_loss 117 | 118 | for loss_layer in self.reg_loss: 119 | loss_name = str(type(loss_layer).__name__) 120 | self.loss_log[loss_name] = loss_layer(adj, adj_new, assign_tensor) 121 | return adj_new, h 122 | 123 | 124 | class LinkPredLoss(nn.Module): 125 | # loss used in diffpool 126 | def forward(self, adj, anext, s_l): 127 | link_pred_loss = ( 128 | adj - s_l.matmul(s_l.transpose(-1, -2))).norm(dim=(1, 2)) 129 | link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2)) 130 | return link_pred_loss.mean() 131 | 132 | 133 | class DenseDiffPool(nn.Module): 134 | def __init__(self, nfeat, nnext, nhid, link_pred=False, entropy=True): 135 | super().__init__() 136 | self.link_pred = link_pred 137 | self.log = {} 138 | self.link_pred_layer = self.LinkPredLoss() 139 | self.embed = DenseGraphSage(nfeat, nhid, use_bn=True) 140 | self.assign = DiffPoolAssignment(nfeat, nnext) 141 | self.reg_loss = nn.ModuleList([]) 142 | self.loss_log = {} 143 | if link_pred: 144 | self.reg_loss.append(LinkPredLoss()) 145 | if entropy: 146 | self.reg_loss.append(EntropyLoss()) 147 | 148 | def forward(self, x, adj, log=False): 149 | z_l = self.embed(x, adj) 150 | s_l = self.assign(x, adj) 151 | if log: 152 | self.log['s'] = s_l.cpu().numpy() 153 | xnext = torch.matmul(s_l.transpose(-1, -2), z_l) 154 | anext = (s_l.transpose(-1, -2)).matmul(adj).matmul(s_l) 155 | 156 | for loss_layer in self.reg_loss: 157 | loss_name = str(type(loss_layer).__name__) 158 | self.loss_log[loss_name] = loss_layer(adj, anext, s_l) 159 | if log: 160 | self.log['a'] = anext.cpu().numpy() 161 | return xnext, anext 162 | 163 | class DiffPoolAssignment(nn.Module): 164 | def __init__(self, nfeat, nnext): 165 | super().__init__() 166 | self.assign_mat = DenseGraphSage(nfeat, nnext, use_bn=True) 167 | 168 | def forward(self, x, adj, log=False): 169 | s_l_init = self.assign_mat(x, adj) 170 | s_l = F.softmax(s_l_init, dim=-1) 171 | return s_l 172 | -------------------------------------------------------------------------------- /layers/gat_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from dgl.nn.pytorch import GATConv 6 | 7 | """ 8 | GAT: Graph Attention Network 9 | Graph Attention Networks (Veličković et al., ICLR 2018) 10 | https://arxiv.org/abs/1710.10903 11 | """ 12 | 13 | class GATLayer(nn.Module): 14 | """ 15 | Parameters 16 | ---------- 17 | in_dim : 18 | Number of input features. 19 | out_dim : 20 | Number of output features. 21 | num_heads : int 22 | Number of heads in Multi-Head Attention. 
23 | dropout : 24 | Required for dropout of attn and feat in GATConv 25 | batch_norm : 26 | boolean flag for batch_norm layer. 27 | residual : 28 | If True, use residual connection inside this layer. Default: ``False``. 29 | activation : callable activation function/layer or None, optional. 30 | If not None, applies an activation function to the updated node features. 31 | 32 | Using dgl builtin GATConv by default: 33 | https://github.com/graphdeeplearning/benchmarking-gnns/commit/206e888ecc0f8d941c54e061d5dffcc7ae2142fc 34 | """ 35 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=False, activation=F.elu): 36 | super().__init__() 37 | self.residual = residual 38 | self.activation = activation 39 | self.batch_norm = batch_norm 40 | 41 | if in_dim != (out_dim*num_heads): 42 | self.residual = False 43 | 44 | self.gatconv = GATConv(in_dim, out_dim, num_heads, dropout, dropout) 45 | 46 | if self.batch_norm: 47 | self.batchnorm_h = nn.BatchNorm1d(out_dim * num_heads) 48 | 49 | def forward(self, g, h): 50 | h_in = h # for residual connection 51 | 52 | h = self.gatconv(g, h).flatten(1) 53 | 54 | if self.batch_norm: 55 | h = self.batchnorm_h(h) 56 | 57 | if self.activation: 58 | h = self.activation(h) 59 | 60 | if self.residual: 61 | h = h_in + h # residual connection 62 | 63 | return h 64 | 65 | 66 | ############################################################## 67 | # 68 | # Additional layers for edge feature/representation analysis 69 | # 70 | ############################################################## 71 | 72 | 73 | class CustomGATHeadLayer(nn.Module): 74 | def __init__(self, in_dim, out_dim, dropout, batch_norm): 75 | super().__init__() 76 | self.dropout = dropout 77 | self.batch_norm = batch_norm 78 | 79 | self.fc = nn.Linear(in_dim, out_dim, bias=False) 80 | self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False) 81 | self.batchnorm_h = nn.BatchNorm1d(out_dim) 82 | 83 | def edge_attention(self, edges): 84 | z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1) 85 | a = self.attn_fc(z2) 86 | return {'e': F.leaky_relu(a)} 87 | 88 | def message_func(self, edges): 89 | return {'z': edges.src['z'], 'e': edges.data['e']} 90 | 91 | def reduce_func(self, nodes): 92 | alpha = F.softmax(nodes.mailbox['e'], dim=1) 93 | alpha = F.dropout(alpha, self.dropout, training=self.training) 94 | h = torch.sum(alpha * nodes.mailbox['z'], dim=1) 95 | return {'h': h} 96 | 97 | def forward(self, g, h): 98 | z = self.fc(h) 99 | g.ndata['z'] = z 100 | g.apply_edges(self.edge_attention) 101 | g.update_all(self.message_func, self.reduce_func) 102 | h = g.ndata['h'] 103 | 104 | if self.batch_norm: 105 | h = self.batchnorm_h(h) 106 | 107 | h = F.elu(h) 108 | 109 | h = F.dropout(h, self.dropout, training=self.training) 110 | 111 | return h 112 | 113 | 114 | class CustomGATLayer(nn.Module): 115 | """ 116 | Param: [in_dim, out_dim, n_heads] 117 | """ 118 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True): 119 | super().__init__() 120 | 121 | self.in_channels = in_dim 122 | self.out_channels = out_dim 123 | self.num_heads = num_heads 124 | self.residual = residual 125 | 126 | if in_dim != (out_dim*num_heads): 127 | self.residual = False 128 | 129 | self.heads = nn.ModuleList() 130 | for i in range(num_heads): 131 | self.heads.append(CustomGATHeadLayer(in_dim, out_dim, dropout, batch_norm)) 132 | self.merge = 'cat' 133 | 134 | def forward(self, g, h, e): 135 | h_in = h # for residual connection 136 | 137 | head_outs = [attn_head(g, h) for attn_head in 
self.heads] 138 | 139 | if self.merge == 'cat': 140 | h = torch.cat(head_outs, dim=1) 141 | else: 142 | h = torch.mean(torch.stack(head_outs)) 143 | 144 | if self.residual: 145 | h = h_in + h # residual connection 146 | 147 | return h, e 148 | 149 | def __repr__(self): 150 | return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, 151 | self.in_channels, 152 | self.out_channels, self.num_heads, self.residual) 153 | 154 | 155 | ############################################################## 156 | 157 | 158 | class CustomGATHeadLayerEdgeReprFeat(nn.Module): 159 | def __init__(self, in_dim, out_dim, dropout, batch_norm): 160 | super().__init__() 161 | self.dropout = dropout 162 | self.batch_norm = batch_norm 163 | 164 | self.fc_h = nn.Linear(in_dim, out_dim, bias=False) 165 | self.fc_e = nn.Linear(in_dim, out_dim, bias=False) 166 | self.fc_proj = nn.Linear(3* out_dim, out_dim) 167 | self.attn_fc = nn.Linear(3* out_dim, 1, bias=False) 168 | self.batchnorm_h = nn.BatchNorm1d(out_dim) 169 | self.batchnorm_e = nn.BatchNorm1d(out_dim) 170 | 171 | def edge_attention(self, edges): 172 | z = torch.cat([edges.data['z_e'], edges.src['z_h'], edges.dst['z_h']], dim=1) 173 | e_proj = self.fc_proj(z) 174 | attn = F.leaky_relu(self.attn_fc(z)) 175 | return {'attn': attn, 'e_proj': e_proj} 176 | 177 | def message_func(self, edges): 178 | return {'z': edges.src['z_h'], 'attn': edges.data['attn']} 179 | 180 | def reduce_func(self, nodes): 181 | alpha = F.softmax(nodes.mailbox['attn'], dim=1) 182 | h = torch.sum(alpha * nodes.mailbox['z'], dim=1) 183 | return {'h': h} 184 | 185 | def forward(self, g, h, e): 186 | z_h = self.fc_h(h) 187 | z_e = self.fc_e(e) 188 | g.ndata['z_h'] = z_h 189 | g.edata['z_e'] = z_e 190 | 191 | g.apply_edges(self.edge_attention) 192 | 193 | g.update_all(self.message_func, self.reduce_func) 194 | 195 | h = g.ndata['h'] 196 | e = g.edata['e_proj'] 197 | 198 | if self.batch_norm: 199 | h = self.batchnorm_h(h) 200 | e = self.batchnorm_e(e) 201 | 202 | h = F.elu(h) 203 | e = F.elu(e) 204 | 205 | h = F.dropout(h, self.dropout, training=self.training) 206 | e = F.dropout(e, self.dropout, training=self.training) 207 | 208 | return h, e 209 | 210 | 211 | class CustomGATLayerEdgeReprFeat(nn.Module): 212 | """ 213 | Param: [in_dim, out_dim, n_heads] 214 | """ 215 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True): 216 | super().__init__() 217 | 218 | self.in_channels = in_dim 219 | self.out_channels = out_dim 220 | self.num_heads = num_heads 221 | self.residual = residual 222 | 223 | if in_dim != (out_dim*num_heads): 224 | self.residual = False 225 | 226 | self.heads = nn.ModuleList() 227 | for i in range(num_heads): 228 | self.heads.append(CustomGATHeadLayerEdgeReprFeat(in_dim, out_dim, dropout, batch_norm)) 229 | self.merge = 'cat' 230 | 231 | def forward(self, g, h, e): 232 | h_in = h # for residual connection 233 | e_in = e 234 | 235 | head_outs_h = [] 236 | head_outs_e = [] 237 | for attn_head in self.heads: 238 | h_temp, e_temp = attn_head(g, h, e) 239 | head_outs_h.append(h_temp) 240 | head_outs_e.append(e_temp) 241 | 242 | if self.merge == 'cat': 243 | h = torch.cat(head_outs_h, dim=1) 244 | e = torch.cat(head_outs_e, dim=1) 245 | else: 246 | raise NotImplementedError 247 | 248 | if self.residual: 249 | h = h_in + h # residual connection 250 | e = e_in + e 251 | 252 | return h, e 253 | 254 | def __repr__(self): 255 | return '{}(in_channels={}, out_channels={}, heads={}, 
residual={})'.format(self.__class__.__name__, 256 | self.in_channels, 257 | self.out_channels, self.num_heads, self.residual) 258 | 259 | 260 | ############################################################## 261 | 262 | 263 | class CustomGATHeadLayerIsotropic(nn.Module): 264 | def __init__(self, in_dim, out_dim, dropout, batch_norm): 265 | super().__init__() 266 | self.dropout = dropout 267 | self.batch_norm = batch_norm 268 | 269 | self.fc = nn.Linear(in_dim, out_dim, bias=False) 270 | self.batchnorm_h = nn.BatchNorm1d(out_dim) 271 | 272 | def message_func(self, edges): 273 | return {'z': edges.src['z']} 274 | 275 | def reduce_func(self, nodes): 276 | h = torch.sum(nodes.mailbox['z'], dim=1) 277 | return {'h': h} 278 | 279 | def forward(self, g, h): 280 | z = self.fc(h) 281 | g.ndata['z'] = z 282 | g.update_all(self.message_func, self.reduce_func) 283 | h = g.ndata['h'] 284 | 285 | if self.batch_norm: 286 | h = self.batchnorm_h(h) 287 | 288 | h = F.elu(h) 289 | 290 | h = F.dropout(h, self.dropout, training=self.training) 291 | 292 | return h 293 | 294 | 295 | class CustomGATLayerIsotropic(nn.Module): 296 | """ 297 | Param: [in_dim, out_dim, n_heads] 298 | """ 299 | def __init__(self, in_dim, out_dim, num_heads, dropout, batch_norm, residual=True): 300 | super().__init__() 301 | 302 | self.in_channels = in_dim 303 | self.out_channels = out_dim 304 | self.num_heads = num_heads 305 | self.residual = residual 306 | 307 | if in_dim != (out_dim*num_heads): 308 | self.residual = False 309 | 310 | self.heads = nn.ModuleList() 311 | for i in range(num_heads): 312 | self.heads.append(CustomGATHeadLayerIsotropic(in_dim, out_dim, dropout, batch_norm)) 313 | self.merge = 'cat' 314 | 315 | def forward(self, g, h, e): 316 | h_in = h # for residual connection 317 | 318 | head_outs = [attn_head(g, h) for attn_head in self.heads] 319 | 320 | if self.merge == 'cat': 321 | h = torch.cat(head_outs, dim=1) 322 | else: 323 | h = torch.mean(torch.stack(head_outs)) 324 | 325 | if self.residual: 326 | h = h_in + h # residual connection 327 | 328 | return h, e 329 | 330 | def __repr__(self): 331 | return '{}(in_channels={}, out_channels={}, heads={}, residual={})'.format(self.__class__.__name__, 332 | self.in_channels, 333 | self.out_channels, self.num_heads, self.residual) 334 | -------------------------------------------------------------------------------- /layers/gated_gcn_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl.function as fn 5 | 6 | """ 7 | ResGatedGCN: Residual Gated Graph ConvNets 8 | An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) 9 | https://arxiv.org/pdf/1711.07553v2.pdf 10 | """ 11 | 12 | class GatedGCNLayer(nn.Module): 13 | """ 14 | Param: [] 15 | """ 16 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False): 17 | super().__init__() 18 | self.in_channels = input_dim 19 | self.out_channels = output_dim 20 | self.dropout = dropout 21 | self.batch_norm = batch_norm 22 | self.residual = residual 23 | 24 | if input_dim != output_dim: 25 | self.residual = False 26 | 27 | self.A = nn.Linear(input_dim, output_dim, bias=True) 28 | self.B = nn.Linear(input_dim, output_dim, bias=True) 29 | self.C = nn.Linear(input_dim, output_dim, bias=True) 30 | self.D = nn.Linear(input_dim, output_dim, bias=True) 31 | self.E = nn.Linear(input_dim, output_dim, bias=True) 32 | self.bn_node_h = 
nn.BatchNorm1d(output_dim) 33 | self.bn_node_e = nn.BatchNorm1d(output_dim) 34 | 35 | def forward(self, g, h, e): 36 | 37 | h_in = h # for residual connection 38 | e_in = e # for residual connection 39 | 40 | g.ndata['h'] = h 41 | g.ndata['Ah'] = self.A(h) 42 | g.ndata['Bh'] = self.B(h) 43 | g.ndata['Dh'] = self.D(h) 44 | g.ndata['Eh'] = self.E(h) 45 | g.edata['e'] = e 46 | g.edata['Ce'] = self.C(e) 47 | 48 | g.apply_edges(fn.u_add_v('Dh', 'Eh', 'DEh')) 49 | g.edata['e'] = g.edata['DEh'] + g.edata['Ce'] 50 | g.edata['sigma'] = torch.sigmoid(g.edata['e']) 51 | g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h')) 52 | g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma')) 53 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6) 54 | #g.update_all(self.message_func,self.reduce_func) 55 | h = g.ndata['h'] # result of graph convolution 56 | e = g.edata['e'] # result of graph convolution 57 | 58 | if self.batch_norm: 59 | h = self.bn_node_h(h) # batch normalization 60 | e = self.bn_node_e(e) # batch normalization 61 | 62 | h = F.relu(h) # non-linear activation 63 | e = F.relu(e) # non-linear activation 64 | 65 | if self.residual: 66 | h = h_in + h # residual connection 67 | e = e_in + e # residual connection 68 | 69 | h = F.dropout(h, self.dropout, training=self.training) 70 | e = F.dropout(e, self.dropout, training=self.training) 71 | 72 | return h, e 73 | 74 | def __repr__(self): 75 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, 76 | self.in_channels, 77 | self.out_channels) 78 | 79 | 80 | ############################################################## 81 | # 82 | # Additional layers for edge feature/representation analysis 83 | # 84 | ############################################################## 85 | 86 | 87 | class GatedGCNLayerEdgeFeatOnly(nn.Module): 88 | """ 89 | Param: [] 90 | """ 91 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False): 92 | super().__init__() 93 | self.in_channels = input_dim 94 | self.out_channels = output_dim 95 | self.dropout = dropout 96 | self.batch_norm = batch_norm 97 | self.residual = residual 98 | 99 | if input_dim != output_dim: 100 | self.residual = False 101 | 102 | self.A = nn.Linear(input_dim, output_dim, bias=True) 103 | self.B = nn.Linear(input_dim, output_dim, bias=True) 104 | self.D = nn.Linear(input_dim, output_dim, bias=True) 105 | self.E = nn.Linear(input_dim, output_dim, bias=True) 106 | self.bn_node_h = nn.BatchNorm1d(output_dim) 107 | 108 | 109 | def forward(self, g, h, e): 110 | 111 | h_in = h # for residual connection 112 | 113 | g.ndata['h'] = h 114 | g.ndata['Ah'] = self.A(h) 115 | g.ndata['Bh'] = self.B(h) 116 | g.ndata['Dh'] = self.D(h) 117 | g.ndata['Eh'] = self.E(h) 118 | #g.update_all(self.message_func,self.reduce_func) 119 | g.apply_edges(fn.u_add_v('Dh', 'Eh', 'e')) 120 | g.edata['sigma'] = torch.sigmoid(g.edata['e']) 121 | g.update_all(fn.u_mul_e('Bh', 'sigma', 'm'), fn.sum('m', 'sum_sigma_h')) 122 | g.update_all(fn.copy_e('sigma', 'm'), fn.sum('m', 'sum_sigma')) 123 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_sigma_h'] / (g.ndata['sum_sigma'] + 1e-6) 124 | h = g.ndata['h'] # result of graph convolution 125 | 126 | if self.batch_norm: 127 | h = self.bn_node_h(h) # batch normalization 128 | 129 | h = F.relu(h) # non-linear activation 130 | 131 | if self.residual: 132 | h = h_in + h # residual connection 133 | 134 | h = F.dropout(h, self.dropout, training=self.training) 135 | 136 | return h, e 137 | 138 | def 
__repr__(self): 139 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, 140 | self.in_channels, 141 | self.out_channels) 142 | 143 | 144 | ############################################################## 145 | 146 | 147 | class GatedGCNLayerIsotropic(nn.Module): 148 | """ 149 | Param: [] 150 | """ 151 | def __init__(self, input_dim, output_dim, dropout, batch_norm, residual=False): 152 | super().__init__() 153 | self.in_channels = input_dim 154 | self.out_channels = output_dim 155 | self.dropout = dropout 156 | self.batch_norm = batch_norm 157 | self.residual = residual 158 | 159 | if input_dim != output_dim: 160 | self.residual = False 161 | 162 | self.A = nn.Linear(input_dim, output_dim, bias=True) 163 | self.B = nn.Linear(input_dim, output_dim, bias=True) 164 | self.bn_node_h = nn.BatchNorm1d(output_dim) 165 | 166 | 167 | def forward(self, g, h, e): 168 | 169 | h_in = h # for residual connection 170 | 171 | g.ndata['h'] = h 172 | g.ndata['Ah'] = self.A(h) 173 | g.ndata['Bh'] = self.B(h) 174 | #g.update_all(self.message_func,self.reduce_func) 175 | g.update_all(fn.copy_u('Bh', 'm'), fn.sum('m', 'sum_h')) 176 | g.ndata['h'] = g.ndata['Ah'] + g.ndata['sum_h'] 177 | h = g.ndata['h'] # result of graph convolution 178 | 179 | if self.batch_norm: 180 | h = self.bn_node_h(h) # batch normalization 181 | 182 | h = F.relu(h) # non-linear activation 183 | 184 | if self.residual: 185 | h = h_in + h # residual connection 186 | 187 | h = F.dropout(h, self.dropout, training=self.training) 188 | 189 | return h, e 190 | 191 | def __repr__(self): 192 | return '{}(in_channels={}, out_channels={})'.format(self.__class__.__name__, 193 | self.in_channels, 194 | self.out_channels) 195 | 196 | -------------------------------------------------------------------------------- /layers/gcn_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | import dgl.function as fn 7 | from dgl.nn.pytorch import GraphConv 8 | 9 | """ 10 | GCN: Graph Convolutional Networks 11 | Thomas N. 
Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) 12 | http://arxiv.org/abs/1609.02907 13 | """ 14 | 15 | # Sends a message of node feature h 16 | # Equivalent to => return {'m': edges.src['h']} 17 | msg = fn.copy_src(src='h', out='m') 18 | reduce = fn.mean('m', 'h') 19 | 20 | class NodeApplyModule(nn.Module): 21 | # Update node feature h_v with (Wh_v+b) 22 | def __init__(self, in_dim, out_dim): 23 | super().__init__() 24 | self.linear = nn.Linear(in_dim, out_dim) 25 | 26 | def forward(self, node): 27 | h = self.linear(node.data['h']) 28 | return {'h': h} 29 | 30 | class GCNLayer(nn.Module): 31 | """ 32 | Param: [in_dim, out_dim] 33 | """ 34 | def __init__(self, in_dim, out_dim, activation, dropout, batch_norm, residual=False, dgl_builtin=False): 35 | super().__init__() 36 | self.in_channels = in_dim 37 | self.out_channels = out_dim 38 | self.batch_norm = batch_norm 39 | self.residual = residual 40 | self.dgl_builtin = dgl_builtin 41 | 42 | if in_dim != out_dim: 43 | self.residual = False 44 | 45 | self.batchnorm_h = nn.BatchNorm1d(out_dim) 46 | self.activation = activation 47 | self.dropout = nn.Dropout(dropout) 48 | if not self.dgl_builtin: 49 | self.apply_mod = NodeApplyModule(in_dim, out_dim) 50 | elif dgl.__version__ < "0.5": 51 | self.conv = GraphConv(in_dim, out_dim) 52 | else: 53 | self.conv = GraphConv(in_dim, out_dim, allow_zero_in_degree=True) 54 | 55 | 56 | def forward(self, g, feature): 57 | h_in = feature # to be used for residual connection 58 | 59 | if not self.dgl_builtin: 60 | g.ndata['h'] = feature 61 | g.update_all(msg, reduce) 62 | g.apply_nodes(func=self.apply_mod) 63 | h = g.ndata['h'] # result of graph convolution 64 | else: 65 | h = self.conv(g, feature) 66 | 67 | if self.batch_norm: 68 | h = self.batchnorm_h(h) # batch normalization 69 | 70 | if self.activation: 71 | h = self.activation(h) 72 | 73 | if self.residual: 74 | h = h_in + h # residual connection 75 | 76 | h = self.dropout(h) 77 | return h 78 | 79 | def __repr__(self): 80 | return '{}(in_channels={}, out_channels={}, residual={})'.format(self.__class__.__name__, 81 | self.in_channels, 82 | self.out_channels, self.residual) -------------------------------------------------------------------------------- /layers/gin_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl.function as fn 5 | 6 | """ 7 | GIN: Graph Isomorphism Networks 8 | How Powerful are Graph Neural Networks? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) 9 | https://arxiv.org/pdf/1810.00826.pdf 10 | """ 11 | 12 | class GINLayer(nn.Module): 13 | """ 14 | [!] code adapted from dgl implementation of GINConv 15 | 16 | Parameters 17 | ---------- 18 | apply_func : callable activation function/layer or None 19 | If not None, apply this function to the updated node feature, 20 | the :math:`f_\Theta` in the formula. 21 | aggr_type : 22 | Aggregator type to use (``sum``, ``max`` or ``mean``). 23 | out_dim : 24 | Required for batch norm layer; should match out_dim of apply_func if not None. 25 | dropout : 26 | Required for dropout of output features. 27 | batch_norm : 28 | boolean flag for batch_norm layer. 29 | residual : 30 | boolean flag for using residual connection. 31 | init_eps : optional 32 | Initial :math:`\epsilon` value, default: ``0``. 33 | learn_eps : bool, optional 34 | If True, :math:`\epsilon` will be a learnable parameter.
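    The update implemented in ``forward`` below (shown for ``sum`` aggregation; ``max``/``mean`` replace the sum when selected) is

    .. math::
        h_v^{\prime} = \mathrm{ReLU}\Big(\mathrm{BN}\big(f_\Theta\big((1 + \epsilon)\, h_v + \sum_{u \in \mathcal{N}(v)} h_u\big)\big)\Big),

    after which the optional residual connection and dropout are applied.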
35 | 36 | """ 37 | def __init__(self, apply_func, aggr_type, dropout, batch_norm, residual=False, init_eps=0, learn_eps=False): 38 | super().__init__() 39 | self.apply_func = apply_func 40 | 41 | if aggr_type == 'sum': 42 | self._reducer = fn.sum 43 | elif aggr_type == 'max': 44 | self._reducer = fn.max 45 | elif aggr_type == 'mean': 46 | self._reducer = fn.mean 47 | else: 48 | raise KeyError('Aggregator type {} not recognized.'.format(aggr_type)) 49 | 50 | self.batch_norm = batch_norm 51 | self.residual = residual 52 | self.dropout = dropout 53 | 54 | in_dim = apply_func.mlp.input_dim 55 | out_dim = apply_func.mlp.output_dim 56 | 57 | if in_dim != out_dim: 58 | self.residual = False 59 | 60 | # to specify whether eps is trainable or not. 61 | if learn_eps: 62 | self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps])) 63 | else: 64 | self.register_buffer('eps', torch.FloatTensor([init_eps])) 65 | 66 | self.bn_node_h = nn.BatchNorm1d(out_dim) 67 | 68 | def forward(self, g, h): 69 | h_in = h # for residual connection 70 | 71 | g = g.local_var() 72 | g.ndata['h'] = h 73 | g.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) 74 | h = (1 + self.eps) * h + g.ndata['neigh'] 75 | if self.apply_func is not None: 76 | h = self.apply_func(h) 77 | 78 | if self.batch_norm: 79 | h = self.bn_node_h(h) # batch normalization 80 | 81 | h = F.relu(h) # non-linear activation 82 | 83 | if self.residual: 84 | h = h_in + h # residual connection 85 | 86 | h = F.dropout(h, self.dropout, training=self.training) 87 | 88 | return h 89 | 90 | 91 | class ApplyNodeFunc(nn.Module): 92 | """ 93 | This class is used in class GINNet 94 | Update the node feature hv with MLP 95 | """ 96 | def __init__(self, mlp): 97 | super().__init__() 98 | self.mlp = mlp 99 | 100 | def forward(self, h): 101 | h = self.mlp(h) 102 | return h 103 | 104 | 105 | class MLP(nn.Module): 106 | """MLP with linear output""" 107 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 108 | 109 | super().__init__() 110 | self.linear_or_not = True # default is linear model 111 | self.num_layers = num_layers 112 | self.output_dim = output_dim 113 | self.input_dim = input_dim 114 | 115 | if num_layers < 1: 116 | raise ValueError("number of layers should be positive!") 117 | elif num_layers == 1: 118 | # Linear model 119 | self.linear = nn.Linear(input_dim, output_dim) 120 | else: 121 | # Multi-layer model 122 | self.linear_or_not = False 123 | self.linears = torch.nn.ModuleList() 124 | self.batch_norms = torch.nn.ModuleList() 125 | 126 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 127 | for layer in range(num_layers - 2): 128 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 129 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 130 | 131 | for layer in range(num_layers - 1): 132 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 133 | 134 | def forward(self, x): 135 | if self.linear_or_not: 136 | # If linear model 137 | return self.linear(x) 138 | else: 139 | # If MLP 140 | h = x 141 | for i in range(self.num_layers - 1): 142 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 143 | return self.linears[-1](h) 144 | -------------------------------------------------------------------------------- /layers/gmm_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import dgl.function as fn 6 | 7 | """ 8 | GMM: Gaussian Mixture Model Convolution layer 9 | 
Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs (Federico Monti et al., CVPR 2017)
10 | https://arxiv.org/pdf/1611.08402.pdf
11 | """
12 | 
13 | class GMMLayer(nn.Module):
14 |     """
15 |     [!] code adapted from dgl implementation of GMMConv
16 | 
17 |     Parameters
18 |     ----------
19 |     in_dim :
20 |         Number of input features.
21 |     out_dim :
22 |         Number of output features.
23 |     dim :
24 |         Dimensionality of the pseudo-coordinates.
25 |     kernel :
26 |         Number of kernels :math:`K`.
27 |     aggr_type :
28 |         Aggregator type (``sum``, ``mean``, ``max``).
29 |     dropout :
30 |         Required for dropout of output features.
31 |     batch_norm :
32 |         boolean flag for batch_norm layer.
33 |     residual :
34 |         If True, use residual connection inside this layer. Default: ``False``.
35 |     bias :
36 |         If True, adds a learnable bias to the output. Default: ``True``.
37 | 
38 |     """
39 |     def __init__(self, in_dim, out_dim, dim, kernel, aggr_type, dropout,
40 |                  batch_norm, residual=False, bias=True):
41 |         super().__init__()
42 | 
43 |         self.in_dim = in_dim
44 |         self.out_dim = out_dim
45 |         self.dim = dim
46 |         self.kernel = kernel
47 |         self.batch_norm = batch_norm
48 |         self.residual = residual
49 |         self.dropout = dropout
50 | 
51 |         if aggr_type == 'sum':
52 |             self._reducer = fn.sum
53 |         elif aggr_type == 'mean':
54 |             self._reducer = fn.mean
55 |         elif aggr_type == 'max':
56 |             self._reducer = fn.max
57 |         else:
58 |             raise KeyError("Aggregator type {} not recognized.".format(aggr_type))
59 | 
60 |         self.mu = nn.Parameter(torch.Tensor(kernel, dim))
61 |         self.inv_sigma = nn.Parameter(torch.Tensor(kernel, dim))
62 |         self.fc = nn.Linear(in_dim, kernel * out_dim, bias=False)
63 | 
64 |         self.bn_node_h = nn.BatchNorm1d(out_dim)
65 | 
66 |         if in_dim != out_dim:
67 |             self.residual = False
68 | 
69 |         if bias:
70 |             self.bias = nn.Parameter(torch.Tensor(out_dim))
71 |         else:
72 |             self.register_buffer('bias', None)
73 |         self.reset_parameters()
74 | 
75 |     def reset_parameters(self):
76 |         """Reinitialize learnable parameters."""
77 |         gain = init.calculate_gain('relu')
78 |         init.xavier_normal_(self.fc.weight, gain=gain)
79 |         init.normal_(self.mu.data, 0, 0.1)
80 |         init.constant_(self.inv_sigma.data, 1)
81 |         if self.bias is not None:
82 |             init.zeros_(self.bias.data)
83 | 
84 |     def forward(self, g, h, pseudo):
85 |         h_in = h  # for residual connection
86 | 
87 |         g = g.local_var()
88 |         g.ndata['h'] = self.fc(h).view(-1, self.kernel, self.out_dim)
89 |         E = g.number_of_edges()
90 | 
91 |         # compute gaussian weight
92 |         gaussian = -0.5 * ((pseudo.view(E, 1, self.dim) -
93 |                             self.mu.view(1, self.kernel, self.dim)) ** 2)
94 |         gaussian = gaussian * (self.inv_sigma.view(1, self.kernel, self.dim) ** 2)
95 |         gaussian = torch.exp(gaussian.sum(dim=-1, keepdim=True))  # (E, K, 1)
96 |         g.edata['w'] = gaussian
97 |         g.update_all(fn.u_mul_e('h', 'w', 'm'), self._reducer('m', 'h'))
98 |         h = g.ndata['h'].sum(1)
99 | 
100 |         if self.batch_norm:
101 |             h = self.bn_node_h(h)  # batch normalization
102 | 
103 |         h = F.relu(h)  # non-linear activation
104 | 
105 |         if self.residual:
106 |             h = h_in + h  # residual connection
107 | 
108 |         if self.bias is not None:
109 |             h = h + self.bias
110 | 
111 |         h = F.dropout(h, self.dropout, training=self.training)
112 | 
113 |         return h
114 | 
--------------------------------------------------------------------------------
/layers/mlp_readout_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | """
6 |     MLP Layer used after graph vector 
representation 7 | """ 8 | 9 | class MLPReadout(nn.Module): 10 | 11 | def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers 12 | super().__init__() 13 | list_FC_layers = [ nn.Linear( input_dim//2**l , input_dim//2**(l+1) , bias=True ) for l in range(L) ] 14 | list_FC_layers.append(nn.Linear( input_dim//2**L , output_dim , bias=True )) 15 | self.FC_layers = nn.ModuleList(list_FC_layers) 16 | self.L = L 17 | 18 | def forward(self, x): 19 | y = x 20 | for l in range(self.L): 21 | y = self.FC_layers[l](y) 22 | y = F.relu(y) 23 | y = self.FC_layers[self.L](y) 24 | return y -------------------------------------------------------------------------------- /layers/ring_gnn_equiv_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Ring-GNN equi 2 to 2 layer file 7 | On the equivalence between graph isomorphism testing and function approximation with GNNs (Chen et al, 2019) 8 | https://arxiv.org/pdf/1905.12560v1.pdf 9 | 10 | CODE ADPATED FROM https://github.com/leichen2018/Ring-GNN/ 11 | """ 12 | 13 | class RingGNNEquivLayer(nn.Module): 14 | def __init__(self, device, input_dim, output_dim, layer_norm, residual, dropout, 15 | normalization='inf', normalization_val=1.0, radius=2, k2_init = 0.1): 16 | super().__init__() 17 | self.device = device 18 | basis_dimension = 15 19 | self.radius = radius 20 | self.layer_norm = layer_norm 21 | self.residual = residual 22 | self.dropout = dropout 23 | 24 | coeffs_values = lambda i, j, k: torch.randn([i, j, k]) * torch.sqrt(2. / (i + j).float()) 25 | self.diag_bias_list = nn.ParameterList([]) 26 | 27 | for i in range(radius): 28 | for j in range(i+1): 29 | self.diag_bias_list.append(nn.Parameter(torch.zeros(1, output_dim, 1, 1))) 30 | 31 | self.all_bias = nn.Parameter(torch.zeros(1, output_dim, 1, 1)) 32 | self.coeffs_list = nn.ParameterList([]) 33 | 34 | for i in range(radius): 35 | for j in range(i+1): 36 | self.coeffs_list.append(nn.Parameter(coeffs_values(input_dim, output_dim, basis_dimension))) 37 | 38 | self.switch = nn.ParameterList([nn.Parameter(torch.FloatTensor([1])), nn.Parameter(torch.FloatTensor([k2_init]))]) 39 | self.output_dim = output_dim 40 | 41 | self.normalization = normalization 42 | self.normalization_val = normalization_val 43 | 44 | if self.layer_norm: 45 | self.ln_x = LayerNorm(output_dim.item()) 46 | 47 | if self.residual: 48 | self.res_x = nn.Linear(input_dim.item(), output_dim.item()) 49 | 50 | def forward(self, inputs): 51 | m = inputs.size()[3] 52 | 53 | ops_out = ops_2_to_2(inputs, m, normalization=self.normalization) 54 | ops_out = torch.stack(ops_out, dim = 2) 55 | 56 | 57 | output_list = [] 58 | 59 | for i in range(self.radius): 60 | for j in range(i+1): 61 | output_i = torch.einsum('dsb,ndbij->nsij', self.coeffs_list[i*(i+1)//2 + j], ops_out) 62 | 63 | mat_diag_bias = torch.eye(inputs.size()[3]).unsqueeze(0).unsqueeze(0).to(self.device) * self.diag_bias_list[i*(i+1)//2 + j] 64 | # mat_diag_bias = torch.eye(inputs.size()[3]).to('cuda:0').unsqueeze(0).unsqueeze(0) * self.diag_bias_list[i*(i+1)//2 + j] 65 | if j == 0: 66 | output = output_i + mat_diag_bias 67 | else: 68 | output = torch.einsum('abcd,abde->abce', output_i, output) 69 | output_list.append(output) 70 | 71 | output = 0 72 | for i in range(self.radius): 73 | output += output_list[i] * self.switch[i] 74 | 75 | output = output + self.all_bias 76 | 77 | if self.layer_norm: 78 | # Now, changing shapes from [1xdxnxn] to [nxnxd] 
for BN 79 | output = output.permute(3,2,1,0).squeeze() 80 | 81 | # output = self.bn_x(output.reshape(m*m, self.output_dim.item())) # batch normalization 82 | output = self.ln_x(output) # layer normalization 83 | 84 | # Returning output back to original shape 85 | output = output.reshape(m, m, self.output_dim.item()) 86 | output = output.permute(2,1,0).unsqueeze(0) 87 | 88 | output = F.relu(output) # non-linear activation 89 | 90 | if self.residual: 91 | # Now, changing shapes from [1xdxnxn] to [nxnxd] for Linear() layer 92 | inputs, output = inputs.permute(3,2,1,0).squeeze(), output.permute(3,2,1,0).squeeze() 93 | 94 | residual_ = self.res_x(inputs) 95 | output = residual_ + output # residual connection 96 | 97 | # Returning output back to original shape 98 | output = output.permute(2,1,0).unsqueeze(0) 99 | 100 | output = F.dropout(output, self.dropout, training=self.training) 101 | 102 | return output 103 | 104 | 105 | def ops_2_to_2(inputs, dim, normalization='inf', normalization_val=1.0): # N x D x m x m 106 | # input: N x D x m x m 107 | diag_part = torch.diagonal(inputs, dim1 = 2, dim2 = 3) # N x D x m 108 | sum_diag_part = torch.sum(diag_part, dim=2, keepdim = True) # N x D x 1 109 | sum_of_rows = torch.sum(inputs, dim=3) # N x D x m 110 | sum_of_cols = torch.sum(inputs, dim=2) # N x D x m 111 | sum_all = torch.sum(sum_of_rows, dim=2) # N x D 112 | 113 | # op1 - (1234) - extract diag 114 | op1 = torch.diag_embed(diag_part) # N x D x m x m 115 | 116 | # op2 - (1234) + (12)(34) - place sum of diag on diag 117 | op2 = torch.diag_embed(sum_diag_part.repeat(1, 1, dim)) 118 | 119 | # op3 - (1234) + (123)(4) - place sum of row i on diag ii 120 | op3 = torch.diag_embed(sum_of_rows) 121 | 122 | # op4 - (1234) + (124)(3) - place sum of col i on diag ii 123 | op4 = torch.diag_embed(sum_of_cols) 124 | 125 | # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag 126 | op5 = torch.diag_embed(sum_all.unsqueeze(2).repeat(1, 1, dim)) 127 | 128 | # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i 129 | op6 = sum_of_cols.unsqueeze(3).repeat(1, 1, 1, dim) 130 | 131 | # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i 132 | op7 = sum_of_rows.unsqueeze(3).repeat(1, 1, 1, dim) 133 | 134 | # op8 - (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i 135 | op8 = sum_of_cols.unsqueeze(2).repeat(1, 1, dim, 1) 136 | 137 | # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i 138 | op9 = sum_of_rows.unsqueeze(2).repeat(1, 1, dim, 1) 139 | 140 | # op10 - (1234) + (14)(23) - identity 141 | op10 = inputs 142 | 143 | # op11 - (1234) + (13)(24) - transpose 144 | op11 = torch.transpose(inputs, -2, -1) 145 | 146 | # op12 - (1234) + (234)(1) - place ii element in row i 147 | op12 = diag_part.unsqueeze(3).repeat(1, 1, 1, dim) 148 | 149 | # op13 - (1234) + (134)(2) - place ii element in col i 150 | op13 = diag_part.unsqueeze(2).repeat(1, 1, dim, 1) 151 | 152 | # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries 153 | op14 = sum_diag_part.unsqueeze(3).repeat(1, 1, dim, dim) 154 | 155 | # op15 - sum of all ops - place sum of all entries in all entries 156 | op15 = sum_all.unsqueeze(2).unsqueeze(3).repeat(1, 1, dim, dim) 157 | 158 | #A_2 = torch.einsum('abcd,abde->abce', inputs, inputs) 159 | #A_4 = torch.einsum('abcd,abde->abce', A_2, A_2) 160 | #op16 = torch.where(A_4>1, torch.ones(A_4.size()), 
A_4)
161 | 
162 |     if normalization is not None:
163 |         float_dim = float(dim)
164 |         if normalization == 'inf':
165 |             op2 = torch.div(op2, float_dim)
166 |             op3 = torch.div(op3, float_dim)
167 |             op4 = torch.div(op4, float_dim)
168 |             op5 = torch.div(op5, float_dim**2)
169 |             op6 = torch.div(op6, float_dim)
170 |             op7 = torch.div(op7, float_dim)
171 |             op8 = torch.div(op8, float_dim)
172 |             op9 = torch.div(op9, float_dim)
173 |             op14 = torch.div(op14, float_dim)
174 |             op15 = torch.div(op15, float_dim**2)
175 | 
176 |     #return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15, op16]
177 |     '''
178 |     l = [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
179 |     for i, ls in enumerate(l):
180 |         print(i+1)
181 |         print(torch.sum(ls))
182 |         print("$%^&*(*&^%$#$%^&*(*&^%$%^&*(*&^%$%^&*(")
183 |     '''
184 |     return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12, op13, op14, op15]
185 | 
186 | 
187 | class LayerNorm(nn.Module):
188 |     def __init__(self, d):
189 |         super().__init__()
190 |         self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0))  # shape is 1 x 1 x d
191 |         self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0))  # shape is 1 x 1 x d
192 | 
193 |     def forward(self, x):
194 |         # x tensor of the shape n x n x d
195 |         mean = x.mean(dim=(0,1), keepdim=True)
196 |         var = x.var(dim=(0,1), keepdim=True, unbiased=False)
197 |         x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b  # shape is n x n x d
198 |         return x
199 | 
200 | 
--------------------------------------------------------------------------------
/layers/three_wl_gnn_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | """
6 |     Layers used for
7 |     3WLGNN
8 |     Provably Powerful Graph Networks (Maron et al., 2019)
9 |     https://papers.nips.cc/paper/8488-provably-powerful-graph-networks.pdf
10 | 
11 |     CODE adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/
12 | """
13 | 
14 | class RegularBlock(nn.Module):
15 |     """
16 |     Inputs: N x input_depth x m x m
17 |     Take the input through 2 parallel MLP routes, multiply the result, and add a skip-connection at the end.
18 |     At the skip-connection, reduce the dimension back to output_depth
19 |     """
20 |     def __init__(self, depth_of_mlp, in_features, out_features, residual=False):
21 |         super().__init__()
22 | 
23 |         self.residual = residual
24 | 
25 |         self.mlp1 = MlpBlock(in_features, out_features, depth_of_mlp)
26 |         self.mlp2 = MlpBlock(in_features, out_features, depth_of_mlp)
27 | 
28 |         self.skip = SkipConnection(in_features+out_features, out_features)
29 | 
30 |         if self.residual:
31 |             self.res_x = nn.Linear(in_features, out_features)
32 | 
33 |     def forward(self, inputs):
34 |         mlp1 = self.mlp1(inputs)
35 |         mlp2 = self.mlp2(inputs)
36 | 
37 |         mult = torch.matmul(mlp1, mlp2)
38 | 
39 |         out = self.skip(in1=inputs, in2=mult)
40 | 
41 |         if self.residual:
42 |             # Now, changing shapes from [1xdxnxn] to [nxnxd] for Linear() layer
43 |             inputs, out = inputs.permute(3,2,1,0).squeeze(), out.permute(3,2,1,0).squeeze()
44 | 
45 |             residual_ = self.res_x(inputs)
46 |             out = residual_ + out  # residual connection
47 | 
48 |             # Returning output back to original shape
49 |             out = out.permute(2,1,0).unsqueeze(0)
50 | 
51 |         return out
52 | 
53 | 
54 | class MlpBlock(nn.Module):
55 |     """
56 |     Block of MLP layers with activation function after each (1x1 conv layers).
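    Since the input is an (N, d, m, m) tensor, each 1x1 convolution acts as a
    position-wise linear map: the same MLP is applied independently to every
    (i, j) entry of the m x m matrix representation.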
57 | """ 58 | def __init__(self, in_features, out_features, depth_of_mlp, activation_fn=nn.functional.relu): 59 | super().__init__() 60 | self.activation = activation_fn 61 | self.convs = nn.ModuleList() 62 | for i in range(depth_of_mlp): 63 | self.convs.append(nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True)) 64 | _init_weights(self.convs[-1]) 65 | in_features = out_features 66 | 67 | def forward(self, inputs): 68 | out = inputs 69 | for conv_layer in self.convs: 70 | out = self.activation(conv_layer(out)) 71 | 72 | return out 73 | 74 | 75 | class SkipConnection(nn.Module): 76 | """ 77 | Connects the two given inputs with concatenation 78 | :param in1: earlier input tensor of shape N x d1 x m x m 79 | :param in2: later input tensor of shape N x d2 x m x m 80 | :param in_features: d1+d2 81 | :param out_features: output num of features 82 | :return: Tensor of shape N x output_depth x m x m 83 | """ 84 | def __init__(self, in_features, out_features): 85 | super().__init__() 86 | self.conv = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=True) 87 | _init_weights(self.conv) 88 | 89 | def forward(self, in1, in2): 90 | # in1: N x d1 x m x m 91 | # in2: N x d2 x m x m 92 | out = torch.cat((in1, in2), dim=1) 93 | out = self.conv(out) 94 | return out 95 | 96 | 97 | class FullyConnected(nn.Module): 98 | def __init__(self, in_features, out_features, activation_fn=nn.functional.relu): 99 | super().__init__() 100 | 101 | self.fc = nn.Linear(in_features, out_features) 102 | _init_weights(self.fc) 103 | 104 | self.activation = activation_fn 105 | 106 | def forward(self, input): 107 | out = self.fc(input) 108 | if self.activation is not None: 109 | out = self.activation(out) 110 | 111 | return out 112 | 113 | 114 | def diag_offdiag_maxpool(input): 115 | N = input.shape[-1] 116 | 117 | max_diag = torch.max(torch.diagonal(input, dim1=-2, dim2=-1), dim=2)[0] # BxS 118 | 119 | # with torch.no_grad(): 120 | max_val = torch.max(max_diag) 121 | min_val = torch.max(-1 * input) 122 | val = torch.abs(torch.add(max_val, min_val)) 123 | 124 | min_mat = torch.mul(val, torch.eye(N, device=input.device)).view(1, 1, N, N) 125 | 126 | max_offdiag = torch.max(torch.max(input - min_mat, dim=3)[0], dim=2)[0] # BxS 127 | 128 | return torch.cat((max_diag, max_offdiag), dim=1) # output Bx2S 129 | 130 | def _init_weights(layer): 131 | """ 132 | Init weights of the layer 133 | :param layer: 134 | :return: 135 | """ 136 | nn.init.xavier_uniform_(layer.weight) 137 | # nn.init.xavier_normal_(layer.weight) 138 | if layer.bias is not None: 139 | nn.init.zeros_(layer.bias) 140 | 141 | 142 | class LayerNorm(nn.Module): 143 | def __init__(self, d): 144 | super().__init__() 145 | self.a = nn.Parameter(torch.ones(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d 146 | self.b = nn.Parameter(torch.zeros(d).unsqueeze(0).unsqueeze(0)) # shape is 1 x 1 x d 147 | 148 | def forward(self, x): 149 | # x tensor of the shape n x n x d 150 | mean = x.mean(dim=(0,1), keepdim=True) 151 | var = x.var(dim=(0,1), keepdim=True, unbiased=False) 152 | x = self.a * (x - mean) / torch.sqrt(var + 1e-6) + self.b # shape is n x n x d 153 | return x 154 | 155 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from sklearn.metrics import confusion_matrix 6 | from sklearn.metrics import f1_score 7 | import numpy 
as np 8 | 9 | 10 | def MAE(scores, targets): 11 | MAE = F.l1_loss(scores, targets) 12 | MAE = MAE.detach().item() 13 | return MAE 14 | 15 | 16 | def accuracy_TU(scores, targets): 17 | scores = scores.detach().argmax(dim=1) 18 | acc = (scores==targets).float().sum().item() 19 | return acc 20 | 21 | 22 | def accuracy_MNIST_CIFAR(scores, targets): 23 | scores = scores.detach().argmax(dim=1) 24 | acc = (scores==targets).float().sum().item() 25 | return acc 26 | 27 | def accuracy_CITATION_GRAPH(scores, targets): 28 | scores = scores.detach().argmax(dim=1) 29 | acc = (scores==targets).float().sum().item() 30 | acc = acc / len(targets) 31 | return acc 32 | 33 | 34 | def accuracy_SBM(scores, targets): 35 | S = targets.cpu().numpy() 36 | C = np.argmax( torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy() , axis=1 ) 37 | CM = confusion_matrix(S,C).astype(np.float32) 38 | nb_classes = CM.shape[0] 39 | targets = targets.cpu().detach().numpy() 40 | nb_non_empty_classes = 0 41 | pr_classes = np.zeros(nb_classes) 42 | for r in range(nb_classes): 43 | cluster = np.where(targets==r)[0] 44 | if cluster.shape[0] != 0: 45 | pr_classes[r] = CM[r,r]/ float(cluster.shape[0]) 46 | if CM[r,r]>0: 47 | nb_non_empty_classes += 1 48 | else: 49 | pr_classes[r] = 0.0 50 | acc = 100.* np.sum(pr_classes)/ float(nb_classes) 51 | return acc 52 | 53 | 54 | def binary_f1_score(scores, targets): 55 | """Computes the F1 score using scikit-learn for binary class labels. 56 | 57 | Returns the F1 score for the positive class, i.e. labelled '1'. 58 | """ 59 | y_true = targets.cpu().numpy() 60 | y_pred = scores.argmax(dim=1).cpu().numpy() 61 | return f1_score(y_true, y_pred, average='binary') 62 | 63 | 64 | def accuracy_VOC(scores, targets): 65 | scores = scores.detach().argmax(dim=1).cpu() 66 | targets = targets.cpu().detach().numpy() 67 | acc = f1_score(scores, targets, average='weighted') 68 | return acc 69 | -------------------------------------------------------------------------------- /nets/gat_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | """ 8 | GAT: Graph Attention Network 9 | Graph Attention Networks (Veličković et al., ICLR 2018) 10 | https://arxiv.org/abs/1710.10903 11 | """ 12 | from layers.gat_layer import GATLayer 13 | from layers.mlp_readout_layer import MLPReadout 14 | 15 | class GATNet(nn.Module): 16 | def __init__(self, net_params): 17 | super().__init__() 18 | in_dim = net_params['in_dim'] 19 | hidden_dim = net_params['hidden_dim'] 20 | out_dim = net_params['out_dim'] 21 | n_classes = net_params['n_classes'] 22 | num_heads = net_params['n_heads'] 23 | in_feat_dropout = net_params['in_feat_dropout'] 24 | dropout = net_params['dropout'] 25 | n_layers = net_params['L'] 26 | self.readout = net_params['readout'] 27 | self.batch_norm = net_params['batch_norm'] 28 | self.residual = net_params['residual'] 29 | 30 | self.dropout = dropout 31 | 32 | self.embedding_h = nn.Linear(in_dim, hidden_dim * num_heads) 33 | self.in_feat_dropout = nn.Dropout(in_feat_dropout) 34 | 35 | self.layers = nn.ModuleList([GATLayer(hidden_dim * num_heads, hidden_dim, num_heads, 36 | dropout, self.batch_norm, self.residual) for _ in range(n_layers-1)]) 37 | self.layers.append(GATLayer(hidden_dim * num_heads, out_dim, 1, dropout, self.batch_norm, self.residual)) 38 | self.MLP_layer = MLPReadout(out_dim, n_classes) 39 | 40 | def forward(self, g, h, e): 41 | h = self.embedding_h(h) 42 | h = 
self.in_feat_dropout(h) 43 | for conv in self.layers: 44 | h = conv(g, h) 45 | g.ndata['h'] = h 46 | 47 | if self.readout == "sum": 48 | hg = dgl.sum_nodes(g, 'h') 49 | elif self.readout == "max": 50 | hg = dgl.max_nodes(g, 'h') 51 | elif self.readout == "mean": 52 | hg = dgl.mean_nodes(g, 'h') 53 | else: 54 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes 55 | 56 | return self.MLP_layer(hg) 57 | 58 | def loss(self, pred, label): 59 | criterion = nn.CrossEntropyLoss() 60 | loss = criterion(pred, label) 61 | return loss 62 | 63 | 64 | -------------------------------------------------------------------------------- /nets/gated_gcn_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | """ 8 | ResGatedGCN: Residual Gated Graph ConvNets 9 | An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018) 10 | https://arxiv.org/pdf/1711.07553v2.pdf 11 | """ 12 | from layers.gated_gcn_layer import GatedGCNLayer 13 | from layers.mlp_readout_layer import MLPReadout 14 | 15 | class GatedGCNNet(nn.Module): 16 | 17 | def __init__(self, net_params): 18 | super().__init__() 19 | in_dim = net_params['in_dim'] 20 | hidden_dim = net_params['hidden_dim'] 21 | out_dim = net_params['out_dim'] 22 | n_classes = net_params['n_classes'] 23 | dropout = net_params['dropout'] 24 | n_layers = net_params['L'] 25 | self.readout = net_params['readout'] 26 | self.batch_norm = net_params['batch_norm'] 27 | self.residual = net_params['residual'] 28 | 29 | self.embedding_h = nn.Linear(in_dim, hidden_dim) 30 | self.embedding_e = nn.Linear(in_dim, hidden_dim) 31 | self.layers = nn.ModuleList([ GatedGCNLayer(hidden_dim, hidden_dim, dropout, 32 | self.batch_norm, self.residual) for _ in range(n_layers-1) ]) 33 | self.layers.append(GatedGCNLayer(hidden_dim, out_dim, dropout, self.batch_norm, self.residual)) 34 | self.MLP_layer = MLPReadout(out_dim, n_classes) 35 | 36 | def forward(self, g, h, e): 37 | h = self.embedding_h(h) 38 | e = self.embedding_e(e) 39 | 40 | # convnets 41 | for conv in self.layers: 42 | h, e = conv(g, h, e) 43 | g.ndata['h'] = h 44 | 45 | if self.readout == "sum": 46 | hg = dgl.sum_nodes(g, 'h') 47 | elif self.readout == "max": 48 | hg = dgl.max_nodes(g, 'h') 49 | elif self.readout == "mean": 50 | hg = dgl.mean_nodes(g, 'h') 51 | else: 52 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes 53 | 54 | return self.MLP_layer(hg) 55 | 56 | def loss(self, pred, label): 57 | criterion = nn.CrossEntropyLoss() 58 | loss = criterion(pred, label) 59 | return loss 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /nets/gcn_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | """ 8 | GCN: Graph Convolutional Networks 9 | Thomas N. 
Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) 10 | http://arxiv.org/abs/1609.02907 11 | """ 12 | from layers.gcn_layer import GCNLayer 13 | from layers.mlp_readout_layer import MLPReadout 14 | 15 | class GCNNet(nn.Module): 16 | def __init__(self, net_params): 17 | super().__init__() 18 | in_dim = net_params['in_dim'] 19 | hidden_dim = net_params['hidden_dim'] 20 | out_dim = net_params['out_dim'] 21 | n_classes = net_params['n_classes'] 22 | in_feat_dropout = net_params['in_feat_dropout'] 23 | dropout = net_params['dropout'] 24 | n_layers = net_params['L'] 25 | self.readout = net_params['readout'] 26 | self.batch_norm = net_params['batch_norm'] 27 | self.residual = net_params['residual'] 28 | 29 | self.embedding_h = nn.Linear(in_dim, hidden_dim) 30 | self.in_feat_dropout = nn.Dropout(in_feat_dropout) 31 | 32 | self.layers = nn.ModuleList([GCNLayer(hidden_dim, hidden_dim, F.relu, dropout, 33 | self.batch_norm, self.residual) for _ in range(n_layers-1)]) 34 | self.layers.append(GCNLayer(hidden_dim, out_dim, F.relu, dropout, self.batch_norm, self.residual)) 35 | self.MLP_layer = MLPReadout(out_dim, n_classes) 36 | 37 | def forward(self, g, h, e): 38 | h = self.embedding_h(h) 39 | h = self.in_feat_dropout(h) 40 | for conv in self.layers: 41 | h = conv(g, h) 42 | g.ndata['h'] = h 43 | 44 | if self.readout == "sum": 45 | hg = dgl.sum_nodes(g, 'h') 46 | elif self.readout == "max": 47 | hg = dgl.max_nodes(g, 'h') 48 | elif self.readout == "mean": 49 | hg = dgl.mean_nodes(g, 'h') 50 | else: 51 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes 52 | 53 | return self.MLP_layer(hg) 54 | 55 | def loss(self, pred, label): 56 | criterion = nn.CrossEntropyLoss() 57 | loss = criterion(pred, label) 58 | return loss 59 | 60 | -------------------------------------------------------------------------------- /nets/gin_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 7 | 8 | """ 9 | GIN: Graph Isomorphism Networks 10 | HOW POWERFUL ARE GRAPH NEURAL NETWORKS? 
(Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) 11 | https://arxiv.org/pdf/1810.00826.pdf 12 | """ 13 | 14 | from layers.gin_layer import GINLayer, ApplyNodeFunc, MLP 15 | 16 | class GINNet(nn.Module): 17 | 18 | def __init__(self, net_params): 19 | super().__init__() 20 | in_dim = net_params['in_dim'] 21 | hidden_dim = net_params['hidden_dim'] 22 | n_classes = net_params['n_classes'] 23 | dropout = net_params['dropout'] 24 | self.n_layers = net_params['L'] 25 | n_mlp_layers = net_params['n_mlp_GIN'] # GIN 26 | learn_eps = net_params['learn_eps_GIN'] # GIN 27 | neighbor_aggr_type = net_params['neighbor_aggr_GIN'] # GIN 28 | readout = net_params['readout'] # this is graph_pooling_type 29 | batch_norm = net_params['batch_norm'] 30 | residual = net_params['residual'] 31 | 32 | # List of MLPs 33 | self.ginlayers = torch.nn.ModuleList() 34 | 35 | self.embedding_h = nn.Linear(in_dim, hidden_dim) 36 | 37 | for layer in range(self.n_layers): 38 | mlp = MLP(n_mlp_layers, hidden_dim, hidden_dim, hidden_dim) 39 | 40 | self.ginlayers.append(GINLayer(ApplyNodeFunc(mlp), neighbor_aggr_type, 41 | dropout, batch_norm, residual, 0, learn_eps)) 42 | 43 | # Linear function for graph poolings (readout) of output of each layer 44 | # which maps the output of different layers into a prediction score 45 | self.linears_prediction = torch.nn.ModuleList() 46 | 47 | for layer in range(self.n_layers+1): 48 | self.linears_prediction.append(nn.Linear(hidden_dim, n_classes)) 49 | 50 | if readout == 'sum': 51 | self.pool = SumPooling() 52 | elif readout == 'mean': 53 | self.pool = AvgPooling() 54 | elif readout == 'max': 55 | self.pool = MaxPooling() 56 | else: 57 | raise NotImplementedError 58 | 59 | def forward(self, g, h, e): 60 | 61 | h = self.embedding_h(h) 62 | 63 | # list of hidden representation at each layer (including input) 64 | hidden_rep = [h] 65 | 66 | for i in range(self.n_layers): 67 | h = self.ginlayers[i](g, h) 68 | hidden_rep.append(h) 69 | 70 | score_over_layer = 0 71 | 72 | # perform pooling over all nodes in each graph in every layer 73 | for i, h in enumerate(hidden_rep): 74 | pooled_h = self.pool(g, h) 75 | score_over_layer += self.linears_prediction[i](pooled_h) 76 | 77 | return score_over_layer 78 | 79 | def loss(self, pred, label): 80 | criterion = nn.CrossEntropyLoss() 81 | loss = criterion(pred, label) 82 | return loss -------------------------------------------------------------------------------- /nets/graphsage_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | """ 8 | GraphSAGE: 9 | William L. 
Hamilton, Rex Ying, Jure Leskovec, Inductive Representation Learning on Large Graphs (NeurIPS 2017)
10 | https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
11 | """
12 | 
13 | from layers.graphsage_layer import GraphSageLayer
14 | from layers.mlp_readout_layer import MLPReadout
15 | 
16 | class GraphSageNet(nn.Module):
17 |     """
18 |     GraphSAGE network with multiple GraphSageLayer layers
19 |     """
20 |     def __init__(self, net_params):
21 |         super().__init__()
22 |         in_dim = net_params['in_dim']
23 |         hidden_dim = net_params['hidden_dim']
24 |         out_dim = net_params['out_dim']
25 |         n_classes = net_params['n_classes']
26 |         in_feat_dropout = net_params['in_feat_dropout']
27 |         dropout = net_params['dropout']
28 |         aggregator_type = net_params['sage_aggregator']
29 |         n_layers = net_params['L']
30 |         batch_norm = net_params['batch_norm']
31 |         residual = net_params['residual']
32 |         self.readout = net_params['readout']
33 | 
34 |         self.embedding_h = nn.Linear(in_dim, hidden_dim)
35 |         self.in_feat_dropout = nn.Dropout(in_feat_dropout)
36 | 
37 |         self.layers = nn.ModuleList([GraphSageLayer(hidden_dim, hidden_dim, F.relu,
38 |                                                     dropout, aggregator_type, batch_norm, residual) for _ in range(n_layers-1)])
39 |         self.layers.append(GraphSageLayer(hidden_dim, out_dim, F.relu, dropout, aggregator_type, batch_norm, residual))
40 |         self.MLP_layer = MLPReadout(out_dim, n_classes)
41 | 
42 |     def forward(self, g, h, e):
43 |         h = self.embedding_h(h)
44 |         h = self.in_feat_dropout(h)
45 |         for conv in self.layers:
46 |             h = conv(g, h)
47 |         g.ndata['h'] = h
48 | 
49 |         if self.readout == "sum":
50 |             hg = dgl.sum_nodes(g, 'h')
51 |         elif self.readout == "max":
52 |             hg = dgl.max_nodes(g, 'h')
53 |         elif self.readout == "mean":
54 |             hg = dgl.mean_nodes(g, 'h')
55 |         else:
56 |             hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes
57 | 
58 |         return self.MLP_layer(hg)
59 | 
60 | 
61 |     def loss(self, pred, label):
62 |         criterion = nn.CrossEntropyLoss()
63 |         loss = criterion(pred, label)
64 |         return loss
65 | 
--------------------------------------------------------------------------------
/nets/load_net.py:
--------------------------------------------------------------------------------
1 | """
2 |     Utility file to select the GraphNN model
3 |     specified by the user
4 | """
5 | 
6 | from nets.gated_gcn_net import GatedGCNNet
7 | from nets.gcn_net import GCNNet
8 | from nets.gat_net import GATNet
9 | from nets.graphsage_net import GraphSageNet
10 | from nets.gin_net import GINNet
11 | from nets.mo_net import MoNet as MoNet_
12 | from nets.mlp_net import MLPNet
13 | from nets.ring_gnn_net import RingGNNNet
14 | from nets.three_wl_gnn_net import ThreeWLGNNNet
15 | 
16 | def GatedGCN(net_params):
17 |     return GatedGCNNet(net_params)
18 | 
19 | def GCN(net_params):
20 |     return GCNNet(net_params)
21 | 
22 | def GAT(net_params):
23 |     return GATNet(net_params)
24 | 
25 | def GraphSage(net_params):
26 |     return GraphSageNet(net_params)
27 | 
28 | def GIN(net_params):
29 |     return GINNet(net_params)
30 | 
31 | def MoNet(net_params):
32 |     return MoNet_(net_params)
33 | 
34 | def MLP(net_params):
35 |     return MLPNet(net_params)
36 | 
37 | def RingGNN(net_params):
38 |     return RingGNNNet(net_params)
39 | 
40 | def ThreeWLGNN(net_params):
41 |     return ThreeWLGNNNet(net_params)
42 | 
43 | 
44 | def gnn_model(MODEL_NAME, net_params):
45 |     models = {
46 |         'GatedGCN': GatedGCN,
47 |         'GCN': GCN,
48 |         'GAT': GAT,
49 |         'GraphSage': GraphSage,
50 |         'GIN': GIN,
51 |         'MoNet': MoNet_,
52 |         'MLP': MLP,
53 |         'RingGNN': RingGNN,
54 |         '3WLGNN': ThreeWLGNN
55 |     }
56 | 
57 |     return 
models[MODEL_NAME](net_params) -------------------------------------------------------------------------------- /nets/mlp_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | from layers.mlp_readout_layer import MLPReadout 8 | 9 | class MLPNet(nn.Module): 10 | def __init__(self, net_params): 11 | super().__init__() 12 | in_dim = net_params['in_dim'] 13 | hidden_dim = net_params['hidden_dim'] 14 | n_classes = net_params['n_classes'] 15 | in_feat_dropout = net_params['in_feat_dropout'] 16 | dropout = net_params['dropout'] 17 | n_layers = net_params['L'] 18 | self.gated = net_params['gated'] 19 | 20 | self.in_feat_dropout = nn.Dropout(in_feat_dropout) 21 | 22 | feat_mlp_modules = [ 23 | nn.Linear(in_dim, hidden_dim, bias=True), 24 | nn.ReLU(), 25 | nn.Dropout(dropout), 26 | ] 27 | for _ in range(n_layers-1): 28 | feat_mlp_modules.append(nn.Linear(hidden_dim, hidden_dim, bias=True)) 29 | feat_mlp_modules.append(nn.ReLU()) 30 | feat_mlp_modules.append(nn.Dropout(dropout)) 31 | self.feat_mlp = nn.Sequential(*feat_mlp_modules) 32 | 33 | if self.gated: 34 | self.gates = nn.Linear(hidden_dim, hidden_dim, bias=True) 35 | 36 | self.readout_mlp = MLPReadout(hidden_dim, n_classes) 37 | 38 | def forward(self, g, h, e): 39 | h = self.in_feat_dropout(h) 40 | h = self.feat_mlp(h) 41 | if self.gated: 42 | h = torch.sigmoid(self.gates(h)) * h 43 | g.ndata['h'] = h 44 | hg = dgl.sum_nodes(g, 'h') 45 | # hg = torch.cat( 46 | # ( 47 | # dgl.sum_nodes(g, 'h'), 48 | # dgl.max_nodes(g, 'h') 49 | # ), 50 | # dim=1 51 | # ) 52 | 53 | else: 54 | g.ndata['h'] = h 55 | hg = dgl.mean_nodes(g, 'h') 56 | 57 | return self.readout_mlp(hg) 58 | 59 | 60 | def loss(self, pred, label): 61 | criterion = nn.CrossEntropyLoss() 62 | loss = criterion(pred, label) 63 | return loss 64 | -------------------------------------------------------------------------------- /nets/mo_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | 7 | import numpy as np 8 | 9 | """ 10 | GMM: Gaussian Mixture Model Convolution layer 11 | Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs (Federico Monti et al., CVPR 2017) 12 | https://arxiv.org/pdf/1611.08402.pdf 13 | """ 14 | 15 | from layers.gmm_layer import GMMLayer 16 | from layers.mlp_readout_layer import MLPReadout 17 | 18 | class MoNet(nn.Module): 19 | def __init__(self, net_params): 20 | super().__init__() 21 | self.name = 'MoNet' 22 | in_dim = net_params['in_dim'] 23 | hidden_dim = net_params['hidden_dim'] 24 | out_dim = net_params['out_dim'] 25 | kernel = net_params['kernel'] # for MoNet 26 | dim = net_params['pseudo_dim_MoNet'] # for MoNet 27 | n_classes = net_params['n_classes'] 28 | dropout = net_params['dropout'] 29 | n_layers = net_params['L'] 30 | self.readout = net_params['readout'] 31 | batch_norm = net_params['batch_norm'] 32 | residual = net_params['residual'] 33 | self.device = net_params['device'] 34 | 35 | aggr_type = "sum" # default for MoNet 36 | 37 | self.embedding_h = nn.Linear(in_dim, hidden_dim) 38 | 39 | self.layers = nn.ModuleList() 40 | self.pseudo_proj = nn.ModuleList() 41 | 42 | # Hidden layer 43 | for _ in range(n_layers-1): 44 | self.layers.append(GMMLayer(hidden_dim, hidden_dim, dim, kernel, aggr_type, 45 | dropout, batch_norm, residual)) 46 | 
self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh())) 47 | 48 | # Output layer 49 | self.layers.append(GMMLayer(hidden_dim, out_dim, dim, kernel, aggr_type, 50 | dropout, batch_norm, residual)) 51 | self.pseudo_proj.append(nn.Sequential(nn.Linear(2, dim), nn.Tanh())) 52 | 53 | self.MLP_layer = MLPReadout(out_dim, n_classes) 54 | 55 | def forward(self, g, h, e): 56 | h = self.embedding_h(h) 57 | 58 | # computing the 'pseudo' named tensor which depends on node degrees 59 | g.ndata['deg'] = g.in_degrees() 60 | g.apply_edges(self.compute_pseudo) 61 | pseudo = g.edata['pseudo'].to(self.device).float() 62 | 63 | for i in range(len(self.layers)): 64 | h = self.layers[i](g, h, self.pseudo_proj[i](pseudo)) 65 | g.ndata['h'] = h 66 | 67 | if self.readout == "sum": 68 | hg = dgl.sum_nodes(g, 'h') 69 | elif self.readout == "max": 70 | hg = dgl.max_nodes(g, 'h') 71 | elif self.readout == "mean": 72 | hg = dgl.mean_nodes(g, 'h') 73 | else: 74 | hg = dgl.mean_nodes(g, 'h') # default readout is mean nodes 75 | 76 | return self.MLP_layer(hg) 77 | 78 | def compute_pseudo(self, edges): 79 | # compute pseudo edge features for MoNet 80 | # to avoid zero division in case in_degree is 0, we add constant '1' in all node degrees denoting self-loop 81 | srcs = 1/np.sqrt(edges.src['deg']+1) 82 | dsts = 1/np.sqrt(edges.dst['deg']+1) 83 | pseudo = torch.cat((srcs.unsqueeze(-1), dsts.unsqueeze(-1)), dim=1) 84 | return {'pseudo': pseudo} 85 | 86 | def loss(self, pred, label): 87 | criterion = nn.CrossEntropyLoss() 88 | loss = criterion(pred, label) 89 | return loss -------------------------------------------------------------------------------- /nets/ring_gnn_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | import time 7 | 8 | """ 9 | Ring-GNN 10 | On the equivalence between graph isomorphism testing and function approximation with GNNs (Chen et al, 2019) 11 | https://arxiv.org/pdf/1905.12560v1.pdf 12 | """ 13 | from layers.ring_gnn_equiv_layer import RingGNNEquivLayer 14 | from layers.mlp_readout_layer import MLPReadout 15 | 16 | class RingGNNNet(nn.Module): 17 | def __init__(self, net_params): 18 | super().__init__() 19 | self.in_dim_node = net_params['in_dim'] 20 | avg_node_num = net_params['avg_node_num'] 21 | radius = net_params['radius'] 22 | hidden_dim = net_params['hidden_dim'] 23 | n_classes = net_params['n_classes'] 24 | dropout = net_params['dropout'] 25 | n_layers = net_params['L'] 26 | self.layer_norm = net_params['layer_norm'] 27 | self.residual = net_params['residual'] 28 | self.device = net_params['device'] 29 | 30 | self.depth = [torch.LongTensor([1+self.in_dim_node])] + [torch.LongTensor([hidden_dim])] * n_layers 31 | 32 | self.equi_modulelist = nn.ModuleList([RingGNNEquivLayer(self.device, m, n, 33 | layer_norm=self.layer_norm, 34 | residual=self.residual, 35 | dropout=dropout, 36 | radius=radius, 37 | k2_init=0.5/avg_node_num) for m, n in zip(self.depth[:-1], self.depth[1:])]) 38 | 39 | self.prediction = MLPReadout(torch.sum(torch.stack(self.depth)).item(), n_classes) 40 | 41 | def forward(self, x): 42 | """ 43 | CODE ADPATED FROM https://github.com/leichen2018/Ring-GNN/ 44 | """ 45 | 46 | # this x is the tensor with all info available => adj, node feat 47 | 48 | x_list = [x] 49 | for layer in self.equi_modulelist: 50 | x = layer(x) 51 | x_list.append(x) 52 | 53 | # # readout 54 | x_list = [torch.sum(torch.sum(x, dim=3), dim=2) for x in x_list] 
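        # Shape sketch: for a dense batch of N graphs on n nodes, each x in
        # x_list is (N, d_l, n, n); summing over dims 3 and 2 gives an
        # (N, d_l) per-layer graph readout, and the concatenation below
        # produces (N, sum_l d_l), matching the MLPReadout built from
        # torch.sum(torch.stack(self.depth)) in __init__.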
55 | x_list = torch.cat(x_list, dim=1) 56 | 57 | x_out = self.prediction(x_list) 58 | 59 | return x_out 60 | 61 | def loss(self, pred, label): 62 | criterion = nn.CrossEntropyLoss() 63 | loss = criterion(pred, label) 64 | return loss 65 | 66 | -------------------------------------------------------------------------------- /nets/three_wl_gnn_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import dgl 6 | import time 7 | 8 | """ 9 | 3WLGNN / ThreeWLGNN 10 | Provably Powerful Graph Networks (Maron et al., 2019) 11 | https://papers.nips.cc/paper/8488-provably-powerful-graph-networks.pdf 12 | 13 | CODE adapted from https://github.com/hadarser/ProvablyPowerfulGraphNetworks_torch/ 14 | """ 15 | 16 | from layers.three_wl_gnn_layers import RegularBlock, MlpBlock, SkipConnection, FullyConnected, diag_offdiag_maxpool 17 | from layers.mlp_readout_layer import MLPReadout 18 | 19 | class ThreeWLGNNNet(nn.Module): 20 | def __init__(self, net_params): 21 | super().__init__() 22 | 23 | self.in_dim_node = net_params['in_dim'] 24 | depth_of_mlp = net_params['depth_of_mlp'] 25 | hidden_dim = net_params['hidden_dim'] 26 | n_classes = net_params['n_classes'] 27 | dropout = net_params['dropout'] 28 | n_layers = net_params['L'] 29 | self.layer_norm = net_params['layer_norm'] 30 | self.residual = net_params['residual'] 31 | self.device = net_params['device'] 32 | self.diag_pool_readout = True # if True, uses the new_suffix readout from original code 33 | 34 | block_features = [hidden_dim] * n_layers # L here is the block number 35 | 36 | original_features_num = self.in_dim_node + 1 # Number of features of the input 37 | 38 | # sequential mlp blocks 39 | last_layer_features = original_features_num 40 | self.reg_blocks = nn.ModuleList() 41 | for layer, next_layer_features in enumerate(block_features): 42 | mlp_block = RegularBlock(depth_of_mlp, last_layer_features, next_layer_features, self.residual) 43 | self.reg_blocks.append(mlp_block) 44 | last_layer_features = next_layer_features 45 | 46 | if self.diag_pool_readout: 47 | self.fc_layers = nn.ModuleList() 48 | for output_features in block_features: 49 | # each block's output will be pooled (thus have 2*output_features), and pass through a fully connected 50 | fc = FullyConnected(2*output_features, n_classes, activation_fn=None) 51 | self.fc_layers.append(fc) 52 | else: 53 | self.mlp_prediction = MLPReadout(sum(block_features)+original_features_num, n_classes) 54 | 55 | def forward(self, x): 56 | if self.diag_pool_readout: 57 | scores = torch.tensor(0, device=self.device, dtype=x.dtype) 58 | else: 59 | x_list = [x] 60 | 61 | for i, block in enumerate(self.reg_blocks): 62 | 63 | x = block(x) 64 | if self.diag_pool_readout: 65 | scores = self.fc_layers[i](diag_offdiag_maxpool(x)) + scores 66 | else: 67 | x_list.append(x) 68 | 69 | if self.diag_pool_readout: 70 | return scores 71 | else: 72 | # readout like RingGNN 73 | x_list = [torch.sum(torch.sum(x, dim=3), dim=2) for x in x_list] 74 | x_list = torch.cat(x_list, dim=1) 75 | 76 | x_out = self.mlp_prediction(x_list) 77 | return x_out 78 | 79 | 80 | def loss(self, pred, label): 81 | criterion = nn.CrossEntropyLoss() 82 | loss = criterion(pred, label) 83 | return loss 84 | -------------------------------------------------------------------------------- /pooling/diffpool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | 6 | import time 7 | import numpy as np 8 | from scipy.linalg import block_diag 9 | 10 | import dgl 11 | 12 | """ 13 | 14 | 15 | DIFFPOOL: 16 | Z. Ying, J. You, C. Morris, X. Ren, W. Hamilton, and J. Leskovec, 17 | Hierarchical graph representation learning with differentiable pooling (NeurIPS 2018) 18 | https://arxiv.org/pdf/1806.08804.pdf 19 | 20 | ! code started from dgl diffpool examples dir 21 | """ 22 | 23 | from layers.graphsage_layer import GraphSageLayer, DenseGraphSage # this is GraphSageLayer, DiffPoolBatchedGraphLayer 24 | from layers.diffpool_layer import DiffPoolLayer, DenseDiffPool # this is GraphSageLayer, DiffPoolBatchedGraphLayer 25 | # from .graphsage_net import GraphSageNet # this is GraphSage 26 | # replace BatchedDiffPool with DenseDiffPool and BatchedGraphSAGE with DenseGraphSage 27 | 28 | 29 | class DiffPoolNet(nn.Module): 30 | """ 31 | DiffPool Fuse with GNN layers and pooling layers in sequence 32 | """ 33 | 34 | def __init__(self, net_params, pool_ratio=0.5): 35 | 36 | super().__init__() 37 | input_dim = net_params['in_dim'] 38 | hidden_dim = net_params['hidden_dim'] 39 | embedding_dim = net_params['hidden_dim'] # net_params['embedding_dim'] 40 | label_dim = net_params['n_classes'] 41 | activation = F.relu 42 | n_layers = net_params['L'] # this is the gnn_per_block param 43 | dropout = net_params['dropout'] 44 | self.batch_norm = net_params['batch_norm'] 45 | self.residual = net_params['residual'] 46 | aggregator_type = net_params['sage_aggregator'] 47 | # pool_ratio = net_params['pool_ratio'] 48 | 49 | self.device = net_params['device'] 50 | self.link_pred = True # net_params['linkpred'] 51 | self.concat = False # net_params['cat'] 52 | self.n_pooling = 1 # net_params['num_pool'] 53 | self.batch_size = net_params['batch_size'] 54 | self.e_feat = net_params['edge_feat'] 55 | self.link_pred_loss = [] 56 | self.entropy_loss = [] 57 | 58 | self.embedding_h = nn.Linear(input_dim, hidden_dim) 59 | 60 | # list of GNN modules before the first diffpool operation 61 | self.gc_before_pool = nn.ModuleList() 62 | 63 | self.assign_dim = int(net_params['max_num_node'] * pool_ratio) * self.batch_size # net_params['assign_dim'] 64 | self.bn = True 65 | self.num_aggs = 1 66 | 67 | # constructing layers 68 | # layers before diffpool 69 | assert n_layers >= 3, "n_layers too few" 70 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, hidden_dim, activation, 71 | dropout, aggregator_type, self.residual, self.bn)) 72 | 73 | for _ in range(n_layers - 2): 74 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, hidden_dim, activation, 75 | dropout, aggregator_type, self.residual, self.bn)) 76 | 77 | self.gc_before_pool.append(GraphSageLayer(hidden_dim, embedding_dim, None, dropout, aggregator_type, self.residual)) 78 | 79 | 80 | assign_dims = [] 81 | assign_dims.append(self.assign_dim) 82 | if self.concat: 83 | # diffpool layer receive pool_emedding_dim node feature tensor 84 | # and return pool_embedding_dim node embedding 85 | pool_embedding_dim = hidden_dim * (n_layers - 1) + embedding_dim 86 | else: 87 | 88 | pool_embedding_dim = embedding_dim 89 | 90 | self.first_diffpool_layer = DiffPoolLayer(pool_embedding_dim, self.assign_dim, hidden_dim, 91 | activation, dropout, aggregator_type, self.link_pred, self.batch_norm) 92 | gc_after_per_pool = nn.ModuleList() 93 | 94 | # list of list of GNN modules, each list after one diffpool operation 95 | self.gc_after_pool = nn.ModuleList() 96 | 97 | for _ in 
range(n_layers - 1): 98 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, hidden_dim, self.residual)) 99 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, embedding_dim, self.residual)) 100 | self.gc_after_pool.append(gc_after_per_pool) 101 | 102 | self.assign_dim = int(self.assign_dim * pool_ratio) 103 | 104 | self.diffpool_layers = nn.ModuleList() 105 | # each pooling module 106 | for _ in range(self.n_pooling - 1): 107 | self.diffpool_layers.append(DenseDiffPool(pool_embedding_dim, self.assign_dim, hidden_dim, self.link_pred)) 108 | 109 | gc_after_per_pool = nn.ModuleList() 110 | 111 | for _ in range(n_layers - 1): 112 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, hidden_dim, self.residual)) 113 | gc_after_per_pool.append(DenseGraphSage(hidden_dim, embedding_dim, self.residual)) 114 | self.gc_after_pool.append(gc_after_per_pool) 115 | 116 | assign_dims.append(self.assign_dim) 117 | self.assign_dim = int(self.assign_dim * pool_ratio) 118 | 119 | # predicting layer 120 | if self.concat: 121 | self.pred_input_dim = pool_embedding_dim * \ 122 | self.num_aggs * (self.n_pooling + 1) 123 | else: 124 | self.pred_input_dim = embedding_dim * self.num_aggs 125 | self.pred_layer = nn.Linear(self.pred_input_dim, label_dim) 126 | 127 | # weight initialization 128 | for m in self.modules(): 129 | if isinstance(m, nn.Linear): 130 | m.weight.data = init.xavier_uniform_(m.weight.data, 131 | gain=nn.init.calculate_gain('relu')) 132 | if m.bias is not None: 133 | m.bias.data = init.constant_(m.bias.data, 0.0) 134 | 135 | def gcn_forward(self, g, h, e, gc_layers, cat=False): 136 | """ 137 | Return gc_layer embedding cat. 138 | """ 139 | block_readout = [] 140 | for gc_layer in gc_layers[:-1]: 141 | h, e = gc_layer(g, h, e) 142 | block_readout.append(h) 143 | h, e = gc_layers[-1](g, h, e) 144 | block_readout.append(h) 145 | if cat: 146 | block = torch.cat(block_readout, dim=1) # N x F, F = F1 + F2 + ... 147 | else: 148 | block = h 149 | return block 150 | 151 | def gcn_forward_tensorized(self, h, adj, gc_layers, cat=False): 152 | block_readout = [] 153 | for gc_layer in gc_layers: 154 | h = gc_layer(h, adj) 155 | block_readout.append(h) 156 | if cat: 157 | block = torch.cat(block_readout, dim=2) # N x F, F = F1 + F2 + ... 
158 | else: 159 | block = h 160 | return block 161 | 162 | def forward(self, g, h, e): 163 | self.link_pred_loss = [] 164 | self.entropy_loss = [] 165 | 166 | # node feature for assignment matrix computation is the same as the 167 | # original node feature 168 | h = self.embedding_h(h) 169 | h_a = h 170 | 171 | out_all = [] 172 | 173 | # we use GCN blocks to get an embedding first 174 | g_embedding = self.gcn_forward(g, h, e, self.gc_before_pool, self.concat) 175 | 176 | g.ndata['h'] = g_embedding 177 | 178 | readout = dgl.sum_nodes(g, 'h') 179 | out_all.append(readout) 180 | if self.num_aggs == 2: 181 | readout = dgl.max_nodes(g, 'h') 182 | out_all.append(readout) 183 | 184 | adj, h = self.first_diffpool_layer(g, g_embedding) 185 | node_per_pool_graph = int(adj.size()[0] / self.batch_size) 186 | 187 | h, adj = self.batch2tensor(adj, h, node_per_pool_graph) 188 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[0], self.concat) 189 | 190 | readout = torch.sum(h, dim=1) 191 | out_all.append(readout) 192 | if self.num_aggs == 2: 193 | readout, _ = torch.max(h, dim=1) 194 | out_all.append(readout) 195 | 196 | for i, diffpool_layer in enumerate(self.diffpool_layers): 197 | h, adj = diffpool_layer(h, adj) 198 | h = self.gcn_forward_tensorized(h, adj, self.gc_after_pool[i + 1], self.concat) 199 | 200 | readout = torch.sum(h, dim=1) 201 | out_all.append(readout) 202 | 203 | if self.num_aggs == 2: 204 | readout, _ = torch.max(h, dim=1) 205 | out_all.append(readout) 206 | 207 | if self.concat or self.num_aggs > 1: 208 | final_readout = torch.cat(out_all, dim=1) 209 | else: 210 | final_readout = readout 211 | ypred = self.pred_layer(final_readout) 212 | return ypred 213 | 214 | def batch2tensor(self, batch_adj, batch_feat, node_per_pool_graph): 215 | """ 216 | transform a batched graph to batched adjacency tensor and node feature tensor 217 | """ 218 | batch_size = int(batch_adj.size()[0] / node_per_pool_graph) 219 | adj_list = [] 220 | feat_list = [] 221 | 222 | for i in range(batch_size): 223 | start = i * node_per_pool_graph 224 | end = (i + 1) * node_per_pool_graph 225 | 226 | # 1/sqrt(V) normalization 227 | snorm_n = torch.FloatTensor(node_per_pool_graph, 1).fill_(1./float(node_per_pool_graph)).sqrt().to(self.device) 228 | 229 | adj_list.append(batch_adj[start:end, start:end]) 230 | feat_list.append((batch_feat[start:end, :])*snorm_n) 231 | adj_list = list(map(lambda x: torch.unsqueeze(x, 0), adj_list)) 232 | feat_list = list(map(lambda x: torch.unsqueeze(x, 0), feat_list)) 233 | adj = torch.cat(adj_list, dim=0) 234 | feat = torch.cat(feat_list, dim=0) 235 | 236 | return feat, adj 237 | 238 | def loss(self, pred, label): 239 | ''' 240 | loss function 241 | ''' 242 | #softmax + CE 243 | criterion = nn.CrossEntropyLoss() 244 | loss = criterion(pred, label) 245 | for diffpool_layer in self.diffpool_layers: 246 | for key, value in diffpool_layer.loss_log.items(): 247 | loss += value 248 | return loss 249 | 250 | -------------------------------------------------------------------------------- /pooling/edge_sparse_max.py: -------------------------------------------------------------------------------- 1 | """ 2 | An original implementation of sparsemax (Martins & Astudillo, 2016) is available at 3 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py. 4 | See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016` 5 | for detailed description. 
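In short, sparsemax is the Euclidean projection of a score vector z onto the
probability simplex: sparsemax(z) = argmin_{p in simplex} ||p - z||^2, which
evaluates to p_i = max(z_i - tau(z), 0) for a per-vector threshold tau(z).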
6 | 
7 | Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
8 | with the same node as end-node in graphs.
9 | """
10 | import dgl
11 | import torch
12 | from dgl.backend import astype
13 | from dgl.base import ALL, is_all
14 | from dgl.heterograph_index import HeteroGraphIndex
15 | from dgl.sparse import _gsddmm, _gspmm
16 | from torch import Tensor
17 | from torch.autograd import Function
18 | 
19 | 
20 | def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_degrees:Tensor):
21 |     """Sort edge scores for each node"""
22 |     num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())
23 | 
24 |     # Compute the index for dense score matrix with size (N x D_{max})
25 |     # Note that the end_n_ids here is the end_node tensor in dgl graph,
26 |     # which is not grouped by its node id (i.e. in this form: 0,0,1,1,1,...,N,N).
27 |     # Thus here we first sort the end_node tensor to make it easier to compute
28 |     # indices in the dense edge score matrix. Since we will need the original order
29 |     # for the following gspmm and gsddmm operations, we also keep the reverse mapping
30 |     # (the reverse_perm) here.
31 |     end_n_ids, perm = torch.sort(end_n_ids)
32 |     scores = scores[perm]
33 |     _, reverse_perm = torch.sort(perm)
34 | 
35 |     index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
36 |     index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
37 |     index = index.long()
38 | 
39 |     dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
40 |     dense_scores[index] = scores
41 |     dense_scores = dense_scores.view(num_nodes, max_in_degree)
42 | 
43 |     sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
44 |     _, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
45 |     dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
46 |     dense_reverse_perm = dense_reverse_perm.view(-1)
47 |     cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
48 |     sorted_dense_scores = sorted_dense_scores.view(-1)
49 |     arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
50 |     arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)
51 | 
52 |     valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
53 |     sorted_scores = sorted_dense_scores[valid_mask]
54 |     cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
55 |     arange_vec = arange_vec[valid_mask]
56 |     dense_reverse_perm = dense_reverse_perm[valid_mask].long()
57 | 
58 |     return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm
59 | 
60 | 
61 | def _threshold_and_support_graph(gidx:HeteroGraphIndex, scores:Tensor, end_n_ids:Tensor):
62 |     """Find the threshold for each node and its edges"""
63 |     in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
64 |     cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)
65 | 
66 |     # perform sort on edges for each node
67 |     sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(scores, end_n_ids,
68 |                                                                                           in_degrees, cum_in_degrees)
69 |     cumsum_scores = cumsum_scores - 1.
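    # Sparsemax support/threshold (Martins & Astudillo, 2016), applied per end-node.
    # With z sorted in decreasing order, the support keeps the k largest logits
    # satisfying 1 + k*z_(k) > cumsum_k(z); since 1 was already subtracted from
    # cumsum_scores above, this is exactly `rhos * sorted_scores > cumsum_scores`.
    # The threshold is then tau = (cumsum over the support - 1) / support_size.
    # Worked toy example for one node with incoming logits z = [0.9, 0.3, -0.2]:
    #   cumsum - 1 = [-0.1, 0.2, 0.0], k*z_(k) = [0.9, 0.6, -0.6]
    #   -> support = [True, True, False], support_size = 2, tau = (1.2 - 1)/2 = 0.1
    #   -> outputs relu(z - tau) = [0.8, 0.2, 0.0], which sum to 1.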
70 |     support = rhos * sorted_scores > cumsum_scores
71 |     support = support[dense_reverse_perm]  # from sorted order to unsorted order
72 |     support = support[reverse_perm]  # from src-dst order to eid order
73 | 
74 |     support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
75 |     support_size = support_size.long()
76 |     idx = support_size + cum_in_degrees - 1
77 | 
78 |     # mask invalid indices: if the batch does not start from 0 or is not contiguous, negative indices may appear
79 |     mask = idx < 0
80 |     idx[mask] = 0
81 |     tau = cumsum_scores.gather(0, idx.long())
82 |     tau /= support_size.to(scores.dtype)
83 | 
84 |     return tau, support_size
85 | 
86 | 
87 | class EdgeSparsemaxFunction(Function):
88 |     r"""
89 |     Description
90 |     -----------
91 |     Pytorch Auto-Grad Function for edge sparsemax.
92 | 
93 |     We define this auto-grad function here since
94 |     sparsemax involves sort and select operations, which are
95 |     not differentiable.
96 |     """
97 |     @staticmethod
98 |     def forward(ctx, gidx:HeteroGraphIndex, scores:Tensor,
99 |                 eids:Tensor, end_n_ids:Tensor, norm_by:str):
100 |         if not is_all(eids):
101 |             gidx = gidx.edge_subgraph([eids], True).graph
102 |         if norm_by == "src":
103 |             gidx = gidx.reverse()
104 | 
105 |         # use feat - max(feat) for numerical stability; sparsemax is shift-invariant.
106 |         scores = scores.float()
107 |         scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
108 |         scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")
109 | 
110 |         # find the threshold for each node and perform the ReLU(z - tau(z)) operation.
111 |         tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
112 |         out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
113 |         ctx.backward_cache = gidx
114 |         ctx.save_for_backward(supp_size, out)
115 |         torch.cuda.empty_cache()
116 |         return out
117 | 
118 |     @staticmethod
119 |     def backward(ctx, grad_out):
120 |         gidx = ctx.backward_cache
121 |         supp_size, out = ctx.saved_tensors
122 |         grad_in = grad_out.clone()
123 | 
124 |         # grad for ReLU
125 |         grad_in[out == 0] = 0
126 | 
127 |         # dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
128 |         v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
129 |         grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
130 |         grad_in = torch.where(out != 0, grad_in_modify, grad_in)
131 |         del gidx
132 |         torch.cuda.empty_cache()
133 | 
134 |         return None, grad_in, None, None, None
135 | 
136 | 
137 | def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
138 |     r"""
139 |     Description
140 |     -----------
141 |     Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
142 | 
143 |     .. math::
144 |       a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
145 | 
146 |     where :math:`z_{ij}` is the signal of edge :math:`j\rightarrow i`, also
147 |     called a logit in the context of sparsemax, and :math:`\tau` is the threshold
148 |     function defined in the `From Softmax to Sparsemax` paper
149 |     (Martins & Astudillo, ICML 2016).
150 | 
151 |     NOTE: currently only homogeneous graphs are supported.
152 | 
153 |     Parameters
154 |     ----------
155 |     graph : DGLGraph
156 |         The graph to perform edge sparsemax on.
157 |     logits : torch.Tensor
158 |         The input edge feature.
159 |     eids : torch.Tensor or ALL, optional
160 |         A tensor of edge ids on which to apply edge sparsemax. If ALL, apply edge
161 |         sparsemax on all edges in the graph. Default: ALL.
162 |     norm_by : str, could be 'src' or 'dst'
163 |         Normalize by source nodes or destination nodes. Default: `dst`.
164 | 
165 |     Returns
166 |     -------
167 |     Tensor
168 |         Sparsemax value.
169 |     """
170 |     # we fetch the edge endpoint tensors here since it is
171 |     # hard to recover the edge index from a HeteroGraphIndex
172 |     # object without extra information such as edge_type.
173 |     row, col = graph.all_edges(order="eid")
174 |     assert norm_by in ["dst", "src"]
175 |     end_n_ids = col if norm_by == "dst" else row
176 |     if not is_all(eids):
177 |         eids = astype(eids, graph.idtype)
178 |         end_n_ids = end_n_ids[eids]
179 |     return EdgeSparsemaxFunction.apply(graph._graph, logits,
180 |                                        eids, end_n_ids, norm_by)
181 | 
182 | 
183 | class EdgeSparsemax(torch.nn.Module):
184 |     r"""
185 |     Description
186 |     -----------
187 |     Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
188 | 
189 |     .. math::
190 |       a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
191 | 
192 |     where :math:`z_{ij}` is the signal of edge :math:`j\rightarrow i`, also
193 |     called a logit in the context of sparsemax, and :math:`\tau` is the threshold
194 |     function defined in the `From Softmax to Sparsemax` paper
195 |     (Martins & Astudillo, ICML 2016).
196 | 
197 |     Parameters
198 |     ----------
199 |     graph : DGLGraph
200 |         The graph to perform edge sparsemax on.
201 |     logits : torch.Tensor
202 |         The input edge feature.
203 |     eids : torch.Tensor or ALL, optional
204 |         A tensor of edge ids on which to apply edge sparsemax. If ALL, apply edge
205 |         sparsemax on all edges in the graph. Default: ALL.
206 |     norm_by : str, could be 'src' or 'dst'
207 |         Normalize by source nodes or destination nodes. Default: `dst`.
208 | 
209 |     NOTE: currently only homogeneous graphs are supported.
210 | 
211 |     Returns
212 |     -------
213 |     Tensor
214 |         Sparsemax value.
215 |     """
216 |     def __init__(self):
217 |         super(EdgeSparsemax, self).__init__()
218 | 
219 |     def forward(self, graph, logits, eids=ALL, norm_by="dst"):
220 |         return edge_sparsemax(graph, logits, eids, norm_by)
221 | 
--------------------------------------------------------------------------------
/pooling/mewispool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import networkx as nx
5 | import dgl
6 | from dgl.nn import GINConv
7 | from scipy import sparse
8 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
9 | 
10 | 
11 | class MLP(nn.Module):
12 |     def __init__(self, input_dim, hidden_dim, output_dim, enhance=False):
13 |         super(MLP, self).__init__()
14 | 
15 |         self.enhance = enhance
16 | 
17 |         self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
18 |         self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
19 |         self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
20 | 
21 |         if enhance:
22 |             self.bn1 = nn.BatchNorm1d(hidden_dim)
23 |             self.bn2 = nn.BatchNorm1d(hidden_dim)
24 |             self.dropout = nn.Dropout(0.2)
25 | 
26 |     def forward(self, x):
27 |         x = self.fc1(x)
28 |         if self.enhance:
29 |             x = self.bn1(x)
30 |         x = torch.relu(x)
31 |         if self.enhance:
32 |             x = self.dropout(x)
33 | 
34 |         x = self.fc2(x)
35 |         if self.enhance:
36 |             x = self.bn2(x)
37 |         x = torch.relu(x)
38 |         if self.enhance:
39 |             x = self.dropout(x)
40 | 
41 |         x = self.fc3(x)
42 | 
43 |         return x
44 | 
45 | 
46 | class MEWISPool(nn.Module):
47 |     def __init__(self, hidden_dim):
48 |         super(MEWISPool, self).__init__()
49 | 
50 |         self.gc1 = GINConv(MLP(1, hidden_dim, hidden_dim), 'sum')
51 |         self.gc2 = GINConv(MLP(hidden_dim, hidden_dim, hidden_dim), 'sum')
52 |         self.gc3 = GINConv(MLP(hidden_dim, hidden_dim, 1), 'sum')
53 | 
54 |     def forward(self, g, x, batch):
55 |         # computing the graph Laplacian and adjacency matrix
56 |         batch_nodes = x.size(0)
57 |         if g.num_edges() != 0:
58 |             # L_indices, L_values = get_laplacian(edge_index)
59 |             # L = torch.sparse.FloatTensor(L_indices, L_values, torch.Size([batch_nodes, batch_nodes]))
60 |             # A = torch.diag(torch.diag(L.to_dense())) - L.to_dense()
61 | 
62 |             num_nodes = g.number_of_nodes()
63 |             A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
64 |             L = sparse.diags(np.asarray(A.sum(axis=1)).ravel()) - A  # combinatorial Laplacian L = D - A, matching the commented-out get_laplacian variant above
65 | 
66 |             A = self.matrix2tensor(A, device=x.device)
67 |             L = self.matrix2tensor(L, device=x.device)
68 | 
69 |             # entropy computation
70 |             entropies = self.compute_entropy(x, L, A, batch)  # Eq. (8)
71 |         else:
72 |             A = torch.zeros([batch_nodes, batch_nodes], device=x.device)  # keep A on the same device as x for loss_fn below
73 |             norm = torch.norm(x, dim=1).unsqueeze(-1)
74 |             entropies = torch.ones_like(norm)  # edgeless case: uniform unit entropy per node (avoids 0/0 for all-zero feature rows)
75 | 
76 |         # graph convolution and probability scores
77 |         probabilities = self.gc1(g, entropies)
78 |         probabilities = self.gc2(g, probabilities)
79 |         probabilities = self.gc3(g, probabilities)
80 |         probabilities = torch.sigmoid(probabilities)
81 | 
82 |         # conditional expectation; Algorithm 1
83 |         gamma = entropies.sum()
84 |         loss = self.loss_fn(entropies, probabilities, A, gamma)  # Eq. (9)
85 | 
86 |         mewis = self.conditional_expectation(entropies, probabilities, A, loss, gamma)
87 | 
88 |         # graph reconstruction; Eq. (10)
89 |         g, h, adj_pooled = self.graph_reconstruction(mewis, x, A, device=x.device)
90 |         edge_index_pooled, batch_pooled = self.to_edge_index(adj_pooled, mewis, batch)
91 | 
92 |         return g, h, loss, batch_pooled
93 | 
94 |     @staticmethod
95 |     def matrix2tensor(csr_matrix, device):
96 |         coo_matrix = csr_matrix.tocoo()
97 |         values = coo_matrix.data
98 |         indices = np.vstack((coo_matrix.row, coo_matrix.col))
99 | 
100 |         i = torch.tensor(indices, device=device).long()
101 |         v = torch.tensor(values, device=device)
102 |         shape = coo_matrix.shape
103 | 
104 |         return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to_dense().to(torch.float32)
105 | 
106 |     @staticmethod
107 |     def compute_entropy(x, L, A, batch):
108 |         # computing local variations; Eq. (5)
109 |         V = x * torch.matmul(L, x) - x * torch.matmul(A, x) + torch.matmul(A, x * x)
110 |         V = torch.norm(V, dim=1)
111 | 
112 |         # computing the probability distributions based on the local variations; Eq. (7)
113 |         P = torch.cat([torch.softmax(V[batch == i], dim=0) for i in torch.unique(batch)])
114 |         P[P == 0.] += 1  # guard against log(0): p*log(p) -> 0 as p -> 0, and setting p = 1 makes the term exactly 0
115 |         # computing the entropies; Eq.
(8) 116 | H = -P * torch.log(P) 117 | 118 | return H.unsqueeze(-1) 119 | 120 | @staticmethod 121 | def loss_fn(entropies, probabilities, A, gamma): 122 | term1 = -torch.matmul(entropies.t(), probabilities)[0, 0] 123 | 124 | term2 = torch.matmul(torch.matmul(probabilities.t(), A), probabilities).sum() 125 | 126 | return gamma + term1 + term2 127 | 128 | def conditional_expectation(self, entropies, probabilities, A, threshold, gamma): 129 | sorted_probabilities = torch.sort(probabilities, descending=True, dim=0) 130 | 131 | dummy_probabilities = probabilities.detach().clone() 132 | selected = set() 133 | rejected = set() 134 | 135 | for i in range(sorted_probabilities.values.size(0)): 136 | node_index = sorted_probabilities.indices[i].item() 137 | neighbors = torch.where(A[node_index] == 1)[0] 138 | if len(neighbors) == 0: 139 | selected.add(node_index) 140 | continue 141 | if node_index not in rejected and node_index not in selected: 142 | s = dummy_probabilities.clone() 143 | s[node_index] = 1 144 | s[neighbors] = 0 145 | 146 | loss = self.loss_fn(entropies, s, A, gamma) 147 | 148 | if loss <= threshold: 149 | selected.add(node_index) 150 | for n in neighbors.tolist(): 151 | rejected.add(n) 152 | 153 | dummy_probabilities[node_index] = 1 154 | dummy_probabilities[neighbors] = 0 155 | 156 | mewis = list(selected) 157 | mewis = sorted(mewis) 158 | 159 | return mewis 160 | 161 | def graph_reconstruction(self, mewis, x, A, device): 162 | x_pooled = x[mewis] 163 | 164 | A2 = torch.matmul(A, A) 165 | A3 = torch.matmul(A2, A) 166 | 167 | A2 = A2[mewis][:, mewis] 168 | A3 = A3[mewis][:, mewis] 169 | 170 | I = torch.eye(len(mewis), device=device) 171 | one = torch.ones([len(mewis), len(mewis)], device=device) 172 | 173 | adj_pooled = (one - I) * torch.clamp(A2 + A3, min=0, max=1) 174 | 175 | src, dst = np.nonzero(adj_pooled.cpu().numpy()) 176 | g = dgl.graph((src, dst), num_nodes=adj_pooled.size(0)).to(device) 177 | g.ndata['h'] = x_pooled 178 | 179 | return g, x_pooled, adj_pooled 180 | 181 | @staticmethod 182 | def to_edge_index(adj_pooled, mewis, batch): 183 | row1, row2 = torch.where(adj_pooled > 0) 184 | edge_index_pooled = torch.cat([row1.unsqueeze(0), row2.unsqueeze(0)], dim=0) 185 | batch_pooled = batch[mewis] 186 | 187 | return edge_index_pooled, batch_pooled 188 | 189 | 190 | class MEWISPoolNet(nn.Module): 191 | def __init__(self, net_params, readout='sum'): 192 | super(MEWISPoolNet, self).__init__() 193 | 194 | input_dim = net_params['in_dim'] 195 | hidden_dim = net_params['hidden_dim'] 196 | n_classes = net_params['n_classes'] 197 | self.e_feat = net_params['edge_feat'] 198 | 199 | self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum') 200 | self.pool1 = MEWISPool(hidden_dim=hidden_dim) 201 | self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum') 202 | self.pool2 = MEWISPool(hidden_dim=hidden_dim) 203 | self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum') 204 | self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) 205 | self.fc2 = nn.Linear(in_features=hidden_dim, out_features=n_classes) 206 | 207 | if readout == 'sum': 208 | self.pool = SumPooling() 209 | elif readout == 'mean': 210 | self.pool = AvgPooling() 211 | elif readout == 'max': 212 | self.pool = MaxPooling() 213 | else: 214 | raise NotImplementedError 215 | 216 | def forward(self, g, h, e): 217 | batch = [] 218 | for i, g_n_nodes in 
enumerate(g.batch_num_nodes()):
219 |             batch += [i] * int(g_n_nodes)  # one batch id per node of graph i
220 |         batch = torch.tensor(batch, device=h.device)
221 | 
222 |         x = torch.relu(self.gc1(g, h))
223 |         g_pooled1, x_pooled1, loss1, batch_pooled1 = self.pool1(g, x, batch)
224 |         x_pooled1 = torch.relu(self.gc2(g_pooled1, x_pooled1))
225 |         g_pooled2, x_pooled2, loss2, batch_pooled2 = self.pool2(g_pooled1, x_pooled1, batch_pooled1)
226 |         x_pooled2 = self.gc3(g_pooled2, x_pooled2)
227 | 
228 |         # readout = torch.cat([self.pool(g, x), self.pool(g_pooled1, x_pooled1), self.pool(g_pooled2, x_pooled2)], dim=1)
229 |         # readout = self.pool(g_pooled2, x_pooled2)
230 |         readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)], dim=0)
231 | 
232 |         out = torch.relu(self.fc1(readout))
233 |         out = self.fc2(out)
234 | 
235 |         return torch.log_softmax(out, dim=-1), loss1 + loss2
236 | 
237 |     def loss(self, pred, label, loss_pool):
238 |         criterion = nn.CrossEntropyLoss()
239 |         loss_classification = criterion(pred, label.view(-1))
240 |         loss = loss_classification + 0.01 * loss_pool
241 |         return loss
242 | 
243 | 
244 | class Net2(nn.Module):  # legacy variant kept from the original (PyG-style) MEWISPool code; its forward expects (x, edge_index, batch) and is not wired into this DGL pipeline
245 |     def __init__(self, input_dim, hidden_dim, num_classes):
246 |         super(Net2, self).__init__()
247 | 
248 |         self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')  # aggregator_type 'sum' added: dgl's GINConv requires it (cf. MEWISPool above)
249 |         self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
250 |         self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
251 |         self.pool1 = MEWISPool(hidden_dim=hidden_dim)
252 |         self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
253 |         self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
254 |         self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)
255 | 
256 |     def forward(self, x, edge_index, batch):
257 |         x = self.gc1(x, edge_index)
258 |         x = torch.relu(x)
259 | 
260 |         x = self.gc2(x, edge_index)
261 |         x = torch.relu(x)
262 | 
263 |         x = self.gc3(x, edge_index)
264 |         x = torch.relu(x)
265 |         readout2 = torch.cat([x[batch == i].mean(0).unsqueeze(0) for i in torch.unique(batch)], dim=0)
266 | 
267 |         x_pooled1, edge_index_pooled1, batch_pooled1, loss1, mewis = self.pool1(x, edge_index, batch)
268 | 
269 |         x_pooled1 = self.gc4(x_pooled1, edge_index_pooled1)
270 | 
271 |         readout = torch.cat([x_pooled1[batch_pooled1 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled1)],
272 |                             dim=0)
273 | 
274 |         out = readout2 + readout
275 | 
276 |         out = self.fc1(out)
277 |         out = torch.relu(out)
278 |         out = self.fc2(out)
279 | 
280 |         return torch.log_softmax(out, dim=-1), loss1
281 | 
282 | 
283 | class Net3(nn.Module):  # legacy variant kept from the original (PyG-style) MEWISPool code; not wired into this DGL pipeline either
284 |     def __init__(self, input_dim, hidden_dim, num_classes):
285 |         super(Net3, self).__init__()
286 | 
287 |         self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
288 |         self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
289 |         self.pool1 = MEWISPool(hidden_dim=hidden_dim)
290 |         self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
291 |         self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
292 |         self.pool2 = MEWISPool(hidden_dim=hidden_dim)
293 |         self.gc5 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True), 'sum')
294 |         self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
295 |         self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)
296 | 
297 |     def forward(self, x, edge_index, batch):
298 |         x = self.gc1(x, edge_index)
299 |         x = torch.relu(x)
300 | 
301 |         x = self.gc2(x, edge_index)
302 |         x = torch.relu(x)
303 | 
304 |         x_pooled1, edge_index_pooled1, batch_pooled1, loss1, mewis1 = self.pool1(x, edge_index, batch)
305 | 
306 |         x_pooled1 = self.gc3(x_pooled1, edge_index_pooled1)
307 |         x_pooled1 = torch.relu(x_pooled1)
308 | 
309 |         x_pooled1 = self.gc4(x_pooled1, edge_index_pooled1)
310 |         x_pooled1 = torch.relu(x_pooled1)
311 | 
312 |         x_pooled2, edge_index_pooled2, batch_pooled2, loss2, mewis2 = self.pool2(x_pooled1, edge_index_pooled1,
313 |                                                                                  batch_pooled1)
314 | 
315 |         x_pooled2 = self.gc5(x_pooled2, edge_index_pooled2)
316 |         x_pooled2 = torch.relu(x_pooled2)
317 | 
318 |         readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)],
319 |                             dim=0)
320 | 
321 |         out = self.fc1(readout)
322 |         out = torch.relu(out)
323 |         out = self.fc2(out)
324 | 
325 |         return torch.log_softmax(out, dim=-1), loss1 + loss2
--------------------------------------------------------------------------------
/pooling/sagpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional as F
4 | import dgl
5 | from dgl.nn import GraphConv, AvgPooling, MaxPooling
6 | from pooling.topkpool import TopKPooling, get_batch_id
7 | from layers.gcn_layer import GCNLayer
8 | from layers.mlp_readout_layer import MLPReadout
9 | 
10 | 
11 | class SAGPool(torch.nn.Module):
12 |     """The self-attention pooling layer from the paper
13 |     "Self-Attention Graph Pooling" (ICML 2019).
14 | 
15 |     Args:
16 |         in_dim (int): The dimension of node features.
17 |         ratio (float, optional): The pooling ratio, i.e. the fraction of nodes
18 |             kept after pooling. (default: :obj:`0.5`)
19 |         non_linearity (Callable, optional): The non-linearity applied to the kept
20 |             scores, a pytorch function. (default: :obj:`torch.tanh`)
21 | 
22 |     Note: node scores are computed by a fixed single-output :obj:`dgl.nn.GraphConv`.
23 |     """
24 |     def __init__(self, in_dim:int, ratio=0.5, non_linearity=torch.tanh):
25 |         super(SAGPool, self).__init__()
26 |         self.in_dim = in_dim
27 |         self.ratio = ratio
28 |         if dgl.__version__ < "0.5":
29 |             self.score_layer = GraphConv(in_dim, 1)
30 |         else:
31 |             self.score_layer = GraphConv(in_dim, 1, allow_zero_in_degree=True)
32 |         self.non_linearity = non_linearity
33 |         self.softmax = torch.nn.Softmax(dim=0)  # scores form a 1-D tensor over all nodes in the batch
34 | 
35 |     def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor, e_feat=None):
36 |         score = self.score_layer(graph, feature).squeeze()
37 |         perm, next_batch_num_nodes = TopKPooling(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes())
38 |         feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
39 |         graph = dgl.node_subgraph(graph, perm)
40 | 
41 |         # node_subgraph currently does not support batched graphs:
42 |         # the 'batch_num_nodes' of the resulting subgraph is None,
43 |         # so we manually set 'batch_num_nodes' here.
44 |         # Since global pooling has nothing to do with 'batch_num_edges',
45 |         # we can leave it to be None or unchanged.
46 |         graph.set_batch_num_nodes(next_batch_num_nodes)  # e.g. ratio=0.5 turns batch_num_nodes [6, 4] into [3, 2]
47 | 
48 |         score = self.softmax(score)
49 |         if torch.nonzero(torch.isnan(score)).size(0) > 0:
50 |             print(score[torch.nonzero(torch.isnan(score))])
51 |             raise KeyError  # defensive guard: abort if score normalization produced NaNs
52 | 
53 |         if e_feat is not None:
54 |             e_feat = graph.edata['feat'].unsqueeze(-1)
55 | 
56 |         return graph, feature, perm, score, e_feat
57 | 
58 | 
59 | class SAGPoolBlock(torch.nn.Module):
60 |     """A combination of GCN layer and SAGPool layer,
61 |     followed by a concatenated (mean||max) readout operation.
62 |     Returns the pooled graph, its node embeddings, and a graph embedding with the same dimension as the node embedding.
63 |     """
64 |     def __init__(self, in_dim:int, pool_ratio=0.5):
65 |         super(SAGPoolBlock, self).__init__()
66 |         self.pool = SAGPool(in_dim, ratio=pool_ratio)
67 |         self.avgpool = AvgPooling()
68 |         self.maxpool = MaxPooling()
69 |         self.mlp = torch.nn.Linear(in_dim * 2, in_dim)
70 | 
71 |     def forward(self, graph, feature):
72 |         graph, out, _, _, _ = self.pool(graph, feature)
73 |         g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
74 |         g_out = F.relu(self.mlp(g_out))
75 |         return graph, out, g_out
76 | 
77 | 
78 | class SAGPoolReadout(torch.nn.Module):
79 |     """A combination of GCN layer and SAGPool layer,
80 |     followed by a concatenated (mean||max) readout operation.
81 |     """
82 |     def __init__(self, net_params, pool_ratio=0.5, pool=True):
83 |         super(SAGPoolReadout, self).__init__()
84 |         in_dim = net_params['in_dim']
85 |         out_dim = net_params['out_dim']
86 |         dropout = net_params['dropout']
87 |         n_classes = net_params['n_classes']
88 |         self.batch_norm = net_params['batch_norm']
89 |         self.residual = net_params['residual']
90 |         self.e_feat = net_params['edge_feat']
91 |         self.conv = GCNLayer(in_dim, out_dim, F.relu, dropout, self.batch_norm, self.residual, e_feat=self.e_feat)
92 |         self.use_pool = pool
93 |         self.pool = SAGPool(out_dim, ratio=pool_ratio)
94 |         self.avgpool = AvgPooling()
95 |         self.maxpool = MaxPooling()
96 |         self.MLP_layer = MLPReadout(out_dim * 2, n_classes)
97 | 
98 |     def forward(self, graph, feature, e_feat=None):
99 |         out, e_feat = self.conv(graph, feature, e_feat)
100 |         if self.use_pool:
101 |             graph, out, _, _, e_feat = self.pool(graph, out, e_feat)
102 |         hg = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
103 |         scores = self.MLP_layer(hg)
104 |         return scores
105 | 
106 |     def loss(self, pred, label, cluster=False):
107 | 
108 |         criterion = torch.nn.CrossEntropyLoss()
109 |         loss = criterion(pred, label)
110 | 
111 |         return loss
112 | 
--------------------------------------------------------------------------------
/pooling/topkpool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from scipy.stats import t
4 | import math
5 | 
6 | 
7 | def get_batch_id(num_nodes:torch.Tensor):
8 |     """Convert the per-graph node counts of a batched graph into a batch-id
9 |     entry for each node.
10 | 
11 |     Args:
12 |         num_nodes (torch.Tensor): The tensor whose elements are the numbers of nodes
13 |             of the graphs in the batched graph.
14 |     """
15 |     batch_size = num_nodes.size(0)
16 |     batch_ids = []
17 |     for i in range(batch_size):
18 |         item = torch.full((num_nodes[i],), i, dtype=torch.long, device=num_nodes.device)
19 |         batch_ids.append(item)
20 |     return torch.cat(batch_ids)
21 | 
22 | 
23 | def TopKPooling(x:torch.Tensor, ratio:float, batch_id:torch.Tensor, num_nodes:torch.Tensor):
24 |     """The top-k pooling method.
Given a graph batch, this method will pool out some 25 | nodes from input node feature tensor for each graph according to the given ratio. 26 | 27 | Args: 28 | x (torch.Tensor): The input node feature batch-tensor to be pooled. 29 | ratio (float): the pool ratio. For example if :obj:`ratio=0.5` then half of the input 30 | tensor will be pooled out. 31 | batch_id (torch.Tensor): The batch_id of each element in the input tensor. 32 | num_nodes (torch.Tensor): The number of nodes of each graph in batch. 33 | 34 | Returns: 35 | perm (torch.Tensor): The index in batch to be kept. 36 | k (torch.Tensor): The remaining number of nodes for each graph. 37 | """ 38 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item() 39 | 40 | cum_num_nodes = torch.cat( 41 | [num_nodes.new_zeros(1), 42 | num_nodes.cumsum(dim=0)[:-1]], dim=0) 43 | 44 | index = torch.arange(batch_id.size(0), dtype=torch.long, device=x.device) 45 | index = (index - cum_num_nodes[batch_id]) + (batch_id * max_num_nodes) 46 | 47 | dense_x = x.new_full((batch_size * max_num_nodes, ), torch.finfo(x.dtype).min) 48 | dense_x[index] = x 49 | dense_x = dense_x.view(batch_size, max_num_nodes) 50 | 51 | _, perm = dense_x.sort(dim=-1, descending=True) 52 | perm = perm + cum_num_nodes.view(-1, 1) 53 | perm = perm.view(-1) 54 | 55 | k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long) 56 | mask = [ 57 | torch.arange(k[i], dtype=torch.long, device=x.device) + 58 | i * max_num_nodes for i in range(batch_size)] 59 | 60 | mask = torch.cat(mask, dim=0) 61 | perm = perm[mask] 62 | 63 | return perm, k 64 | -------------------------------------------------------------------------------- /train_TUs_graph_classification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for training one epoch 3 | and evaluating one epoch 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | 9 | from metrics import accuracy_TU as accuracy 10 | 11 | """ 12 | For GCNs 13 | """ 14 | def train_epoch_sparse(model, optimizer, device, data_loader, epoch): 15 | model.train() 16 | epoch_loss = 0 17 | epoch_train_acc = 0 18 | nb_data = 0 19 | gpu_mem = 0 20 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader): 21 | batch_graphs = batch_graphs.to(device) 22 | batch_x = batch_graphs.ndata['feat'].to(device) # num x feat 23 | batch_e = batch_graphs.edata['feat'].to(device) 24 | batch_labels = batch_labels.to(device) 25 | optimizer.zero_grad() 26 | 27 | batch_scores = model.forward(batch_graphs, batch_x, batch_e) 28 | loss = model.loss(batch_scores, batch_labels) 29 | loss.backward() 30 | optimizer.step() 31 | epoch_loss += loss.detach().item() 32 | epoch_train_acc += accuracy(batch_scores, batch_labels) 33 | nb_data += batch_labels.size(0) 34 | epoch_loss /= (iter + 1) 35 | epoch_train_acc /= nb_data 36 | 37 | return epoch_loss, epoch_train_acc, optimizer 38 | 39 | def evaluate_network_sparse(model, device, data_loader, epoch): 40 | model.eval() 41 | epoch_test_loss = 0 42 | epoch_test_acc = 0 43 | nb_data = 0 44 | with torch.no_grad(): 45 | for iter, (batch_graphs, batch_labels) in enumerate(data_loader): 46 | batch_graphs = batch_graphs.to(device) 47 | batch_x = batch_graphs.ndata['feat'].to(device) 48 | batch_e = batch_graphs.edata['feat'].to(device) 49 | batch_labels = batch_labels.to(device) 50 | 51 | batch_scores = model.forward(batch_graphs, batch_x, batch_e) 52 | loss = model.loss(batch_scores, batch_labels) 53 | epoch_test_loss += loss.detach().item() 54 | 
epoch_test_acc += accuracy(batch_scores, batch_labels) 55 | nb_data += batch_labels.size(0) 56 | epoch_test_loss /= (iter + 1) 57 | epoch_test_acc /= nb_data 58 | 59 | return epoch_test_loss, epoch_test_acc 60 | 61 | 62 | 63 | 64 | """ 65 | For WL-GNNs 66 | """ 67 | def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size): 68 | model.train() 69 | epoch_loss = 0 70 | epoch_train_acc = 0 71 | nb_data = 0 72 | gpu_mem = 0 73 | optimizer.zero_grad() 74 | for iter, (x_with_node_feat, labels) in enumerate(data_loader): 75 | x_with_node_feat = x_with_node_feat.to(device) 76 | labels = labels.to(device) 77 | 78 | scores = model.forward(x_with_node_feat) 79 | loss = model.loss(scores, labels) 80 | loss.backward() 81 | 82 | if not (iter%batch_size): 83 | optimizer.step() 84 | optimizer.zero_grad() 85 | 86 | epoch_loss += loss.detach().item() 87 | epoch_train_acc += accuracy(scores, labels) 88 | nb_data += labels.size(0) 89 | epoch_loss /= (iter + 1) 90 | epoch_train_acc /= nb_data 91 | 92 | return epoch_loss, epoch_train_acc, optimizer 93 | 94 | def evaluate_network_dense(model, device, data_loader, epoch): 95 | model.eval() 96 | epoch_test_loss = 0 97 | epoch_test_acc = 0 98 | nb_data = 0 99 | with torch.no_grad(): 100 | for iter, (x_with_node_feat, labels) in enumerate(data_loader): 101 | x_with_node_feat = x_with_node_feat.to(device) 102 | labels = labels.to(device) 103 | 104 | scores = model.forward(x_with_node_feat) 105 | loss = model.loss(scores, labels) 106 | epoch_test_loss += loss.detach().item() 107 | epoch_test_acc += accuracy(scores, labels) 108 | nb_data += labels.size(0) 109 | epoch_test_loss /= (iter + 1) 110 | epoch_test_acc /= nb_data 111 | 112 | return epoch_test_loss, epoch_test_acc 113 | 114 | 115 | 116 | 117 | def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, counter): 118 | if curr_loss < best_loss: 119 | counter = 0 120 | best_loss = curr_loss 121 | best_epoch = curr_epoch 122 | else: 123 | counter += 1 124 | return best_loss, best_epoch, counter --------------------------------------------------------------------------------
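
A minimal sanity-check sketch of `pooling/edge_sparse_max.py` (not a file in the repository; the toy graph, logits, and expected output below are illustrative assumptions). It shows the defining property of edge sparsemax: the weights over the in-edges of a node sum to 1, and weak edges are pruned to exactly 0, unlike edge softmax:

import torch
import dgl
from pooling.edge_sparse_max import edge_sparsemax

# star graph: three edges 0->3, 1->3, 2->3 share end node 3
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([3, 3, 3])))
logits = torch.tensor([2.0, 1.0, -1.0])  # one logit per edge

att = edge_sparsemax(g, logits, norm_by="dst")
# expected: tensor([1., 0., 0.]) -- the in-edge weights of node 3 sum to 1,
# and the two weaker edges receive exactly zero weight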
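
Likewise, a hedged sketch of the `TopKPooling` helper from `pooling/topkpool.py` on a toy batch (the tensor values are illustrative assumptions):

import torch
from pooling.topkpool import TopKPooling, get_batch_id

num_nodes = torch.tensor([3, 2])                  # a batch of two graphs
batch_id = get_batch_id(num_nodes)                # tensor([0, 0, 0, 1, 1])
scores = torch.tensor([0.1, 0.9, 0.5, 0.7, 0.2])  # one score per node

perm, k = TopKPooling(scores, 0.5, batch_id, num_nodes)
# k = ceil(0.5 * [3, 2]) = tensor([2, 1]); perm = tensor([1, 2, 3]):
# the top-2 nodes of graph 0 and the top-1 node of graph 1, in batch indexing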
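
Finally, a small sketch of the `check_patience` early-stopping helper from `train_TUs_graph_classification.py`; the loss values and the patience budget are illustrative assumptions, and `None` is passed for the unused first argument:

from train_TUs_graph_classification import check_patience

best_loss, best_epoch, counter = float('inf'), 0, 0
patience = 25  # assumed budget; the actual value comes from the JSON configs
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.75]):
    best_loss, best_epoch, counter = check_patience(None, best_loss, best_epoch, val_loss, epoch, counter)
    if counter >= patience:
        break  # stop once validation loss has not improved for `patience` epochs
# after the loop: best_loss == 0.7, best_epoch == 1, counter == 2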