├── .DS_Store
├── README.md
├── data
│   ├── Cora
│   │   ├── processed
│   │   │   ├── data.pt
│   │   │   ├── pre_filter.pt
│   │   │   └── pre_transform.pt
│   │   └── raw
│   │       ├── ind.cora.allx
│   │       ├── ind.cora.ally
│   │       ├── ind.cora.graph
│   │       ├── ind.cora.test.index
│   │       ├── ind.cora.tx
│   │       ├── ind.cora.ty
│   │       ├── ind.cora.x
│   │       ├── ind.cora.y
│   │       ├── trans.cora.graph
│   │       ├── trans.cora.tx
│   │       ├── trans.cora.ty
│   │       ├── trans.cora.x
│   │       └── trans.cora.y
│   └── IMDB-MULTI
│       ├── processed
│       │   ├── data.pt
│       │   ├── pre_filter.pt
│       │   └── pre_transform.pt
│       └── raw
│           ├── IMDB-MULTI_A.txt
│           ├── IMDB-MULTI_graph_indicator.txt
│           └── IMDB-MULTI_graph_labels.txt
├── fig
│   ├── .DS_Store
│   ├── graph-classification.png
│   └── node-classification.png
├── layers.py
├── main_graph_classification.py
├── main_node_classification.py
├── models.py
├── sparse_softmax.py
└── utils.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MVPool
2 | Hierarchical Multi-View Graph Pooling with Structure Learning ([paper](https://ieeexplore.ieee.org/abstract/document/9460814)).
3 | 
4 | This is a PyTorch implementation of the MVPool algorithm, accepted by TKDE. The proposed MVPool performs the pooling operation using multi-view information. A structure learning layer is then stacked on top of the pooling operation to learn a refined graph structure that best preserves the essential topological information. MVPool is a general operator that can be used in various architectures, for both node-level and graph-level representation learning.
5 | 
6 | ## Requirements
7 | * python3.6
8 | * pytorch==1.3.0
9 | * torch-scatter==1.4.0
10 | * torch-sparse==0.4.3
11 | * torch-cluster==1.4.5
12 | * torch-geometric==1.3.2
13 | 
14 | Note:
15 | An older version of torch-sparse (lower than 0.4.4) is required. This code repository is heavily built on [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric), a Geometric Deep Learning extension library for PyTorch. Please refer to [its documentation](https://pytorch-geometric.readthedocs.io/en/latest/) for how to install and use the library.
16 | 
17 | ## Node Classification Datasets
18 | The input contains:
19 | * x, the feature vectors of the labeled training instances
20 | * y, the one-hot labels of the labeled training instances
21 | * allx, the feature vectors of both labeled and unlabeled training instances (a superset of x)
22 | * graph, a dict in the format {index: [index_of_neighbor_nodes]}.
23 | 
24 | Let n be the number of both labeled and unlabeled training instances. These n instances should be indexed from 0 to n - 1 in graph, in the same order as in allx.
25 | 
26 | In addition to x, y, allx, and graph as described above, the preprocessed datasets also include:
27 | * tx, the feature vectors of the test instances
28 | * ty, the one-hot labels of the test instances
29 | * test.index, the indices of test instances in graph, for the inductive setting
30 | * ally, the labels for instances in allx.
31 | 
32 | The indices of test instances in graph for the transductive setting are from #x to #x + #tx - 1, in the same order as in tx.
33 | 
34 | You can use cPickle.load(open(filename)) to load the numpy/scipy objects x, y, tx, ty, allx, ally, and graph; a minimal loading sketch is given below.
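For example, under Python 3 (where `cPickle` has become the standard `pickle` module), loading the raw Cora files could look like the following sketch; the `data/Cora/raw` path matches this repository's layout, and the `encoding='latin1'` argument is usually needed because the files were pickled under Python 2:

```
import pickle

objects = {}
for name in ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']:
    with open('data/Cora/raw/ind.cora.{}'.format(name), 'rb') as f:
        # The raw files were written with Python 2's cPickle, so an explicit
        # encoding is required when unpickling under Python 3.
        objects[name] = pickle.load(f, encoding='latin1')

# test.index is a plain text file with one test-node index per line.
with open('data/Cora/raw/ind.cora.test.index') as f:
    test_index = [int(line) for line in f]
```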
test.index is stored as a text file. More details can be found [here](https://github.com/kimiyoung/planetoid).
35 | 
36 | ### Node Classification
37 | 
38 | ![](https://github.com/cszhangzhen/MVPool/blob/main/fig/node-classification.png)
39 | 
40 | Just execute the following command for the node classification task:
41 | ```
42 | python main_node_classification.py
43 | ```
44 | ### Parameter settings for node classification
45 | | Datasets | lr | weight_decay | batch_size | pool_ratio | lambda | net_layers |
46 | | ------------- | --------- | -------------- | -------- | -------- | -------- | ---------- |
47 | | Cora | 0.01 | 0.01 | Full | 0.5/0.5/0.8/0.5 | 0.9 | 4 |
48 | | Citeseer | 0.005 | 0.1 | Full | 0.7 | 0.0 | 1 |
49 | | Pubmed | 0.01 | 0.001 | Full | 0.05/0.6/0.5/0.9 | 1.0 | 4 |
50 | | CS | 0.01 | 0.01 | Full | 0.05/0.5/0.5/0.5 | 0.0 | 4 |
51 | | Physics | 0.01 | 0.01 | Full | 0.05/0.8/0.8/0.8 | 0.0 | 4 |
52 | 
53 | 
54 | ## Graph Classification Datasets
55 | Graph classification benchmarks are publicly available [here](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets).
56 | 
57 | This folder contains the following comma-separated text files (replace DS by the name of the dataset):
58 | 
59 | **n = total number of nodes**
60 | 
61 | **m = total number of edges**
62 | 
63 | **N = number of graphs**
64 | 
65 | **(1) DS_A.txt (m lines)**
66 | 
67 | *sparse (block diagonal) adjacency matrix for all graphs; each line corresponds to (row, col), i.e., (node_id, node_id)*
68 | 
69 | **(2) DS_graph_indicator.txt (n lines)**
70 | 
71 | *column vector of graph identifiers for all nodes of all graphs; the value in the i-th line is the graph_id of the node with node_id i*
72 | 
73 | **(3) DS_graph_labels.txt (N lines)**
74 | 
75 | *class labels for all graphs in the dataset; the value in the i-th line is the class label of the graph with graph_id i*
76 | 
77 | **(4) DS_node_labels.txt (n lines)**
78 | 
79 | *column vector of node labels; the value in the i-th line corresponds to the node with node_id i*
80 | 
81 | There are OPTIONAL files if the respective information is available:
82 | 
83 | **(5) DS_edge_labels.txt (m lines; same size as DS_A.txt)**
84 | 
85 | *labels for the edges in DS_A.txt*
86 | 
87 | **(6) DS_edge_attributes.txt (m lines; same size as DS_A.txt)**
88 | 
89 | *attributes for the edges in DS_A.txt*
90 | 
91 | **(7) DS_node_attributes.txt (n lines)**
92 | 
93 | *matrix of node attributes; the comma-separated values in the i-th line form the attribute vector of the node with node_id i*
94 | 
95 | **(8) DS_graph_attributes.txt (N lines)**
96 | 
97 | *regression values for all graphs in the dataset; the value in the i-th line is the attribute of the graph with graph_id i*
98 | 
99 | 
100 | ### Run Graph Classification
101 | 
102 | ![](https://github.com/cszhangzhen/MVPool/blob/main/fig/graph-classification.png)
103 | 
104 | Just execute the following command for the graph classification task:
105 | ```
106 | python main_graph_classification.py
107 | ```
108 | 
109 | ## Citing
110 | If you find MVPool useful for your research, please consider citing the following paper:
111 | ```
112 | @article{zhang2021hierarchical,
113 |   title={Hierarchical Multi-View Graph Pooling with Structure Learning},
114 |   author={Zhang, Zhen and Bu, Jiajun and Ester, Martin and Zhang, Jianfeng and Li, Zhao and Yao, Chengwei and Dai, Huifen and Yu, Zhi and Wang, Can},
115 |   journal={IEEE Transactions on Knowledge and Data Engineering},
116 |   year={2021},
117 |   publisher={IEEE}
118 | }
119 | ```
120 | 
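For a quick sanity check of the pooling operator on its own, the following is a minimal, illustrative sketch of calling the `MVPool` layer from `layers.py` on a toy graph; the flag values in the `argparse.Namespace` are assumptions mirroring the command-line defaults, and the tensors are random placeholders:

```
import argparse
import torch
from layers import MVPool

# Illustrative flag values; MVPool reads these fields from `args`.
args = argparse.Namespace(sample_neighbor=True, sparse_attention=True,
                          structure_learning=False, hop_connection=False,
                          hop=3, lamb=2.0)

pool = MVPool(in_channels=16, ratio=0.5, args=args)

# A toy undirected graph: 4 nodes with 16-d features; each edge stored in both directions.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
batch = torch.zeros(4, dtype=torch.long)  # all four nodes belong to graph 0

x_pool, edge_index_pool, edge_attr_pool, batch_pool, perm = pool(x, edge_index, None, batch)
print(x_pool.size(0), 'nodes kept; original indices:', perm.tolist())
```

The layer returns the pooled node features, the (possibly refined) connectivity, the pooled batch vector, and `perm`, the indices of the nodes that were kept.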
-------------------------------------------------------------------------------- /data/Cora/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/processed/data.pt -------------------------------------------------------------------------------- /data/Cora/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/processed/pre_filter.pt -------------------------------------------------------------------------------- /data/Cora/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/processed/pre_transform.pt -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.allx -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.ally -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.graph -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 
1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 
2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 
1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.tx -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.ty -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.x -------------------------------------------------------------------------------- /data/Cora/raw/ind.cora.y: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/ind.cora.y -------------------------------------------------------------------------------- /data/Cora/raw/trans.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/trans.cora.graph -------------------------------------------------------------------------------- /data/Cora/raw/trans.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/trans.cora.tx -------------------------------------------------------------------------------- /data/Cora/raw/trans.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/trans.cora.ty -------------------------------------------------------------------------------- /data/Cora/raw/trans.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/trans.cora.x -------------------------------------------------------------------------------- /data/Cora/raw/trans.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/Cora/raw/trans.cora.y -------------------------------------------------------------------------------- /data/IMDB-MULTI/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/IMDB-MULTI/processed/data.pt -------------------------------------------------------------------------------- /data/IMDB-MULTI/processed/pre_filter.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/IMDB-MULTI/processed/pre_filter.pt -------------------------------------------------------------------------------- /data/IMDB-MULTI/processed/pre_transform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/data/IMDB-MULTI/processed/pre_transform.pt -------------------------------------------------------------------------------- /data/IMDB-MULTI/raw/IMDB-MULTI_graph_labels.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 1 10 | 1 11 | 1 12 | 1 13 | 1 14 | 1 15 | 1 16 | 1 17 | 1 18 | 1 19 | 1 20 | 1 21 | 1 22 | 1 23 | 1 24 | 1 25 | 1 26 | 1 27 | 1 28 | 1 29 | 1 30 | 1 31 | 1 32 | 1 33 | 1 34 | 1 35 | 1 36 | 1 37 | 1 38 | 1 39 | 1 40 | 1 41 | 1 42 | 1 43 | 1 44 | 1 45 | 1 46 | 1 47 | 1 48 | 1 49 | 1 50 | 1 51 | 1 52 | 1 53 | 1 54 | 1 55 | 1 56 | 1 57 | 1 58 | 1 59 | 1 60 | 1 61 | 1 62 | 1 63 | 1 64 | 1 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 
94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | 1 194 | 1 195 | 1 196 | 1 197 | 1 198 | 1 199 | 1 200 | 1 201 | 1 202 | 1 203 | 1 204 | 1 205 | 1 206 | 1 207 | 1 208 | 1 209 | 1 210 | 1 211 | 1 212 | 1 213 | 1 214 | 1 215 | 1 216 | 1 217 | 1 218 | 1 219 | 1 220 | 1 221 | 1 222 | 1 223 | 1 224 | 1 225 | 1 226 | 1 227 | 1 228 | 1 229 | 1 230 | 1 231 | 1 232 | 1 233 | 1 234 | 1 235 | 1 236 | 1 237 | 1 238 | 1 239 | 1 240 | 1 241 | 1 242 | 1 243 | 1 244 | 1 245 | 1 246 | 1 247 | 1 248 | 1 249 | 1 250 | 1 251 | 1 252 | 1 253 | 1 254 | 1 255 | 1 256 | 1 257 | 1 258 | 1 259 | 1 260 | 1 261 | 1 262 | 1 263 | 1 264 | 1 265 | 1 266 | 1 267 | 1 268 | 1 269 | 1 270 | 1 271 | 1 272 | 1 273 | 1 274 | 1 275 | 1 276 | 1 277 | 1 278 | 1 279 | 1 280 | 1 281 | 1 282 | 1 283 | 1 284 | 1 285 | 1 286 | 1 287 | 1 288 | 1 289 | 1 290 | 1 291 | 1 292 | 1 293 | 1 294 | 1 295 | 1 296 | 1 297 | 1 298 | 1 299 | 1 300 | 1 301 | 1 302 | 1 303 | 1 304 | 1 305 | 1 306 | 1 307 | 1 308 | 1 309 | 1 310 | 1 311 | 1 312 | 1 313 | 1 314 | 1 315 | 1 316 | 1 317 | 1 318 | 1 319 | 1 320 | 1 321 | 1 322 | 1 323 | 1 324 | 1 325 | 1 326 | 1 327 | 1 328 | 1 329 | 1 330 | 1 331 | 1 332 | 1 333 | 1 334 | 1 335 | 1 336 | 1 337 | 1 338 | 1 339 | 1 340 | 1 341 | 1 342 | 1 343 | 1 344 | 1 345 | 1 346 | 1 347 | 1 348 | 1 349 | 1 350 | 1 351 | 1 352 | 1 353 | 1 354 | 1 355 | 1 356 | 1 357 | 1 358 | 1 359 | 1 360 | 1 361 | 1 362 | 1 363 | 1 364 | 1 365 | 1 366 | 1 367 | 1 368 | 1 369 | 1 370 | 1 371 | 1 372 | 1 373 | 1 374 | 1 375 | 1 376 | 1 377 | 1 378 | 1 379 | 1 380 | 1 381 | 1 382 | 1 383 | 1 384 | 1 385 | 1 386 | 1 387 | 1 388 | 1 389 | 1 390 | 1 391 | 1 392 | 1 393 | 1 394 | 1 395 | 1 396 | 1 397 | 1 398 | 1 399 | 1 400 | 1 401 | 1 402 | 1 403 | 1 404 | 1 405 | 1 406 | 1 407 | 1 408 | 1 409 | 1 410 | 1 411 | 1 412 | 1 413 | 1 414 | 1 415 | 1 416 | 1 417 | 1 418 | 1 419 | 1 420 | 1 421 | 1 422 | 1 423 | 1 424 | 1 425 | 1 426 | 1 427 | 1 428 | 1 429 | 1 430 | 1 431 | 1 432 | 1 433 | 1 434 | 1 435 | 1 436 | 1 437 | 1 438 | 1 439 | 1 440 | 1 441 | 1 442 | 1 443 | 1 444 | 1 445 | 1 446 | 1 447 | 1 448 | 1 449 | 1 450 | 1 451 | 1 452 | 1 453 | 1 454 | 1 455 | 1 456 | 1 457 | 1 458 | 1 459 | 1 460 | 1 461 | 1 462 | 1 463 | 1 464 | 1 465 | 1 466 | 1 467 | 1 468 | 1 469 | 1 470 | 1 471 | 1 472 | 1 473 | 1 474 | 1 475 | 1 476 | 1 477 | 1 478 | 1 479 | 1 480 | 1 481 | 1 482 | 1 483 | 1 484 | 1 485 | 1 486 | 1 487 | 1 488 | 1 489 | 1 490 | 1 491 | 1 492 | 1 493 | 1 494 | 1 495 | 1 496 | 1 497 | 1 498 | 1 499 | 1 500 | 1 501 | 2 502 | 2 503 | 2 504 | 2 505 | 2 506 | 2 507 | 2 508 | 2 509 | 2 510 | 2 511 | 2 512 | 2 513 | 2 514 | 2 515 | 2 516 | 2 517 | 2 518 | 2 519 | 2 520 | 2 521 | 2 522 | 2 523 | 2 524 | 2 525 | 2 526 | 2 527 | 2 528 | 2 529 | 2 530 | 2 531 | 2 532 | 2 533 | 2 534 | 2 535 | 2 536 | 2 537 | 2 538 | 2 
539 | 2 540 | 2 541 | 2 542 | 2 543 | 2 544 | 2 545 | 2 546 | 2 547 | 2 548 | 2 549 | 2 550 | 2 551 | 2 552 | 2 553 | 2 554 | 2 555 | 2 556 | 2 557 | 2 558 | 2 559 | 2 560 | 2 561 | 2 562 | 2 563 | 2 564 | 2 565 | 2 566 | 2 567 | 2 568 | 2 569 | 2 570 | 2 571 | 2 572 | 2 573 | 2 574 | 2 575 | 2 576 | 2 577 | 2 578 | 2 579 | 2 580 | 2 581 | 2 582 | 2 583 | 2 584 | 2 585 | 2 586 | 2 587 | 2 588 | 2 589 | 2 590 | 2 591 | 2 592 | 2 593 | 2 594 | 2 595 | 2 596 | 2 597 | 2 598 | 2 599 | 2 600 | 2 601 | 2 602 | 2 603 | 2 604 | 2 605 | 2 606 | 2 607 | 2 608 | 2 609 | 2 610 | 2 611 | 2 612 | 2 613 | 2 614 | 2 615 | 2 616 | 2 617 | 2 618 | 2 619 | 2 620 | 2 621 | 2 622 | 2 623 | 2 624 | 2 625 | 2 626 | 2 627 | 2 628 | 2 629 | 2 630 | 2 631 | 2 632 | 2 633 | 2 634 | 2 635 | 2 636 | 2 637 | 2 638 | 2 639 | 2 640 | 2 641 | 2 642 | 2 643 | 2 644 | 2 645 | 2 646 | 2 647 | 2 648 | 2 649 | 2 650 | 2 651 | 2 652 | 2 653 | 2 654 | 2 655 | 2 656 | 2 657 | 2 658 | 2 659 | 2 660 | 2 661 | 2 662 | 2 663 | 2 664 | 2 665 | 2 666 | 2 667 | 2 668 | 2 669 | 2 670 | 2 671 | 2 672 | 2 673 | 2 674 | 2 675 | 2 676 | 2 677 | 2 678 | 2 679 | 2 680 | 2 681 | 2 682 | 2 683 | 2 684 | 2 685 | 2 686 | 2 687 | 2 688 | 2 689 | 2 690 | 2 691 | 2 692 | 2 693 | 2 694 | 2 695 | 2 696 | 2 697 | 2 698 | 2 699 | 2 700 | 2 701 | 2 702 | 2 703 | 2 704 | 2 705 | 2 706 | 2 707 | 2 708 | 2 709 | 2 710 | 2 711 | 2 712 | 2 713 | 2 714 | 2 715 | 2 716 | 2 717 | 2 718 | 2 719 | 2 720 | 2 721 | 2 722 | 2 723 | 2 724 | 2 725 | 2 726 | 2 727 | 2 728 | 2 729 | 2 730 | 2 731 | 2 732 | 2 733 | 2 734 | 2 735 | 2 736 | 2 737 | 2 738 | 2 739 | 2 740 | 2 741 | 2 742 | 2 743 | 2 744 | 2 745 | 2 746 | 2 747 | 2 748 | 2 749 | 2 750 | 2 751 | 2 752 | 2 753 | 2 754 | 2 755 | 2 756 | 2 757 | 2 758 | 2 759 | 2 760 | 2 761 | 2 762 | 2 763 | 2 764 | 2 765 | 2 766 | 2 767 | 2 768 | 2 769 | 2 770 | 2 771 | 2 772 | 2 773 | 2 774 | 2 775 | 2 776 | 2 777 | 2 778 | 2 779 | 2 780 | 2 781 | 2 782 | 2 783 | 2 784 | 2 785 | 2 786 | 2 787 | 2 788 | 2 789 | 2 790 | 2 791 | 2 792 | 2 793 | 2 794 | 2 795 | 2 796 | 2 797 | 2 798 | 2 799 | 2 800 | 2 801 | 2 802 | 2 803 | 2 804 | 2 805 | 2 806 | 2 807 | 2 808 | 2 809 | 2 810 | 2 811 | 2 812 | 2 813 | 2 814 | 2 815 | 2 816 | 2 817 | 2 818 | 2 819 | 2 820 | 2 821 | 2 822 | 2 823 | 2 824 | 2 825 | 2 826 | 2 827 | 2 828 | 2 829 | 2 830 | 2 831 | 2 832 | 2 833 | 2 834 | 2 835 | 2 836 | 2 837 | 2 838 | 2 839 | 2 840 | 2 841 | 2 842 | 2 843 | 2 844 | 2 845 | 2 846 | 2 847 | 2 848 | 2 849 | 2 850 | 2 851 | 2 852 | 2 853 | 2 854 | 2 855 | 2 856 | 2 857 | 2 858 | 2 859 | 2 860 | 2 861 | 2 862 | 2 863 | 2 864 | 2 865 | 2 866 | 2 867 | 2 868 | 2 869 | 2 870 | 2 871 | 2 872 | 2 873 | 2 874 | 2 875 | 2 876 | 2 877 | 2 878 | 2 879 | 2 880 | 2 881 | 2 882 | 2 883 | 2 884 | 2 885 | 2 886 | 2 887 | 2 888 | 2 889 | 2 890 | 2 891 | 2 892 | 2 893 | 2 894 | 2 895 | 2 896 | 2 897 | 2 898 | 2 899 | 2 900 | 2 901 | 2 902 | 2 903 | 2 904 | 2 905 | 2 906 | 2 907 | 2 908 | 2 909 | 2 910 | 2 911 | 2 912 | 2 913 | 2 914 | 2 915 | 2 916 | 2 917 | 2 918 | 2 919 | 2 920 | 2 921 | 2 922 | 2 923 | 2 924 | 2 925 | 2 926 | 2 927 | 2 928 | 2 929 | 2 930 | 2 931 | 2 932 | 2 933 | 2 934 | 2 935 | 2 936 | 2 937 | 2 938 | 2 939 | 2 940 | 2 941 | 2 942 | 2 943 | 2 944 | 2 945 | 2 946 | 2 947 | 2 948 | 2 949 | 2 950 | 2 951 | 2 952 | 2 953 | 2 954 | 2 955 | 2 956 | 2 957 | 2 958 | 2 959 | 2 960 | 2 961 | 2 962 | 2 963 | 2 964 | 2 965 | 2 966 | 2 967 | 2 968 | 2 969 | 2 970 | 2 971 | 2 972 | 2 973 | 2 974 | 2 975 | 2 976 | 2 977 | 2 978 | 2 979 | 2 980 | 2 981 | 2 982 | 2 
983 | 2 984 | 2 985 | 2 986 | 2 987 | 2 988 | 2 989 | 2 990 | 2 991 | 2 992 | 2 993 | 2 994 | 2 995 | 2 996 | 2 997 | 2 998 | 2 999 | 2 1000 | 2 1001 | 3 1002 | 3 1003 | 3 1004 | 3 1005 | 3 1006 | 3 1007 | 3 1008 | 3 1009 | 3 1010 | 3 1011 | 3 1012 | 3 1013 | 3 1014 | 3 1015 | 3 1016 | 3 1017 | 3 1018 | 3 1019 | 3 1020 | 3 1021 | 3 1022 | 3 1023 | 3 1024 | 3 1025 | 3 1026 | 3 1027 | 3 1028 | 3 1029 | 3 1030 | 3 1031 | 3 1032 | 3 1033 | 3 1034 | 3 1035 | 3 1036 | 3 1037 | 3 1038 | 3 1039 | 3 1040 | 3 1041 | 3 1042 | 3 1043 | 3 1044 | 3 1045 | 3 1046 | 3 1047 | 3 1048 | 3 1049 | 3 1050 | 3 1051 | 3 1052 | 3 1053 | 3 1054 | 3 1055 | 3 1056 | 3 1057 | 3 1058 | 3 1059 | 3 1060 | 3 1061 | 3 1062 | 3 1063 | 3 1064 | 3 1065 | 3 1066 | 3 1067 | 3 1068 | 3 1069 | 3 1070 | 3 1071 | 3 1072 | 3 1073 | 3 1074 | 3 1075 | 3 1076 | 3 1077 | 3 1078 | 3 1079 | 3 1080 | 3 1081 | 3 1082 | 3 1083 | 3 1084 | 3 1085 | 3 1086 | 3 1087 | 3 1088 | 3 1089 | 3 1090 | 3 1091 | 3 1092 | 3 1093 | 3 1094 | 3 1095 | 3 1096 | 3 1097 | 3 1098 | 3 1099 | 3 1100 | 3 1101 | 3 1102 | 3 1103 | 3 1104 | 3 1105 | 3 1106 | 3 1107 | 3 1108 | 3 1109 | 3 1110 | 3 1111 | 3 1112 | 3 1113 | 3 1114 | 3 1115 | 3 1116 | 3 1117 | 3 1118 | 3 1119 | 3 1120 | 3 1121 | 3 1122 | 3 1123 | 3 1124 | 3 1125 | 3 1126 | 3 1127 | 3 1128 | 3 1129 | 3 1130 | 3 1131 | 3 1132 | 3 1133 | 3 1134 | 3 1135 | 3 1136 | 3 1137 | 3 1138 | 3 1139 | 3 1140 | 3 1141 | 3 1142 | 3 1143 | 3 1144 | 3 1145 | 3 1146 | 3 1147 | 3 1148 | 3 1149 | 3 1150 | 3 1151 | 3 1152 | 3 1153 | 3 1154 | 3 1155 | 3 1156 | 3 1157 | 3 1158 | 3 1159 | 3 1160 | 3 1161 | 3 1162 | 3 1163 | 3 1164 | 3 1165 | 3 1166 | 3 1167 | 3 1168 | 3 1169 | 3 1170 | 3 1171 | 3 1172 | 3 1173 | 3 1174 | 3 1175 | 3 1176 | 3 1177 | 3 1178 | 3 1179 | 3 1180 | 3 1181 | 3 1182 | 3 1183 | 3 1184 | 3 1185 | 3 1186 | 3 1187 | 3 1188 | 3 1189 | 3 1190 | 3 1191 | 3 1192 | 3 1193 | 3 1194 | 3 1195 | 3 1196 | 3 1197 | 3 1198 | 3 1199 | 3 1200 | 3 1201 | 3 1202 | 3 1203 | 3 1204 | 3 1205 | 3 1206 | 3 1207 | 3 1208 | 3 1209 | 3 1210 | 3 1211 | 3 1212 | 3 1213 | 3 1214 | 3 1215 | 3 1216 | 3 1217 | 3 1218 | 3 1219 | 3 1220 | 3 1221 | 3 1222 | 3 1223 | 3 1224 | 3 1225 | 3 1226 | 3 1227 | 3 1228 | 3 1229 | 3 1230 | 3 1231 | 3 1232 | 3 1233 | 3 1234 | 3 1235 | 3 1236 | 3 1237 | 3 1238 | 3 1239 | 3 1240 | 3 1241 | 3 1242 | 3 1243 | 3 1244 | 3 1245 | 3 1246 | 3 1247 | 3 1248 | 3 1249 | 3 1250 | 3 1251 | 3 1252 | 3 1253 | 3 1254 | 3 1255 | 3 1256 | 3 1257 | 3 1258 | 3 1259 | 3 1260 | 3 1261 | 3 1262 | 3 1263 | 3 1264 | 3 1265 | 3 1266 | 3 1267 | 3 1268 | 3 1269 | 3 1270 | 3 1271 | 3 1272 | 3 1273 | 3 1274 | 3 1275 | 3 1276 | 3 1277 | 3 1278 | 3 1279 | 3 1280 | 3 1281 | 3 1282 | 3 1283 | 3 1284 | 3 1285 | 3 1286 | 3 1287 | 3 1288 | 3 1289 | 3 1290 | 3 1291 | 3 1292 | 3 1293 | 3 1294 | 3 1295 | 3 1296 | 3 1297 | 3 1298 | 3 1299 | 3 1300 | 3 1301 | 3 1302 | 3 1303 | 3 1304 | 3 1305 | 3 1306 | 3 1307 | 3 1308 | 3 1309 | 3 1310 | 3 1311 | 3 1312 | 3 1313 | 3 1314 | 3 1315 | 3 1316 | 3 1317 | 3 1318 | 3 1319 | 3 1320 | 3 1321 | 3 1322 | 3 1323 | 3 1324 | 3 1325 | 3 1326 | 3 1327 | 3 1328 | 3 1329 | 3 1330 | 3 1331 | 3 1332 | 3 1333 | 3 1334 | 3 1335 | 3 1336 | 3 1337 | 3 1338 | 3 1339 | 3 1340 | 3 1341 | 3 1342 | 3 1343 | 3 1344 | 3 1345 | 3 1346 | 3 1347 | 3 1348 | 3 1349 | 3 1350 | 3 1351 | 3 1352 | 3 1353 | 3 1354 | 3 1355 | 3 1356 | 3 1357 | 3 1358 | 3 1359 | 3 1360 | 3 1361 | 3 1362 | 3 1363 | 3 1364 | 3 1365 | 3 1366 | 3 1367 | 3 1368 | 3 1369 | 3 1370 | 3 1371 | 3 1372 | 3 1373 | 3 1374 | 3 1375 | 3 1376 | 3 1377 | 3 1378 | 3 1379 | 
3 1380 | 3 1381 | 3 1382 | 3 1383 | 3 1384 | 3 1385 | 3 1386 | 3 1387 | 3 1388 | 3 1389 | 3 1390 | 3 1391 | 3 1392 | 3 1393 | 3 1394 | 3 1395 | 3 1396 | 3 1397 | 3 1398 | 3 1399 | 3 1400 | 3 1401 | 3 1402 | 3 1403 | 3 1404 | 3 1405 | 3 1406 | 3 1407 | 3 1408 | 3 1409 | 3 1410 | 3 1411 | 3 1412 | 3 1413 | 3 1414 | 3 1415 | 3 1416 | 3 1417 | 3 1418 | 3 1419 | 3 1420 | 3 1421 | 3 1422 | 3 1423 | 3 1424 | 3 1425 | 3 1426 | 3 1427 | 3 1428 | 3 1429 | 3 1430 | 3 1431 | 3 1432 | 3 1433 | 3 1434 | 3 1435 | 3 1436 | 3 1437 | 3 1438 | 3 1439 | 3 1440 | 3 1441 | 3 1442 | 3 1443 | 3 1444 | 3 1445 | 3 1446 | 3 1447 | 3 1448 | 3 1449 | 3 1450 | 3 1451 | 3 1452 | 3 1453 | 3 1454 | 3 1455 | 3 1456 | 3 1457 | 3 1458 | 3 1459 | 3 1460 | 3 1461 | 3 1462 | 3 1463 | 3 1464 | 3 1465 | 3 1466 | 3 1467 | 3 1468 | 3 1469 | 3 1470 | 3 1471 | 3 1472 | 3 1473 | 3 1474 | 3 1475 | 3 1476 | 3 1477 | 3 1478 | 3 1479 | 3 1480 | 3 1481 | 3 1482 | 3 1483 | 3 1484 | 3 1485 | 3 1486 | 3 1487 | 3 1488 | 3 1489 | 3 1490 | 3 1491 | 3 1492 | 3 1493 | 3 1494 | 3 1495 | 3 1496 | 3 1497 | 3 1498 | 3 1499 | 3 1500 | 3 1501 | -------------------------------------------------------------------------------- /fig/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/fig/.DS_Store -------------------------------------------------------------------------------- /fig/graph-classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/fig/graph-classification.png -------------------------------------------------------------------------------- /fig/node-classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszhangzhen/MVPool/69e81573af2c9838dd1f661e846fc1b093c1e345/fig/node-classification.png -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sparse_softmax import Sparsemax 5 | from torch.nn import Parameter 6 | from torch_geometric.data import Data 7 | from torch_geometric.utils import degree 8 | from torch_geometric.nn.conv import MessagePassing 9 | from torch_geometric.nn.pool.topk_pool import filter_adj, topk 10 | from torch_geometric.utils import softmax, dense_to_sparse, add_remaining_self_loops 11 | from torch_scatter import scatter_add 12 | from torch_sparse import spspmm, coalesce 13 | 14 | 15 | class TwoHopNeighborhood(object): 16 | def __call__(self, data): 17 | edge_index, edge_attr = data.edge_index, data.edge_attr 18 | n = data.num_nodes 19 | 20 | fill = 1e16 21 | value = edge_index.new_full((edge_index.size(1),), fill, dtype=torch.float) 22 | 23 | index, value = spspmm(edge_index, value, edge_index, value, n, n, n, True) 24 | 25 | edge_index = torch.cat([edge_index, index], dim=1) 26 | if edge_attr is None: 27 | data.edge_index, _ = coalesce(edge_index, None, n, n) 28 | else: 29 | value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)]) 30 | value = value.expand(-1, *list(edge_attr.size())[1:]) 31 | edge_attr = torch.cat([edge_attr, value], dim=0) 32 | data.edge_index, edge_attr = coalesce(edge_index, edge_attr, n, n, op='min', fill_value=fill) 33 | edge_attr[edge_attr >= 
fill] = 0 34 | data.edge_attr = edge_attr 35 | 36 | return data 37 | 38 | def __repr__(self): 39 | return '{}()'.format(self.__class__.__name__) 40 | 41 | 42 | class GCN(MessagePassing): 43 | def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs): 44 | super(GCN, self).__init__(aggr='add', **kwargs) 45 | 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | self.cached = cached 49 | self.cached_result = None 50 | self.cached_num_edges = None 51 | 52 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 53 | nn.init.xavier_uniform_(self.weight.data) 54 | 55 | if bias: 56 | self.bias = Parameter(torch.Tensor(out_channels)) 57 | nn.init.zeros_(self.bias.data) 58 | else: 59 | self.register_parameter('bias', None) 60 | 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | self.cached_result = None 65 | self.cached_num_edges = None 66 | 67 | @staticmethod 68 | def norm(edge_index, num_nodes, edge_weight, dtype=None): 69 | if edge_weight is None: 70 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device) 71 | 72 | row, col = edge_index 73 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 74 | deg_inv_sqrt = deg.pow(-0.5) 75 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 76 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 77 | 78 | return edge_index, edge_weight 79 | 80 | def forward(self, x, edge_index, edge_weight=None): 81 | x = torch.matmul(x, self.weight) 82 | 83 | if self.cached and self.cached_result is not None: 84 | if edge_index.size(1) != self.cached_num_edges: 85 | raise RuntimeError( 86 | 'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1))) 87 | 88 | if not self.cached or self.cached_result is None: 89 | self.cached_num_edges = edge_index.size(1) 90 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype) 91 | self.cached_result = edge_index, norm 92 | 93 | edge_index, norm = self.cached_result 94 | 95 | return self.propagate(edge_index, x=x, norm=norm) 96 | 97 | def message(self, x_j, norm): 98 | return norm.view(-1, 1) * x_j 99 | 100 | def update(self, aggr_out): 101 | if self.bias is not None: 102 | aggr_out = aggr_out + self.bias 103 | return aggr_out 104 | 105 | def __repr__(self): 106 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) 107 | 108 | 109 | class PageRankScore(MessagePassing): 110 | def __init__(self, channels, k=10, alpha=0.1, **kwargs): 111 | super(PageRankScore, self).__init__(aggr='add', **kwargs) 112 | self.channels = channels 113 | self.k = k 114 | self.alpha = alpha 115 | self.gnn = GCN(channels, 1) 116 | 117 | def forward(self, x, edge_index, edge_weight=None): 118 | edge_index, norm = GCN.norm(edge_index, x.size(0), edge_weight, dtype=x.dtype) 119 | 120 | x = self.gnn(x, edge_index, edge_weight) 121 | hidden = x 122 | for k in range(self.k): 123 | x = self.propagate(edge_index, x=x, norm=norm) 124 | x = x * (1 - self.alpha) 125 | x = x + self.alpha * hidden 126 | 127 | return x 128 | 129 | def message(self, x_j, norm): 130 | return norm.view(-1, 1) * x_j 131 | 132 | 133 | class MVPool(torch.nn.Module): 134 | def __init__(self, in_channels, ratio, args, negative_slop=0.2): 135 | super(MVPool, self).__init__() 136 | self.in_channels = in_channels 137 | self.ratio = ratio 138 | self.sample = args.sample_neighbor 139 | self.sparse = args.sparse_attention 140 | self.sl = args.structure_learning 141 | self.hc = 
args.hop_connection
142 |         self.h_hop = args.hop
143 |         self.lamb = args.lamb
144 |         self.negative_slop = negative_slop
145 | 
146 |         self.att = Parameter(torch.Tensor(1, self.in_channels * 2))
147 |         nn.init.xavier_uniform_(self.att.data)
148 |         self.weight = Parameter(torch.Tensor(1, in_channels))
149 |         nn.init.xavier_uniform_(self.weight.data)
150 |         self.view_att = Parameter(torch.Tensor(3, 3))
151 |         nn.init.xavier_uniform_(self.view_att.data)
152 |         self.view_bias = Parameter(torch.Tensor(3))
153 |         nn.init.zeros_(self.view_bias.data)
154 |         self.alpha = Parameter(torch.Tensor(1))
155 |         nn.init.ones_(self.alpha.data)
156 |         self.beta = Parameter(torch.Tensor(1))
157 |         nn.init.ones_(self.beta.data)
158 |         self.sparse_attention = Sparsemax()
159 |         self.neighbor_augment = TwoHopNeighborhood()
160 |         self.calc_pagerank_score = PageRankScore(in_channels)
161 | 
162 |     def forward(self, x, edge_index, edge_attr, batch=None):
163 |         if batch is None:
164 |             batch = edge_index.new_zeros(x.size(0))
165 | 
166 |         row, col = edge_index
167 |         score1 = torch.sigmoid(self.alpha * torch.log(degree(row, num_nodes=x.size(0)) + 1e-16) + self.beta).view(-1, 1)  # view 1: degree-based score
168 |         x_score2 = (x * self.weight).sum(dim=-1)
169 |         score2 = torch.sigmoid(x_score2 / self.weight.norm(p=2, dim=-1)).view(-1, 1)  # view 2: feature-projection score
170 |         x_score3 = self.calc_pagerank_score(x, edge_index, edge_attr)
171 |         score3 = torch.sigmoid(x_score3).view(-1, 1)  # view 3: PageRank-based score
172 | 
173 |         score_cat = torch.cat([score1, score2, score3], dim=1)
174 |         max_value, _ = torch.max(torch.abs(score_cat), dim=0)
175 |         score_cat = score_cat / max_value
176 |         score_weight = torch.sigmoid(torch.matmul(score_cat, self.view_att) + self.view_bias)
177 |         score_weight = torch.softmax(score_weight, dim=1)
178 |         score = torch.sigmoid(torch.sum(score_cat * score_weight, dim=1))
179 |         # score = score2.view(-1)
180 | 
181 |         # Graph Pooling
182 |         original_x = x
183 |         perm = topk(score, self.ratio, batch)
184 |         x = x[perm] * score[perm].view(-1, 1)
185 |         batch = batch[perm]
186 |         induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))
187 | 
188 |         # If the structure learning layer is disabled, return directly
189 |         if self.sl is False:
190 |             return x, induced_edge_index, induced_edge_attr, batch, perm
191 | 
192 |         # Structure Learning
193 |         if self.sample:
194 |             # A fast mode for large graphs.
195 |             # In large graphs, learning the possible edge weights between each pair of nodes is time-consuming.
196 |             # To accelerate this process, we sample each node's k-hop neighbors and then learn the
197 |             # edge weights between them.
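            # (Concretely: candidate edges are re-scored below by an attention vector over the
            # concatenated endpoint features, the original edge weights are mixed back in with
            # coefficient lambda, and the result is normalized per source node with softmax or sparsemax.)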
198 |             if edge_attr is None:
199 |                 edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device)
200 | 
201 |             if self.h_hop >= 2:
202 |                 hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
203 |                 for _ in range(self.h_hop - 1):
204 |                     hop_data = self.neighbor_augment(hop_data)
205 |                 hop_edge_index = hop_data.edge_index
206 |                 hop_edge_attr = hop_data.edge_attr
207 |                 new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=score.size(0))
208 |                 if self.hc is True:
209 |                     return x, new_edge_index, new_edge_attr, batch, perm
210 |             else:
211 |                 new_edge_index = induced_edge_index
212 |                 new_edge_attr = induced_edge_attr
213 |                 if self.hc is True:
214 |                     return x, new_edge_index, new_edge_attr, batch, perm
215 | 
216 |             new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0))
217 |             row, col = new_edge_index
218 |             weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
219 |             weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb
220 |             adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
221 |             adj[row, col] = weights
222 |             new_edge_index, weights = dense_to_sparse(adj)
223 |             row, col = new_edge_index
224 |             if self.sparse:
225 |                 new_edge_attr = self.sparse_attention(weights, row)
226 |             else:
227 |                 new_edge_attr = softmax(weights, row, x.size(0))
228 |             # filter out zero-weight edges
229 |             adj[row, col] = new_edge_attr
230 |             new_edge_index, new_edge_attr = dense_to_sparse(adj)
231 |             # release GPU memory
232 |             del adj
233 |             torch.cuda.empty_cache()
234 |         else:
235 |             # Learn possible edge weights between every pair of nodes in the pooled subgraph; relatively slower.
236 |             if edge_attr is None:
237 |                 induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
238 |                                                device=induced_edge_index.device)
239 |             num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
240 |             shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
241 |             cum_num_nodes = num_nodes.cumsum(dim=0)
242 |             adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device)
243 |             # Construct the batched fully-connected graph in block-diagonal matrix format
244 |             for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
245 |                 adj[idx_i:idx_j, idx_i:idx_j] = 1.0
246 |             new_edge_index, _ = dense_to_sparse(adj)
247 |             row, col = new_edge_index
248 | 
249 |             weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1)
250 |             weights = F.leaky_relu(weights, self.negative_slop)
251 |             adj[row, col] = weights
252 |             induced_row, induced_col = induced_edge_index
253 | 
254 |             adj[induced_row, induced_col] += induced_edge_attr * self.lamb
255 |             weights = adj[row, col]
256 |             if self.sparse:
257 |                 new_edge_attr = self.sparse_attention(weights, row)
258 |             else:
259 |                 new_edge_attr = softmax(weights, row, x.size(0))
260 |             # filter out zero-weight edges
261 |             adj[row, col] = new_edge_attr
262 |             new_edge_index, new_edge_attr = dense_to_sparse(adj)
263 |             # release GPU memory
264 |             del adj
265 |             torch.cuda.empty_cache()
266 | 
267 |         return x, new_edge_index, new_edge_attr, batch, perm
268 | 
--------------------------------------------------------------------------------
/main_graph_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import time
5 | 
6 | import torch
7 | import torch.nn.functional as F
8 | 
from models import GraphClassificationModel 9 | from torch.utils.data import random_split 10 | from torch_geometric.data import DataLoader 11 | from torch_geometric.datasets import TUDataset 12 | from torch_geometric.transforms import OneHotDegree 13 | from torch_geometric.utils import degree 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--seed', type=int, default=777, help='random seed') 18 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 19 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 20 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay') 21 | parser.add_argument('--nhid', type=int, default=128, help='hidden size') 22 | parser.add_argument('--sample_neighbor', type=bool, default=True, help='whether sample neighbors') 23 | parser.add_argument('--sparse_attention', type=bool, default=True, help='whether use sparse attention') 24 | parser.add_argument('--structure_learning', type=bool, default=False, help='whether perform structure learning') 25 | parser.add_argument('--hop_connection', type=bool, default=False, help='whether directly connect node within h-hops') 26 | parser.add_argument('--hop', type=int, default=3, help='h-hops') 27 | parser.add_argument('--pooling_ratio', type=float, default=0.8, help='pooling ratio') 28 | parser.add_argument('--dropout_ratio', type=float, default=0.0, help='dropout ratio') 29 | parser.add_argument('--lamb', type=float, default=2.0, help='trade-off parameter') 30 | parser.add_argument('--dataset', type=str, default='IMDB-MULTI', help='DD/PROTEINS/NCI1/NCI109/Mutagenicity/ENZYMES') 31 | parser.add_argument('--device', type=str, default='cuda:1', help='specify cuda devices') 32 | parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs') 33 | parser.add_argument('--patience', type=int, default=100, help='patience for early stopping') 34 | 35 | args = parser.parse_args() 36 | torch.manual_seed(args.seed) 37 | if torch.cuda.is_available(): 38 | torch.cuda.manual_seed(args.seed) 39 | 40 | if args.dataset == 'IMDB-MULTI' or args.dataset == 'REDDIT-MULTI-12K': 41 | dataset = TUDataset('data/', name=args.dataset) 42 | max_degree = 0 43 | for g in dataset: 44 | if g.edge_index.size(1) > 0: 45 | max_degree = max(max_degree, int(degree(g.edge_index[0]).max().item())) 46 | dataset.transform = OneHotDegree(max_degree) 47 | args.num_classes = dataset.num_classes 48 | args.num_features = dataset.num_features 49 | else: 50 | dataset = TUDataset('data/', name=args.dataset, use_node_attr=True) 51 | args.num_classes = dataset.num_classes 52 | args.num_features = dataset.num_features 53 | 54 | print(args) 55 | 56 | num_training = int(len(dataset) * 0.8) 57 | num_val = int(len(dataset) * 0.1) 58 | num_test = len(dataset) - (num_training + num_val) 59 | training_set, validation_set, test_set = random_split(dataset, [num_training, num_val, num_test]) 60 | 61 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True) 62 | val_loader = DataLoader(validation_set, batch_size=args.batch_size, shuffle=False) 63 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) 64 | 65 | model = GraphClassificationModel(args).to(args.device) 66 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 67 | 68 | 69 | def train(): 70 | min_loss = 1e10 71 | patience_cnt = 0 72 | val_loss_values = [] 73 | best_epoch = 0 74 | 75 | t = time.time() 76 | 
for epoch in range(args.epochs):
77 |         model.train()  # compute_test() switches the model to eval mode, so re-enable training mode each epoch
78 |         loss_train = 0.0
79 |         correct = 0
80 |         for i, data in enumerate(train_loader):
81 |             optimizer.zero_grad()
82 |             data = data.to(args.device)
83 |             out = model(data)
84 |             loss = F.nll_loss(out, data.y)
85 |             loss.backward()
86 |             optimizer.step()
87 |             loss_train += loss.item()
88 |             pred = out.max(dim=1)[1]
89 |             correct += pred.eq(data.y).sum().item()
90 |         acc_train = correct / len(train_loader.dataset)
91 |         acc_val, loss_val = compute_test(val_loader)
92 |         print('Epoch: {:04d}'.format(epoch + 1), 'loss_train: {:.6f}'.format(loss_train),
93 |               'acc_train: {:.6f}'.format(acc_train), 'loss_val: {:.6f}'.format(loss_val),
94 |               'acc_val: {:.6f}'.format(acc_val), 'time: {:.6f}s'.format(time.time() - t))
95 | 
96 |         val_loss_values.append(loss_val)
97 |         torch.save(model.state_dict(), '{}.pth'.format(epoch))
98 |         if val_loss_values[-1] < min_loss:
99 |             min_loss = val_loss_values[-1]
100 |             best_epoch = epoch
101 |             patience_cnt = 0
102 |         else:
103 |             patience_cnt += 1
104 | 
105 |         if patience_cnt == args.patience:
106 |             break
107 | 
108 |         files = glob.glob('*.pth')
109 |         for f in files:
110 |             epoch_nb = int(f.split('.')[0])
111 |             if epoch_nb < best_epoch:
112 |                 os.remove(f)
113 | 
114 |     files = glob.glob('*.pth')
115 |     for f in files:
116 |         epoch_nb = int(f.split('.')[0])
117 |         if epoch_nb > best_epoch:
118 |             os.remove(f)
119 |     print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t))
120 | 
121 |     return best_epoch
122 | 
123 | 
124 | def compute_test(loader):
125 |     model.eval()
126 |     correct = 0.0
127 |     loss_test = 0.0
128 |     for data in loader:
129 |         data = data.to(args.device)
130 |         out = model(data)
131 |         pred = out.max(dim=1)[1]
132 |         correct += pred.eq(data.y).sum().item()
133 |         loss_test += F.nll_loss(out, data.y).item()
134 |     return correct / len(loader.dataset), loss_test
135 | 
136 | 
137 | if __name__ == '__main__':
138 |     # Model training
139 |     best_model = train()
140 |     # Restore best model for test set
141 |     model.load_state_dict(torch.load('{}.pth'.format(best_model)))
142 |     test_acc, test_loss = compute_test(test_loader)
143 |     print('Test set results, loss = {:.6f}, accuracy = {:.6f}'.format(test_loss, test_acc))
144 | 
145 | 
146 | 
--------------------------------------------------------------------------------
/main_node_classification.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import glob
4 | import argparse
5 | 
6 | import torch
7 | import torch.nn.functional as F
8 | from models import NodeClassificationModel
9 | from torch_geometric.datasets import Coauthor, Planetoid
10 | from utils import index_to_mask, random_splits
11 | 
12 | parser = argparse.ArgumentParser()
13 | 
14 | parser.add_argument('--seed', type=int, default=777, help='random seed')
15 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
16 | parser.add_argument('--weight_decay', type=float, default=0.01, help='weight decay')
17 | parser.add_argument('--nhid', type=int, default=64, help='hidden size')
18 | parser.add_argument('--depth', type=int, default=4, help='number of encoder layers')
19 | parser.add_argument('--sample_neighbor', type=bool, default=True, help='whether sample neighbors within h-hops')
20 | parser.add_argument('--sparse_attention', type=bool, default=True, help='whether use sparse attention')
21 | parser.add_argument('--structure_learning', type=bool, default=True, help='whether perform structure learning')
22 | 
parser.add_argument('--hop_connection', type=bool, default=False, help='whether directly connect node within h-hops') 23 | parser.add_argument('--hop', type=int, default=2, help='h-hops') 24 | parser.add_argument('--lamb', type=float, default=0.0, help='trade-off parameter') 25 | parser.add_argument('--dataset', type=str, default='CS', help='Cora/Citeseer/Pubmed/Physics') 26 | parser.add_argument('--device', type=str, default='cuda:0', help='specify cuda devices') 27 | parser.add_argument('--epochs', type=int, default=200, help='maximum number of epochs') 28 | parser.add_argument('--pool1', type=float, default=0.05, help='pool1 parameter') 29 | parser.add_argument('--pool2', type=float, default=0.5, help='pool2 parameter') 30 | parser.add_argument('--pool3', type=float, default=0.5, help='pool3 parameter') 31 | parser.add_argument('--pool4', type=float, default=0.5, help='pool4 parameter') 32 | parser.add_argument('--pool5', type=float, default=0.8, help='pool5 parameter') 33 | 34 | args = parser.parse_args() 35 | torch.manual_seed(args.seed) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed(args.seed) 38 | 39 | if args.dataset == 'Physics' or args.dataset == 'CS': 40 | dataset = Coauthor(os.path.join('data', args.dataset), args.dataset) 41 | data = dataset.data 42 | data = random_splits(data, dataset.num_classes) 43 | else: 44 | dataset = Planetoid(os.path.join('data', args.dataset), args.dataset) 45 | data = dataset.data 46 | 47 | args.num_nodes = data.x.size(0) 48 | args.num_features = data.x.size(1) 49 | args.num_classes = dataset.num_classes 50 | 51 | print(args) 52 | 53 | model = NodeClassificationModel(args).to(args.device) 54 | data = data.to(args.device) 55 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 56 | 57 | 58 | def train(): 59 | best_test_acc = 0 60 | val_acc_values = [] 61 | val_loss_values = [] 62 | best_epoch = 0 63 | min_loss = 1e10 64 | 65 | t = time.time() 66 | for epoch in range(args.epochs): 67 | loss_train = 0.0 68 | model.train() 69 | optimizer.zero_grad() 70 | out = model(data.x, data.edge_index) 71 | out = F.log_softmax(out, dim=1) 72 | loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) 73 | loss.backward() 74 | optimizer.step() 75 | loss_train += loss.item() 76 | pred = out[data.train_mask].max(dim=1)[1] 77 | correct = pred.eq(data.y[data.train_mask]).sum().item() 78 | acc_train = correct / data.train_mask.sum().item() 79 | acc_val, loss_val = compute_test(data.val_mask) 80 | acc_test, loss_test = compute_test(data.test_mask) 81 | 82 | if acc_test > best_test_acc: 83 | best_test_acc = acc_test 84 | 85 | print('Epoch: {:04d}'.format(epoch + 1), 'loss_train: {:.4f}'.format(loss_train), 86 | 'acc_train: {:.4f}'.format(acc_train), 'loss_val: {:.4f}'.format(loss_val), 87 | 'acc_val: {:.4f}'.format(acc_val), 'acc_test: {:.4f}'.format(best_test_acc), 88 | 'time: {:.4f}s'.format(time.time() - t)) 89 | 90 | val_acc_values.append(acc_val) 91 | val_loss_values.append(loss_val) 92 | torch.save(model.state_dict(), '{}.pth'.format(epoch)) 93 | 94 | if val_loss_values[-1] < min_loss: 95 | min_loss = val_loss_values[-1] 96 | best_epoch = epoch 97 | 98 | files = glob.glob('*.pth') 99 | for f in files: 100 | epoch_nb = int(f.split('.')[0]) 101 | if epoch_nb < best_epoch: 102 | os.remove(f) 103 | 104 | files = glob.glob('*.pth') 105 | for f in files: 106 | epoch_nb = int(f.split('.')[0]) 107 | if epoch_nb > best_epoch: 108 | os.remove(f) 109 | print('Optimization Finished! 
def compute_test(mask):
    model.eval()
    with torch.no_grad():
        correct = 0
        loss_test = 0.0
        out = model(data.x, data.edge_index)
        out = F.log_softmax(out, dim=1)
        pred = out[mask].max(dim=1)[1]
        correct += pred.eq(data.y[mask]).sum().item()
        loss_test += F.nll_loss(out[mask], data.y[mask]).item()
    return correct / mask.sum().item(), loss_test


def save_embedding(output_file):
    # Writes one node per line as "<label> <dim_1> ... <dim_d>". Note that this
    # helper relies on a gen_embedding() method of the model, which is not
    # defined in models.py below.
    model.eval()
    with open(output_file, 'w') as f:
        with torch.no_grad():
            embeddings = model.gen_embedding(data.x, data.edge_index)
            embeddings = embeddings.cpu().detach().numpy()
            gt = data.y.cpu().detach().numpy()
            num_nodes, num_dims = embeddings.shape
            for i in range(num_nodes):
                write_string = str(gt[i])
                for j in range(num_dims):
                    write_string += ' ' + str(embeddings[i, j])
                write_string += '\n'
                f.write(write_string)


if __name__ == '__main__':
    # Model training
    best_model = train()
    # Restore the best model for the test set
    model.load_state_dict(torch.load('{}.pth'.format(best_model)))
    test_acc, test_loss = compute_test(data.test_mask)
    print('Test set results, loss = {:.4f}, accuracy = {:.4f}'.format(test_loss, test_acc))
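# Reading back a file produced by save_embedding() (a minimal sketch; the file
# name 'embedding.txt' is an arbitrary example, and gen_embedding() must be
# supplied by the model):
#
#     import numpy as np
#     raw = np.loadtxt('embedding.txt')
#     labels, embeddings = raw[:, 0].astype(int), raw[:, 1:]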
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from layers import GCN, MVPool


class GraphClassificationModel(torch.nn.Module):
    def __init__(self, args):
        super(GraphClassificationModel, self).__init__()
        self.args = args
        self.num_features = args.num_features
        self.nhid = args.nhid
        self.num_classes = args.num_classes
        self.pooling_ratio = args.pooling_ratio
        self.dropout_ratio = args.dropout_ratio
        self.sample = args.sample_neighbor
        self.sparse = args.sparse_attention
        self.sl = args.structure_learning
        self.lamb = args.lamb

        self.conv1 = GCNConv(self.num_features, self.nhid)
        self.conv2 = GCN(self.nhid, self.nhid)
        self.conv3 = GCN(self.nhid, self.nhid)

        self.pool1 = MVPool(self.nhid, self.pooling_ratio, args)
        self.pool2 = MVPool(self.nhid, self.pooling_ratio, args)
        self.pool3 = MVPool(self.nhid, self.pooling_ratio, args)

        self.lin1 = torch.nn.Linear(self.nhid * 2, self.nhid)
        self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
        self.lin3 = torch.nn.Linear(self.nhid // 2, self.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_attr = None

        # Three conv/pool stages; each contributes a max+mean global readout.
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch, _ = self.pool1(x, edge_index, edge_attr, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index, edge_attr))
        x, edge_index, edge_attr, batch, _ = self.pool3(x, edge_index, edge_attr, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Sum the readouts from all three scales, then classify.
        x = F.relu(x1) + F.relu(x2) + F.relu(x3)

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.log_softmax(self.lin3(x), dim=-1)

        return x


class NodeClassificationModel(torch.nn.Module):
    def __init__(self, args, sum_res=False, act=F.relu):
        super(NodeClassificationModel, self).__init__()
        assert args.depth >= 1
        self.in_channels = args.num_features
        self.hidden_channels = args.nhid
        self.out_channels = args.num_classes
        self.depth = args.depth
        self.pool_ratios = [args.pool1, args.pool2, args.pool3, args.pool4, args.pool5]
        self.act = act
        self.sum_res = sum_res

        channels = self.hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(self.in_channels, channels))
        for i in range(self.depth):
            self.pools.append(MVPool(channels, self.pool_ratios[i], args))
            self.down_convs.append(GCN(channels, channels))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(self.depth):
            self.up_convs.append(GCN(in_channels, channels))
        self.up_convs.append(GCN(channels, self.out_channels))

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        # Input-feature dropout (note the unusually high rate).
        x = F.dropout(x, p=0.92, training=self.training)
        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        # Encoder: pool then convolve, caching the graph at each resolution.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            x, edge_index, edge_weight, batch, perm = self.pools[i - 1](x, edge_index, edge_weight, batch)
            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]

        # Decoder: unpool back through the cached resolutions (U-Net style),
        # merging the skip connection by sum or concatenation.
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x)
        x = self.up_convs[-1](x, edge_index, edge_weight)

        return x
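# A minimal standalone check (toy tensors, not part of the original file) of
# the unpooling step above: features of the kept nodes, indexed by `perm`, are
# scattered back to their pre-pooling slots, while pooled-away nodes stay zero.
if __name__ == '__main__':
    res = torch.zeros(5, 2)
    perm = torch.tensor([0, 3])
    up = torch.zeros_like(res)
    up[perm] = torch.ones(2, 2)
    print(up)  # rows 0 and 3 are ones; rows 1, 2 and 4 remain zero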
--------------------------------------------------------------------------------
/sparse_softmax.py:
--------------------------------------------------------------------------------
"""
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
for a detailed description.

We make some modifications so that it works in scatter-operation scenarios, e.g.,
computing the softmax according to batch indicators.

Usage:
    >>> x = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943,
    ...                   0.1831, -0.0061])
    >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    >>> sparse_attention = Sparsemax()
    >>> res = sparse_attention(x, batch)
    >>> print(res)
    tensor([0.5343, 0.0000, 0.0000, 0.4657, 0.0612, 0.0000, 0.0000, 0.0000, 0.5640,
            0.3748])
"""
import torch
import torch.nn as nn
from torch.autograd import Function
from torch_scatter import scatter_add, scatter_max


def scatter_sort(x, batch, fill_value=-1e16):
    """Sort the entries of each group in `batch` in descending order and return
    the sorted values together with their running cumulative sums."""
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

    # Scatter the ragged groups into a dense (batch_size, max_num_nodes) matrix,
    # padding with fill_value so that padded slots sort to the end.
    dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)

    sorted_x, _ = dense_x.sort(dim=-1, descending=True)
    cumsum_sorted_x = sorted_x.cumsum(dim=-1)
    cumsum_sorted_x = cumsum_sorted_x.view(-1)

    sorted_x = sorted_x.view(-1)
    filled_index = sorted_x != fill_value

    sorted_x = sorted_x[filled_index]
    cumsum_sorted_x = cumsum_sorted_x[filled_index]

    return sorted_x, cumsum_sorted_x


def _make_ix_like(batch):
    """Return a 1, 2, ..., n_k index vector that restarts at 1 for each group."""
    num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
    idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
    idx = torch.cat(idx, dim=0)

    return idx


def _threshold_and_support(x, batch):
    """Sparsemax building block: compute the threshold.
    Args:
        x: input tensor to apply sparsemax to
        batch: group indicators
    Returns:
        the threshold value tau and the support size of each group
    """
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

    sorted_input, input_cumsum = scatter_sort(x, batch)
    input_cumsum = input_cumsum - 1.0
    rhos = _make_ix_like(batch).to(x.dtype)
    support = rhos * sorted_input > input_cumsum

    support_size = scatter_add(support.to(batch.dtype), batch)
    # Mask invalid indices: if batch does not start from 0 or is not contiguous,
    # the computation may produce negative indices.
    idx = support_size + cum_num_nodes - 1
    mask = idx < 0
    idx[mask] = 0
    tau = input_cumsum.gather(0, idx)
    tau /= support_size.to(x.dtype)

    return tau, support_size
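# Hand-computed illustration of the threshold rule above (values rounded): for
# a single group x = [1.7, 0.6, -1.0], we get sorted_x = [1.7, 0.6, -1.0],
# cumsum - 1 = [0.7, 1.3, 0.3] and rhos = [1, 2, 3], so the support test
# rhos * sorted_x > cumsum - 1 gives [True, False, False]. Hence
# support_size = 1, tau = 0.7 / 1 = 0.7, and the sparsemax output
# clamp(x - tau, min=0) = [1.0, 0.0, 0.0] sums to one over a sparse support.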
class SparsemaxFunction(Function):

    @staticmethod
    def forward(ctx, x, batch):
        """sparsemax: normalizing sparse transform
        Parameters:
            ctx: context object
            x (Tensor): shape (N,)
            batch: group indicators
        Returns:
            output (Tensor): same shape as input
        """
        # Subtract the per-group max for numerical stability; sparsemax is
        # invariant to constant shifts within a group.
        max_val, _ = scatter_max(x, batch)
        x -= max_val[batch]
        tau, supp_size = _threshold_and_support(x, batch)
        output = torch.clamp(x - tau[batch], min=0)
        ctx.save_for_backward(supp_size, output, batch)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output, batch = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        # Within each group, subtract the mean gradient over the support.
        v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)

        return grad_input, None


sparsemax = SparsemaxFunction.apply


class Sparsemax(nn.Module):

    def __init__(self):
        super(Sparsemax, self).__init__()

    def forward(self, x, batch):
        return sparsemax(x, batch)


if __name__ == '__main__':
    sparse_attention = Sparsemax()
    input_x = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 0.1831, -0.0061])
    input_batch = torch.cat([torch.zeros(4, dtype=torch.long), torch.ones(6, dtype=torch.long)], dim=0)
    res = sparse_attention(input_x, input_batch)
    print(res)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch


def index_to_mask(index, size):
    # e.g. index_to_mask(torch.tensor([0, 2]), size=4) -> tensor([True, False, True, False])
    mask = torch.zeros((size,), dtype=torch.bool)
    mask[index] = 1
    return mask


def random_splits(data, num_classes):
    # Set new random planetoid-style splits:
    # * 20 * num_classes labels for training
    # * 30 * num_classes labels for validation
    # * the rest for testing
    indices = []
    for i in range(num_classes):
        index = (data.y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    train_index = torch.cat([i[:20] for i in indices], dim=0)
    val_index = torch.cat([i[20:50] for i in indices], dim=0)

    rest_index = torch.cat([i[50:] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(val_index, size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index, size=data.num_nodes)

    return data
--------------------------------------------------------------------------------