├── .gitignore ├── LICENSE ├── README.md ├── data ├── ind.cora.allx ├── ind.cora.ally ├── ind.cora.graph ├── ind.cora.test.index ├── ind.cora.tx ├── ind.cora.ty ├── ind.cora.x └── ind.cora.y ├── execute_cora.py ├── execute_cora_sparse.py ├── models ├── __init__.py ├── base_gattn.py ├── gat.py └── sp_gat.py ├── pre_trained └── cora │ ├── checkpoint │ ├── mod_cora.ckpt.data-00000-of-00001 │ ├── mod_cora.ckpt.index │ └── mod_cora.ckpt.meta └── utils ├── __init__.py ├── layers.py ├── process.py └── process_ppi.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Petar Veličković 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAT 2 | Graph Attention Networks (Veličković *et al.*, ICLR 2018): [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903) 3 | 4 | GAT layer | t-SNE + Attention coefficients on Cora 5 | :-------------------------:|:-------------------------: 6 | ![](https://camo.githubusercontent.com/4fe1a90e67d17a2330d7cfcddc930d5f7501750c/68747470733a2f2f7777772e64726f70626f782e636f6d2f732f71327a703170366b37396a6a6431352f6761745f6c617965722e706e673f7261773d31) | ![](https://raw.githubusercontent.com/PetarV-/GAT/gh-pages/assets/t-sne.png) 7 | 8 | ## Overview 9 | Here we provide the implementation of a Graph Attention Network (GAT) layer in TensorFlow, along with a minimal execution example (on the Cora dataset). The repository is organised as follows: 10 | - `data/` contains the necessary dataset files for Cora; 11 | - `models/` contains the implementation of the GAT network (`gat.py`); 12 | - `pre_trained/` contains a pre-trained Cora model (achieving 84.4% accuracy on the test set); 13 | - `utils/` contains: 14 | * an implementation of an attention head, along with an experimental sparse version (`layers.py`); 15 | * preprocessing subroutines (`process.py`); 16 | * preprocessing utilities for the PPI benchmark (`process_ppi.py`). 17 | 18 | Finally, `execute_cora.py` puts all of the above together and may be used to execute a full training run on Cora. 19 | 20 | ## Sparse version 21 | An experimental sparse version is also available, working only when the batch size is equal to 1. 22 | The sparse model may be found at `models/sp_gat.py`. 23 | 24 | You may execute a full training run of the sparse model on Cora through `execute_cora_sparse.py`. 25 | 26 | ## Dependencies 27 | 28 | The script has been tested running under Python 3.5.2, with the following packages installed (along with their dependencies): 29 | 30 | - `numpy==1.14.1` 31 | - `scipy==1.0.0` 32 | - `networkx==2.1` 33 | - `tensorflow-gpu==1.6.0` 34 | 35 | In addition, CUDA 9.0 and cuDNN 7 have been used. 36 | 37 | ## Reference 38 | If you make use of the GAT model in your research, please cite the following in your manuscript: 39 | 40 | ``` 41 | @article{ 42 | velickovic2018graph, 43 | title="{Graph Attention Networks}", 44 | author={Veli{\v{c}}kovi{\'{c}}, Petar and Cucurull, Guillem and Casanova, Arantxa and Romero, Adriana and Li{\`{o}}, Pietro and Bengio, Yoshua}, 45 | journal={International Conference on Learning Representations}, 46 | year={2018}, 47 | url={https://openreview.net/forum?id=rJXMpikCZ}, 48 | note={accepted as poster}, 49 | } 50 | ``` 51 | 52 | For getting started with GATs, as well as graph representation learning in general, we **highly** recommend the [pytorch-GAT](https://github.com/gordicaleksa/pytorch-GAT) repository by [Aleksa Gordić](https://github.com/gordicaleksa). It ships with an inductive (PPI) example as well. 
53 | 54 | GAT is a popular method for graph representation learning, with optimised implementations within virtually all standard GRL libraries: 55 | - \[PyTorch\] [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) 56 | - \[PyTorch/TensorFlow\] [Deep Graph Library](https://www.dgl.ai/) 57 | - \[TensorFlow\] [Spektral](https://graphneural.network/) 58 | - \[JAX\] [jraph](https://github.com/deepmind/jraph) 59 | 60 | We recommend using either one of those (depending on your favoured framework), as their implementations have been more readily battle-tested. 61 | 62 | Early on post-release, two unofficial ports of the GAT model to various frameworks quickly surfaced. To honour the effort of their developers as early adopters of the GAT layer, we leave pointers to them here. 63 | - \[Keras\] [keras-gat](https://github.com/danielegrattarola/keras-gat), developed by [Daniele Grattarola](https://github.com/danielegrattarola); 64 | - \[PyTorch\] [pyGAT](https://github.com/Diego999/pyGAT), developed by [Diego Antognini](https://github.com/Diego999). 65 | 66 | ## License 67 | MIT 68 | -------------------------------------------------------------------------------- /data/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.allx -------------------------------------------------------------------------------- /data/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.ally -------------------------------------------------------------------------------- /data/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.graph -------------------------------------------------------------------------------- /data/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 
2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 
1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 
2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /data/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.tx -------------------------------------------------------------------------------- /data/ind.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.ty -------------------------------------------------------------------------------- /data/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.x -------------------------------------------------------------------------------- /data/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/data/ind.cora.y 
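The `ind.cora.*` files above are stored in the pickled Planetoid split format that `utils/process.load_data` expects. The snippet below is a minimal, illustrative sketch (assuming the repository root as the working directory and the dependencies listed in the README) that simply mirrors the preprocessing steps at the top of `execute_cora.py`:

```python
import numpy as np

from utils import process

# Load the Planetoid-format Cora split shipped in data/ (adjacency matrix,
# node features, one-hot labels and train/val/test masks).
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = process.load_data('cora')

# Row-normalise the features; the dense copy is what the non-sparse model consumes.
features, spars = process.preprocess_features(features)
nb_nodes = features.shape[0]

# Add a batch dimension and convert the adjacency matrix into the additive
# attention bias used by the GAT layer (0 for neighbours, -1e9 elsewhere).
adj = adj.todense()[np.newaxis]
biases = process.adj_to_bias(adj, [nb_nodes], nhood=1)

print(adj.shape, features.shape, y_train.shape)
```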
-------------------------------------------------------------------------------- /execute_cora.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from models import GAT 6 | from utils import process 7 | 8 | checkpt_file = 'pre_trained/cora/mod_cora.ckpt' 9 | 10 | dataset = 'cora' 11 | 12 | # training params 13 | batch_size = 1 14 | nb_epochs = 100000 15 | patience = 100 16 | lr = 0.005 # learning rate 17 | l2_coef = 0.0005 # weight decay 18 | hid_units = [8] # numbers of hidden units per each attention head in each layer 19 | n_heads = [8, 1] # additional entry for the output layer 20 | residual = False 21 | nonlinearity = tf.nn.elu 22 | model = GAT 23 | 24 | print('Dataset: ' + dataset) 25 | print('----- Opt. hyperparams -----') 26 | print('lr: ' + str(lr)) 27 | print('l2_coef: ' + str(l2_coef)) 28 | print('----- Archi. hyperparams -----') 29 | print('nb. layers: ' + str(len(hid_units))) 30 | print('nb. units per layer: ' + str(hid_units)) 31 | print('nb. attention heads: ' + str(n_heads)) 32 | print('residual: ' + str(residual)) 33 | print('nonlinearity: ' + str(nonlinearity)) 34 | print('model: ' + str(model)) 35 | 36 | adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = process.load_data(dataset) 37 | features, spars = process.preprocess_features(features) 38 | 39 | nb_nodes = features.shape[0] 40 | ft_size = features.shape[1] 41 | nb_classes = y_train.shape[1] 42 | 43 | adj = adj.todense() 44 | 45 | features = features[np.newaxis] 46 | adj = adj[np.newaxis] 47 | y_train = y_train[np.newaxis] 48 | y_val = y_val[np.newaxis] 49 | y_test = y_test[np.newaxis] 50 | train_mask = train_mask[np.newaxis] 51 | val_mask = val_mask[np.newaxis] 52 | test_mask = test_mask[np.newaxis] 53 | 54 | biases = process.adj_to_bias(adj, [nb_nodes], nhood=1) 55 | 56 | with tf.Graph().as_default(): 57 | with tf.name_scope('input'): 58 | ftr_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, ft_size)) 59 | bias_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, nb_nodes)) 60 | lbl_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes, nb_classes)) 61 | msk_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes)) 62 | attn_drop = tf.placeholder(dtype=tf.float32, shape=()) 63 | ffd_drop = tf.placeholder(dtype=tf.float32, shape=()) 64 | is_train = tf.placeholder(dtype=tf.bool, shape=()) 65 | 66 | logits = model.inference(ftr_in, nb_classes, nb_nodes, is_train, 67 | attn_drop, ffd_drop, 68 | bias_mat=bias_in, 69 | hid_units=hid_units, n_heads=n_heads, 70 | residual=residual, activation=nonlinearity) 71 | log_resh = tf.reshape(logits, [-1, nb_classes]) 72 | lab_resh = tf.reshape(lbl_in, [-1, nb_classes]) 73 | msk_resh = tf.reshape(msk_in, [-1]) 74 | loss = model.masked_softmax_cross_entropy(log_resh, lab_resh, msk_resh) 75 | accuracy = model.masked_accuracy(log_resh, lab_resh, msk_resh) 76 | 77 | train_op = model.training(loss, lr, l2_coef) 78 | 79 | saver = tf.train.Saver() 80 | 81 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 82 | 83 | vlss_mn = np.inf 84 | vacc_mx = 0.0 85 | curr_step = 0 86 | 87 | with tf.Session() as sess: 88 | sess.run(init_op) 89 | 90 | train_loss_avg = 0 91 | train_acc_avg = 0 92 | val_loss_avg = 0 93 | val_acc_avg = 0 94 | 95 | for epoch in range(nb_epochs): 96 | tr_step = 0 97 | tr_size = features.shape[0] 98 | 99 | while tr_step * batch_size < tr_size: 100 | _, 
loss_value_tr, acc_tr = sess.run([train_op, loss, accuracy], 101 | feed_dict={ 102 | ftr_in: features[tr_step*batch_size:(tr_step+1)*batch_size], 103 | bias_in: biases[tr_step*batch_size:(tr_step+1)*batch_size], 104 | lbl_in: y_train[tr_step*batch_size:(tr_step+1)*batch_size], 105 | msk_in: train_mask[tr_step*batch_size:(tr_step+1)*batch_size], 106 | is_train: True, 107 | attn_drop: 0.6, ffd_drop: 0.6}) 108 | train_loss_avg += loss_value_tr 109 | train_acc_avg += acc_tr 110 | tr_step += 1 111 | 112 | vl_step = 0 113 | vl_size = features.shape[0] 114 | 115 | while vl_step * batch_size < vl_size: 116 | loss_value_vl, acc_vl = sess.run([loss, accuracy], 117 | feed_dict={ 118 | ftr_in: features[vl_step*batch_size:(vl_step+1)*batch_size], 119 | bias_in: biases[vl_step*batch_size:(vl_step+1)*batch_size], 120 | lbl_in: y_val[vl_step*batch_size:(vl_step+1)*batch_size], 121 | msk_in: val_mask[vl_step*batch_size:(vl_step+1)*batch_size], 122 | is_train: False, 123 | attn_drop: 0.0, ffd_drop: 0.0}) 124 | val_loss_avg += loss_value_vl 125 | val_acc_avg += acc_vl 126 | vl_step += 1 127 | 128 | print('Training: loss = %.5f, acc = %.5f | Val: loss = %.5f, acc = %.5f' % 129 | (train_loss_avg/tr_step, train_acc_avg/tr_step, 130 | val_loss_avg/vl_step, val_acc_avg/vl_step)) 131 | 132 | if val_acc_avg/vl_step >= vacc_mx or val_loss_avg/vl_step <= vlss_mn: 133 | if val_acc_avg/vl_step >= vacc_mx and val_loss_avg/vl_step <= vlss_mn: 134 | vacc_early_model = val_acc_avg/vl_step 135 | vlss_early_model = val_loss_avg/vl_step 136 | saver.save(sess, checkpt_file) 137 | vacc_mx = np.max((val_acc_avg/vl_step, vacc_mx)) 138 | vlss_mn = np.min((val_loss_avg/vl_step, vlss_mn)) 139 | curr_step = 0 140 | else: 141 | curr_step += 1 142 | if curr_step == patience: 143 | print('Early stop! 
Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx) 144 | print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model) 145 | break 146 | 147 | train_loss_avg = 0 148 | train_acc_avg = 0 149 | val_loss_avg = 0 150 | val_acc_avg = 0 151 | 152 | saver.restore(sess, checkpt_file) 153 | 154 | ts_size = features.shape[0] 155 | ts_step = 0 156 | ts_loss = 0.0 157 | ts_acc = 0.0 158 | 159 | while ts_step * batch_size < ts_size: 160 | loss_value_ts, acc_ts = sess.run([loss, accuracy], 161 | feed_dict={ 162 | ftr_in: features[ts_step*batch_size:(ts_step+1)*batch_size], 163 | bias_in: biases[ts_step*batch_size:(ts_step+1)*batch_size], 164 | lbl_in: y_test[ts_step*batch_size:(ts_step+1)*batch_size], 165 | msk_in: test_mask[ts_step*batch_size:(ts_step+1)*batch_size], 166 | is_train: False, 167 | attn_drop: 0.0, ffd_drop: 0.0}) 168 | ts_loss += loss_value_ts 169 | ts_acc += acc_ts 170 | ts_step += 1 171 | 172 | print('Test loss:', ts_loss/ts_step, '; Test accuracy:', ts_acc/ts_step) 173 | 174 | sess.close() 175 | -------------------------------------------------------------------------------- /execute_cora_sparse.py: -------------------------------------------------------------------------------- 1 | import time 2 | import scipy.sparse as sp 3 | import numpy as np 4 | import tensorflow as tf 5 | import argparse 6 | 7 | from models import GAT 8 | from models import SpGAT 9 | from utils import process 10 | 11 | checkpt_file = 'pre_trained/cora/mod_cora.ckpt' 12 | 13 | dataset = 'cora' 14 | 15 | # training params 16 | batch_size = 1 17 | nb_epochs = 100000 18 | patience = 100 19 | lr = 0.005 # learning rate 20 | l2_coef = 0.0005 # weight decay 21 | hid_units = [8] # numbers of hidden units per each attention head in each layer 22 | n_heads = [8, 1] # additional entry for the output layer 23 | residual = False 24 | nonlinearity = tf.nn.elu 25 | # model = GAT 26 | model = SpGAT 27 | 28 | print('Dataset: ' + dataset) 29 | print('----- Opt. hyperparams -----') 30 | print('lr: ' + str(lr)) 31 | print('l2_coef: ' + str(l2_coef)) 32 | print('----- Archi. hyperparams -----') 33 | print('nb. layers: ' + str(len(hid_units))) 34 | print('nb. units per layer: ' + str(hid_units)) 35 | print('nb. 
attention heads: ' + str(n_heads)) 36 | print('residual: ' + str(residual)) 37 | print('nonlinearity: ' + str(nonlinearity)) 38 | print('model: ' + str(model)) 39 | 40 | sparse = True 41 | 42 | adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = process.load_data(dataset) 43 | features, spars = process.preprocess_features(features) 44 | 45 | nb_nodes = features.shape[0] 46 | ft_size = features.shape[1] 47 | nb_classes = y_train.shape[1] 48 | 49 | features = features[np.newaxis] 50 | y_train = y_train[np.newaxis] 51 | y_val = y_val[np.newaxis] 52 | y_test = y_test[np.newaxis] 53 | train_mask = train_mask[np.newaxis] 54 | val_mask = val_mask[np.newaxis] 55 | test_mask = test_mask[np.newaxis] 56 | 57 | if sparse: 58 | biases = process.preprocess_adj_bias(adj) 59 | else: 60 | adj = adj.todense() 61 | adj = adj[np.newaxis] 62 | biases = process.adj_to_bias(adj, [nb_nodes], nhood=1) 63 | 64 | with tf.Graph().as_default(): 65 | with tf.name_scope('input'): 66 | ftr_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, ft_size)) 67 | if sparse: 68 | #bias_idx = tf.placeholder(tf.int64) 69 | #bias_val = tf.placeholder(tf.float32) 70 | #bias_shape = tf.placeholder(tf.int64) 71 | bias_in = tf.sparse_placeholder(dtype=tf.float32) 72 | else: 73 | bias_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, nb_nodes)) 74 | lbl_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes, nb_classes)) 75 | msk_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes)) 76 | attn_drop = tf.placeholder(dtype=tf.float32, shape=()) 77 | ffd_drop = tf.placeholder(dtype=tf.float32, shape=()) 78 | is_train = tf.placeholder(dtype=tf.bool, shape=()) 79 | 80 | logits = model.inference(ftr_in, nb_classes, nb_nodes, is_train, 81 | attn_drop, ffd_drop, 82 | bias_mat=bias_in, 83 | hid_units=hid_units, n_heads=n_heads, 84 | residual=residual, activation=nonlinearity) 85 | log_resh = tf.reshape(logits, [-1, nb_classes]) 86 | lab_resh = tf.reshape(lbl_in, [-1, nb_classes]) 87 | msk_resh = tf.reshape(msk_in, [-1]) 88 | loss = model.masked_softmax_cross_entropy(log_resh, lab_resh, msk_resh) 89 | accuracy = model.masked_accuracy(log_resh, lab_resh, msk_resh) 90 | 91 | train_op = model.training(loss, lr, l2_coef) 92 | 93 | saver = tf.train.Saver() 94 | 95 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 96 | 97 | vlss_mn = np.inf 98 | vacc_mx = 0.0 99 | curr_step = 0 100 | 101 | with tf.Session() as sess: 102 | sess.run(init_op) 103 | 104 | train_loss_avg = 0 105 | train_acc_avg = 0 106 | val_loss_avg = 0 107 | val_acc_avg = 0 108 | 109 | for epoch in range(nb_epochs): 110 | tr_step = 0 111 | tr_size = features.shape[0] 112 | 113 | while tr_step * batch_size < tr_size: 114 | if sparse: 115 | bbias = biases 116 | else: 117 | bbias = biases[tr_step*batch_size:(tr_step+1)*batch_size] 118 | 119 | _, loss_value_tr, acc_tr = sess.run([train_op, loss, accuracy], 120 | feed_dict={ 121 | ftr_in: features[tr_step*batch_size:(tr_step+1)*batch_size], 122 | bias_in: bbias, 123 | lbl_in: y_train[tr_step*batch_size:(tr_step+1)*batch_size], 124 | msk_in: train_mask[tr_step*batch_size:(tr_step+1)*batch_size], 125 | is_train: True, 126 | attn_drop: 0.6, ffd_drop: 0.6}) 127 | train_loss_avg += loss_value_tr 128 | train_acc_avg += acc_tr 129 | tr_step += 1 130 | 131 | vl_step = 0 132 | vl_size = features.shape[0] 133 | 134 | while vl_step * batch_size < vl_size: 135 | if sparse: 136 | bbias = biases 137 | else: 138 | bbias = 
biases[vl_step*batch_size:(vl_step+1)*batch_size] 139 | loss_value_vl, acc_vl = sess.run([loss, accuracy], 140 | feed_dict={ 141 | ftr_in: features[vl_step*batch_size:(vl_step+1)*batch_size], 142 | bias_in: bbias, 143 | lbl_in: y_val[vl_step*batch_size:(vl_step+1)*batch_size], 144 | msk_in: val_mask[vl_step*batch_size:(vl_step+1)*batch_size], 145 | is_train: False, 146 | attn_drop: 0.0, ffd_drop: 0.0}) 147 | val_loss_avg += loss_value_vl 148 | val_acc_avg += acc_vl 149 | vl_step += 1 150 | 151 | print('Training: loss = %.5f, acc = %.5f | Val: loss = %.5f, acc = %.5f' % 152 | (train_loss_avg/tr_step, train_acc_avg/tr_step, 153 | val_loss_avg/vl_step, val_acc_avg/vl_step)) 154 | 155 | if val_acc_avg/vl_step >= vacc_mx or val_loss_avg/vl_step <= vlss_mn: 156 | if val_acc_avg/vl_step >= vacc_mx and val_loss_avg/vl_step <= vlss_mn: 157 | vacc_early_model = val_acc_avg/vl_step 158 | vlss_early_model = val_loss_avg/vl_step 159 | saver.save(sess, checkpt_file) 160 | vacc_mx = np.max((val_acc_avg/vl_step, vacc_mx)) 161 | vlss_mn = np.min((val_loss_avg/vl_step, vlss_mn)) 162 | curr_step = 0 163 | else: 164 | curr_step += 1 165 | if curr_step == patience: 166 | print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx) 167 | print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model) 168 | break 169 | 170 | train_loss_avg = 0 171 | train_acc_avg = 0 172 | val_loss_avg = 0 173 | val_acc_avg = 0 174 | 175 | saver.restore(sess, checkpt_file) 176 | 177 | ts_size = features.shape[0] 178 | ts_step = 0 179 | ts_loss = 0.0 180 | ts_acc = 0.0 181 | 182 | while ts_step * batch_size < ts_size: 183 | if sparse: 184 | bbias = biases 185 | else: 186 | bbias = biases[ts_step*batch_size:(ts_step+1)*batch_size] 187 | loss_value_ts, acc_ts = sess.run([loss, accuracy], 188 | feed_dict={ 189 | ftr_in: features[ts_step*batch_size:(ts_step+1)*batch_size], 190 | bias_in: bbias, 191 | lbl_in: y_test[ts_step*batch_size:(ts_step+1)*batch_size], 192 | msk_in: test_mask[ts_step*batch_size:(ts_step+1)*batch_size], 193 | is_train: False, 194 | attn_drop: 0.0, ffd_drop: 0.0}) 195 | ts_loss += loss_value_ts 196 | ts_acc += acc_ts 197 | ts_step += 1 198 | 199 | print('Test loss:', ts_loss/ts_step, '; Test accuracy:', ts_acc/ts_step) 200 | 201 | sess.close() 202 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gat import GAT 2 | from .sp_gat import SpGAT 3 | -------------------------------------------------------------------------------- /models/base_gattn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class BaseGAttN: 4 | def loss(logits, labels, nb_classes, class_weights): 5 | sample_wts = tf.reduce_sum(tf.multiply(tf.one_hot(labels, nb_classes), class_weights), axis=-1) 6 | xentropy = tf.multiply(tf.nn.sparse_softmax_cross_entropy_with_logits( 7 | labels=labels, logits=logits), sample_wts) 8 | return tf.reduce_mean(xentropy, name='xentropy_mean') 9 | 10 | def training(loss, lr, l2_coef): 11 | # weight decay 12 | vars = tf.trainable_variables() 13 | lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if v.name not 14 | in ['bias', 'gamma', 'b', 'g', 'beta']]) * l2_coef 15 | 16 | # optimizer 17 | opt = tf.train.AdamOptimizer(learning_rate=lr) 18 | 19 | # training op 20 | train_op = opt.minimize(loss+lossL2) 21 | 22 | return train_op 23 | 24 | def preshape(logits, labels, 
nb_classes): 25 | new_sh_lab = [-1] 26 | new_sh_log = [-1, nb_classes] 27 | log_resh = tf.reshape(logits, new_sh_log) 28 | lab_resh = tf.reshape(labels, new_sh_lab) 29 | return log_resh, lab_resh 30 | 31 | def confmat(logits, labels): 32 | preds = tf.argmax(logits, axis=1) 33 | return tf.confusion_matrix(labels, preds) 34 | 35 | ########################## 36 | # Adapted from tkipf/gcn # 37 | ########################## 38 | 39 | def masked_softmax_cross_entropy(logits, labels, mask): 40 | """Softmax cross-entropy loss with masking.""" 41 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels) 42 | mask = tf.cast(mask, dtype=tf.float32) 43 | mask /= tf.reduce_mean(mask) 44 | loss *= mask 45 | return tf.reduce_mean(loss) 46 | 47 | def masked_sigmoid_cross_entropy(logits, labels, mask): 48 | """Sigmoid cross-entropy loss with masking.""" 49 | labels = tf.cast(labels, dtype=tf.float32) 50 | loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 51 | loss = tf.reduce_mean(loss, axis=1) 52 | mask = tf.cast(mask, dtype=tf.float32) 53 | mask /= tf.reduce_mean(mask) 54 | loss *= mask 55 | return tf.reduce_mean(loss) 56 | 57 | def masked_accuracy(logits, labels, mask): 58 | """Accuracy with masking.""" 59 | correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) 60 | accuracy_all = tf.cast(correct_prediction, tf.float32) 61 | mask = tf.cast(mask, dtype=tf.float32) 62 | mask /= tf.reduce_mean(mask) 63 | accuracy_all *= mask 64 | return tf.reduce_mean(accuracy_all) 65 | 66 | def micro_f1(logits, labels, mask): 67 | """Micro-averaged F1 score with masking.""" 68 | predicted = tf.round(tf.nn.sigmoid(logits)) 69 | 70 | # Use integers to avoid any nasty FP behaviour 71 | predicted = tf.cast(predicted, dtype=tf.int32) 72 | labels = tf.cast(labels, dtype=tf.int32) 73 | mask = tf.cast(mask, dtype=tf.int32) 74 | 75 | # expand the mask so that broadcasting works ([nb_nodes, 1]) 76 | mask = tf.expand_dims(mask, -1) 77 | 78 | # Count true positives, true negatives, false positives and false negatives. 79 | tp = tf.count_nonzero(predicted * labels * mask) 80 | tn = tf.count_nonzero((predicted - 1) * (labels - 1) * mask) 81 | fp = tf.count_nonzero(predicted * (labels - 1) * mask) 82 | fn = tf.count_nonzero((predicted - 1) * labels * mask) 83 | 84 | # Calculate accuracy, precision, recall and F1 score. 
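        # The counts above are pooled over every node and every output label before
        # precision and recall are formed, so the value returned below is the
        # micro-averaged F1 score (not a per-class macro average).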
85 | precision = tp / (tp + fp) 86 | recall = tp / (tp + fn) 87 | fmeasure = (2 * precision * recall) / (precision + recall) 88 | fmeasure = tf.cast(fmeasure, tf.float32) 89 | return fmeasure 90 | -------------------------------------------------------------------------------- /models/gat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from utils import layers 5 | from models.base_gattn import BaseGAttN 6 | 7 | class GAT(BaseGAttN): 8 | def inference(inputs, nb_classes, nb_nodes, training, attn_drop, ffd_drop, 9 | bias_mat, hid_units, n_heads, activation=tf.nn.elu, residual=False): 10 | attns = [] 11 | for _ in range(n_heads[0]): 12 | attns.append(layers.attn_head(inputs, bias_mat=bias_mat, 13 | out_sz=hid_units[0], activation=activation, 14 | in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) 15 | h_1 = tf.concat(attns, axis=-1) 16 | for i in range(1, len(hid_units)): 17 | h_old = h_1 18 | attns = [] 19 | for _ in range(n_heads[i]): 20 | attns.append(layers.attn_head(h_1, bias_mat=bias_mat, 21 | out_sz=hid_units[i], activation=activation, 22 | in_drop=ffd_drop, coef_drop=attn_drop, residual=residual)) 23 | h_1 = tf.concat(attns, axis=-1) 24 | out = [] 25 | for i in range(n_heads[-1]): 26 | out.append(layers.attn_head(h_1, bias_mat=bias_mat, 27 | out_sz=nb_classes, activation=lambda x: x, 28 | in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) 29 | logits = tf.add_n(out) / n_heads[-1] 30 | 31 | return logits 32 | -------------------------------------------------------------------------------- /models/sp_gat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from utils import layers 5 | from models.base_gattn import BaseGAttN 6 | 7 | class SpGAT(BaseGAttN): 8 | def inference(inputs, nb_classes, nb_nodes, training, attn_drop, ffd_drop, 9 | bias_mat, hid_units, n_heads, activation=tf.nn.elu, 10 | residual=False): 11 | attns = [] 12 | for _ in range(n_heads[0]): 13 | attns.append(layers.sp_attn_head(inputs, 14 | adj_mat=bias_mat, 15 | out_sz=hid_units[0], activation=activation, nb_nodes=nb_nodes, 16 | in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) 17 | h_1 = tf.concat(attns, axis=-1) 18 | for i in range(1, len(hid_units)): 19 | h_old = h_1 20 | attns = [] 21 | for _ in range(n_heads[i]): 22 | attns.append(layers.sp_attn_head(h_1, 23 | adj_mat=bias_mat, 24 | out_sz=hid_units[i], activation=activation, nb_nodes=nb_nodes, 25 | in_drop=ffd_drop, coef_drop=attn_drop, residual=residual)) 26 | h_1 = tf.concat(attns, axis=-1) 27 | out = [] 28 | for i in range(n_heads[-1]): 29 | out.append(layers.sp_attn_head(h_1, adj_mat=bias_mat, 30 | out_sz=nb_classes, activation=lambda x: x, nb_nodes=nb_nodes, 31 | in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) 32 | logits = tf.add_n(out) / n_heads[-1] 33 | 34 | return logits 35 | -------------------------------------------------------------------------------- /pre_trained/cora/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "mod_cora.ckpt" 2 | all_model_checkpoint_paths: "mod_cora.ckpt" 3 | -------------------------------------------------------------------------------- /pre_trained/cora/mod_cora.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/pre_trained/cora/mod_cora.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /pre_trained/cora/mod_cora.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/pre_trained/cora/mod_cora.ckpt.index -------------------------------------------------------------------------------- /pre_trained/cora/mod_cora.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/pre_trained/cora/mod_cora.ckpt.meta -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PetarV-/GAT/5af87e7fce2b90ae1cbd621cd58059036a3c7436/utils/__init__.py -------------------------------------------------------------------------------- /utils/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | conv1d = tf.layers.conv1d 5 | 6 | def attn_head(seq, out_sz, bias_mat, activation, in_drop=0.0, coef_drop=0.0, residual=False): 7 | with tf.name_scope('my_attn'): 8 | if in_drop != 0.0: 9 | seq = tf.nn.dropout(seq, 1.0 - in_drop) 10 | 11 | seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) 12 | 13 | # simplest self-attention possible 14 | f_1 = tf.layers.conv1d(seq_fts, 1, 1) 15 | f_2 = tf.layers.conv1d(seq_fts, 1, 1) 16 | logits = f_1 + tf.transpose(f_2, [0, 2, 1]) 17 | coefs = tf.nn.softmax(tf.nn.leaky_relu(logits) + bias_mat) 18 | 19 | if coef_drop != 0.0: 20 | coefs = tf.nn.dropout(coefs, 1.0 - coef_drop) 21 | if in_drop != 0.0: 22 | seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop) 23 | 24 | vals = tf.matmul(coefs, seq_fts) 25 | ret = tf.contrib.layers.bias_add(vals) 26 | 27 | # residual connection 28 | if residual: 29 | if seq.shape[-1] != ret.shape[-1]: 30 | ret = ret + conv1d(seq, ret.shape[-1], 1) # activation 31 | else: 32 | ret = ret + seq 33 | 34 | return activation(ret) # activation 35 | 36 | # Experimental sparse attention head (for running on datasets such as Pubmed) 37 | # N.B. Because of limitations of current TF implementation, will work _only_ if batch_size = 1! 
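# Compared to attn_head above, the adjacency here arrives as a tf.SparseTensor
# (built by utils.process.preprocess_adj_bias, i.e. the binarised adjacency with
# self-loops), so attention logits are only materialised for existing edges and
# are normalised with tf.sparse_softmax, instead of adding a dense -1e9 bias
# matrix before a dense softmax.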
38 | def sp_attn_head(seq, out_sz, adj_mat, activation, nb_nodes, in_drop=0.0, coef_drop=0.0, residual=False): 39 | with tf.name_scope('sp_attn'): 40 | if in_drop != 0.0: 41 | seq = tf.nn.dropout(seq, 1.0 - in_drop) 42 | 43 | seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) 44 | 45 | # simplest self-attention possible 46 | f_1 = tf.layers.conv1d(seq_fts, 1, 1) 47 | f_2 = tf.layers.conv1d(seq_fts, 1, 1) 48 | 49 | f_1 = tf.reshape(f_1, (nb_nodes, 1)) 50 | f_2 = tf.reshape(f_2, (nb_nodes, 1)) 51 | 52 | f_1 = adj_mat*f_1 53 | f_2 = adj_mat * tf.transpose(f_2, [1,0]) 54 | 55 | logits = tf.sparse_add(f_1, f_2) 56 | lrelu = tf.SparseTensor(indices=logits.indices, 57 | values=tf.nn.leaky_relu(logits.values), 58 | dense_shape=logits.dense_shape) 59 | coefs = tf.sparse_softmax(lrelu) 60 | 61 | if coef_drop != 0.0: 62 | coefs = tf.SparseTensor(indices=coefs.indices, 63 | values=tf.nn.dropout(coefs.values, 1.0 - coef_drop), 64 | dense_shape=coefs.dense_shape) 65 | if in_drop != 0.0: 66 | seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop) 67 | 68 | # As tf.sparse_tensor_dense_matmul expects its arguments to have rank-2, 69 | # here we make an assumption that our input is of batch size 1, and reshape appropriately. 70 | # The method will fail in all other cases! 71 | coefs = tf.sparse_reshape(coefs, [nb_nodes, nb_nodes]) 72 | seq_fts = tf.squeeze(seq_fts) 73 | vals = tf.sparse_tensor_dense_matmul(coefs, seq_fts) 74 | vals = tf.expand_dims(vals, axis=0) 75 | vals.set_shape([1, nb_nodes, out_sz]) 76 | ret = tf.contrib.layers.bias_add(vals) 77 | 78 | # residual connection 79 | if residual: 80 | if seq.shape[-1] != ret.shape[-1]: 81 | ret = ret + conv1d(seq, ret.shape[-1], 1) # activation 82 | else: 83 | ret = ret + seq 84 | 85 | return activation(ret) # activation 86 | 87 | -------------------------------------------------------------------------------- /utils/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pkl 3 | import networkx as nx 4 | import scipy.sparse as sp 5 | from scipy.sparse.linalg.eigen.arpack import eigsh 6 | import sys 7 | 8 | """ 9 | Prepare adjacency matrix by expanding up to a given neighbourhood. 10 | This will insert loops on every node. 11 | Finally, the matrix is converted to bias vectors. 
12 | Expected shape: [graph, nodes, nodes] 13 | """ 14 | def adj_to_bias(adj, sizes, nhood=1): 15 | nb_graphs = adj.shape[0] 16 | mt = np.empty(adj.shape) 17 | for g in range(nb_graphs): 18 | mt[g] = np.eye(adj.shape[1]) 19 | for _ in range(nhood): 20 | mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1]))) 21 | for i in range(sizes[g]): 22 | for j in range(sizes[g]): 23 | if mt[g][i][j] > 0.0: 24 | mt[g][i][j] = 1.0 25 | return -1e9 * (1.0 - mt) 26 | 27 | 28 | ############################################### 29 | # This section of code adapted from tkipf/gcn # 30 | ############################################### 31 | 32 | def parse_index_file(filename): 33 | """Parse index file.""" 34 | index = [] 35 | for line in open(filename): 36 | index.append(int(line.strip())) 37 | return index 38 | 39 | def sample_mask(idx, l): 40 | """Create mask.""" 41 | mask = np.zeros(l) 42 | mask[idx] = 1 43 | return np.array(mask, dtype=np.bool) 44 | 45 | def load_data(dataset_str): # {'pubmed', 'citeseer', 'cora'} 46 | """Load data.""" 47 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'] 48 | objects = [] 49 | for i in range(len(names)): 50 | with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f: 51 | if sys.version_info > (3, 0): 52 | objects.append(pkl.load(f, encoding='latin1')) 53 | else: 54 | objects.append(pkl.load(f)) 55 | 56 | x, y, tx, ty, allx, ally, graph = tuple(objects) 57 | test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str)) 58 | test_idx_range = np.sort(test_idx_reorder) 59 | 60 | if dataset_str == 'citeseer': 61 | # Fix citeseer dataset (there are some isolated nodes in the graph) 62 | # Find isolated nodes, add them as zero-vecs into the right position 63 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) 64 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) 65 | tx_extended[test_idx_range-min(test_idx_range), :] = tx 66 | tx = tx_extended 67 | ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) 68 | ty_extended[test_idx_range-min(test_idx_range), :] = ty 69 | ty = ty_extended 70 | 71 | features = sp.vstack((allx, tx)).tolil() 72 | features[test_idx_reorder, :] = features[test_idx_range, :] 73 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph)) 74 | 75 | labels = np.vstack((ally, ty)) 76 | labels[test_idx_reorder, :] = labels[test_idx_range, :] 77 | 78 | idx_test = test_idx_range.tolist() 79 | idx_train = range(len(y)) 80 | idx_val = range(len(y), len(y)+500) 81 | 82 | train_mask = sample_mask(idx_train, labels.shape[0]) 83 | val_mask = sample_mask(idx_val, labels.shape[0]) 84 | test_mask = sample_mask(idx_test, labels.shape[0]) 85 | 86 | y_train = np.zeros(labels.shape) 87 | y_val = np.zeros(labels.shape) 88 | y_test = np.zeros(labels.shape) 89 | y_train[train_mask, :] = labels[train_mask, :] 90 | y_val[val_mask, :] = labels[val_mask, :] 91 | y_test[test_mask, :] = labels[test_mask, :] 92 | 93 | print(adj.shape) 94 | print(features.shape) 95 | 96 | return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask 97 | 98 | def load_random_data(size): 99 | 100 | adj = sp.random(size, size, density=0.002) # density similar to cora 101 | features = sp.random(size, 1000, density=0.015) 102 | int_labels = np.random.randint(7, size=(size)) 103 | labels = np.zeros((size, 7)) # Nx7 104 | labels[np.arange(size), int_labels] = 1 105 | 106 | train_mask = np.zeros((size,)).astype(bool) 107 | train_mask[np.arange(size)[0:int(size/2)]] = 1 108 | 109 | val_mask = 
np.zeros((size,)).astype(bool) 110 | val_mask[np.arange(size)[int(size/2):]] = 1 111 | 112 | test_mask = np.zeros((size,)).astype(bool) 113 | test_mask[np.arange(size)[int(size/2):]] = 1 114 | 115 | y_train = np.zeros(labels.shape) 116 | y_val = np.zeros(labels.shape) 117 | y_test = np.zeros(labels.shape) 118 | y_train[train_mask, :] = labels[train_mask, :] 119 | y_val[val_mask, :] = labels[val_mask, :] 120 | y_test[test_mask, :] = labels[test_mask, :] 121 | 122 | # sparse NxN, sparse NxF, norm NxC, ..., norm Nx1, ... 123 | return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask 124 | 125 | def sparse_to_tuple(sparse_mx): 126 | """Convert sparse matrix to tuple representation.""" 127 | def to_tuple(mx): 128 | if not sp.isspmatrix_coo(mx): 129 | mx = mx.tocoo() 130 | coords = np.vstack((mx.row, mx.col)).transpose() 131 | values = mx.data 132 | shape = mx.shape 133 | return coords, values, shape 134 | 135 | if isinstance(sparse_mx, list): 136 | for i in range(len(sparse_mx)): 137 | sparse_mx[i] = to_tuple(sparse_mx[i]) 138 | else: 139 | sparse_mx = to_tuple(sparse_mx) 140 | 141 | return sparse_mx 142 | 143 | def standardize_data(f, train_mask): 144 | """Standardize feature matrix and convert to tuple representation""" 145 | # standardize data 146 | f = f.todense() 147 | mu = f[train_mask == True, :].mean(axis=0) 148 | sigma = f[train_mask == True, :].std(axis=0) 149 | f = f[:, np.squeeze(np.array(sigma > 0))] 150 | mu = f[train_mask == True, :].mean(axis=0) 151 | sigma = f[train_mask == True, :].std(axis=0) 152 | f = (f - mu) / sigma 153 | return f 154 | 155 | def preprocess_features(features): 156 | """Row-normalize feature matrix and convert to tuple representation""" 157 | rowsum = np.array(features.sum(1)) 158 | r_inv = np.power(rowsum, -1).flatten() 159 | r_inv[np.isinf(r_inv)] = 0. 160 | r_mat_inv = sp.diags(r_inv) 161 | features = r_mat_inv.dot(features) 162 | return features.todense(), sparse_to_tuple(features) 163 | 164 | def normalize_adj(adj): 165 | """Symmetrically normalize adjacency matrix.""" 166 | adj = sp.coo_matrix(adj) 167 | rowsum = np.array(adj.sum(1)) 168 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 169 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
170 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 171 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 172 | 173 | 174 | def preprocess_adj(adj): 175 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" 176 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) 177 | return sparse_to_tuple(adj_normalized) 178 | 179 | def preprocess_adj_bias(adj): 180 | num_nodes = adj.shape[0] 181 | adj = adj + sp.eye(num_nodes) # self-loop 182 | adj[adj > 0.0] = 1.0 183 | if not sp.isspmatrix_coo(adj): 184 | adj = adj.tocoo() 185 | adj = adj.astype(np.float32) 186 | indices = np.vstack((adj.col, adj.row)).transpose() # This is where I made a mistake, I used (adj.row, adj.col) instead 187 | # return tf.SparseTensor(indices=indices, values=adj.data, dense_shape=adj.shape) 188 | return indices, adj.data, adj.shape 189 | -------------------------------------------------------------------------------- /utils/process_ppi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import networkx as nx 4 | from networkx.readwrite import json_graph 5 | import scipy.sparse as sp 6 | import pdb 7 | import sys 8 | sys.setrecursionlimit(99999) 9 | 10 | 11 | def run_dfs(adj, msk, u, ind, nb_nodes): 12 | if msk[u] == -1: 13 | msk[u] = ind 14 | #for v in range(nb_nodes): 15 | for v in adj[u,:].nonzero()[1]: 16 | #if adj[u,v]== 1: 17 | run_dfs(adj, msk, v, ind, nb_nodes) 18 | 19 | # Use depth-first search to split a graph into subgraphs 20 | def dfs_split(adj): 21 | # Assume adj is of shape [nb_nodes, nb_nodes] 22 | nb_nodes = adj.shape[0] 23 | ret = np.full(nb_nodes, -1, dtype=np.int32) 24 | 25 | graph_id = 0 26 | 27 | for i in range(nb_nodes): 28 | if ret[i] == -1: 29 | run_dfs(adj, ret, i, graph_id, nb_nodes) 30 | graph_id += 1 31 | 32 | return ret 33 | 34 | def test(adj, mapping): 35 | nb_nodes = adj.shape[0] 36 | for i in range(nb_nodes): 37 | #for j in range(nb_nodes): 38 | for j in adj[i, :].nonzero()[1]: 39 | if mapping[i] != mapping[j]: 40 | # if adj[i,j] == 1: 41 | return False 42 | return True 43 | 44 | 45 | 46 | def find_split(adj, mapping, ds_label): 47 | nb_nodes = adj.shape[0] 48 | dict_splits={} 49 | for i in range(nb_nodes): 50 | #for j in range(nb_nodes): 51 | for j in adj[i, :].nonzero()[1]: 52 | if mapping[i]==0 or mapping[j]==0: 53 | dict_splits[0]=None 54 | elif mapping[i] == mapping[j]: 55 | if ds_label[i]['val'] == ds_label[j]['val'] and ds_label[i]['test'] == ds_label[j]['test']: 56 | 57 | if mapping[i] not in dict_splits.keys(): 58 | if ds_label[i]['val']: 59 | dict_splits[mapping[i]] = 'val' 60 | 61 | elif ds_label[i]['test']: 62 | dict_splits[mapping[i]]='test' 63 | 64 | else: 65 | dict_splits[mapping[i]] = 'train' 66 | 67 | else: 68 | if ds_label[i]['test']: 69 | ind_label='test' 70 | elif ds_label[i]['val']: 71 | ind_label='val' 72 | else: 73 | ind_label='train' 74 | if dict_splits[mapping[i]]!= ind_label: 75 | print ('inconsistent labels within a graph exiting!!!') 76 | return None 77 | else: 78 | print ('label of both nodes different, exiting!!') 79 | return None 80 | return dict_splits 81 | 82 | 83 | 84 | 85 | def process_p2p(): 86 | 87 | 88 | print ('Loading G...') 89 | with open('p2p_dataset/ppi-G.json') as jsonfile: 90 | g_data = json.load(jsonfile) 91 | print (len(g_data)) 92 | G = json_graph.node_link_graph(g_data) 93 | 94 | #Extracting adjacency matrix 95 | adj=nx.adjacency_matrix(G) 96 | 97 | prev_key='' 98 | for key, value in g_data.items(): 
99 | if prev_key!=key: 100 | print (key) 101 | prev_key=key 102 | 103 | print ('Loading id_map...') 104 | with open('p2p_dataset/ppi-id_map.json') as jsonfile: 105 | id_map = json.load(jsonfile) 106 | print (len(id_map)) 107 | 108 | id_map = {int(k):int(v) for k,v in id_map.items()} 109 | for key, value in id_map.items(): 110 | id_map[key]=[value] 111 | print (len(id_map)) 112 | 113 | print ('Loading features...') 114 | features_=np.load('p2p_dataset/ppi-feats.npy') 115 | print (features_.shape) 116 | 117 | #standarizing features 118 | from sklearn.preprocessing import StandardScaler 119 | 120 | train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]) 121 | train_feats = features_[train_ids[:,0]] 122 | scaler = StandardScaler() 123 | scaler.fit(train_feats) 124 | features_ = scaler.transform(features_) 125 | 126 | features = sp.csr_matrix(features_).tolil() 127 | 128 | 129 | print ('Loading class_map...') 130 | class_map = {} 131 | with open('p2p_dataset/ppi-class_map.json') as jsonfile: 132 | class_map = json.load(jsonfile) 133 | print (len(class_map)) 134 | 135 | #pdb.set_trace() 136 | #Split graph into sub-graphs 137 | print ('Splitting graph...') 138 | splits=dfs_split(adj) 139 | 140 | #Rearrange sub-graph index and append sub-graphs with 1 or 2 nodes to bigger sub-graphs 141 | print ('Re-arranging sub-graph IDs...') 142 | list_splits=splits.tolist() 143 | group_inc=1 144 | 145 | for i in range(np.max(list_splits)+1): 146 | if list_splits.count(i)>=3: 147 | splits[np.array(list_splits) == i] =group_inc 148 | group_inc+=1 149 | else: 150 | #splits[np.array(list_splits) == i] = 0 151 | ind_nodes=np.argwhere(np.array(list_splits) == i) 152 | ind_nodes=ind_nodes[:,0].tolist() 153 | split=None 154 | 155 | for ind_node in ind_nodes: 156 | if g_data['nodes'][ind_node]['val']: 157 | if split is None or split=='val': 158 | splits[np.array(list_splits) == i] = 21 159 | split='val' 160 | else: 161 | raise ValueError('new node is VAL but previously was {}'.format(split)) 162 | elif g_data['nodes'][ind_node]['test']: 163 | if split is None or split=='test': 164 | splits[np.array(list_splits) == i] = 23 165 | split='test' 166 | else: 167 | raise ValueError('new node is TEST but previously was {}'.format(split)) 168 | else: 169 | if split is None or split == 'train': 170 | splits[np.array(list_splits) == i] = 1 171 | split='train' 172 | else: 173 | pdb.set_trace() 174 | raise ValueError('new node is TRAIN but previously was {}'.format(split)) 175 | 176 | #counting number of nodes per sub-graph 177 | list_splits=splits.tolist() 178 | nodes_per_graph=[] 179 | for i in range(1,np.max(list_splits) + 1): 180 | nodes_per_graph.append(list_splits.count(i)) 181 | 182 | #Splitting adj matrix into sub-graphs 183 | subgraph_nodes=np.max(nodes_per_graph) 184 | adj_sub=np.empty((len(nodes_per_graph), subgraph_nodes, subgraph_nodes)) 185 | feat_sub = np.empty((len(nodes_per_graph), subgraph_nodes, features.shape[1])) 186 | labels_sub = np.empty((len(nodes_per_graph), subgraph_nodes, 121)) 187 | 188 | for i in range(1, np.max(list_splits) + 1): 189 | #Creating same size sub-graphs 190 | indexes = np.where(splits == i)[0] 191 | subgraph_=adj[indexes,:][:,indexes] 192 | 193 | if subgraph_.shape[0]