├── .gitignore ├── README.md ├── config └── hparams_testdb.yml ├── data ├── fig │ ├── ModelStructure.png │ └── ReportedResults.png └── gnn_enzymes_source_20190905 │ └── ENZYMES │ ├── ENZYMES_A.txt │ ├── ENZYMES_graph_indicator.txt │ ├── ENZYMES_graph_labels.txt │ ├── ENZYMES_node_attributes.txt │ ├── ENZYMES_node_labels.txt │ └── README.txt ├── gnn_hpool ├── __init__.py ├── bin │ ├── __init__.py │ └── train_eval.py ├── layers │ ├── __init__.py │ ├── gcn_layer.py │ └── hierarchical_diff_pooling.py ├── models │ ├── __init__.py │ ├── gcn_hpool_encoder.py │ └── gcn_hpool_submodel.py └── utils │ ├── __init__.py │ ├── common_utils.py │ ├── evaluate.py │ ├── get_loss.py │ ├── global_variables.py │ ├── hparam.py │ ├── hparams_lib.py │ ├── load_data.py │ └── load_data_test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # tmp results 107 | results/* 108 | .idea* 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Graph Representation Learning 2 | 3 | # Requirements 4 | * Python==3.6.x 5 | * PyTorch==1.1.0 6 | * NumPy>=1.16.3 7 | * matplotlib>=3.0.3 8 | * networkx==2.4 9 | * tensorboardX==1.8 10 | 11 | # DataSet & Task 12 | ENZYMES dataset is used, which includes 600 molecule structures. The edge feature denotes whether there is a connection between two molecules and the node feature denotes what kind of element for a node. 
13 | * Classify the type of enzyme 14 | * 600 samples (training: 540, testing: 60) 15 | * 6 types; 100 samples per type 16 | * the number of nodes per graph: 2 ~ 125 (median value ~ 30) 17 | * dimension of node features: 3 18 | 19 | # Model Structure 20 | ![](./data/fig/ModelStructure.png) 21 | 22 | # Usage 23 | ```shell 24 | python train.py --hparam_path=./config/hparams_testdb.yml # or other config files you defined 25 | ``` 26 | 27 | # Results 28 | ## Reported Results 29 | ![](./data/fig/ReportedResults.png) 30 | ## Replication 31 | 32 | Best val result: 0.6133 @ epoch 765 33 | 34 | 35 | # Reference 36 | [1] Ying, Zhitao, et al. "Hierarchical graph representation learning with differentiable pooling." Advances in Neural Information Processing Systems. 2018. 37 | 38 | [2] Huang, Gao, et al. "Densely connected convolutional networks." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017. 39 | -------------------------------------------------------------------------------- /config/hparams_testdb.yml: -------------------------------------------------------------------------------- 1 | device: 'cuda' 2 | cuda_visible_devices: '2' 3 | datadir: 'data/gnn_enzymes_source_20190905' 4 | model_save_path: 'results' 5 | dataname: 'ENZYMES' 6 | fold_num: 10 7 | max_num_nodes: 1000 8 | channel_list: [3, 30, 30, 30, 30] 9 | node_list: [3, 10, 10] 10 | batch_size: 20 11 | epoch: 1000 12 | timestamp: 1567969359 13 | learning_rate: 0.001 14 | grad_clip: 2 15 | -------------------------------------------------------------------------------- /data/fig/ModelStructure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/murphyyhuang/gnn_hierarchical_pooling/5c875d5821b49cacba3ac052fccd2a05a2274716/data/fig/ModelStructure.png -------------------------------------------------------------------------------- /data/fig/ReportedResults.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/murphyyhuang/gnn_hierarchical_pooling/5c875d5821b49cacba3ac052fccd2a05a2274716/data/fig/ReportedResults.png -------------------------------------------------------------------------------- /data/gnn_enzymes_source_20190905/ENZYMES/ENZYMES_graph_labels.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 6 3 | 6 4 | 6 5 | 6 6 | 6 7 | 6 8 | 6 9 | 6 10 | 6 11 | 6 12 | 6 13 | 6 14 | 6 15 | 6 16 | 6 17 | 6 18 | 6 19 | 6 20 | 6 21 | 6 22 | 6 23 | 6 24 | 6 25 | 6 26 | 6 27 | 6 28 | 6 29 | 6 30 | 6 31 | 6 32 | 6 33 | 6 34 | 6 35 | 6 36 | 6 37 | 6 38 | 6 39 | 6 40 | 6 41 | 6 42 | 6 43 | 6 44 | 6 45 | 6 46 | 6 47 | 6 48 | 6 49 | 6 50 | 6 51 | 6 52 | 6 53 | 6 54 | 6 55 | 6 56 | 6 57 | 6 58 | 6 59 | 6 60 | 6 61 | 6 62 | 6 63 | 6 64 | 6 65 | 6 66 | 6 67 | 6 68 | 6 69 | 6 70 | 6 71 | 6 72 | 6 73 | 6 74 | 6 75 | 6 76 | 6 77 | 6 78 | 6 79 | 6 80 | 6 81 | 6 82 | 6 83 | 6 84 | 6 85 | 6 86 | 6 87 | 6 88 | 6 89 | 6 90 | 6 91 | 6 92 | 6 93 | 6 94 | 6 95 | 6 96 | 6 97 | 6 98 | 6 99 | 6 100 | 6 101 | 5 102 | 5 103 | 5 104 | 5 105 | 5 106 | 5 107 | 5 108 | 5 109 | 5 110 | 5 111 | 5 112 | 5 113 | 5 114 | 5 115 | 5 116 | 5 117 | 5 118 | 5 119 | 5 120 | 5 121 | 5 122 | 5 123 | 5 124 | 5 125 | 5 126 | 5 127 | 5 128 | 5 129 | 5 130 | 5 131 | 5 132 | 5 133 | 5 134 | 5 135 | 5 136 | 5 137 | 5 138 | 5 139 | 5 140 | 5 141 | 5 142 | 5 143 | 5 144 | 5 145 | 5 146 | 5 147 | 5 148 | 5 149 | 5 150 | 5 151 | 5 152 | 5 153 | 5 154 | 5 155 | 5 156 | 5 157 | 5 158 | 5 159 | 5 160 | 5 161 | 5 162 | 5 163 | 5 164 | 5 165 | 5 166 | 5 167 | 5 168 | 5 169 | 5 170 | 5 171 | 5 172 | 5 173 | 5 174 | 5 175 | 5 176 | 5 177 | 5 178 | 5 179 | 5 180 | 5 181 | 5 182 | 5 183 | 5 184 | 5 185 | 5 186 | 5 187 | 5 188 | 5 189 | 5 190 | 5 191 | 5 192 | 5 193 | 5 194 | 5 195 | 5 196 | 5 197 | 5 198 | 5 199 | 5 200 | 5 201 | 1 202 | 1 203 | 1 204 | 1 205 | 1 206 | 1 207 | 
1 208 | 1 209 | 1 210 | 1 211 | 1 212 | 1 213 | 1 214 | 1 215 | 1 216 | 1 217 | 1 218 | 1 219 | 1 220 | 1 221 | 1 222 | 1 223 | 1 224 | 1 225 | 1 226 | 1 227 | 1 228 | 1 229 | 1 230 | 1 231 | 1 232 | 1 233 | 1 234 | 1 235 | 1 236 | 1 237 | 1 238 | 1 239 | 1 240 | 1 241 | 1 242 | 1 243 | 1 244 | 1 245 | 1 246 | 1 247 | 1 248 | 1 249 | 1 250 | 1 251 | 1 252 | 1 253 | 1 254 | 1 255 | 1 256 | 1 257 | 1 258 | 1 259 | 1 260 | 1 261 | 1 262 | 1 263 | 1 264 | 1 265 | 1 266 | 1 267 | 1 268 | 1 269 | 1 270 | 1 271 | 1 272 | 1 273 | 1 274 | 1 275 | 1 276 | 1 277 | 1 278 | 1 279 | 1 280 | 1 281 | 1 282 | 1 283 | 1 284 | 1 285 | 1 286 | 1 287 | 1 288 | 1 289 | 1 290 | 1 291 | 1 292 | 1 293 | 1 294 | 1 295 | 1 296 | 1 297 | 1 298 | 1 299 | 1 300 | 1 301 | 2 302 | 2 303 | 2 304 | 2 305 | 2 306 | 2 307 | 2 308 | 2 309 | 2 310 | 2 311 | 2 312 | 2 313 | 2 314 | 2 315 | 2 316 | 2 317 | 2 318 | 2 319 | 2 320 | 2 321 | 2 322 | 2 323 | 2 324 | 2 325 | 2 326 | 2 327 | 2 328 | 2 329 | 2 330 | 2 331 | 2 332 | 2 333 | 2 334 | 2 335 | 2 336 | 2 337 | 2 338 | 2 339 | 2 340 | 2 341 | 2 342 | 2 343 | 2 344 | 2 345 | 2 346 | 2 347 | 2 348 | 2 349 | 2 350 | 2 351 | 2 352 | 2 353 | 2 354 | 2 355 | 2 356 | 2 357 | 2 358 | 2 359 | 2 360 | 2 361 | 2 362 | 2 363 | 2 364 | 2 365 | 2 366 | 2 367 | 2 368 | 2 369 | 2 370 | 2 371 | 2 372 | 2 373 | 2 374 | 2 375 | 2 376 | 2 377 | 2 378 | 2 379 | 2 380 | 2 381 | 2 382 | 2 383 | 2 384 | 2 385 | 2 386 | 2 387 | 2 388 | 2 389 | 2 390 | 2 391 | 2 392 | 2 393 | 2 394 | 2 395 | 2 396 | 2 397 | 2 398 | 2 399 | 2 400 | 2 401 | 3 402 | 3 403 | 3 404 | 3 405 | 3 406 | 3 407 | 3 408 | 3 409 | 3 410 | 3 411 | 3 412 | 3 413 | 3 414 | 3 415 | 3 416 | 3 417 | 3 418 | 3 419 | 3 420 | 3 421 | 3 422 | 3 423 | 3 424 | 3 425 | 3 426 | 3 427 | 3 428 | 3 429 | 3 430 | 3 431 | 3 432 | 3 433 | 3 434 | 3 435 | 3 436 | 3 437 | 3 438 | 3 439 | 3 440 | 3 441 | 3 442 | 3 443 | 3 444 | 3 445 | 3 446 | 3 447 | 3 448 | 3 449 | 3 450 | 3 451 | 3 452 | 3 453 | 3 454 | 3 455 | 3 456 | 3 457 | 
3 458 | 3 459 | 3 460 | 3 461 | 3 462 | 3 463 | 3 464 | 3 465 | 3 466 | 3 467 | 3 468 | 3 469 | 3 470 | 3 471 | 3 472 | 3 473 | 3 474 | 3 475 | 3 476 | 3 477 | 3 478 | 3 479 | 3 480 | 3 481 | 3 482 | 3 483 | 3 484 | 3 485 | 3 486 | 3 487 | 3 488 | 3 489 | 3 490 | 3 491 | 3 492 | 3 493 | 3 494 | 3 495 | 3 496 | 3 497 | 3 498 | 3 499 | 3 500 | 3 501 | 4 502 | 4 503 | 4 504 | 4 505 | 4 506 | 4 507 | 4 508 | 4 509 | 4 510 | 4 511 | 4 512 | 4 513 | 4 514 | 4 515 | 4 516 | 4 517 | 4 518 | 4 519 | 4 520 | 4 521 | 4 522 | 4 523 | 4 524 | 4 525 | 4 526 | 4 527 | 4 528 | 4 529 | 4 530 | 4 531 | 4 532 | 4 533 | 4 534 | 4 535 | 4 536 | 4 537 | 4 538 | 4 539 | 4 540 | 4 541 | 4 542 | 4 543 | 4 544 | 4 545 | 4 546 | 4 547 | 4 548 | 4 549 | 4 550 | 4 551 | 4 552 | 4 553 | 4 554 | 4 555 | 4 556 | 4 557 | 4 558 | 4 559 | 4 560 | 4 561 | 4 562 | 4 563 | 4 564 | 4 565 | 4 566 | 4 567 | 4 568 | 4 569 | 4 570 | 4 571 | 4 572 | 4 573 | 4 574 | 4 575 | 4 576 | 4 577 | 4 578 | 4 579 | 4 580 | 4 581 | 4 582 | 4 583 | 4 584 | 4 585 | 4 586 | 4 587 | 4 588 | 4 589 | 4 590 | 4 591 | 4 592 | 4 593 | 4 594 | 4 595 | 4 596 | 4 597 | 4 598 | 4 599 | 4 600 | 4 601 | -------------------------------------------------------------------------------- /data/gnn_enzymes_source_20190905/ENZYMES/README.txt: -------------------------------------------------------------------------------- 1 | README for dataset ENZYMES 2 | 3 | 4 | === Usage === 5 | 6 | This folder contains the following comma separated text files 7 | (replace DS by the name of the dataset): 8 | 9 | n = total number of nodes 10 | m = total number of edges 11 | N = number of graphs 12 | 13 | (1) DS_A.txt (m lines) 14 | sparse (block diagonal) adjacency matrix for all graphs, 15 | each line corresponds to (row, col) resp. 
(node_id, node_id) 16 | 17 | (2) DS_graph_indicator.txt (n lines) 18 | column vector of graph identifiers for all nodes of all graphs, 19 | the value in the i-th line is the graph_id of the node with node_id i 20 | 21 | (3) DS_graph_labels.txt (N lines) 22 | class labels for all graphs in the dataset, 23 | the value in the i-th line is the class label of the graph with graph_id i 24 | 25 | (4) DS_node_labels.txt (n lines) 26 | column vector of node labels, 27 | the value in the i-th line corresponds to the node with node_id i 28 | 29 | There are OPTIONAL files if the respective information is available: 30 | 31 | (5) DS_edge_labels.txt (m lines; same size as DS_A.txt) 32 | labels for the edges in DS_A.txt 33 | 34 | (6) DS_edge_attributes.txt (m lines; same size as DS_A.txt) 35 | attributes for the edges in DS_A.txt 36 | 37 | (7) DS_node_attributes.txt (n lines) 38 | matrix of node attributes, 39 | the comma separated values in the i-th line are the attribute vector of the node with node_id i 40 | 41 | (8) DS_graph_attributes.txt (N lines) 42 | regression values for all graphs in the dataset, 43 | the value in the i-th line is the attribute of the graph with graph_id i 44 | 45 | 46 | === Description === 47 | 48 | ENZYMES is a dataset of protein tertiary structures obtained from (Borgwardt et al., 2005) 49 | consisting of 600 enzymes from the BRENDA enzyme database (Schomburg et al., 2004). 50 | In this case the task is to correctly assign each enzyme to one of the 6 EC top-level 51 | classes. 52 | 53 | 54 | === Previous Use of the Dataset === 55 | 56 | Feragen, A., Kasenburg, N., Petersen, J., de Bruijne, M., Borgwardt, K.M.: Scalable 57 | kernels for graphs with continuous attributes. In: C.J.C. Burges, L. Bottou, Z. Ghahra- 58 | mani, K.Q. Weinberger (eds.) NIPS, pp. 216-224 (2013) 59 | 60 | Neumann, M., Garnett R., Bauckhage Ch., Kersting K.: Propagation Kernels: Efficient Graph 61 | Kernels from Propagated Information. Under review at MLJ.
def train_eval(hparams):
    """Run k-fold training/validation and log the fold-averaged accuracy.

    For every fold index in ``range(hparams.fold_num)`` a fresh
    ``GcnHpoolEncoder`` is built, trained with ``train_eval_iter`` and its
    per-epoch validation accuracies are collected.  The accuracies are then
    averaged across folds (epoch by epoch) and the best averaged value and
    the epoch index at which it occurs are logged.

    Args:
        hparams: hyper-parameter object; this function reads ``fold_num``,
            ``model_save_path``, ``timestamp`` and ``device`` and forwards the
            whole object to the data loader, the model and the training loop.
    """
    data_loader = load_data.GraphDataLoaderWrapper(hparams)

    all_vals = []
    for val_idx in range(hparams.fold_num):
        logging.warning('* validation index: {}'.format(val_idx))
        training_loader, validation_loader = data_loader.get_loader(val_idx)
        summary_writer = tensorboardX.SummaryWriter(
            logdir=os.path.join(hparams.model_save_path, str(hparams.timestamp) + '/val_{}'.format(val_idx))
        )

        try:
            model = gcn_hpool_encoder.GcnHpoolEncoder(hparams).to(torch.device(hparams.device))
            _, val_accs = train_eval_iter(model, training_loader, validation_loader, summary_writer, hparams)
        finally:
            # One writer is opened per fold; close it so event files are
            # flushed and file handles do not pile up across folds.
            summary_writer.close()
        all_vals.append(np.array(val_accs))

    # Rows are folds, columns are epochs: average over folds for each epoch.
    all_vals = np.vstack(all_vals)
    all_vals = np.mean(all_vals, axis=0)
    # The original message had no '{}' placeholder, so the results were
    # silently dropped from the log line.
    logging.warning('* all of the validation results: {}'.format(all_vals))
    logging.warning('* the best validation results & its id: {} @ {}'.format(np.max(all_vals), np.argmax(all_vals)))
def log_assignment(assign_tensor, writer, epoch, batch_idx):
    """Render selected assignment matrices as heat-maps and log them as an image.

    One subplot is drawn per entry of ``batch_idx``; each subplot shows
    ``assign_tensor[idx]`` with the 'BuPu' colour map and a colour bar.  The
    finished figure is converted with ``tensorboardX.utils.figure_to_image``
    and written under the tag ``'assignment'``.

    Args:
        assign_tensor: tensor indexed along its first axis by the entries of
            ``batch_idx`` (presumably the batch axis of the pooling
            assignment -- confirm against the caller).
        writer: object exposing ``add_image(tag, data, step)``.
        epoch: global step passed to ``add_image``.
        batch_idx: sequence of indices into the first axis of
            ``assign_tensor``; each must be a valid index for that axis.
    """
    plt.switch_backend('agg')
    fig = plt.figure(figsize=(8, 6), dpi=300)

    # Move the tensor to host memory once instead of once per subplot.
    assignment = assign_tensor.cpu().data.numpy()

    # Keep the original 2x2 layout for up to four plots, and add rows beyond
    # that so longer index lists no longer make plt.subplot raise.
    n_cols = 2
    n_rows = max(2, (len(batch_idx) + n_cols - 1) // n_cols)

    for plot_pos, sample_idx in enumerate(batch_idx, start=1):
        plt.subplot(n_rows, n_cols, plot_pos)
        plt.imshow(assignment[sample_idx], cmap=plt.get_cmap('BuPu'))
        cbar = plt.colorbar()
        cbar.solids.set_edgecolor("face")
    plt.tight_layout()
    fig.canvas.draw()

    data = tensorboardX.utils.figure_to_image(fig)
    writer.add_image('assignment', data, epoch)
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``norm(adj) @ (input @ weight) + bias`` where ``norm(adj)`` is
    the symmetrically normalised adjacency with self-loops added (see
    :meth:`norm`).

    Args:
        in_features: size of the last dimension of the input features.
        out_features: size of the last dimension of the output features.
        hparams: hyper-parameter object; kept on the instance as
            ``_hparams`` for compatibility but no longer used to pick a
            device (the device of ``adj`` is used instead).
        bias: when true, a learnable bias of size ``out_features`` is added.
    """

    def __init__(self, in_features, out_features, hparams, bias=True):
        super(GraphConvolution, self).__init__()
        self._hparams = hparams
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight and bias uniformly in ``[-1/sqrt(out), 1/sqrt(out)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def norm(self, adj):
        """Return ``D^-1/2 (adj + I) D^-1/2``.

        ``D`` holds the row sums of ``|adj + I|`` (absolute values, so that
        signed adjacency matrices are supported).

        Args:
            adj: adjacency tensor whose last two dimensions are square.

        Returns:
            Tensor of the same shape as ``adj``.
        """
        node_num = adj.shape[-1]
        # Build the identity on the same device as ``adj`` (the original used
        # hparams.device, which breaks whenever the two differ) and let
        # broadcasting add it to every graph in the batch.
        self_loop = torch.eye(node_num, device=adj.device)
        adj_post = adj + self_loop
        # signed adjacency matrix: degrees are taken over absolute values
        deg_inv_sqrt = torch.sum(torch.abs(adj_post), dim=-1).pow(-0.5)
        # Row scaling followed by column scaling equals D^-1/2 A D^-1/2
        # without materialising the dense diagonal matrices.
        return deg_inv_sqrt.unsqueeze(-1) * adj_post * deg_inv_sqrt.unsqueeze(-2)

    def forward(self, input, adj):
        """Apply the graph convolution.

        Args:
            input: node features; last dimension must equal ``in_features``.
            adj: adjacency tensor matching ``input`` in its leading and node
                dimensions.

        Returns:
            Tensor with last dimension ``out_features``.
        """
        support = torch.matmul(input, self.weight)
        adj_norm = self.norm(adj)
        # Same product as (support^T @ adj_norm^T)^T, written directly.
        output = torch.matmul(adj_norm, support)
        if self.bias is not None:
            return output + self.bias
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
math:: 15 | \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 16 | \mathbf{X} 17 | 18 | \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot 19 | \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S}) 20 | 21 | based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B 22 | \times N \times C}`. 23 | Returns pooled node feature matrix, coarsened adjacency matrix and the 24 | auxiliary link prediction objective :math:`\| \mathbf{A} - 25 | \mathrm{softmax}(\mathbf{S}) \cdot {\mathrm{softmax}(\mathbf{S})}^{\top} 26 | \|_F`. 27 | 28 | Args: 29 | x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B 30 | \times N \times F}` with batch-size :math:`B`, (maximum) 31 | number of nodes :math:`N` for each graph, and feature dimension 32 | :math:`F`. 33 | adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B 34 | \times N \times N}`. 35 | s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B 36 | \times N \times C}` with number of clusters :math:`C`. The softmax 37 | does not have to be applied beforehand, since it is executed 38 | within this method. 39 | mask (ByteTensor, optional): Mask matrix 40 | :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating 41 | the valid nodes for each graph. 
class GcnHpoolEncoder(Module):
    """Graph classifier: an entry GCN block, one hierarchical pooling block and an MLP head.

    ``forward`` max-pools the entry block's node embeddings over the node
    axis, concatenates them with the output of ``GcnHpoolSubmodel`` and
    feeds the result to ``pred_model``.
    """

    def __init__(self, hparams):
        super(GcnHpoolEncoder, self).__init__()

        # Work on a private copy of the hyper-parameters.
        self._hparams = hparams_lib.copy_hparams(hparams)
        self.build_graph()
        self.reset_parameters()

        # Used by apply_bn / construct_mask; neither is called during
        # construction, so assigning it last is fine.
        self._device = torch.device(self._hparams.device)

    def reset_parameters(self):
        """Re-initialise every GraphConvolution sub-module.

        Weights get Xavier-uniform initialisation with the ReLU gain and
        biases (when present) are set to zero.  ``self.modules()`` also walks
        into ``gcn_hpool_layer``, so its convolutions are covered too.
        """
        for m in self.modules():
            if isinstance(m, gcn_layer.GraphConvolution):
                m.weight.data = torch.nn.init.xavier_uniform_(m.weight.data, gain=torch.nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    m.bias.data = torch.nn.init.constant_(m.bias.data, 0.0)

    def build_graph(self):
        """Create the sub-modules from ``channel_list`` and ``node_list``."""

        # entry GCN
        self.entry_conv_first = gcn_layer.GraphConvolution(
            in_features=self._hparams.channel_list[0],
            out_features=self._hparams.channel_list[1],
            hparams=self._hparams,
        )
        self.entry_conv_block = gcn_layer.GraphConvolution(
            in_features=self._hparams.channel_list[1],
            out_features=self._hparams.channel_list[1],
            hparams=self._hparams,
        )
        self.entry_conv_last = gcn_layer.GraphConvolution(
            in_features=self._hparams.channel_list[1],
            out_features=self._hparams.channel_list[2],
            hparams=self._hparams,
        )

        # gcn_forward concatenates the three entry-layer outputs, whose widths
        # are channel_list[1], channel_list[1] and channel_list[2].  The
        # ``* 3`` below therefore matches only when channel_list[1] ==
        # channel_list[2].
        self.gcn_hpool_layer = GcnHpoolSubmodel(
            self._hparams.channel_list[2] * 3, self._hparams.channel_list[3], self._hparams.channel_list[4],
            self._hparams.node_list[0], self._hparams.node_list[1], self._hparams.node_list[2],
            self._hparams
        )

        # The head consumes the concatenation built in forward().
        # NOTE(review): the number of output logits is channel_list[-1], not
        # an explicit class count -- confirm this is intended.
        self.pred_model = torch.nn.Sequential(
            torch.nn.Linear(2 * 3 * self._hparams.channel_list[-3], self._hparams.channel_list[-2]),
            torch.nn.ReLU(),
            torch.nn.Linear(self._hparams.channel_list[-2], self._hparams.channel_list[-1])
        )

    def forward(self, graph_input):
        """Compute predictions for a batch of graphs.

        Args:
            graph_input: mapping indexed by ``g_key.x``, ``g_key.adj_mat`` and
                ``g_key.node_num`` (presumably node features, dense adjacency
                and per-graph node counts -- confirm against the data loader).

        Returns:
            The output of ``pred_model``.
        """

        node_feature = graph_input[g_key.x]
        adjacency_mat = graph_input[g_key.adj_mat]
        batch_num_nodes = graph_input[g_key.node_num]

        # input mask
        max_num_nodes = adjacency_mat.size()[1]
        embedding_mask = self.construct_mask(max_num_nodes, batch_num_nodes)

        # entry embedding gcn
        embedding_tensor_1 = self.gcn_forward(
            node_feature, adjacency_mat,
            self.entry_conv_first, self.entry_conv_block, self.entry_conv_last,
            embedding_mask
        )
        # Max over dim 1 (the node axis of the masked embeddings).
        output_1, _ = torch.max(embedding_tensor_1, dim=1)

        # hpool layer; only the first of its four return values is used here
        output_2, _, _, _ = self.gcn_hpool_layer(
            embedding_tensor_1, node_feature, adjacency_mat, embedding_mask
        )

        output = torch.cat([output_1, output_2], dim=1)
        ypred = self.pred_model(output)

        return ypred

    def gcn_forward(self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None):
        """Run three stacked graph convolutions and concatenate their outputs.

        The first two layers are followed by ReLU and ``apply_bn``; the last
        one is left linear.  The three outputs are concatenated along dim 2
        and, when ``embedding_mask`` is given, multiplied by it.
        """
        out_all = []

        layer_out_1 = F.relu(conv_first(x, adj))
        layer_out_1 = self.apply_bn(layer_out_1)
        out_all.append(layer_out_1)

        layer_out_2 = F.relu(conv_block(layer_out_1, adj))
        layer_out_2 = self.apply_bn(layer_out_2)
        out_all.append(layer_out_2)

        layer_out_3 = conv_last(layer_out_2, adj)
        out_all.append(layer_out_3)
        out_all = torch.cat(out_all, dim=2)
        if embedding_mask is not None:
            out_all = out_all * embedding_mask

        return out_all

    def apply_bn(self, x):
        """Apply a freshly constructed ``BatchNorm1d`` sized by ``x.size()[1]``.

        NOTE(review): the module is created on every call and never stored on
        ``self``, so it keeps no parameters or running statistics between
        calls and is not affected by ``self.eval()`` -- confirm this is
        intended.
        """
        bn_module = torch.nn.BatchNorm1d(x.size()[1]).to(self._device)
        return bn_module(x)

    def construct_mask(self, max_nodes, batch_num_nodes):
        """Build a 0/1 mask marking the valid nodes of every graph.

        Row ``i`` has ones in its first ``batch_num_nodes[i]`` positions and
        zeros elsewhere.

        Returns:
            Tensor of shape ``[batch_size, max_nodes, 1]`` on ``self._device``.
        """
        # masks
        packed_masks = [torch.ones(int(num)) for num in batch_num_nodes]
        batch_size = len(batch_num_nodes)
        out_tensor = torch.zeros(batch_size, max_nodes)
        for i, mask in enumerate(packed_masks):
            out_tensor[i, :batch_num_nodes[i]] = mask
        return out_tensor.unsqueeze(2).to(self._device)
def forward(self, embedding_tensor, pool_x_tensor, adj, embedding_mask):
    ''' Pool the graph once with dense_diff_pool, then embed the pooled graph.

    Returns:
      A tuple (output, adj_pool, x_pool, embedding_tensor): `output` is the
      max over dim 1 of the pooled-graph embedding, `adj_pool` and `x_pool`
      are the first two results of dense_diff_pool, and `embedding_tensor`
      is the embedding computed on the pooled graph.

    Side effect: the assignment tensor is stored in `self.pool_tensor`.
    '''
    # assignment tensor: GCN stack -> linear -> softmax over the last dim
    pool_features = self.gcn_forward(
        pool_x_tensor, adj,
        self.pool_conv_first, self.pool_conv_block, self.pool_conv_last,
        embedding_mask
    )
    assignment = F.softmax(self.pool_linear(pool_features), dim=-1)
    if embedding_mask is not None:
        # re-apply the mask: softmax makes masked rows non-zero again
        assignment = assignment * embedding_mask

    pooled = dense_diff_pool(embedding_tensor, adj, assignment)
    x_pool, adj_pool = pooled[0], pooled[1]

    # embed the coarsened graph (no mask is passed here)
    pooled_embedding = self.gcn_forward(
        x_pool, adj_pool,
        self.embed_conv_first, self.embed_conv_block, self.embed_conv_last,
    )
    output = torch.max(pooled_embedding, dim=1)[0]

    self.pool_tensor = assignment
    return output, adj_pool, x_pool, pooled_embedding
def exp_moving_avg(x, decay=0.9):
    """Return the exponential moving average of the sequence `x`.

    The first output equals `x[0]`; each later output is
    `prev - (1 - decay) * (prev - v)` for the next value `v`.

    Args:
      x: Indexable, sliceable sequence of values (numbers, arrays, ...).
      decay: Smoothing factor; larger values weight the history more.
    Returns:
      A list with one smoothed value per element of `x`; an empty list when
      `x` is empty.
    """
    # len() instead of truthiness so that numpy arrays are accepted too.
    if len(x) == 0:
        return []

    shadow = x[0]
    smoothed = [shadow]
    for v in x[1:]:
        # Rebind instead of `-=`: an in-place update would mutate x[0] and make
        # every entry of the result alias the same object when the elements
        # are mutable (numpy arrays, tensors).
        shadow = shadow - (1 - decay) * (shadow - v)
        smoothed.append(shadow)
    return smoothed
def cross_entropy(prediction, reference):
    """Return the cross-entropy loss averaged over the batch.

    Args:
      prediction: Unnormalized class scores passed to `F.cross_entropy`.
      reference: Target class indices.
    Returns:
      The mean loss as returned by `F.cross_entropy`.
    """
    # `size_average=True` is deprecated; `reduction='mean'` is its equivalent.
    return F.cross_entropy(prediction, reference, reduction='mean')
def add_hparam(self, name, value):
    """Register a new hyperparameter and remember its type.

    A list or tuple value is recorded as multi-valued, typed by its first
    element; any other value is recorded as single-valued.

    Args:
      name: Name of the hyperparameter.
      value: Value of the hyperparameter.
    Raises:
      ValueError: if an attribute called `name` already exists with a
        non-None value, or if a list/tuple value is empty.
    """
    if getattr(self, name, None) is not None:
        raise ValueError('Hyperparameter name is reserved: %s' % name)

    is_multi_valued = isinstance(value, (list, tuple))
    if is_multi_valued and not value:
        raise ValueError('Multi-valued hyperparameters cannot be empty: %s' % name)

    element_type = type(value[0]) if is_multi_valued else type(value)
    self._hparam_types[name] = (element_type, is_multi_valued)
    setattr(self, name, value)
def from_yaml(self, read_dir):
    """Load hyperparameters from a YAML file and register each of them.

    Args:
      read_dir: Path of the YAML file to read.
    Raises:
      ValueError: propagated from `add_hparam` (e.g. reserved name or empty
        multi-valued entry).
    """
    with open(read_dir, 'r') as yml_reader:
        hparams_dict = yaml.load(yml_reader, Loader=yaml.SafeLoader)

    # An empty YAML document parses to None. Treat it like an empty mapping
    # (which already adds nothing) instead of failing on `None.items()`.
    if hparams_dict is None:
        return

    for name, value in hparams_dict.items():
        self.add_hparam(name, value)
def copy_hparams(hparams):
    """Build a new HParams object holding the same name/value pairs.

    The values themselves are passed through as-is (no deep copy is made).
    """
    return hparam.HParams(**hparams.values())
def preprocess_graph(self, graph_list):
    ''' Convert networkx graphs into dictionaries of padded tensors.

    For every graph a dictionary is built with the keys from `g_key`:
      x:        node features, zero-padded to `max_num_nodes` rows
      y:        graph label (long)
      node_num: number of nodes in the graph (int16)
      adj_mat:  adjacency matrix, zero-padded to max_num_nodes x max_num_nodes
    All tensors are moved to `self._device`.

    Args:
      graph_list: iterable of networkx graphs whose nodes carry a 'label'
        attribute and whose graph attributes carry a 'label' entry.
    Returns:
      List with one dictionary per input graph, in input order.
    '''
    max_nodes = self._hparams.max_num_nodes
    feature_dim = self._hparams.channel_list[0]
    processed_graph_list = []

    for graph in graph_list:
        graph_tmp_dict = {}

        # adjacency matrix; rows/columns follow the order of graph.nodes()
        adj = nx.to_numpy_array(graph)
        num_nodes = adj.shape[0]

        # node features
        node_tmp_feature = np.zeros((max_nodes, feature_dim))
        for index, node in enumerate(graph.nodes()):
            # use the node label as input features
            # change the key here if the real features of node is wanted
            # `graph.nodes[...]` replaces `graph.node[...]`, which was removed
            # in networkx 2.4; looking up by node key keeps each feature row
            # aligned with the matching adjacency row.
            node_tmp_feature[index, :] = graph.nodes[node]['label']

        graph_tmp_dict[g_key.x] = torch.tensor(node_tmp_feature, dtype=torch.float32).to(self._device)
        graph_tmp_dict[g_key.y] = torch.tensor(graph.graph['label'], dtype=torch.long).to(self._device)
        graph_tmp_dict[g_key.node_num] = torch.tensor(num_nodes, dtype=torch.int16).to(self._device)

        # zero-padded adjacency: only the top-left num_nodes block is filled
        padded_adj = torch.zeros(max_nodes, max_nodes).to(self._device)
        padded_adj[:num_nodes, :num_nodes] = torch.tensor(adj, dtype=torch.float32).to(self._device)
        graph_tmp_dict[g_key.adj_mat] = padded_adj

        processed_graph_list.append(graph_tmp_dict)

    return processed_graph_list
def read_graphfile(datadir, dataname, max_nodes=None):
    ''' Read data from https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
        graph index starts with 1 in file

    Files are looked up as <datadir>/<dataname>/<dataname>_*.txt. The node
    label and node attribute files are optional; a message is printed when
    one of them cannot be opened.

    Args:
      datadir: directory that contains the dataset folder.
      dataname: dataset name (folder name and file prefix).
      max_nodes: when not None, graphs with more nodes are dropped.
    Returns:
      List of networkx objects with graph and node labels. Nodes of each
      graph are relabeled to 0..n-1 in iteration order. Graph labels are
      remapped to consecutive integers in order of first appearance.
    '''
    prefix = os.path.join(datadir, dataname, dataname)

    # index of graphs that a given node belongs to (node ids are 1-based)
    graph_indic = {}
    with open(prefix + '_graph_indicator.txt') as f:
        for node_id, line in enumerate(f, start=1):
            graph_indic[node_id] = int(line.strip("\n"))

    # optional node labels, shifted to start at 0
    node_labels = []
    num_unique_node_labels = 0
    try:
        with open(prefix + '_node_labels.txt') as f:
            for line in f:
                node_labels.append(int(line.strip("\n")) - 1)
        if node_labels:
            num_unique_node_labels = max(node_labels) + 1
    except IOError:
        print('No node labels')

    # optional node attributes: comma/whitespace separated floats per line
    node_attrs = []
    try:
        with open(prefix + '_node_attributes.txt') as f:
            for line in f:
                attrs = [float(attr) for attr in re.split(r"[,\s]+", line.strip()) if attr]
                node_attrs.append(np.array(attrs))
    except IOError:
        print('No node attributes')

    # assume that all graph labels appear in the dataset
    # (set of labels don't have to be consecutive)
    raw_graph_labels = []
    with open(prefix + '_graph_labels.txt') as f:
        for line in f:
            raw_graph_labels.append(int(line.strip("\n")))
    label_map_to_int = {}
    for val in raw_graph_labels:
        # first appearance decides the new integer id
        label_map_to_int.setdefault(val, len(label_map_to_int))
    graph_labels = np.array([label_map_to_int[val] for val in raw_graph_labels])

    # edges grouped by the graph of their first endpoint
    adj_list = {i: [] for i in range(1, len(graph_labels) + 1)}
    with open(prefix + '_A.txt') as f:
        for line in f:
            fields = line.strip("\n").split(",")
            e0, e1 = (int(fields[0].strip(" ")), int(fields[1].strip(" ")))
            adj_list[graph_indic[e0]].append((e0, e1))

    graphs = []
    for graph_id in range(1, 1 + len(adj_list)):
        # indexed from 1 here
        G = nx.from_edgelist(adj_list[graph_id])
        if max_nodes is not None and G.number_of_nodes() > max_nodes:
            continue

        # add features and labels
        # (`G.nodes[u]` replaces `G.node[u]`, which was removed in networkx 2.4)
        G.graph['label'] = graph_labels[graph_id - 1]
        for u in G.nodes():
            if len(node_labels) > 0:
                node_label_one_hot = [0] * num_unique_node_labels
                node_label_one_hot[node_labels[u - 1]] = 1
                G.nodes[u]['label'] = node_label_one_hot
            if len(node_attrs) > 0:
                G.nodes[u]['feat'] = node_attrs[u - 1]
        if len(node_attrs) > 0:
            G.graph['feat_dim'] = node_attrs[0].shape[0]

        # relabeling, indexed from 0; no version switch is needed because
        # iterating G.nodes() yields the nodes in the same order either way
        mapping = {node: new_id for new_id, node in enumerate(G.nodes())}
        graphs.append(nx.relabel_nodes(G, mapping))
    return graphs
parser.add_argument('--hparam_path', nargs='?', type=str, 35 | default='./config/hparams_testdb.yml', 36 | help='The path to .yml file which contains all the hyperparameters.' 37 | ) 38 | 39 | args = parser.parse_args() 40 | main(args) 41 | --------------------------------------------------------------------------------