├── Framework.png
├── LICENSE
├── README.md
├── __init__.py
├── data
│   ├── .DS_Store
│   ├── ind.cora.allx
│   ├── ind.cora.ally
│   ├── ind.cora.graph
│   ├── ind.cora.test.index
│   ├── ind.cora.tx
│   ├── ind.cora.ty
│   ├── ind.cora.x
│   └── ind.cora.y
├── inits.py
├── layers.py
├── metrics.py
├── models.py
├── train.py
├── utils.py
└── weighting_func.py
/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/Framework.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Heng Chang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpGAT
2 |
3 | This is a TensorFlow implementation of Spectral Graph Attention Network with Fast Eigen-approximation (**SpGAT**).
4 |
5 | Heng Chang, Yu Rong, Tingyang Xu, Wenbing Huang, Somayeh Sojoudi, Junzhou Huang, Wenwu Zhu, [Spectral Graph Attention Network with Fast Eigen-approximation](https://dl.acm.org/doi/abs/10.1145/3459637.3482187), CIKM 2021.
6 |
7 |
8 | ![SpGAT framework](Framework.png)
9 |
10 |
11 | ## Requirements
12 | * python3
13 | * tensorflow (tested on 1.12.0)
14 | * networkx
15 | * numpy
16 | * scipy
17 | * scikit-learn (`sklearn`)
18 |
19 | An Anaconda environment is recommended.
20 |
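For reference, a minimal environment setup might look like the following (TensorFlow 1.12.0 requires Python 3.6 or earlier; the environment name and the versions of the other packages are assumptions, not pinned by this repo):

```bash
conda create -n spgat python=3.6
conda activate spgat
pip install tensorflow==1.12.0 networkx numpy scipy scikit-learn
```
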
21 | ## Run the code
22 | To replicate the result of SpGAT on Cora:
23 | ```bash
24 | python train.py
25 | ```
26 | To replicate the result of SpGAT_Cheby on Cora:
27 | ```bash
28 | python train.py --model SpGAT_Cheby
29 | ```
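Other hyperparameters are exposed as TensorFlow flags in `train.py` and can be overridden on the command line; for example (the values shown here are the defaults):

```bash
python train.py --dataset cora --hidden1 64 --learning_rate 0.01 --dropout 0.5 --epochs 200
```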
30 |
31 | ## Acknowledgement
32 | This repo is modified from [GWNN](https://github.com/Eilene/GWNN), and we sincerely thank them for their contributions.
33 |
34 | ## Reference
35 | - If you find ``SpGAT`` useful in your research, please cite the following in your manuscript:
36 |
37 | ```
38 | @article{chang2020spectral,
39 | title={Spectral Graph Attention Network with Fast Eigen-approximation},
40 | author={Chang, Heng and Rong, Yu and Xu, Tingyang and Huang, Wenbing and Sojoudi, Somayeh and Huang, Junzhou and Zhu, Wenwu},
41 | journal={arXiv preprint arXiv:2003.07450},
42 | year={2020}
43 | }
44 | ```
45 |
46 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/.DS_Store
--------------------------------------------------------------------------------
/data/ind.cora.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.allx
--------------------------------------------------------------------------------
/data/ind.cora.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.ally
--------------------------------------------------------------------------------
/data/ind.cora.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.graph
--------------------------------------------------------------------------------
/data/ind.cora.test.index:
--------------------------------------------------------------------------------
1 | 2692
2 | 2532
3 | 2050
4 | 1715
5 | 2362
6 | 2609
7 | 2622
8 | 1975
9 | 2081
10 | 1767
11 | 2263
12 | 1725
13 | 2588
14 | 2259
15 | 2357
16 | 1998
17 | 2574
18 | 2179
19 | 2291
20 | 2382
21 | 1812
22 | 1751
23 | 2422
24 | 1937
25 | 2631
26 | 2510
27 | 2378
28 | 2589
29 | 2345
30 | 1943
31 | 1850
32 | 2298
33 | 1825
34 | 2035
35 | 2507
36 | 2313
37 | 1906
38 | 1797
39 | 2023
40 | 2159
41 | 2495
42 | 1886
43 | 2122
44 | 2369
45 | 2461
46 | 1925
47 | 2565
48 | 1858
49 | 2234
50 | 2000
51 | 1846
52 | 2318
53 | 1723
54 | 2559
55 | 2258
56 | 1763
57 | 1991
58 | 1922
59 | 2003
60 | 2662
61 | 2250
62 | 2064
63 | 2529
64 | 1888
65 | 2499
66 | 2454
67 | 2320
68 | 2287
69 | 2203
70 | 2018
71 | 2002
72 | 2632
73 | 2554
74 | 2314
75 | 2537
76 | 1760
77 | 2088
78 | 2086
79 | 2218
80 | 2605
81 | 1953
82 | 2403
83 | 1920
84 | 2015
85 | 2335
86 | 2535
87 | 1837
88 | 2009
89 | 1905
90 | 2636
91 | 1942
92 | 2193
93 | 2576
94 | 2373
95 | 1873
96 | 2463
97 | 2509
98 | 1954
99 | 2656
100 | 2455
101 | 2494
102 | 2295
103 | 2114
104 | 2561
105 | 2176
106 | 2275
107 | 2635
108 | 2442
109 | 2704
110 | 2127
111 | 2085
112 | 2214
113 | 2487
114 | 1739
115 | 2543
116 | 1783
117 | 2485
118 | 2262
119 | 2472
120 | 2326
121 | 1738
122 | 2170
123 | 2100
124 | 2384
125 | 2152
126 | 2647
127 | 2693
128 | 2376
129 | 1775
130 | 1726
131 | 2476
132 | 2195
133 | 1773
134 | 1793
135 | 2194
136 | 2581
137 | 1854
138 | 2524
139 | 1945
140 | 1781
141 | 1987
142 | 2599
143 | 1744
144 | 2225
145 | 2300
146 | 1928
147 | 2042
148 | 2202
149 | 1958
150 | 1816
151 | 1916
152 | 2679
153 | 2190
154 | 1733
155 | 2034
156 | 2643
157 | 2177
158 | 1883
159 | 1917
160 | 1996
161 | 2491
162 | 2268
163 | 2231
164 | 2471
165 | 1919
166 | 1909
167 | 2012
168 | 2522
169 | 1865
170 | 2466
171 | 2469
172 | 2087
173 | 2584
174 | 2563
175 | 1924
176 | 2143
177 | 1736
178 | 1966
179 | 2533
180 | 2490
181 | 2630
182 | 1973
183 | 2568
184 | 1978
185 | 2664
186 | 2633
187 | 2312
188 | 2178
189 | 1754
190 | 2307
191 | 2480
192 | 1960
193 | 1742
194 | 1962
195 | 2160
196 | 2070
197 | 2553
198 | 2433
199 | 1768
200 | 2659
201 | 2379
202 | 2271
203 | 1776
204 | 2153
205 | 1877
206 | 2027
207 | 2028
208 | 2155
209 | 2196
210 | 2483
211 | 2026
212 | 2158
213 | 2407
214 | 1821
215 | 2131
216 | 2676
217 | 2277
218 | 2489
219 | 2424
220 | 1963
221 | 1808
222 | 1859
223 | 2597
224 | 2548
225 | 2368
226 | 1817
227 | 2405
228 | 2413
229 | 2603
230 | 2350
231 | 2118
232 | 2329
233 | 1969
234 | 2577
235 | 2475
236 | 2467
237 | 2425
238 | 1769
239 | 2092
240 | 2044
241 | 2586
242 | 2608
243 | 1983
244 | 2109
245 | 2649
246 | 1964
247 | 2144
248 | 1902
249 | 2411
250 | 2508
251 | 2360
252 | 1721
253 | 2005
254 | 2014
255 | 2308
256 | 2646
257 | 1949
258 | 1830
259 | 2212
260 | 2596
261 | 1832
262 | 1735
263 | 1866
264 | 2695
265 | 1941
266 | 2546
267 | 2498
268 | 2686
269 | 2665
270 | 1784
271 | 2613
272 | 1970
273 | 2021
274 | 2211
275 | 2516
276 | 2185
277 | 2479
278 | 2699
279 | 2150
280 | 1990
281 | 2063
282 | 2075
283 | 1979
284 | 2094
285 | 1787
286 | 2571
287 | 2690
288 | 1926
289 | 2341
290 | 2566
291 | 1957
292 | 1709
293 | 1955
294 | 2570
295 | 2387
296 | 1811
297 | 2025
298 | 2447
299 | 2696
300 | 2052
301 | 2366
302 | 1857
303 | 2273
304 | 2245
305 | 2672
306 | 2133
307 | 2421
308 | 1929
309 | 2125
310 | 2319
311 | 2641
312 | 2167
313 | 2418
314 | 1765
315 | 1761
316 | 1828
317 | 2188
318 | 1972
319 | 1997
320 | 2419
321 | 2289
322 | 2296
323 | 2587
324 | 2051
325 | 2440
326 | 2053
327 | 2191
328 | 1923
329 | 2164
330 | 1861
331 | 2339
332 | 2333
333 | 2523
334 | 2670
335 | 2121
336 | 1921
337 | 1724
338 | 2253
339 | 2374
340 | 1940
341 | 2545
342 | 2301
343 | 2244
344 | 2156
345 | 1849
346 | 2551
347 | 2011
348 | 2279
349 | 2572
350 | 1757
351 | 2400
352 | 2569
353 | 2072
354 | 2526
355 | 2173
356 | 2069
357 | 2036
358 | 1819
359 | 1734
360 | 1880
361 | 2137
362 | 2408
363 | 2226
364 | 2604
365 | 1771
366 | 2698
367 | 2187
368 | 2060
369 | 1756
370 | 2201
371 | 2066
372 | 2439
373 | 1844
374 | 1772
375 | 2383
376 | 2398
377 | 1708
378 | 1992
379 | 1959
380 | 1794
381 | 2426
382 | 2702
383 | 2444
384 | 1944
385 | 1829
386 | 2660
387 | 2497
388 | 2607
389 | 2343
390 | 1730
391 | 2624
392 | 1790
393 | 1935
394 | 1967
395 | 2401
396 | 2255
397 | 2355
398 | 2348
399 | 1931
400 | 2183
401 | 2161
402 | 2701
403 | 1948
404 | 2501
405 | 2192
406 | 2404
407 | 2209
408 | 2331
409 | 1810
410 | 2363
411 | 2334
412 | 1887
413 | 2393
414 | 2557
415 | 1719
416 | 1732
417 | 1986
418 | 2037
419 | 2056
420 | 1867
421 | 2126
422 | 1932
423 | 2117
424 | 1807
425 | 1801
426 | 1743
427 | 2041
428 | 1843
429 | 2388
430 | 2221
431 | 1833
432 | 2677
433 | 1778
434 | 2661
435 | 2306
436 | 2394
437 | 2106
438 | 2430
439 | 2371
440 | 2606
441 | 2353
442 | 2269
443 | 2317
444 | 2645
445 | 2372
446 | 2550
447 | 2043
448 | 1968
449 | 2165
450 | 2310
451 | 1985
452 | 2446
453 | 1982
454 | 2377
455 | 2207
456 | 1818
457 | 1913
458 | 1766
459 | 1722
460 | 1894
461 | 2020
462 | 1881
463 | 2621
464 | 2409
465 | 2261
466 | 2458
467 | 2096
468 | 1712
469 | 2594
470 | 2293
471 | 2048
472 | 2359
473 | 1839
474 | 2392
475 | 2254
476 | 1911
477 | 2101
478 | 2367
479 | 1889
480 | 1753
481 | 2555
482 | 2246
483 | 2264
484 | 2010
485 | 2336
486 | 2651
487 | 2017
488 | 2140
489 | 1842
490 | 2019
491 | 1890
492 | 2525
493 | 2134
494 | 2492
495 | 2652
496 | 2040
497 | 2145
498 | 2575
499 | 2166
500 | 1999
501 | 2434
502 | 1711
503 | 2276
504 | 2450
505 | 2389
506 | 2669
507 | 2595
508 | 1814
509 | 2039
510 | 2502
511 | 1896
512 | 2168
513 | 2344
514 | 2637
515 | 2031
516 | 1977
517 | 2380
518 | 1936
519 | 2047
520 | 2460
521 | 2102
522 | 1745
523 | 2650
524 | 2046
525 | 2514
526 | 1980
527 | 2352
528 | 2113
529 | 1713
530 | 2058
531 | 2558
532 | 1718
533 | 1864
534 | 1876
535 | 2338
536 | 1879
537 | 1891
538 | 2186
539 | 2451
540 | 2181
541 | 2638
542 | 2644
543 | 2103
544 | 2591
545 | 2266
546 | 2468
547 | 1869
548 | 2582
549 | 2674
550 | 2361
551 | 2462
552 | 1748
553 | 2215
554 | 2615
555 | 2236
556 | 2248
557 | 2493
558 | 2342
559 | 2449
560 | 2274
561 | 1824
562 | 1852
563 | 1870
564 | 2441
565 | 2356
566 | 1835
567 | 2694
568 | 2602
569 | 2685
570 | 1893
571 | 2544
572 | 2536
573 | 1994
574 | 1853
575 | 1838
576 | 1786
577 | 1930
578 | 2539
579 | 1892
580 | 2265
581 | 2618
582 | 2486
583 | 2583
584 | 2061
585 | 1796
586 | 1806
587 | 2084
588 | 1933
589 | 2095
590 | 2136
591 | 2078
592 | 1884
593 | 2438
594 | 2286
595 | 2138
596 | 1750
597 | 2184
598 | 1799
599 | 2278
600 | 2410
601 | 2642
602 | 2435
603 | 1956
604 | 2399
605 | 1774
606 | 2129
607 | 1898
608 | 1823
609 | 1938
610 | 2299
611 | 1862
612 | 2420
613 | 2673
614 | 1984
615 | 2204
616 | 1717
617 | 2074
618 | 2213
619 | 2436
620 | 2297
621 | 2592
622 | 2667
623 | 2703
624 | 2511
625 | 1779
626 | 1782
627 | 2625
628 | 2365
629 | 2315
630 | 2381
631 | 1788
632 | 1714
633 | 2302
634 | 1927
635 | 2325
636 | 2506
637 | 2169
638 | 2328
639 | 2629
640 | 2128
641 | 2655
642 | 2282
643 | 2073
644 | 2395
645 | 2247
646 | 2521
647 | 2260
648 | 1868
649 | 1988
650 | 2324
651 | 2705
652 | 2541
653 | 1731
654 | 2681
655 | 2707
656 | 2465
657 | 1785
658 | 2149
659 | 2045
660 | 2505
661 | 2611
662 | 2217
663 | 2180
664 | 1904
665 | 2453
666 | 2484
667 | 1871
668 | 2309
669 | 2349
670 | 2482
671 | 2004
672 | 1965
673 | 2406
674 | 2162
675 | 1805
676 | 2654
677 | 2007
678 | 1947
679 | 1981
680 | 2112
681 | 2141
682 | 1720
683 | 1758
684 | 2080
685 | 2330
686 | 2030
687 | 2432
688 | 2089
689 | 2547
690 | 1820
691 | 1815
692 | 2675
693 | 1840
694 | 2658
695 | 2370
696 | 2251
697 | 1908
698 | 2029
699 | 2068
700 | 2513
701 | 2549
702 | 2267
703 | 2580
704 | 2327
705 | 2351
706 | 2111
707 | 2022
708 | 2321
709 | 2614
710 | 2252
711 | 2104
712 | 1822
713 | 2552
714 | 2243
715 | 1798
716 | 2396
717 | 2663
718 | 2564
719 | 2148
720 | 2562
721 | 2684
722 | 2001
723 | 2151
724 | 2706
725 | 2240
726 | 2474
727 | 2303
728 | 2634
729 | 2680
730 | 2055
731 | 2090
732 | 2503
733 | 2347
734 | 2402
735 | 2238
736 | 1950
737 | 2054
738 | 2016
739 | 1872
740 | 2233
741 | 1710
742 | 2032
743 | 2540
744 | 2628
745 | 1795
746 | 2616
747 | 1903
748 | 2531
749 | 2567
750 | 1946
751 | 1897
752 | 2222
753 | 2227
754 | 2627
755 | 1856
756 | 2464
757 | 2241
758 | 2481
759 | 2130
760 | 2311
761 | 2083
762 | 2223
763 | 2284
764 | 2235
765 | 2097
766 | 1752
767 | 2515
768 | 2527
769 | 2385
770 | 2189
771 | 2283
772 | 2182
773 | 2079
774 | 2375
775 | 2174
776 | 2437
777 | 1993
778 | 2517
779 | 2443
780 | 2224
781 | 2648
782 | 2171
783 | 2290
784 | 2542
785 | 2038
786 | 1855
787 | 1831
788 | 1759
789 | 1848
790 | 2445
791 | 1827
792 | 2429
793 | 2205
794 | 2598
795 | 2657
796 | 1728
797 | 2065
798 | 1918
799 | 2427
800 | 2573
801 | 2620
802 | 2292
803 | 1777
804 | 2008
805 | 1875
806 | 2288
807 | 2256
808 | 2033
809 | 2470
810 | 2585
811 | 2610
812 | 2082
813 | 2230
814 | 1915
815 | 1847
816 | 2337
817 | 2512
818 | 2386
819 | 2006
820 | 2653
821 | 2346
822 | 1951
823 | 2110
824 | 2639
825 | 2520
826 | 1939
827 | 2683
828 | 2139
829 | 2220
830 | 1910
831 | 2237
832 | 1900
833 | 1836
834 | 2197
835 | 1716
836 | 1860
837 | 2077
838 | 2519
839 | 2538
840 | 2323
841 | 1914
842 | 1971
843 | 1845
844 | 2132
845 | 1802
846 | 1907
847 | 2640
848 | 2496
849 | 2281
850 | 2198
851 | 2416
852 | 2285
853 | 1755
854 | 2431
855 | 2071
856 | 2249
857 | 2123
858 | 1727
859 | 2459
860 | 2304
861 | 2199
862 | 1791
863 | 1809
864 | 1780
865 | 2210
866 | 2417
867 | 1874
868 | 1878
869 | 2116
870 | 1961
871 | 1863
872 | 2579
873 | 2477
874 | 2228
875 | 2332
876 | 2578
877 | 2457
878 | 2024
879 | 1934
880 | 2316
881 | 1841
882 | 1764
883 | 1737
884 | 2322
885 | 2239
886 | 2294
887 | 1729
888 | 2488
889 | 1974
890 | 2473
891 | 2098
892 | 2612
893 | 1834
894 | 2340
895 | 2423
896 | 2175
897 | 2280
898 | 2617
899 | 2208
900 | 2560
901 | 1741
902 | 2600
903 | 2059
904 | 1747
905 | 2242
906 | 2700
907 | 2232
908 | 2057
909 | 2147
910 | 2682
911 | 1792
912 | 1826
913 | 2120
914 | 1895
915 | 2364
916 | 2163
917 | 1851
918 | 2391
919 | 2414
920 | 2452
921 | 1803
922 | 1989
923 | 2623
924 | 2200
925 | 2528
926 | 2415
927 | 1804
928 | 2146
929 | 2619
930 | 2687
931 | 1762
932 | 2172
933 | 2270
934 | 2678
935 | 2593
936 | 2448
937 | 1882
938 | 2257
939 | 2500
940 | 1899
941 | 2478
942 | 2412
943 | 2107
944 | 1746
945 | 2428
946 | 2115
947 | 1800
948 | 1901
949 | 2397
950 | 2530
951 | 1912
952 | 2108
953 | 2206
954 | 2091
955 | 1740
956 | 2219
957 | 1976
958 | 2099
959 | 2142
960 | 2671
961 | 2668
962 | 2216
963 | 2272
964 | 2229
965 | 2666
966 | 2456
967 | 2534
968 | 2697
969 | 2688
970 | 2062
971 | 2691
972 | 2689
973 | 2154
974 | 2590
975 | 2626
976 | 2390
977 | 1813
978 | 2067
979 | 1952
980 | 2518
981 | 2358
982 | 1789
983 | 2076
984 | 2049
985 | 2119
986 | 2013
987 | 2124
988 | 2556
989 | 2105
990 | 2093
991 | 1885
992 | 2305
993 | 2354
994 | 2135
995 | 2601
996 | 1770
997 | 1995
998 | 2504
999 | 1749
1000 | 2157
1001 |
--------------------------------------------------------------------------------
/data/ind.cora.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.tx
--------------------------------------------------------------------------------
/data/ind.cora.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.ty
--------------------------------------------------------------------------------
/data/ind.cora.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.x
--------------------------------------------------------------------------------
/data/ind.cora.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SwiftieH/SpGAT/b9fdd1a326e28d4d4dcd922cdebaedd764783cf6/data/ind.cora.y
--------------------------------------------------------------------------------
/inits.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def uniform(shape, scale=1.0, name=None):
6 | """Uniform init."""
7 | initial = tf.random_uniform(shape, minval=0.0, maxval=scale, dtype=tf.float32)
8 | return tf.Variable(initial, name=name)
9 |
10 |
11 | def glorot(shape, name=None):
12 | """Glorot & Bengio (AISTATS 2010) init."""
13 | init_range = np.sqrt(6.0/(shape[0]+shape[1]))
14 | initial = tf.random_uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32)
15 | return tf.Variable(initial, name=name)
16 |
17 |
18 | def zeros(shape, name=None):
19 | """All zeros."""
20 | initial = tf.zeros(shape, dtype=tf.float32)
21 | return tf.Variable(initial, name=name)
22 |
23 |
24 | def ones(shape, name=None):
25 | """All ones."""
26 | initial = tf.ones(shape, dtype=tf.float32)
27 | return tf.Variable(initial, name=name)
28 |
29 | def ones_fix(shape, name=None):
30 | """All ones."""
31 | initial = tf.ones(shape, dtype=tf.float32)
32 | return tf.Variable(initial, name=name, trainable=False)
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
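
A quick NumPy sketch of the Glorot & Bengio bound used by `inits.glorot` above; the helper name `glorot_numpy` is ours, for illustration only:

```python
import numpy as np

def glorot_numpy(shape, seed=0):
    """Sample uniformly from [-r, r] with r = sqrt(6 / (fan_in + fan_out))."""
    rng = np.random.RandomState(seed)
    r = np.sqrt(6.0 / (shape[0] + shape[1]))
    return rng.uniform(-r, r, size=shape).astype(np.float32)

W = glorot_numpy((1433, 64))  # e.g., Cora feature dim -> hidden1
print(W.shape, float(W.min()), float(W.max()))
```
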
/layers.py:
--------------------------------------------------------------------------------
1 | from inits import *
2 | import numpy as np
3 | import tensorflow as tf
4 | from sklearn.preprocessing import normalize
5 | flags = tf.app.flags
6 | FLAGS = flags.FLAGS
7 |
8 | # global unique layer ID dictionary for layer name assignment
9 | _LAYER_UIDS = {}
10 |
11 |
12 | def get_layer_uid(layer_name=''):
13 | """Helper function, assigns unique layer IDs."""
14 | if layer_name not in _LAYER_UIDS:
15 | _LAYER_UIDS[layer_name] = 1
16 | return 1
17 | else:
18 | _LAYER_UIDS[layer_name] += 1
19 | return _LAYER_UIDS[layer_name]
20 |
21 | def sparse_dropout(x, keep_prob, noise_shape):
22 | """Dropout for sparse tensors."""
23 | random_tensor = keep_prob
24 | random_tensor += tf.random_uniform(noise_shape)
25 | dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
26 | pre_out = tf.sparse_retain(x, dropout_mask)
27 | return pre_out * (1./keep_prob)
28 |
29 | def dot(x, y, sparse=False):
30 | """Wrapper for tf.matmul (sparse vs dense)."""
31 | if sparse:
32 | res = tf.sparse_tensor_dense_matmul(x, y)
33 | else:
34 | res = tf.matmul(x, y)
35 | return res
36 |
37 | class Layer(object):
38 | """Base layer class. Defines basic API for all layer objects.
39 | Implementation inspired by keras (http://keras.io).
40 |
41 | # Properties
42 | name: String, defines the variable scope of the layer.
43 | logging: Boolean, switches Tensorflow histogram logging on/off
44 |
45 | # Methods
46 | _call(inputs): Defines computation graph of layer
47 | (i.e. takes input, returns output)
48 | __call__(inputs): Wrapper for _call()
49 | _log_vars(): Log all variables
50 | """
51 |
52 | def __init__(self, **kwargs):
53 | allowed_kwargs = {'name', 'logging'}
54 | for kwarg in kwargs.keys():
55 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
56 | name = kwargs.get('name')
57 | if not name:
58 | layer = self.__class__.__name__.lower()
59 | name = layer + '_' + str(get_layer_uid(layer))
60 | self.name = name
61 | self.vars = {}
62 | logging = kwargs.get('logging', False)
63 | self.logging = logging
64 | self.sparse_inputs = False
65 |
66 | def _call(self, inputs):
67 | return inputs
68 |
69 | def __call__(self, inputs):
70 | with tf.name_scope(self.name):
71 | if self.logging and not self.sparse_inputs:
72 | tf.summary.histogram(self.name + '/inputs', inputs)
73 | outputs = self._call(inputs)
74 | if self.logging:
75 | tf.summary.histogram(self.name + '/outputs', outputs)
76 | return outputs
77 |
78 | def _log_vars(self):
79 | for var in self.vars:
80 | tf.summary.histogram(self.name + '/vars/' + var, self.vars[var])
81 |
82 | class SpGAT_Conv(Layer):
83 | """Graph convolution layer."""
84 | def __init__(self, k_por, node_num,weight_normalize,input_dim, output_dim, placeholders, dropout=0.,
85 | sparse_inputs=False, act=tf.nn.relu, bias=False,
86 | featureless=False, **kwargs):
87 | super(SpGAT_Conv, self).__init__(**kwargs)
88 |
89 | if dropout:
90 | self.dropout = placeholders['dropout']
91 | else:
92 | self.dropout = 0.
93 |
94 | self.k_por = k_por
95 | self.node_num = node_num
96 | self.weight_normalize = weight_normalize
97 | self.act = act
98 | self.support = placeholders['support']
99 | self.sparse_inputs = sparse_inputs
100 | self.featureless = featureless
101 | self.bias = bias
102 |
103 | # helper variable for sparse dropout
104 | self.num_features_nonzero = placeholders['num_features_nonzero']
105 |
106 | with tf.variable_scope(self.name + '_vars'):
107 | self.vars['weights_' + str(0)] = glorot([input_dim, output_dim],
108 | name='weights_' + str(0))
109 | k_fre = int(self.k_por * self.node_num)
110 | init_alpha = np.array([1, 1], dtype='float32')
111 | self.alpha = tf.get_variable("tf_var_initialized_from_alpha", initializer = init_alpha, trainable=True)
112 | self.alpha = tf.nn.softmax(self.alpha)
113 | self.vars['low_w'] = self.alpha[0]
114 | self.vars['high_w'] = self.alpha[1]
115 |
116 |
117 | self.vars['kernel_low'] = ones_fix([k_fre], name='kernel_low')
118 | self.vars['kernel_high'] = ones_fix([self.node_num - k_fre], name='kernel_high')
119 | self.vars['kernel_low'] = self.vars['kernel_low'] * self.vars['low_w']
120 | self.vars['kernel_high'] = self.vars['kernel_high'] * self.vars['high_w']
121 |
122 |
123 | if self.bias:
124 | self.vars['bias'] = zeros([output_dim], name='bias')
125 |
126 | if self.logging:
127 | self._log_vars()
128 |
129 | def _call(self, inputs):
130 | x = inputs
131 |
132 | # dropout
133 | if self.sparse_inputs:
134 | x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero)
135 | else:
136 | x = tf.nn.dropout(x, 1-self.dropout)
137 |
138 | supports_low = tf.matmul(tf.sparse_tensor_to_dense(self.support[0]),tf.diag(self.vars['kernel_low']),a_is_sparse=True,b_is_sparse=True)
139 | supports_low = tf.matmul(supports_low,tf.sparse_tensor_to_dense(self.support[1]),a_is_sparse=True,b_is_sparse=True)
140 | pre_sup = dot(x, self.vars['weights_' + str(0)],sparse=self.sparse_inputs)
141 | output_low = dot(supports_low,pre_sup)
142 |
143 |
144 | supports_high = tf.matmul(tf.sparse_tensor_to_dense(self.support[2]),tf.diag(self.vars['kernel_high']),a_is_sparse=True,b_is_sparse=True)
145 | supports_high = tf.matmul(supports_high,tf.sparse_tensor_to_dense(self.support[3]),a_is_sparse=True,b_is_sparse=True)
146 | output_high = dot(supports_high,pre_sup)
147 |
148 | #Mean Pooling
149 | #output = output_low + output_high
150 | #Max Pooling
151 | output = tf.concat([tf.expand_dims(output_low, axis = 0), tf.expand_dims(output_high, axis = 0)], axis = 0)
152 | output = tf.reduce_max(output, axis = 0)
153 | #import pdb; pdb.set_trace()
154 | if self.bias:
155 | output += self.vars['bias']
156 |
157 | return self.act(output)
158 |
--------------------------------------------------------------------------------
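
To make the forward pass of `SpGAT_Conv._call` concrete, here is a NumPy sketch of the low/high-frequency split with max pooling over the two branches. Toy sizes, random matrices as stand-ins for the wavelet bases held in `support`, and all names are ours:

```python
import numpy as np

n, k, d_in, d_out = 6, 2, 4, 3                     # nodes, low-frequency cutoff, dims
rng = np.random.RandomState(0)
psi_inv, psi = rng.rand(n, n), rng.rand(n, n)      # stand-ins for the (inverse) wavelet bases
X, W = rng.rand(n, d_in), rng.rand(d_in, d_out)

alpha = np.exp([1.0, 1.0]); alpha /= alpha.sum()   # softmax over the two learnable scalars
kernel_low = alpha[0] * np.ones(k)                 # attention weight on low frequencies
kernel_high = alpha[1] * np.ones(n - k)            # attention weight on high frequencies

pre_sup = X @ W
out_low = psi_inv[:, :k] @ np.diag(kernel_low) @ psi[:k, :] @ pre_sup
out_high = psi_inv[:, k:] @ np.diag(kernel_high) @ psi[k:, :] @ pre_sup
output = np.maximum(out_low, out_high)             # the max-pooling branch in _call()
print(output.shape)                                # (6, 3)
```
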
/metrics.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def masked_softmax_cross_entropy(preds, labels, mask):
5 | """Softmax cross-entropy loss with masking."""
6 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
7 | mask = tf.cast(mask, dtype=tf.float32)
8 | mask /= tf.reduce_mean(mask)
9 | loss *= mask
10 | return tf.reduce_mean(loss)
11 |
12 |
13 | def masked_accuracy(preds, labels, mask):
14 | """Accuracy with masking."""
15 | correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(labels, 1))
16 | accuracy_all = tf.cast(correct_prediction, tf.float32)
17 | mask = tf.cast(mask, dtype=tf.float32)
18 | mask /= tf.reduce_mean(mask)
19 | accuracy_all *= mask
20 | return tf.reduce_mean(accuracy_all)
21 |
--------------------------------------------------------------------------------
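
The masking trick in both functions above divides the mask by its mean, which turns a global `reduce_mean` into a mean over only the masked entries. A small NumPy check of that identity:

```python
import numpy as np

loss = np.array([0.5, 1.0, 2.0, 4.0])   # per-node losses
mask = np.array([1.0, 1.0, 0.0, 0.0])   # only the first two nodes are labeled
mask /= mask.mean()                      # kept entries scaled by N / num_masked
print((loss * mask).mean())              # 0.75 == loss[:2].mean()
```
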
/models.py:
--------------------------------------------------------------------------------
1 | from layers import *
2 | from metrics import *
3 |
4 | flags = tf.app.flags
5 | FLAGS = flags.FLAGS
6 |
7 |
8 | class Model(object):
9 | def __init__(self, **kwargs):
10 | allowed_kwargs = {'name', 'logging'}
11 | for kwarg in kwargs.keys():
12 | assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
13 | name = kwargs.get('name')
14 | if not name:
15 | name = self.__class__.__name__.lower()
16 | self.name = name
17 |
18 | logging = kwargs.get('logging', False)
19 | self.logging = logging
20 |
21 | self.vars = {}
22 | self.placeholders = {}
23 |
24 | self.layers = []
25 | self.activations = []
26 |
27 | self.inputs = None
28 | self.outputs = None
29 |
30 | self.loss = 0
31 | self.accuracy = 0
32 | self.optimizer = None
33 | self.opt_op = None
34 |
35 | def _build(self):
36 | raise NotImplementedError
37 |
38 | def build(self):
39 | """ Wrapper for _build() """
40 | with tf.variable_scope(self.name):
41 | self._build()
42 |
43 | # Build sequential layer model
44 | self.activations.append(self.inputs)
45 | for layer in self.layers:
46 | hidden = layer(self.activations[-1])
47 | self.activations.append(hidden)
48 | self.outputs = self.activations[-1]
49 |
50 | # Store model variables for easy access
51 | variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
52 | self.vars = {var.name: var for var in variables}
53 |
54 | # Build metrics
55 | self._loss()
56 | self._accuracy()
57 |
58 | self.opt_op = self.optimizer.minimize(self.loss)
59 |
60 | def predict(self):
61 | pass
62 |
63 | def _loss(self):
64 | raise NotImplementedError
65 |
66 | def _accuracy(self):
67 | raise NotImplementedError
68 |
69 | def save(self, sess=None):
70 | if not sess:
71 | raise AttributeError("TensorFlow session not provided.")
72 | saver = tf.train.Saver(self.vars)
73 | save_path = saver.save(sess, "tmp/%s.ckpt" % self.name)
74 | print("Model saved in file: %s" % save_path)
75 |
76 | def load(self, sess=None):
77 | if not sess:
78 | raise AttributeError("TensorFlow session not provided.")
79 | saver = tf.train.Saver(self.vars)
80 | save_path = "tmp/%s.ckpt" % self.name
81 | saver.restore(sess, save_path)
82 | print("Model restored from file: %s" % save_path)
83 |
84 |
85 | class SpGAT(Model):
86 | def __init__(self, k_por, node_num, weight_normalize, placeholders, input_dim, **kwargs):
87 | super(SpGAT, self).__init__(**kwargs)
88 |
89 | self.weight_normalize = weight_normalize
90 | self.inputs = placeholders['features']
91 | self.k_por = k_por
92 | self.input_dim = input_dim
93 | self.node_num = node_num
94 | self.output_dim = placeholders['labels'].get_shape().as_list()[1]
95 | self.placeholders = placeholders
96 |
97 | self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
98 |
99 | self.build()
100 |
101 | def _loss(self):
102 | # Weight decay loss
103 | for var in self.layers[0].vars.values():
104 | self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
105 |
106 | # Cross entropy error
107 | self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
108 | self.placeholders['labels_mask'])
109 |
110 | def _accuracy(self):
111 | self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
112 | self.placeholders['labels_mask'])
113 |
114 | def _build(self):
115 |
116 | self.layers.append(SpGAT_Conv(k_por = self.k_por,
117 | node_num=self.node_num,
118 | weight_normalize = self.weight_normalize,
119 | input_dim=self.input_dim,
120 | output_dim=FLAGS.hidden1,
121 | placeholders=self.placeholders,
122 | act=tf.nn.relu,
123 | dropout=True,
124 | sparse_inputs=True,
125 | logging=self.logging))
126 |
127 |
128 | self.layers.append(SpGAT_Conv(k_por = self.k_por,
129 | node_num=self.node_num,
130 | weight_normalize = self.weight_normalize,
131 | input_dim=FLAGS.hidden1,
132 | output_dim=self.output_dim,
133 | placeholders=self.placeholders,
134 | act=lambda x: x,
135 | dropout=True,
136 | logging=self.logging))
137 |
138 | def predict(self):
139 | return tf.nn.softmax(self.outputs)
140 |
141 |
142 |
--------------------------------------------------------------------------------
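
The `Model.build()` loop simply threads each layer's output into the next. Stripped of TensorFlow, the pattern is just this (plain-Python stand-ins for the layers):

```python
layers = [lambda x: 2 * x, lambda x: x + 1]   # stand-ins for the two SpGAT_Conv layers
activations = [3]                             # activations[0] is the model input
for layer in layers:
    activations.append(layer(activations[-1]))
outputs = activations[-1]
print(outputs)                                # 7
```
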
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding:UTF-8 -*-
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 |
9 | from utils import *
10 | from models import SpGAT
11 |
12 | import os
13 | os.environ['CUDA_VISIBLE_DEVICES']='-1'
14 |
15 | # Set random seed
16 | seed = 322
17 | np.random.seed(seed)
18 | tf.set_random_seed(seed)
19 |
20 | # Settings
21 | flags = tf.app.flags
22 | FLAGS = flags.FLAGS
23 | flags.DEFINE_string('dataset', 'cora', 'Dataset string.') # 'cora', 'citeseer', 'pubmed'
24 | flags.DEFINE_string('model', 'SpGAT', 'Model string.') # 'SpGAT', 'SpGAT_Cheby'
25 | flags.DEFINE_float('wavelet_s', 1.0, 'Wavelet scale s.')
26 | flags.DEFINE_float('threshold', 1e-4, 'Sparseness threshold on the wavelet bases.')
27 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
28 | flags.DEFINE_bool('alldata', False, 'Use all labeled data with a 60/20/20 split instead of the standard split.')
29 | flags.DEFINE_integer('epochs', 200, 'Number of epochs to train.')
30 | flags.DEFINE_integer('hidden1', 64, 'Number of units in hidden layer 1.')
31 | flags.DEFINE_float('dropout', 0.5, 'Dropout rate (1 - keep probability).')
32 | flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.')
33 | flags.DEFINE_integer('early_stopping', 200, 'Tolerance for early stopping (# of epochs).')
34 | flags.DEFINE_bool('mask', True, 'Mask flag (only used in logging).')
35 | flags.DEFINE_bool('laplacian_normalize', True, 'Whether to use the normalized graph Laplacian.')
36 | flags.DEFINE_bool('sparse_ness', True, 'Whether to sparsify the wavelet bases by thresholding.')
37 | flags.DEFINE_bool('weight_normalize', False, 'Whether to L1-normalize the wavelet bases.')
38 | flags.DEFINE_string('gpu', '-1', 'Which GPU to use (-1 for CPU).')
39 | flags.DEFINE_integer('repeating', 1, 'Number of repeated runs.')
40 |
41 | os.environ['CUDA_VISIBLE_DEVICES']=FLAGS.gpu
42 |
43 |
44 | # Load data
45 | labels, adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(FLAGS.dataset,alldata=FLAGS.alldata)
46 | # Some preprocessing, normalization
47 | features = preprocess_features(features)
48 | node_num = adj.shape[0]
49 |
50 | print("************Loading data finished, Begin constructing wavelet************")
51 |
52 | dataset = FLAGS.dataset
53 | s = FLAGS.wavelet_s
54 | laplacian_normalize = FLAGS.laplacian_normalize
55 | sparse_ness = FLAGS.sparse_ness
56 | threshold = FLAGS.threshold
57 | weight_normalize = FLAGS.weight_normalize
58 | if FLAGS.model == "SpGAT":
59 | support_t = wavelet_basis(dataset,adj, s, laplacian_normalize,sparse_ness,threshold,weight_normalize)
60 | elif FLAGS.model == "SpGAT_Cheby":
61 | s = 2.0
62 | support_t = wavelet_basis_appro(dataset,adj, s, laplacian_normalize,sparse_ness,threshold,weight_normalize)
63 | if dataset == 'cora':
64 | k_por = 0.05 # best $d$ for cora
65 | if dataset == 'pubmed':
66 | k_por = 0.10 # best $d$ for pubmed
67 | if dataset == 'citeseer':
68 | k_por = 0.15 # best $d$ for citeseer
69 | k_fre = int(k_por * node_num)
70 | support = [support_t[0][:,:k_fre], support_t[1][:k_fre,:], support_t[0][:,k_fre:], support_t[1][k_fre:,:]]
71 | sparse_to_tuple(support) # converts the list entries to tuple representation in place
72 | num_supports = len(support)
73 | model_func = SpGAT
74 |
75 | # Define placeholders
76 | placeholders = {
77 | 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
78 | 'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)),
79 | 'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
80 | 'labels_mask': tf.placeholder(tf.int32),
81 | 'dropout': tf.placeholder_with_default(0., shape=()),
82 | 'num_features_nonzero': tf.placeholder(tf.int32) # helper variable for sparse dropout
83 | }
84 |
85 | # Define model evaluation function
86 | def evaluate(features, support, labels, mask, placeholders):
87 | feed_dict_val = construct_feed_dict(features, support, labels, mask, placeholders)
88 | outs_val = sess.run([model.outputs,model.loss, model.accuracy], feed_dict=feed_dict_val)
89 | return outs_val[0], outs_val[1], outs_val[2]
90 |
91 | test_acc_bestval = []
92 | test_acc_besttest = []
93 | val_acc = []
94 |
95 | for _ in range(FLAGS.repeating):
96 |
97 | #seed = np.random.randint(999)
98 | #np.random.seed(seed)
99 | #tf.set_random_seed(seed)
100 |
101 | # Create model
102 | weight_normalize = FLAGS.weight_normalize
103 | node_num = adj.shape[0]
104 | model = model_func(k_por, node_num,weight_normalize, placeholders, input_dim=features[2][1], logging=True)
105 | print("**************Constructing wavelet finished, Begin training**************")
106 | # Initialize session
107 | sess = tf.Session()
108 |
109 | # Init variables
110 | sess.run(tf.global_variables_initializer())
111 |
112 | # Train model
113 | cost_val = []
114 | best_val_acc = 0.0
115 | output_test_acc = 0.0
116 | best_test_acc = 0.0
117 |
118 | for epoch in range(FLAGS.epochs):
119 |
120 | # Construct feed dictionary
121 | feed_dict = construct_feed_dict(features, support, y_train, train_mask, placeholders)
122 | feed_dict.update({placeholders['dropout']: FLAGS.dropout})
123 |
124 | # Training step
125 | outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
126 |
127 | # Validation
128 | val_output,cost, acc = evaluate(features, support, y_val, val_mask, placeholders)
129 | cost_val.append(cost)
130 | # Test
131 | test_output, test_cost, test_acc = evaluate(features, support, y_test, test_mask, placeholders)
132 |
133 | # best val acc
134 | if(best_val_acc <= acc):
135 | best_val_acc = acc
136 | output_test_acc = test_acc
137 | # best test acc
138 | if(best_test_acc <= test_acc):
139 | #import pdb; pdb.set_trace()
140 | best_test_acc = test_acc
141 |
142 |
143 | # Print results
144 | print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(outs[1]),
145 | "train_acc=", "{:.5f}".format(outs[2]), "val_loss=", "{:.5f}".format(cost),
146 | "val_acc=", "{:.5f}".format(acc), "test_loss=", "{:.5f}".format(test_cost), "test_acc=", "{:.5f}".format(test_acc))
147 |
148 | if epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(cost_val[-(FLAGS.early_stopping+1):-1]):
149 | print("Early stopping...")
150 | break
151 |
152 | print("Optimization Finished!")
153 |
154 | print("dataset: ",FLAGS.dataset," model: ",FLAGS.model,",sparse_ness: ",FLAGS.sparse_ness,
155 | ",laplacian_normalize: ",FLAGS.laplacian_normalize,",threshold",FLAGS.threshold,",wavelet_s:",FLAGS.wavelet_s,",mask:",FLAGS.mask,
156 | ",weight_normalize:",FLAGS.weight_normalize,
157 | ",learning_rate:",FLAGS.learning_rate,",hidden1:",FLAGS.hidden1,",dropout:",FLAGS.dropout,",alldata:",FLAGS.alldata)
158 |
159 | print("Val accuracy:", best_val_acc, " Test accuracy: ",output_test_acc)
160 | test_acc_bestval.append(output_test_acc)
161 | test_acc_besttest.append(best_test_acc)
162 | val_acc.append(best_val_acc)
163 |
164 | print("********************************************************")
165 |
166 | result = []
167 | result.append(np.array(test_acc_bestval))
168 | result.append(np.array(test_acc_besttest))
169 | result.append(np.array(val_acc))
170 |
171 | alpha_1_low = sess.run(model.layers[0].vars['low_w'], feed_dict = feed_dict)
172 | alpha_1_high = sess.run(model.layers[0].vars['high_w'], feed_dict = feed_dict)
173 | alpha_2_low = sess.run(model.layers[1].vars['low_w'], feed_dict = feed_dict)
174 | alpha_2_high = sess.run(model.layers[1].vars['high_w'], feed_dict = feed_dict)
175 |
176 |
177 | r_half = int(FLAGS.repeating / 2)
178 | print("REPEAT\t{}".format(FLAGS.repeating))
179 | print("Model\t{}".format(FLAGS.model))
180 | print("Low frequency portion\t{} %".format(k_por * 100))
181 | print("{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}\t{:<8}".format('DATASET', 'best_val_mean', 'best_val_std',
182 | 'best_test_mean', 'best_test_std', 'half_best_val_mean',
183 | 'half_best_val_std', 'half_best_test_mean', 'half_best_test_std'))
184 | print("{:<8}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}".format(
185 | FLAGS.dataset,
186 | result[0].mean(),
187 | result[0].std(),
188 | result[1].mean(),
189 | result[1].std(),
190 | result[0][np.argsort(result[0])[r_half:]].mean(),
191 | result[0][np.argsort(result[0])[r_half:]].std(),
192 | result[1][np.argsort(result[1])[r_half:]].mean(), #2 for validation
193 | result[1][np.argsort(result[1])[r_half:]].std()))
194 |
195 | print("{:<8}\t{:<8}\t{:<8}\t{:<8}".format('alpha_1_low', 'alpha_1_high', 'alpha_2_low', 'alpha_2_high'))
196 | print("{:<8.6f}\t{:<8.6f}\t{:<8.6f}\t{:<8.6f}".format(
197 | alpha_1_low,
198 | alpha_1_high,
199 | alpha_2_low,
200 | alpha_2_high))
201 |
202 |
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
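
The early-stopping test inside the epoch loop above compares the latest validation loss against the mean of the preceding `early_stopping` epochs. A standalone sketch of that rule (the function name `should_stop` is ours):

```python
import numpy as np

def should_stop(cost_val, patience):
    """Mirror of the stopping check in train.py's epoch loop."""
    epoch = len(cost_val) - 1
    return epoch > patience and cost_val[-1] > np.mean(cost_val[-(patience + 1):-1])

print(should_stop([1.0, 0.9, 0.8, 0.7, 0.6], patience=3))   # False: still improving
print(should_stop([1.0, 0.5, 0.5, 0.5, 10.0], patience=3))  # True: validation loss jumped
```
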
/utils.py:
--------------------------------------------------------------------------------
1 | from sklearn.preprocessing import normalize
2 | import numpy as np
3 | import pickle as pkl
4 | import networkx as nx
5 | import scipy.sparse as sp
6 | import scipy.special as ss
7 | from scipy.sparse.linalg.eigen.arpack import eigsh
8 | import sys
9 | import warnings
10 | warnings.filterwarnings("ignore")
11 | from weighting_func import laplacian,fourier,weight_wavelet,weight_wavelet_inverse
12 |
13 |
14 | def parse_index_file(filename):
15 | """Parse index file."""
16 | index = []
17 | for line in open(filename):
18 | index.append(int(line.strip()))
19 | return index
20 |
21 |
22 | def sample_mask(idx, l):
23 | """Create mask."""
24 | mask = np.zeros(l)
25 | mask[idx] = 1
26 | return np.array(mask, dtype=bool)
27 |
28 |
29 | def load_data(dataset_str,alldata = True):
30 | """
31 | Loads input data from the data/ directory
32 |
33 | ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
34 | ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
35 | ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
36 | (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
37 | ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
38 | ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
39 | ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
40 | ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
41 | object;
42 | ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.
43 |
44 | All objects above must be saved using python pickle module.
45 |
46 | :param dataset_str: Dataset name
47 | :return: All data input files loaded (as well the training/test data).
48 | """
49 | names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
50 | objects = []
51 | for i in range(len(names)):
52 | with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f:
53 | if sys.version_info > (3, 0):
54 | objects.append(pkl.load(f, encoding='latin1'))
55 | else:
56 | objects.append(pkl.load(f))
57 |
58 | x, y, tx, ty, allx, ally, graph = tuple(objects)
59 |
60 | test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str))
61 | test_idx_range = np.sort(test_idx_reorder)
62 |
63 | if dataset_str == 'citeseer':
64 | # Fix citeseer dataset (there are some isolated nodes in the graph)
65 | # Find isolated nodes, add them as zero-vecs into the right position
66 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
67 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
68 | tx_extended[test_idx_range-min(test_idx_range), :] = tx
69 | tx = tx_extended
70 | ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
71 | ty_extended[test_idx_range-min(test_idx_range), :] = ty
72 | ty = ty_extended
73 |
74 | features = sp.vstack((allx, tx)).tolil()
75 | features[test_idx_reorder, :] = features[test_idx_range, :]
76 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
77 |
78 | labels = np.vstack((ally, ty))
79 | labels[test_idx_reorder, :] = labels[test_idx_range, :]
80 |
81 | idx_test = test_idx_range.tolist()
82 | idx_train = range(len(y))
83 | idx_val = range(len(y), len(y)+500)
84 |
85 | if alldata:
86 | features = sp.vstack((allx, tx)).tolil()
87 | labels = np.vstack((ally, ty))
88 | num = labels.shape[0]
89 | idx_train = range(num // 5 * 3) # integer division: range() requires ints under Python 3
90 | idx_val = range(num // 5 * 3, num // 5 * 4)
91 | idx_test = range(num // 5 * 4, num)
92 |
93 | train_mask = sample_mask(idx_train, labels.shape[0])
94 | val_mask = sample_mask(idx_val, labels.shape[0])
95 | test_mask = sample_mask(idx_test, labels.shape[0])
96 |
97 | y_train = np.zeros(labels.shape)
98 | y_val = np.zeros(labels.shape)
99 | y_test = np.zeros(labels.shape)
100 | y_train[train_mask, :] = labels[train_mask, :]
101 | y_val[val_mask, :] = labels[val_mask, :]
102 | y_test[test_mask, :] = labels[test_mask, :]
103 | return labels,adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask
104 |
105 |
106 | def sparse_to_tuple(sparse_mx):
107 | """Convert sparse matrix to tuple representation."""
108 | def to_tuple(mx):
109 | if not sp.isspmatrix_coo(mx):
110 | mx = mx.tocoo()
111 | coords = np.vstack((mx.row, mx.col)).transpose()
112 | values = mx.data
113 | shape = mx.shape
114 | return coords, values, shape
115 |
116 | if isinstance(sparse_mx, list):
117 | for i in range(len(sparse_mx)):
118 | sparse_mx[i] = to_tuple(sparse_mx[i])
119 | else:
120 | sparse_mx = to_tuple(sparse_mx)
121 |
122 | return sparse_mx
123 |
124 |
125 | def preprocess_features(features):
126 | """Row-normalize feature matrix and convert to tuple representation"""
127 | rowsum = np.array(features.sum(1))
128 | # print rowsum
129 | r_inv = np.power(rowsum, -1).flatten()
130 | r_inv[np.isinf(r_inv)] = 0.
131 | r_mat_inv = sp.diags(r_inv,0)
132 | features = r_mat_inv.dot(features)
133 | return sparse_to_tuple(features)
134 |
135 |
136 | def normalize_adj(adj):
137 | """Symmetrically normalize adjacency matrix."""
138 | adj = sp.coo_matrix(adj)
139 | rowsum = np.array(adj.sum(1))
140 | d_inv_sqrt = np.power(rowsum, -0.5).flatten()
141 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
142 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt,0)
143 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
144 |
145 |
146 | def preprocess_adj(adj):
147 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
148 | adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
149 | # return adj_normalized
150 | return sparse_to_tuple(adj_normalized)
151 |
152 |
153 | def construct_feed_dict(features, support, labels, labels_mask, placeholders):
154 | """Construct feed dictionary."""
155 | feed_dict = dict()
156 | feed_dict.update({placeholders['labels']: labels})
157 | feed_dict.update({placeholders['labels_mask']: labels_mask})
158 | feed_dict.update({placeholders['features']: features})
159 | feed_dict.update({placeholders['support'][i]: support[i] for i in range(len(support))})
160 | feed_dict.update({placeholders['num_features_nonzero']: features[1].shape})
161 | return feed_dict
162 |
163 | def wavelet_basis(dataset,adj,s,laplacian_normalize,sparse_ness,threshold,weight_normalize):
164 |
165 | L = laplacian(adj,normalized=laplacian_normalize)
166 | lamb, U = fourier(dataset,L)
167 | #import pdb; pdb.set_trace()
168 | Weight = weight_wavelet(s,lamb,U)
169 | inverse_Weight = weight_wavelet_inverse(s,lamb,U)
170 | del U,lamb
171 |
172 | if (sparse_ness):
173 | Weight[Weight < threshold] = 0.0
174 | inverse_Weight[inverse_Weight < threshold] = 0.0
175 |
176 | if (weight_normalize == True):
177 | Weight = normalize(Weight, norm='l1', axis=1)
178 | inverse_Weight = normalize(inverse_Weight, norm='l1', axis=1)
179 |
180 | Weight = sp.csr_matrix(Weight)
181 | inverse_Weight = sp.csr_matrix(inverse_Weight)
182 | t_k = [inverse_Weight,Weight]
183 | return(t_k)
184 |
185 | def wavelet_basis_appro(dataset,adj,s,laplacian_normalize,sparse_ness,threshold,weight_normalize):
186 |
187 | L = laplacian(adj,normalized=laplacian_normalize)
188 | L = L - sp.eye(adj.shape[0])
189 | L = L.todense()
190 | # quick version for s = 2
191 | #Weight = 16.844 * sp.eye(adj.shape[0]) + 23.507 * L
192 | #inverse_Weight = 0.309 * sp.eye(adj.shape[0]) - 0.431 * L
193 | Weight = np.exp(s) * ss.iv(0,s) * np.eye(adj.shape[0]) + 2 * np.exp(s) * ss.iv(1,s) * L
194 | inverse_Weight = np.exp(-s) * ss.iv(0,-s) * np.eye(adj.shape[0]) + 2 * np.exp(-s) * ss.iv(1,-s) * L
195 |
196 | if (sparse_ness):
197 | Weight[Weight < threshold] = 0.0
198 | inverse_Weight[inverse_Weight < threshold] = 0.0
199 |
200 | if (weight_normalize == True):
201 | Weight = normalize(Weight, norm='l1', axis=1)
202 | inverse_Weight = normalize(inverse_Weight, norm='l1', axis=1)
203 |
204 | Weight = sp.csr_matrix(Weight)
205 | inverse_Weight = sp.csr_matrix(inverse_Weight)
206 | t_k = [inverse_Weight,Weight]
207 | return(t_k)
208 |
209 |
--------------------------------------------------------------------------------
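
As a small check of `preprocess_features`, here is row-normalization on a toy sparse matrix. The zero-safe reciprocal plays the role of the `r_inv[np.isinf(r_inv)] = 0.` line above:

```python
import numpy as np
import scipy.sparse as sp

X = sp.csr_matrix(np.array([[1., 1.], [2., 0.], [0., 0.]]))
rowsum = np.asarray(X.sum(1)).flatten()
r_inv = np.divide(1.0, rowsum, out=np.zeros_like(rowsum), where=rowsum != 0)
X_norm = sp.diags(r_inv).dot(X)
print(X_norm.toarray())   # rows sum to 1; the all-zero row stays zero
```
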
/weighting_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle as pkl
3 | import networkx as nx
4 | import scipy.sparse.linalg # also imports scipy.sparse; needed for eigs/eigsh in fourier()
5 | import sys
6 | import math
7 | import warnings
8 | warnings.filterwarnings("ignore")
9 |
10 | def adj_matrix():
11 | names = [ 'graph']
12 | objects = []
13 | for i in range(len(names)):
14 | with open("data/ind.{}.{}".format("cora", names[i]), 'rb') as f:
15 | if sys.version_info > (3, 0):
16 | objects = pkl.load(f, encoding='latin1')
17 | else:
18 | objects = pkl.load(f)
19 | graph = objects
20 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
21 | return adj
22 |
23 | def laplacian(W, normalized=False):
24 | """Return the Laplacian of the weight matrix."""
25 | # Degree matrix.
26 | d = W.sum(axis=0)
27 | # Laplacian matrix.
28 | if not normalized:
29 | D = scipy.sparse.diags(d.A.squeeze(), 0)
30 | L = D - W
31 | else:
32 | # d += np.spacing(np.array(0, W.dtype))
33 | d = 1 / np.sqrt(d)
34 | D = scipy.sparse.diags(d.A.squeeze(), 0)
35 | I = scipy.sparse.identity(d.size, dtype=W.dtype)
36 | L = I - D * W * D
37 |
38 | # assert np.abs(L - L.T).mean() < 1e-9
39 | assert type(L) is scipy.sparse.csr.csr_matrix
40 | return L
41 |
42 | def fourier(dataset,L, algo='eigh', k=100):
43 | """Return the Fourier basis, i.e. the EVD of the Laplacian."""
44 | # print "eigen decomposition:"
45 | def sort(lamb, U):
46 | idx = lamb.argsort()
47 | return lamb[idx], U[:, idx]
48 | # if(dataset == "pubmed"):
49 | # # print "loading pubmed U"
50 | # rfile = open("data/pubmed_U.pkl")
51 | # lamb, U = pkl.load(rfile)
52 | # rfile.close()
53 | # else:
54 | if algo == 'eig':
55 | lamb, U = np.linalg.eig(L.toarray())
56 | lamb, U = sort(lamb, U)
57 | elif algo == 'eigh':
58 | lamb, U = np.linalg.eigh(L.toarray())
59 | lamb, U = sort(lamb, U)
60 | elif algo == 'eigs':
61 | lamb, U = scipy.sparse.linalg.eigs(L, k=k, which='SM')
62 | lamb, U = sort(lamb, U)
63 | elif algo == 'eigsh':
64 | lamb, U = scipy.sparse.linalg.eigsh(L, k=k, which='SM')
65 | # print "end"
66 | # wfile = open("data/pubmed_U.pkl","w")
67 | # pkl.dump([lamb,U],wfile)
68 | # wfile.close()
69 | # print "pkl U end"
70 | return lamb, U
71 |
72 | def weight_wavelet(s, lamb, U):
73 | # Heat-kernel filter: psi_s = U diag(exp(-s * lambda)) U^T.
74 | # Work on a copy of the eigenvalues: the original in-place loop
75 | # mutated lamb, corrupting the subsequent weight_wavelet_inverse call.
76 | lamb = np.exp(-s * np.asarray(lamb, dtype=np.float64))
77 | Weight = np.dot(np.dot(U, np.diag(lamb)), np.transpose(U))
78 |
79 | return Weight
80 |
81 | def weight_wavelet_inverse(s, lamb, U):
82 | # Inverse filter: psi_s^{-1} = U diag(exp(s * lambda)) U^T,
83 | # again without mutating lamb in place.
84 | lamb = np.exp(s * np.asarray(lamb, dtype=np.float64))
85 |
86 | Weight = np.dot(np.dot(U, np.diag(lamb)), np.transpose(U))
87 |
88 | return Weight
89 |
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
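
Finally, a NumPy sketch of the wavelet pair built by `weight_wavelet` / `weight_wavelet_inverse` on a toy path-graph Laplacian, verifying that the two bases are inverses of each other:

```python
import numpy as np

L = np.array([[ 1., -1.,  0.],
              [-1.,  2., -1.],
              [ 0., -1.,  1.]])                  # unnormalized Laplacian of a 3-node path
lamb, U = np.linalg.eigh(L)
s = 1.0
psi = U @ np.diag(np.exp(-s * lamb)) @ U.T       # wavelet basis
psi_inv = U @ np.diag(np.exp(s * lamb)) @ U.T    # inverse wavelet basis
print(np.allclose(psi @ psi_inv, np.eye(3)))     # True
```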