├── Framework.png ├── LICENSE ├── README.md ├── __init__.py ├── data ├── .DS_Store ├── ind.cora.allx ├── ind.cora.ally ├── ind.cora.graph ├── ind.cora.test.index ├── ind.cora.tx ├── ind.cora.ty ├── ind.cora.x └── ind.cora.y ├── inits.py ├── layers.py ├── metrics.py ├── models.py ├── train.py ├── utils.py └── weighting_func.py /Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/Framework.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Heng Chang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpGAT 2 | 3 | This is a TensorFlow implementation of Spectral Graph Attention Network with Fast Eigen-approximation (**SpGAT**). 4 | 5 | Heng Chang, Yu Rong, Tingyang Xu, Wenbing Huang, Somayeh Sojoudi, Junzhou Huang, Wenwu Zhu, [Spectral Graph Attention Network with Fast Eigen-approximation](https://dl.acm.org/doi/abs/10.1145/3459637.3482187), CIKM 2021. 6 | 7 |
8 | framework 9 |
10 | 11 | ## Requirements 12 | * python3 13 | * tensorflow (tested on 1.12.0) 14 | * networkx 15 | * numpy 16 | * scipy 17 | * sklearn 18 | 19 | Anaconda environment is recommended. 20 | 21 | ## Run the code 22 | To replicate the result of SpGAT on Cora: 23 | ```bash 24 | python train.py 25 | ``` 26 | To replicate the result of SpGAT_Cheby on Cora: 27 | ```bash 28 | python train.py --model SpGAT_Cheby 29 | ``` 30 | 31 | ## Acknowledgement 32 | This repo is modified from [GWNN](https://github.com/Eilene/GWNN), and we sincerely thank them for their contributions. 33 | 34 | ## Reference 35 | - If you find ``SpGAT`` useful in your research, please cite the following in your manuscript: 36 | 37 | ``` 38 | @article{chang2020spectral, 39 | title={Spectral Graph Attention Network with Fast Eigen-approximation}, 40 | author={Chang, Heng and Rong, Yu and Xu, Tingyang and Huang, Wenbing and Sojoudi, Somayeh and Huang, Junzhou and Zhu, Wenwu}, 41 | journal={arXiv preprint arXiv:2003.07450}, 42 | year={2020} 43 | } 44 | ``` 45 | 46 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/.DS_Store -------------------------------------------------------------------------------- /data/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.allx -------------------------------------------------------------------------------- /data/ind.cora.ally: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.ally -------------------------------------------------------------------------------- /data/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.graph -------------------------------------------------------------------------------- /data/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 
| 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 
313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 
2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 
| 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 
858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /data/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.tx -------------------------------------------------------------------------------- /data/ind.cora.ty: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.ty -------------------------------------------------------------------------------- /data/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.x -------------------------------------------------------------------------------- /data/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.y -------------------------------------------------------------------------------- /inits.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def uniform(shape, scale=1.0, name=None): 6 | """Uniform init: return a float32 tf.Variable of `shape` sampled with tf.random_uniform between 0.0 and `scale`.""" 7 | initial = tf.random_uniform(shape, minval=0.0, maxval=scale, dtype=tf.float32) 8 | return tf.Variable(initial, name=name) 9 | 10 | 11 | def glorot(shape, name=None): 12 | """Glorot & Bengio (AISTATS 2010) init: float32 tf.Variable sampled uniformly between -r and r, with r = sqrt(6 / (shape[0] + shape[1])).""" 13 | init_range = np.sqrt(6.0/(shape[0]+shape[1]))  # only the first two entries of `shape` enter the range 14 | initial = tf.random_uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32) 15 | return tf.Variable(initial, name=name) 16 | 17 | 18 | def zeros(shape, name=None): 19 | """All zeros: return a float32 tf.Variable of `shape` initialised with tf.zeros.""" 20 | initial = tf.zeros(shape, dtype=tf.float32) 21 | return tf.Variable(initial, name=name) 22 | 23 | 24 | def ones(shape, name=None): 25 | """All ones: return a float32 tf.Variable of `shape` initialised with tf.ones.""" 26 | initial = tf.ones(shape, dtype=tf.float32) 27 | return tf.Variable(initial, name=name) 28 | 29 | def ones_fix(shape, name=None): 30 | """All ones, fixed: same as ones() but the tf.Variable is created with trainable=False.""" 31 | initial = tf.ones(shape, dtype=tf.float32) 32 | return tf.Variable(initial, name=name, trainable=False) 33 | 34 | 35 | 36 | 37 | 
-------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from inits import * 2 | import numpy as np 3 | import tensorflow as tf 4 | from sklearn.preprocessing import normalize 5 | flags = tf.app.flags 6 | FLAGS = flags.FLAGS 7 | 8 | # global unique layer ID dictionary for layer name assignment (maps layer name -> last ID handed out) 9 | _LAYER_UIDS = {} 10 | 11 | 12 | def get_layer_uid(layer_name=''): 13 | """Return the next 1-based ID for `layer_name`, counting calls per name in the module-level _LAYER_UIDS dict.""" 14 | if layer_name not in _LAYER_UIDS: 15 | _LAYER_UIDS[layer_name] = 1 16 | return 1 17 | else: 18 | _LAYER_UIDS[layer_name] += 1 19 | return _LAYER_UIDS[layer_name] 20 | 21 | def sparse_dropout(x, keep_prob, noise_shape): 22 | """Dropout for sparse tensors: retain the entries of `x` selected by a random mask of shape `noise_shape` and scale the result by 1/keep_prob.""" 23 | random_tensor = keep_prob 24 | random_tensor += tf.random_uniform(noise_shape) 25 | dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)  # for keep_prob in [0, 1], floor(keep_prob + noise) is 0 or 1; cast to a boolean keep-mask 26 | pre_out = tf.sparse_retain(x, dropout_mask) 27 | return pre_out * (1./keep_prob) 28 | 29 | def dot(x, y, sparse=False): 30 | """Wrapper for matrix multiplication: tf.sparse_tensor_dense_matmul(x, y) when `sparse` is true, tf.matmul(x, y) otherwise.""" 31 | if sparse: 32 | res = tf.sparse_tensor_dense_matmul(x, y) 33 | else: 34 | res = tf.matmul(x, y) 35 | return res 36 | 37 | class Layer(object): 38 | """Base layer class. Defines basic API for all layer objects. 39 | Implementation inspired by keras (http://keras.io). 40 | 41 | # Properties 42 | name: String, used for the layer's tf.name_scope and summary tags (auto-generated when omitted). 43 | logging: Boolean, switches Tensorflow histogram logging on/off 44 | 45 | # Methods 46 | _call(inputs): Defines computation graph of layer 47 | (i.e.
takes input, returns output) 48 | __call__(inputs): Wrapper for _call() that runs it inside tf.name_scope(self.name) 49 | _log_vars(): Log a histogram summary for every entry of self.vars 50 | """ 51 | 52 | def __init__(self, **kwargs): 53 | allowed_kwargs = {'name', 'logging'}  # the only keyword arguments accepted 54 | for kwarg in kwargs.keys(): 55 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 56 | name = kwargs.get('name') 57 | if not name: 58 | layer = self.__class__.__name__.lower() 59 | name = layer + '_' + str(get_layer_uid(layer))  # default name: <lowercased class name>_<uid> 60 | self.name = name 61 | self.vars = {} 62 | logging = kwargs.get('logging', False) 63 | self.logging = logging 64 | self.sparse_inputs = False 65 | 66 | def _call(self, inputs): 67 | return inputs  # identity in the base class 68 | 69 | def __call__(self, inputs): 70 | with tf.name_scope(self.name): 71 | if self.logging and not self.sparse_inputs:  # input histogram is skipped for sparse inputs 72 | tf.summary.histogram(self.name + '/inputs', inputs) 73 | outputs = self._call(inputs) 74 | if self.logging: 75 | tf.summary.histogram(self.name + '/outputs', outputs) 76 | return outputs 77 | 78 | def _log_vars(self): 79 | for var in self.vars: 80 | tf.summary.histogram(self.name + '/vars/' + var, self.vars[var]) 81 | 82 | class SpGAT_Conv(Layer): 83 | """Graph convolution layer: projects the inputs with one weight matrix, propagates them through a 'low' support (placeholders['support'][0], [1]) and a 'high' support ([2], [3]), each scaled by a softmax-normalised learned weight, and combines the two results with an element-wise maximum.""" 84 | def __init__(self, k_por, node_num,weight_normalize,input_dim, output_dim, placeholders, dropout=0., 85 | sparse_inputs=False, act=tf.nn.relu, bias=False, 86 | featureless=False, **kwargs): 87 | super(SpGAT_Conv, self).__init__(**kwargs) 88 | 89 | if dropout:  # any truthy value enables dropout; the rate itself is read from placeholders['dropout'] 90 | self.dropout = placeholders['dropout'] 91 | else: 92 | self.dropout = 0. 
93 | 94 | self.k_por = k_por 95 | self.node_num = node_num 96 | self.weight_normalize = weight_normalize  # stored only; not read anywhere else in this class 97 | self.act = act 98 | self.support = placeholders['support'] 99 | self.sparse_inputs = sparse_inputs 100 | self.featureless = featureless  # stored only; not read anywhere else in this class 101 | self.bias = bias 102 | 103 | # helper variable for sparse dropout (passed to sparse_dropout as noise_shape) 104 | self.num_features_nonzero = placeholders['num_features_nonzero'] 105 | 106 | with tf.variable_scope(self.name + '_vars'): 107 | self.vars['weights_' + str(0)] = glorot([input_dim, output_dim], 108 | name='weights_' + str(0)) 109 | k_fre = int(self.k_por * self.node_num)  # length of the 'low' kernel; the 'high' kernel gets node_num - k_fre entries 110 | init_alpha = np.array([1, 1], dtype='float32') 111 | self.alpha = tf.get_variable("tf_var_initialized_from_alpha", initializer = init_alpha, trainable=True) 112 | self.alpha = tf.nn.softmax(self.alpha)  # from here on self.alpha is the softmax output, not the raw variable 113 | self.vars['low_w'] = self.alpha[0] 114 | self.vars['high_w'] = self.alpha[1] 115 | 116 | 117 | self.vars['kernel_low'] = ones_fix([k_fre], name='kernel_low') 118 | self.vars['kernel_high'] = ones_fix([self.node_num - k_fre], name='kernel_high') 119 | self.vars['kernel_low'] = self.vars['kernel_low'] * self.vars['low_w']  # non-trainable ones scaled by the learned weight 120 | self.vars['kernel_high'] = self.vars['kernel_high'] * self.vars['high_w'] 121 | 122 | 123 | if self.bias: 124 | self.vars['bias'] = zeros([output_dim], name='bias') 125 | 126 | if self.logging: 127 | self._log_vars() 128 | 129 | def _call(self, inputs): 130 | x = inputs 131 | 132 | # dropout (keep probability is 1 - self.dropout) 133 | if self.sparse_inputs: 134 | x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero) 135 | else: 136 | x = tf.nn.dropout(x, 1-self.dropout) 137 | # low branch: dense(support[0]) x diag(kernel_low) x dense(support[1]), applied to the projected features 138 | supports_low = tf.matmul(tf.sparse_tensor_to_dense(self.support[0]),tf.diag(self.vars['kernel_low']),a_is_sparse=True,b_is_sparse=True) 139 | supports_low = tf.matmul(supports_low,tf.sparse_tensor_to_dense(self.support[1]),a_is_sparse=True,b_is_sparse=True) 140 | pre_sup = dot(x, self.vars['weights_' + str(0)],sparse=self.sparse_inputs) 141 | output_low = dot(supports_low,pre_sup) 142 | 143 | # high branch: same construction with support[2], kernel_high and support[3]; reuses pre_sup 144 | 
supports_high = tf.matmul(tf.sparse_tensor_to_dense(self.support[2]),tf.diag(self.vars['kernel_high']),a_is_sparse=True,b_is_sparse=True) 145 | supports_high = tf.matmul(supports_high,tf.sparse_tensor_to_dense(self.support[3]),a_is_sparse=True,b_is_sparse=True) 146 | output_high = dot(supports_high,pre_sup) 147 | 148 | # Alternative kept for reference (labelled "Mean Pooling" by the authors; it is a plain sum): 149 | #output = output_low + output_high 150 | # Max Pooling: stack the two branch outputs on a new leading axis and take the element-wise maximum 151 | output = tf.concat([tf.expand_dims(output_low, axis = 0), tf.expand_dims(output_high, axis = 0)], axis = 0) 152 | output = tf.reduce_max(output, axis = 0) 153 | 154 | if self.bias: 155 | output += self.vars['bias'] 156 | 157 | return self.act(output) 158 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def masked_softmax_cross_entropy(preds, labels, mask): 5 | """Softmax cross-entropy loss with masking: per-row losses are weighted by mask / mean(mask) before the final mean, so for a 0/1 mask the result is the average loss over the masked rows.""" 6 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels) 7 | mask = tf.cast(mask, dtype=tf.float32) 8 | mask /= tf.reduce_mean(mask)  # NOTE(review): divides by zero when the mask selects nothing 9 | loss *= mask 10 | return tf.reduce_mean(loss) 11 | 12 | 13 | def masked_accuracy(preds, labels, mask): 14 | """Accuracy with masking: compares argmax over axis 1 of `preds` and `labels`, weights the hits by mask / mean(mask) and returns their mean.""" 15 | correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1)) 16 | accuracy_all = tf.cast(correct_prediction, tf.float32) 17 | mask = tf.cast(mask, dtype=tf.float32) 18 | mask /= tf.reduce_mean(mask)  # NOTE(review): divides by zero when the mask selects nothing 19 | accuracy_all *= mask 20 | return tf.reduce_mean(accuracy_all) 21 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | from metrics import * 3 | 4 | flags = tf.app.flags 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | class Model(object):  # base class: subclasses provide _build(), _loss(), _accuracy() and set self.optimizer before build() 9 | def __init__(self, **kwargs): 10 | allowed_kwargs = {'name', 'logging'}  # the only keyword arguments accepted 11 | for kwarg in 
kwargs.keys(): 12 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg 13 | name = kwargs.get('name') 14 | if not name: 15 | name = self.__class__.__name__.lower()  # default name: lowercased class name 16 | self.name = name 17 | 18 | logging = kwargs.get('logging', False) 19 | self.logging = logging 20 | 21 | self.vars = {} 22 | self.placeholders = {} 23 | 24 | self.layers = [] 25 | self.activations = [] 26 | 27 | self.inputs = None 28 | self.outputs = None 29 | 30 | self.loss = 0 31 | self.accuracy = 0 32 | self.optimizer = None 33 | self.opt_op = None 34 | 35 | def _build(self): 36 | raise NotImplementedError 37 | 38 | def build(self): 39 | """ Wrapper for _build(): runs it under tf.variable_scope(self.name), chains self.layers over self.inputs, collects the model variables, then builds loss, accuracy and the minimize op. """ 40 | with tf.variable_scope(self.name): 41 | self._build() 42 | 43 | # Build sequential layer model (each layer is called on the previous activation) 44 | self.activations.append(self.inputs) 45 | for layer in self.layers: 46 | hidden = layer(self.activations[-1]) 47 | self.activations.append(hidden) 48 | self.outputs = self.activations[-1] 49 | 50 | # Store model variables for easy access (keyed by variable name) 51 | variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 52 | self.vars = {var.name: var for var in variables} 53 | 54 | # Build metrics 55 | self._loss() 56 | self._accuracy() 57 | # self.optimizer is None after Model.__init__; it must be assigned before build() reaches this line 58 | self.opt_op = self.optimizer.minimize(self.loss) 59 | 60 | def predict(self): 61 | pass  # base implementation returns None 62 | 63 | def _loss(self): 64 | raise NotImplementedError 65 | 66 | def _accuracy(self): 67 | raise NotImplementedError 68 | 69 | def save(self, sess=None): 70 | if not sess: 71 | raise AttributeError("TensorFlow session not provided.") 72 | saver = tf.train.Saver(self.vars) 73 | save_path = saver.save(sess, "tmp/%s.ckpt" % self.name)  # relative path; the tmp/ directory is not created here 74 | print("Model saved in file: %s" % save_path) 75 | 76 | def load(self, sess=None): 77 | if not sess: 78 | raise AttributeError("TensorFlow session not provided.") 79 | saver = tf.train.Saver(self.vars) 80 | save_path = "tmp/%s.ckpt" % self.name  # same relative path that save() writes to 81 | saver.restore(sess, save_path) 82 | print("Model restored from file: %s" % save_path) 83 | 84 | 85 | 
class SpGAT(Model):
    """Two-layer network stacked from ``SpGAT_Conv`` layers.

    The first layer maps ``input_dim`` to ``FLAGS.hidden1`` with a ReLU; the
    second maps ``FLAGS.hidden1`` to the number of label columns with an
    identity activation. The graph is built during construction.
    """

    def __init__(self, k_por, node_num, weight_normalize, placeholders, input_dim, **kwargs):
        super(SpGAT, self).__init__(**kwargs)

        self.placeholders = placeholders
        self.inputs = placeholders['features']
        # Output width is read off the second dimension of the labels placeholder.
        self.output_dim = placeholders['labels'].get_shape().as_list()[1]

        self.input_dim = input_dim
        self.node_num = node_num
        self.k_por = k_por
        self.weight_normalize = weight_normalize

        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

        self.build()

    def _loss(self):
        """Add first-layer L2 weight decay and the masked cross entropy to the loss."""
        # Weight decay loss (first layer only)
        first_layer_vars = self.layers[0].vars.values()
        for weight in first_layer_vars:
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(weight)

        # Cross entropy error
        labels = self.placeholders['labels']
        mask = self.placeholders['labels_mask']
        self.loss += masked_softmax_cross_entropy(self.outputs, labels, mask)

    def _accuracy(self):
        """Set the masked accuracy of the outputs against the label placeholders."""
        labels = self.placeholders['labels']
        mask = self.placeholders['labels_mask']
        self.accuracy = masked_accuracy(self.outputs, labels, mask)

    def _build(self):
        """Create the hidden and output ``SpGAT_Conv`` layers, in that order."""
        # Arguments common to both layers.
        shared = dict(k_por=self.k_por,
                      node_num=self.node_num,
                      weight_normalize=self.weight_normalize,
                      placeholders=self.placeholders,
                      dropout=True,
                      logging=self.logging)

        hidden_layer = SpGAT_Conv(input_dim=self.input_dim,
                                  output_dim=FLAGS.hidden1,
                                  act=tf.nn.relu,
                                  sparse_inputs=True,
                                  **shared)
        self.layers.append(hidden_layer)

        output_layer = SpGAT_Conv(input_dim=FLAGS.hidden1,
                                  output_dim=self.output_dim,
                                  act=lambda x: x,
                                  **shared)
        self.layers.append(output_layer)

    def predict(self):
        """Return the softmax of the model outputs."""
        return tf.nn.softmax(self.outputs)
coding:UTF-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | from utils import * 10 | from models import SpGAT 11 | 12 | import os 13 | os.environ['CUDA_VISIBLE_DEVICES']='-1' 14 | 15 | # Set random seed 16 | seed = 322 17 | np.random.seed(seed) 18 | tf.set_random_seed(seed) 19 | 20 | # Settings 21 | flags = tf.app.flags 22 | FLAGS = flags.FLAGS 23 | flags.DEFINE_string('dataset', 'cora', 'Dataset string.') # 'cora', 'citeseer', 'pubmed' 24 | flags.DEFINE_string('model', 'SpGAT', 'Model string.') # 'SpGAT', 'SpGAT_Cheby' 25 | flags.DEFINE_float('wavelet_s', 1.0, 'wavelet s .') 26 | flags.DEFINE_float('threshold', 1e-4, 'sparseness threshold .') 27 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 28 | flags.DEFINE_bool('alldata', False, 'All data string.') 29 | flags.DEFINE_integer('epochs', 200, 'Number of epochs to train.')#1000 30 | flags.DEFINE_integer('hidden1', 64, 'Number of units in hidden layer 1.') 31 | flags.DEFINE_float('dropout', 0.5, 'Dropout rate (1 - keep probability).') 32 | flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.') 33 | flags.DEFINE_integer('early_stopping', 200, 'Tolerance for early stopping (# of epochs).') 34 | flags.DEFINE_bool('mask', True, 'mask string.') 35 | flags.DEFINE_bool('laplacian_normalize', True, 'laplacian normalize string.') 36 | flags.DEFINE_bool('sparse_ness', True, 'wavelet sparse_ness string.') 37 | flags.DEFINE_bool('weight_normalize', False, 'weight normalize string.') 38 | flags.DEFINE_string('gpu', '-1', 'which gpu to use.')#1000 39 | flags.DEFINE_integer('repeating', 1, 'Number of repeating times')#1000 40 | 41 | os.environ['CUDA_VISIBLE_DEVICES']=FLAGS.gpu 42 | 43 | 44 | # Load data 45 | labels, adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(FLAGS.dataset,alldata=FLAGS.alldata) 46 | # Some 
preprocessing, normalization 47 | features = preprocess_features(features) 48 | node_num = adj.shape[0] 49 | 50 | print("************Loading data finished, Begin constructing wavelet************") 51 | 52 | dataset = FLAGS.dataset 53 | s = FLAGS.wavelet_s 54 | laplacian_normalize = FLAGS.laplacian_normalize 55 | sparse_ness = FLAGS.sparse_ness 56 | threshold = FLAGS.threshold 57 | weight_normalize = FLAGS.weight_normalize 58 | if FLAGS.model == "SpGAT": 59 | support_t = wavelet_basis(dataset,adj, s, laplacian_normalize,sparse_ness,threshold,weight_normalize) 60 | elif FLAGS.model == "SpGAT_Cheby": 61 | s = 2.0 62 | support_t = wavelet_basis_appro(dataset,adj, s, laplacian_normalize,sparse_ness,threshold,weight_normalize) 63 | if dataset == 'cora': 64 | k_por = 0.05 # best $d$ for cora 65 | if dataset == 'pubmed': 66 | k_por = 0.10 # best $d$ for pubmed 67 | if dataset == 'citeseer': 68 | k_por = 0.15 # best $d$ for citeseer 69 | k_fre = int(k_por * node_num) 70 | support = [support_t[0][:,:k_fre], support_t[1][:k_fre,:], support_t[0][:,k_fre:], support_t[1][k_fre:,:]] 71 | sparse_to_tuple(support) 72 | num_supports = len(support) 73 | model_func = SpGAT 74 | 75 | # Define placeholders 76 | placeholders = { 77 | 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)], 78 | 'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)), 79 | 'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])), 80 | 'labels_mask': tf.placeholder(tf.int32), 81 | 'dropout': tf.placeholder_with_default(0., shape=()), 82 | 'num_features_nonzero': tf.placeholder(tf.int32) # helper variable for sparse dropout 83 | } 84 | 85 | # Define model evaluation function 86 | def evaluate(features, support, labels, mask, placeholders): 87 | feed_dict_val = construct_feed_dict(features, support, labels, mask, placeholders) 88 | outs_val = sess.run([model.outputs,model.loss, model.accuracy], feed_dict=feed_dict_val) 89 | return 
outs_val[0], outs_val[1], outs_val[2] 90 | 91 | test_acc_bestval = [] 92 | test_acc_besttest = [] 93 | val_acc = [] 94 | 95 | for _ in range(FLAGS.repeating): 96 | 97 | #seed = np.random.randint(999) 98 | #np.random.seed(seed) 99 | #tf.set_random_seed(seed) 100 | 101 | # Create model 102 | weight_normalize = FLAGS.weight_normalize 103 | node_num = adj.shape[0] 104 | model = model_func(k_por, node_num,weight_normalize, placeholders, input_dim=features[2][1], logging=True) 105 | print("**************Constructing wavelet finished, Begin training**************") 106 | # Initialize session 107 | sess = tf.Session() 108 | 109 | # Init variables 110 | sess.run(tf.global_variables_initializer()) 111 | 112 | # Train model 113 | cost_val = [] 114 | best_val_acc = 0.0 115 | output_test_acc = 0.0 116 | best_test_acc = 0.0 117 | 118 | for epoch in range(FLAGS.epochs): 119 | 120 | # Construct feed dictionary 121 | feed_dict = construct_feed_dict(features, support, y_train, train_mask, placeholders) 122 | feed_dict.update({placeholders['dropout']: FLAGS.dropout}) 123 | 124 | # Training step 125 | outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict) 126 | 127 | # Validation 128 | val_output,cost, acc = evaluate(features, support, y_val, val_mask, placeholders) 129 | cost_val.append(cost) 130 | # Test 131 | test_output, test_cost, test_acc = evaluate(features, support, y_test, test_mask, placeholders) 132 | 133 | # best val acc 134 | if(best_val_acc <= acc): 135 | best_val_acc = acc 136 | output_test_acc = test_acc 137 | # best test acc 138 | if(best_test_acc <= test_acc): 139 | #import pdb; pdb.set_trace() 140 | best_test_acc = test_acc 141 | 142 | 143 | # Print results 144 | print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(outs[1]), 145 | "train_acc=", "{:.5f}".format(outs[2]), "val_loss=", "{:.5f}".format(cost), 146 | "val_acc=", "{:.5f}".format(acc), "test_loss=", "{:.5f}".format(test_cost), "test_acc=", "{:.5f}".format(test_acc)) 
147 | 148 | if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(cost_val[-(FLAGS.early_stopping+1):-1]): 149 | print("Early stopping...") 150 | break 151 | 152 | print("Optimization Finished!") 153 | 154 | print("dataset: ",FLAGS.dataset," model: ",FLAGS.model,",sparse_ness: ",FLAGS.sparse_ness, 155 | ",laplacian_normalize: ",FLAGS.laplacian_normalize,",threshold",FLAGS.threshold,",wavelet_s:",FLAGS.wavelet_s,",mask:",FLAGS.mask, 156 | ",weight_normalize:",FLAGS.weight_normalize, 157 | ",learning_rate:",FLAGS.learning_rate,",hidden1:",FLAGS.hidden1,",dropout:",FLAGS.dropout,",alldata:",FLAGS.alldata) 158 | 159 | print("Val accuracy:", best_val_acc, " Test accuracy: ",output_test_acc) 160 | test_acc_bestval.append(output_test_acc) 161 | test_acc_besttest.append(best_test_acc) 162 | val_acc.append(best_val_acc) 163 | 164 | print("********************************************************") 165 | 166 | result = [] 167 | result.append(np.array(test_acc_bestval)) 168 | result.append(np.array(test_acc_besttest)) 169 | result.append(np.array(val_acc)) 170 | 171 | alpha_1_low = sess.run(model.layers[0].vars['low_w'], feed_dict = feed_dict) 172 | alpha_1_high = sess.run(model.layers[0].vars['high_w'], feed_dict = feed_dict) 173 | alpha_2_low = sess.run(model.layers[1].vars['low_w'], feed_dict = feed_dict) 174 | alpha_2_high = sess.run(model.layers[1].vars['high_w'], feed_dict = feed_dict) 175 | 176 | 177 | r_half = int(FLAGS.repeating / 2) 178 | print("REPEAT\t{}".format(FLAGS.repeating)) 179 | print("Model\t{}".format(FLAGS.model)) 180 | print("Low frequency portion\t{} %".format(k_por * 100)) 181 | print("{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}".format('DATASET', 'best_val_mean', 'best_val_std', 182 | 'best_test_mean', 'best_test_std', 'half_best_val_mean', 183 | 'half_best_val_std', 'half_best_test_mean', 'half_best_test_std')) 184 | print("{:<8}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}".format( 185 | 
def sample_mask(idx, l):
    """Create a boolean mask of length `l` that is True at the positions in `idx`.

    :param idx: Index or sequence of indices to switch on.
    :param l: Length of the mask.
    :return: 1-D ``numpy.ndarray`` of dtype ``bool``.
    """
    # Use the builtin ``bool``: the ``np.bool`` alias was removed from NumPy.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask
ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object; 37 | ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object; 38 | ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object; 39 | ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object; 40 | ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict 41 | object; 42 | ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object. 43 | 44 | All objects above must be saved using python pickle module. 45 | 46 | :param dataset_str: Dataset name 47 | :return: All data input files loaded (as well the training/test data). 48 | """ 49 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'] 50 | objects = [] 51 | for i in range(len(names)): 52 | with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f: 53 | if sys.version_info > (3, 0): 54 | objects.append(pkl.load(f, encoding='latin1')) 55 | else: 56 | objects.append(pkl.load(f)) 57 | 58 | x, y, tx, ty, allx, ally, graph = tuple(objects) 59 | 60 | test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str)) 61 | test_idx_range = np.sort(test_idx_reorder) 62 | 63 | if dataset_str == 'citeseer': 64 | # Fix citeseer dataset (there are some isolated nodes in the graph) 65 | # Find isolated nodes, add them as zero-vecs into the right position 66 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) 67 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) 68 | tx_extended[test_idx_range-min(test_idx_range), :] = tx 69 | tx = tx_extended 70 | ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) 71 | ty_extended[test_idx_range-min(test_idx_range), :] = ty 72 | ty = ty_extended 73 | 74 | features = sp.vstack((allx, tx)).tolil() 75 | features[test_idx_reorder, :] = 
def sparse_to_tuple(sparse_mx):
    """Convert a sparse matrix to a ``(coords, values, shape)`` tuple.

    A list argument is converted element by element *in place* and the same
    list object is returned; any other argument is converted and returned
    as a single tuple.
    """
    def _as_triplet(matrix):
        # Row/column index arrays are only available in COO format.
        coo = matrix if sp.isspmatrix_coo(matrix) else matrix.tocoo()
        indices = np.vstack((coo.row, coo.col)).transpose()
        return indices, coo.data, coo.shape

    if not isinstance(sparse_mx, list):
        return _as_triplet(sparse_mx)

    for position, matrix in enumerate(sparse_mx):
        sparse_mx[position] = _as_triplet(matrix)
    return sparse_mx
def normalize_adj(adj):
    """Scale an adjacency matrix on both sides by row-sum ** -0.5.

    Rows whose sum is zero get a scaling factor of 0 instead of infinity.
    The result is returned as a COO sparse matrix.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    scale = np.power(degrees, -0.5).flatten()
    # Zero row sums produce inf above; neutralise them.
    scale[np.isinf(scale)] = 0.
    scaling = sp.diags(scale, 0)
    right_scaled = coo.dot(scaling)
    both_scaled = right_scaled.transpose().dot(scaling)
    return both_scaled.tocoo()
def wavelet_basis_appro(dataset,adj,s,laplacian_normalize,sparse_ness,threshold,weight_normalize):
    """Build the wavelet pair from a first-order expansion in the shifted Laplacian.

    Both matrices have the form ``c0 * I + c1 * (L - I)``, where the
    coefficients come from ``exp(+-s)`` and the modified Bessel functions
    ``scipy.special.iv`` of order 0 and 1, so no eigendecomposition is needed.

    :param dataset: Unused; kept so the signature matches ``wavelet_basis``.
    :param adj: Adjacency matrix passed to ``laplacian``.
    :param s: Scale parameter.
    :param laplacian_normalize: Forwarded to ``laplacian`` as ``normalized``.
    :param sparse_ness: If true, entries below ``threshold`` are set to 0.
    :param threshold: Cut-off used when ``sparse_ness`` is true.
    :param weight_normalize: If true, rows are L1-normalised.
    :return: ``[inverse_Weight, Weight]`` as CSR matrices.
    """
    n = adj.shape[0]
    L = laplacian(adj,normalized=laplacian_normalize)
    # ``toarray`` yields a plain ndarray; ``todense`` would yield the legacy
    # ``np.matrix`` type, which newer scikit-learn ``normalize`` rejects.
    L = (L - sp.eye(n)).toarray()

    identity = np.eye(n)
    Weight = np.exp(s) * ss.iv(0,s) * identity + 2 * np.exp(s) * ss.iv(1,s) * L
    inverse_Weight = np.exp(-s) * ss.iv(0,-s) * identity + 2 * np.exp(-s) * ss.iv(1,-s) * L

    if sparse_ness:
        # NOTE(review): this comparison is signed, so every negative entry is
        # zeroed as well, not only entries of small magnitude -- confirm that
        # this is intended.
        Weight[Weight < threshold] = 0.0
        inverse_Weight[inverse_Weight < threshold] = 0.0

    if weight_normalize:
        Weight = normalize(Weight, norm='l1', axis=1)
        inverse_Weight = normalize(inverse_Weight, norm='l1', axis=1)

    Weight = sp.csr_matrix(Weight)
    inverse_Weight = sp.csr_matrix(inverse_Weight)
    t_k = [inverse_Weight,Weight]
    return t_k
def weight_wavelet(s,lamb,U):
    """Return ``U . diag(exp(-s * lamb)) . U^T``.

    The caller's ``lamb`` is left untouched. Overwriting it in place would
    hand already-exponentiated values to any later use of the same array,
    such as a following ``weight_wavelet_inverse(s, lamb, U)`` call.

    :param s: Scale parameter.
    :param lamb: 1-D sequence of eigenvalues.
    :param U: Matrix whose columns are the matching eigenvectors.
    :return: Dense weight matrix as a ``numpy.ndarray``.
    """
    # Computed on a fresh array so the input eigenvalues are never modified.
    kernel = np.exp(-np.asarray(lamb) * s)
    basis = np.asarray(U)
    # Scaling column j of U by kernel[j] equals U . diag(kernel) without
    # materialising the diagonal matrix.
    Weight = np.dot(basis * kernel, np.transpose(basis))

    return Weight
--------------------------------------------------------------------------------