├── .gitignore ├── LICENSE ├── README.md ├── gae ├── __init__.py ├── data │ ├── ind.citeseer.allx │ ├── ind.citeseer.graph │ ├── ind.citeseer.test.index │ ├── ind.citeseer.tx │ ├── ind.citeseer.x │ ├── ind.cora.allx │ ├── ind.cora.graph │ ├── ind.cora.test.index │ ├── ind.cora.tx │ └── ind.cora.x ├── layers.py ├── model.py ├── optimizer.py ├── train.py └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 zfjsail 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gae-pytorch 2 | Graph Auto-Encoder in PyTorch 3 | 4 | This is a PyTorch implementation of the Variational Graph Auto-Encoder model described in the paper: 5 | 6 | T. N. Kipf, M. Welling, [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308), NIPS Workshop on Bayesian Deep Learning (2016) 7 | 8 | The code in this repo is based on or refers to https://github.com/tkipf/gae, https://github.com/tkipf/pygcn and https://github.com/vmasrani/gae_in_pytorch. 9 | 10 | ### Requirements 11 | - Python 3 12 | - PyTorch 0.4 13 | - install requirements via ``` 14 | pip install -r requirements.txt``` 15 | 16 | ### How to run 17 | ```bash 18 | python gae/train.py 19 | ``` 20 | -------------------------------------------------------------------------------- /gae/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /gae/data/ind.citeseer.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.citeseer.allx -------------------------------------------------------------------------------- /gae/data/ind.citeseer.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.citeseer.graph -------------------------------------------------------------------------------- /gae/data/ind.citeseer.test.index: -------------------------------------------------------------------------------- 1 | 2488 2 | 2644 3 | 3261 4 | 2804 5 | 3176 6 | 2432 7 | 3310 8 | 2410 9 | 2812 10 | 2520 11 | 2994 12 | 3282 13 | 2680 14 | 2848 15 | 2670 16 | 3005 17 | 2977 18 | 2592 19 | 2967 20 | 2461 21 | 3184 22 | 2852 23 | 2768 24 | 2905 25 | 2851 26 | 3129 27 | 3164 28 | 2438 29 | 2793 30 | 2763 31 | 2528 32 | 2954 33 | 2347 34 | 2640 35 | 3265 36 | 2874 37 | 2446 38 | 2856 39 | 3149 40 | 2374 41 | 3097 42 | 3301 43 | 2664 44 | 2418 45 | 2655 46 | 2464 47 | 2596 48 | 3262 49 | 3278 50 | 2320 51 | 2612 52 | 2614 53 | 2550 54 | 2626 55 | 2772 56 | 3007 57 | 2733 58 | 2516 59 | 2476 60 | 2798 61 | 2561 62 | 2839 63 | 2685 64 | 2391 65 | 2705 66 | 3098 67 | 2754 68 | 3251 69 | 2767 70 | 2630 71 | 2727 72 | 2513 73 | 2701 74 | 3264 75 | 2792 76 | 2821 77 | 3260 78 | 2462 79 | 3307 80 | 2639 81 | 2900 82 | 3060 83 | 2672 84 | 3116 85 | 2731 86 | 3316 87 | 2386 88 | 2425 89 | 2518 90 | 3151 91 | 2586 92 | 2797 93 | 2479 94 | 3117 95 | 2580 96 | 3182 97 | 2459 98 | 2508 99 | 3052 100 | 3230 101 | 3215 102 | 2803 103 | 2969 104 | 2562 105 | 2398 106 | 3325 107 | 2343 108 | 3030 109 | 2414 110 | 2776 111 | 2383 112 | 3173 113 | 2850 114 | 2499 115 | 3312 116 | 2648 117 | 2784 118 | 2898 119 | 3056 120 | 2484 121 | 3179 122 | 3132 123 | 2577 124 | 2563 125 | 2867 126 | 3317 127 | 2355 128 | 3207 129 | 3178 130 | 2968 131 | 3319 132 | 2358 133 | 2764 134 | 3001 135 | 2683 136 | 3271 137 | 2321 138 | 2567 139 | 2502 140 | 3246 141 | 2715 142 | 3066 143 | 2390 144 | 2381 145 | 3162 146 | 2741 147 | 2498 148 | 2790 149 | 3038 150 | 3321 151 | 2481 152 | 3050 153 | 3161 154 | 3122 155 | 2801 156 | 2957 157 | 3177 158 | 2965 159 | 2621 160 | 3208 161 | 2921 162 
| 2802 163 | 2357 164 | 2677 165 | 2519 166 | 2860 167 | 2696 168 | 2368 169 | 3241 170 | 2858 171 | 2419 172 | 2762 173 | 2875 174 | 3222 175 | 3064 176 | 2827 177 | 3044 178 | 2471 179 | 3062 180 | 2982 181 | 2736 182 | 2322 183 | 2709 184 | 2766 185 | 2424 186 | 2602 187 | 2970 188 | 2675 189 | 3299 190 | 2554 191 | 2964 192 | 2597 193 | 2753 194 | 2979 195 | 2523 196 | 2912 197 | 2896 198 | 2317 199 | 3167 200 | 2813 201 | 2482 202 | 2557 203 | 3043 204 | 3244 205 | 2985 206 | 2460 207 | 2363 208 | 3272 209 | 3045 210 | 3192 211 | 2453 212 | 2656 213 | 2834 214 | 2443 215 | 3202 216 | 2926 217 | 2711 218 | 2633 219 | 2384 220 | 2752 221 | 3285 222 | 2817 223 | 2483 224 | 2919 225 | 2924 226 | 2661 227 | 2698 228 | 2361 229 | 2662 230 | 2819 231 | 3143 232 | 2316 233 | 3196 234 | 2739 235 | 2345 236 | 2578 237 | 2822 238 | 3229 239 | 2908 240 | 2917 241 | 2692 242 | 3200 243 | 2324 244 | 2522 245 | 3322 246 | 2697 247 | 3163 248 | 3093 249 | 3233 250 | 2774 251 | 2371 252 | 2835 253 | 2652 254 | 2539 255 | 2843 256 | 3231 257 | 2976 258 | 2429 259 | 2367 260 | 3144 261 | 2564 262 | 3283 263 | 3217 264 | 3035 265 | 2962 266 | 2433 267 | 2415 268 | 2387 269 | 3021 270 | 2595 271 | 2517 272 | 2468 273 | 3061 274 | 2673 275 | 2348 276 | 3027 277 | 2467 278 | 3318 279 | 2959 280 | 3273 281 | 2392 282 | 2779 283 | 2678 284 | 3004 285 | 2634 286 | 2974 287 | 3198 288 | 2342 289 | 2376 290 | 3249 291 | 2868 292 | 2952 293 | 2710 294 | 2838 295 | 2335 296 | 2524 297 | 2650 298 | 3186 299 | 2743 300 | 2545 301 | 2841 302 | 2515 303 | 2505 304 | 3181 305 | 2945 306 | 2738 307 | 2933 308 | 3303 309 | 2611 310 | 3090 311 | 2328 312 | 3010 313 | 3016 314 | 2504 315 | 2936 316 | 3266 317 | 3253 318 | 2840 319 | 3034 320 | 2581 321 | 2344 322 | 2452 323 | 2654 324 | 3199 325 | 3137 326 | 2514 327 | 2394 328 | 2544 329 | 2641 330 | 2613 331 | 2618 332 | 2558 333 | 2593 334 | 2532 335 | 2512 336 | 2975 337 | 3267 338 | 2566 339 | 2951 340 | 3300 341 | 2869 342 | 2629 343 | 2747 344 | 3055 345 | 2831 346 | 3105 347 | 3168 348 | 3100 349 | 2431 350 | 2828 351 | 2684 352 | 3269 353 | 2910 354 | 2865 355 | 2693 356 | 2884 357 | 3228 358 | 2783 359 | 3247 360 | 2770 361 | 3157 362 | 2421 363 | 2382 364 | 2331 365 | 3203 366 | 3240 367 | 2351 368 | 3114 369 | 2986 370 | 2688 371 | 2439 372 | 2996 373 | 3079 374 | 3103 375 | 3296 376 | 2349 377 | 2372 378 | 3096 379 | 2422 380 | 2551 381 | 3069 382 | 2737 383 | 3084 384 | 3304 385 | 3022 386 | 2542 387 | 3204 388 | 2949 389 | 2318 390 | 2450 391 | 3140 392 | 2734 393 | 2881 394 | 2576 395 | 3054 396 | 3089 397 | 3125 398 | 2761 399 | 3136 400 | 3111 401 | 2427 402 | 2466 403 | 3101 404 | 3104 405 | 3259 406 | 2534 407 | 2961 408 | 3191 409 | 3000 410 | 3036 411 | 2356 412 | 2800 413 | 3155 414 | 3224 415 | 2646 416 | 2735 417 | 3020 418 | 2866 419 | 2426 420 | 2448 421 | 3226 422 | 3219 423 | 2749 424 | 3183 425 | 2906 426 | 2360 427 | 2440 428 | 2946 429 | 2313 430 | 2859 431 | 2340 432 | 3008 433 | 2719 434 | 3058 435 | 2653 436 | 3023 437 | 2888 438 | 3243 439 | 2913 440 | 3242 441 | 3067 442 | 2409 443 | 3227 444 | 2380 445 | 2353 446 | 2686 447 | 2971 448 | 2847 449 | 2947 450 | 2857 451 | 3263 452 | 3218 453 | 2861 454 | 3323 455 | 2635 456 | 2966 457 | 2604 458 | 2456 459 | 2832 460 | 2694 461 | 3245 462 | 3119 463 | 2942 464 | 3153 465 | 2894 466 | 2555 467 | 3128 468 | 2703 469 | 2323 470 | 2631 471 | 2732 472 | 2699 473 | 2314 474 | 2590 475 | 3127 476 | 2891 477 | 2873 478 | 2814 479 | 2326 480 | 3026 481 | 3288 482 | 3095 483 | 2706 484 | 2457 485 | 
2377 486 | 2620 487 | 2526 488 | 2674 489 | 3190 490 | 2923 491 | 3032 492 | 2334 493 | 3254 494 | 2991 495 | 3277 496 | 2973 497 | 2599 498 | 2658 499 | 2636 500 | 2826 501 | 3148 502 | 2958 503 | 3258 504 | 2990 505 | 3180 506 | 2538 507 | 2748 508 | 2625 509 | 2565 510 | 3011 511 | 3057 512 | 2354 513 | 3158 514 | 2622 515 | 3308 516 | 2983 517 | 2560 518 | 3169 519 | 3059 520 | 2480 521 | 3194 522 | 3291 523 | 3216 524 | 2643 525 | 3172 526 | 2352 527 | 2724 528 | 2485 529 | 2411 530 | 2948 531 | 2445 532 | 2362 533 | 2668 534 | 3275 535 | 3107 536 | 2496 537 | 2529 538 | 2700 539 | 2541 540 | 3028 541 | 2879 542 | 2660 543 | 3324 544 | 2755 545 | 2436 546 | 3048 547 | 2623 548 | 2920 549 | 3040 550 | 2568 551 | 3221 552 | 3003 553 | 3295 554 | 2473 555 | 3232 556 | 3213 557 | 2823 558 | 2897 559 | 2573 560 | 2645 561 | 3018 562 | 3326 563 | 2795 564 | 2915 565 | 3109 566 | 3086 567 | 2463 568 | 3118 569 | 2671 570 | 2909 571 | 2393 572 | 2325 573 | 3029 574 | 2972 575 | 3110 576 | 2870 577 | 3284 578 | 2816 579 | 2647 580 | 2667 581 | 2955 582 | 2333 583 | 2960 584 | 2864 585 | 2893 586 | 2458 587 | 2441 588 | 2359 589 | 2327 590 | 3256 591 | 3099 592 | 3073 593 | 3138 594 | 2511 595 | 2666 596 | 2548 597 | 2364 598 | 2451 599 | 2911 600 | 3237 601 | 3206 602 | 3080 603 | 3279 604 | 2934 605 | 2981 606 | 2878 607 | 3130 608 | 2830 609 | 3091 610 | 2659 611 | 2449 612 | 3152 613 | 2413 614 | 2722 615 | 2796 616 | 3220 617 | 2751 618 | 2935 619 | 3238 620 | 2491 621 | 2730 622 | 2842 623 | 3223 624 | 2492 625 | 3074 626 | 3094 627 | 2833 628 | 2521 629 | 2883 630 | 3315 631 | 2845 632 | 2907 633 | 3083 634 | 2572 635 | 3092 636 | 2903 637 | 2918 638 | 3039 639 | 3286 640 | 2587 641 | 3068 642 | 2338 643 | 3166 644 | 3134 645 | 2455 646 | 2497 647 | 2992 648 | 2775 649 | 2681 650 | 2430 651 | 2932 652 | 2931 653 | 2434 654 | 3154 655 | 3046 656 | 2598 657 | 2366 658 | 3015 659 | 3147 660 | 2944 661 | 2582 662 | 3274 663 | 2987 664 | 2642 665 | 2547 666 | 2420 667 | 2930 668 | 2750 669 | 2417 670 | 2808 671 | 3141 672 | 2997 673 | 2995 674 | 2584 675 | 2312 676 | 3033 677 | 3070 678 | 3065 679 | 2509 680 | 3314 681 | 2396 682 | 2543 683 | 2423 684 | 3170 685 | 2389 686 | 3289 687 | 2728 688 | 2540 689 | 2437 690 | 2486 691 | 2895 692 | 3017 693 | 2853 694 | 2406 695 | 2346 696 | 2877 697 | 2472 698 | 3210 699 | 2637 700 | 2927 701 | 2789 702 | 2330 703 | 3088 704 | 3102 705 | 2616 706 | 3081 707 | 2902 708 | 3205 709 | 3320 710 | 3165 711 | 2984 712 | 3185 713 | 2707 714 | 3255 715 | 2583 716 | 2773 717 | 2742 718 | 3024 719 | 2402 720 | 2718 721 | 2882 722 | 2575 723 | 3281 724 | 2786 725 | 2855 726 | 3014 727 | 2401 728 | 2535 729 | 2687 730 | 2495 731 | 3113 732 | 2609 733 | 2559 734 | 2665 735 | 2530 736 | 3293 737 | 2399 738 | 2605 739 | 2690 740 | 3133 741 | 2799 742 | 2533 743 | 2695 744 | 2713 745 | 2886 746 | 2691 747 | 2549 748 | 3077 749 | 3002 750 | 3049 751 | 3051 752 | 3087 753 | 2444 754 | 3085 755 | 3135 756 | 2702 757 | 3211 758 | 3108 759 | 2501 760 | 2769 761 | 3290 762 | 2465 763 | 3025 764 | 3019 765 | 2385 766 | 2940 767 | 2657 768 | 2610 769 | 2525 770 | 2941 771 | 3078 772 | 2341 773 | 2916 774 | 2956 775 | 2375 776 | 2880 777 | 3009 778 | 2780 779 | 2370 780 | 2925 781 | 2332 782 | 3146 783 | 2315 784 | 2809 785 | 3145 786 | 3106 787 | 2782 788 | 2760 789 | 2493 790 | 2765 791 | 2556 792 | 2890 793 | 2400 794 | 2339 795 | 3201 796 | 2818 797 | 3248 798 | 3280 799 | 2570 800 | 2569 801 | 2937 802 | 3174 803 | 2836 804 | 2708 805 | 2820 806 | 3195 807 | 2617 808 | 
3197 809 | 2319 810 | 2744 811 | 2615 812 | 2825 813 | 2603 814 | 2914 815 | 2531 816 | 3193 817 | 2624 818 | 2365 819 | 2810 820 | 3239 821 | 3159 822 | 2537 823 | 2844 824 | 2758 825 | 2938 826 | 3037 827 | 2503 828 | 3297 829 | 2885 830 | 2608 831 | 2494 832 | 2712 833 | 2408 834 | 2901 835 | 2704 836 | 2536 837 | 2373 838 | 2478 839 | 2723 840 | 3076 841 | 2627 842 | 2369 843 | 2669 844 | 3006 845 | 2628 846 | 2788 847 | 3276 848 | 2435 849 | 3139 850 | 3235 851 | 2527 852 | 2571 853 | 2815 854 | 2442 855 | 2892 856 | 2978 857 | 2746 858 | 3150 859 | 2574 860 | 2725 861 | 3188 862 | 2601 863 | 2378 864 | 3075 865 | 2632 866 | 2794 867 | 3270 868 | 3071 869 | 2506 870 | 3126 871 | 3236 872 | 3257 873 | 2824 874 | 2989 875 | 2950 876 | 2428 877 | 2405 878 | 3156 879 | 2447 880 | 2787 881 | 2805 882 | 2720 883 | 2403 884 | 2811 885 | 2329 886 | 2474 887 | 2785 888 | 2350 889 | 2507 890 | 2416 891 | 3112 892 | 2475 893 | 2876 894 | 2585 895 | 2487 896 | 3072 897 | 3082 898 | 2943 899 | 2757 900 | 2388 901 | 2600 902 | 3294 903 | 2756 904 | 3142 905 | 3041 906 | 2594 907 | 2998 908 | 3047 909 | 2379 910 | 2980 911 | 2454 912 | 2862 913 | 3175 914 | 2588 915 | 3031 916 | 3012 917 | 2889 918 | 2500 919 | 2791 920 | 2854 921 | 2619 922 | 2395 923 | 2807 924 | 2740 925 | 2412 926 | 3131 927 | 3013 928 | 2939 929 | 2651 930 | 2490 931 | 2988 932 | 2863 933 | 3225 934 | 2745 935 | 2714 936 | 3160 937 | 3124 938 | 2849 939 | 2676 940 | 2872 941 | 3287 942 | 3189 943 | 2716 944 | 3115 945 | 2928 946 | 2871 947 | 2591 948 | 2717 949 | 2546 950 | 2777 951 | 3298 952 | 2397 953 | 3187 954 | 2726 955 | 2336 956 | 3268 957 | 2477 958 | 2904 959 | 2846 960 | 3121 961 | 2899 962 | 2510 963 | 2806 964 | 2963 965 | 3313 966 | 2679 967 | 3302 968 | 2663 969 | 3053 970 | 2469 971 | 2999 972 | 3311 973 | 2470 974 | 2638 975 | 3120 976 | 3171 977 | 2689 978 | 2922 979 | 2607 980 | 2721 981 | 2993 982 | 2887 983 | 2837 984 | 2929 985 | 2829 986 | 3234 987 | 2649 988 | 2337 989 | 2759 990 | 2778 991 | 2771 992 | 2404 993 | 2589 994 | 3123 995 | 3209 996 | 2729 997 | 3252 998 | 2606 999 | 2579 1000 | 2552 1001 | -------------------------------------------------------------------------------- /gae/data/ind.citeseer.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.citeseer.tx -------------------------------------------------------------------------------- /gae/data/ind.citeseer.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.citeseer.x -------------------------------------------------------------------------------- /gae/data/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.cora.allx -------------------------------------------------------------------------------- /gae/data/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.cora.graph -------------------------------------------------------------------------------- /gae/data/ind.cora.test.index: -------------------------------------------------------------------------------- 1 | 2692 
2 | 2532 3 | 2050 4 | 1715 5 | 2362 6 | 2609 7 | 2622 8 | 1975 9 | 2081 10 | 1767 11 | 2263 12 | 1725 13 | 2588 14 | 2259 15 | 2357 16 | 1998 17 | 2574 18 | 2179 19 | 2291 20 | 2382 21 | 1812 22 | 1751 23 | 2422 24 | 1937 25 | 2631 26 | 2510 27 | 2378 28 | 2589 29 | 2345 30 | 1943 31 | 1850 32 | 2298 33 | 1825 34 | 2035 35 | 2507 36 | 2313 37 | 1906 38 | 1797 39 | 2023 40 | 2159 41 | 2495 42 | 1886 43 | 2122 44 | 2369 45 | 2461 46 | 1925 47 | 2565 48 | 1858 49 | 2234 50 | 2000 51 | 1846 52 | 2318 53 | 1723 54 | 2559 55 | 2258 56 | 1763 57 | 1991 58 | 1922 59 | 2003 60 | 2662 61 | 2250 62 | 2064 63 | 2529 64 | 1888 65 | 2499 66 | 2454 67 | 2320 68 | 2287 69 | 2203 70 | 2018 71 | 2002 72 | 2632 73 | 2554 74 | 2314 75 | 2537 76 | 1760 77 | 2088 78 | 2086 79 | 2218 80 | 2605 81 | 1953 82 | 2403 83 | 1920 84 | 2015 85 | 2335 86 | 2535 87 | 1837 88 | 2009 89 | 1905 90 | 2636 91 | 1942 92 | 2193 93 | 2576 94 | 2373 95 | 1873 96 | 2463 97 | 2509 98 | 1954 99 | 2656 100 | 2455 101 | 2494 102 | 2295 103 | 2114 104 | 2561 105 | 2176 106 | 2275 107 | 2635 108 | 2442 109 | 2704 110 | 2127 111 | 2085 112 | 2214 113 | 2487 114 | 1739 115 | 2543 116 | 1783 117 | 2485 118 | 2262 119 | 2472 120 | 2326 121 | 1738 122 | 2170 123 | 2100 124 | 2384 125 | 2152 126 | 2647 127 | 2693 128 | 2376 129 | 1775 130 | 1726 131 | 2476 132 | 2195 133 | 1773 134 | 1793 135 | 2194 136 | 2581 137 | 1854 138 | 2524 139 | 1945 140 | 1781 141 | 1987 142 | 2599 143 | 1744 144 | 2225 145 | 2300 146 | 1928 147 | 2042 148 | 2202 149 | 1958 150 | 1816 151 | 1916 152 | 2679 153 | 2190 154 | 1733 155 | 2034 156 | 2643 157 | 2177 158 | 1883 159 | 1917 160 | 1996 161 | 2491 162 | 2268 163 | 2231 164 | 2471 165 | 1919 166 | 1909 167 | 2012 168 | 2522 169 | 1865 170 | 2466 171 | 2469 172 | 2087 173 | 2584 174 | 2563 175 | 1924 176 | 2143 177 | 1736 178 | 1966 179 | 2533 180 | 2490 181 | 2630 182 | 1973 183 | 2568 184 | 1978 185 | 2664 186 | 2633 187 | 2312 188 | 2178 189 | 1754 190 | 2307 191 | 2480 192 | 1960 193 | 1742 194 | 1962 195 | 2160 196 | 2070 197 | 2553 198 | 2433 199 | 1768 200 | 2659 201 | 2379 202 | 2271 203 | 1776 204 | 2153 205 | 1877 206 | 2027 207 | 2028 208 | 2155 209 | 2196 210 | 2483 211 | 2026 212 | 2158 213 | 2407 214 | 1821 215 | 2131 216 | 2676 217 | 2277 218 | 2489 219 | 2424 220 | 1963 221 | 1808 222 | 1859 223 | 2597 224 | 2548 225 | 2368 226 | 1817 227 | 2405 228 | 2413 229 | 2603 230 | 2350 231 | 2118 232 | 2329 233 | 1969 234 | 2577 235 | 2475 236 | 2467 237 | 2425 238 | 1769 239 | 2092 240 | 2044 241 | 2586 242 | 2608 243 | 1983 244 | 2109 245 | 2649 246 | 1964 247 | 2144 248 | 1902 249 | 2411 250 | 2508 251 | 2360 252 | 1721 253 | 2005 254 | 2014 255 | 2308 256 | 2646 257 | 1949 258 | 1830 259 | 2212 260 | 2596 261 | 1832 262 | 1735 263 | 1866 264 | 2695 265 | 1941 266 | 2546 267 | 2498 268 | 2686 269 | 2665 270 | 1784 271 | 2613 272 | 1970 273 | 2021 274 | 2211 275 | 2516 276 | 2185 277 | 2479 278 | 2699 279 | 2150 280 | 1990 281 | 2063 282 | 2075 283 | 1979 284 | 2094 285 | 1787 286 | 2571 287 | 2690 288 | 1926 289 | 2341 290 | 2566 291 | 1957 292 | 1709 293 | 1955 294 | 2570 295 | 2387 296 | 1811 297 | 2025 298 | 2447 299 | 2696 300 | 2052 301 | 2366 302 | 1857 303 | 2273 304 | 2245 305 | 2672 306 | 2133 307 | 2421 308 | 1929 309 | 2125 310 | 2319 311 | 2641 312 | 2167 313 | 2418 314 | 1765 315 | 1761 316 | 1828 317 | 2188 318 | 1972 319 | 1997 320 | 2419 321 | 2289 322 | 2296 323 | 2587 324 | 2051 325 | 2440 326 | 2053 327 | 2191 328 | 1923 329 | 2164 330 | 1861 331 | 2339 332 | 2333 333 | 2523 334 | 
2670 335 | 2121 336 | 1921 337 | 1724 338 | 2253 339 | 2374 340 | 1940 341 | 2545 342 | 2301 343 | 2244 344 | 2156 345 | 1849 346 | 2551 347 | 2011 348 | 2279 349 | 2572 350 | 1757 351 | 2400 352 | 2569 353 | 2072 354 | 2526 355 | 2173 356 | 2069 357 | 2036 358 | 1819 359 | 1734 360 | 1880 361 | 2137 362 | 2408 363 | 2226 364 | 2604 365 | 1771 366 | 2698 367 | 2187 368 | 2060 369 | 1756 370 | 2201 371 | 2066 372 | 2439 373 | 1844 374 | 1772 375 | 2383 376 | 2398 377 | 1708 378 | 1992 379 | 1959 380 | 1794 381 | 2426 382 | 2702 383 | 2444 384 | 1944 385 | 1829 386 | 2660 387 | 2497 388 | 2607 389 | 2343 390 | 1730 391 | 2624 392 | 1790 393 | 1935 394 | 1967 395 | 2401 396 | 2255 397 | 2355 398 | 2348 399 | 1931 400 | 2183 401 | 2161 402 | 2701 403 | 1948 404 | 2501 405 | 2192 406 | 2404 407 | 2209 408 | 2331 409 | 1810 410 | 2363 411 | 2334 412 | 1887 413 | 2393 414 | 2557 415 | 1719 416 | 1732 417 | 1986 418 | 2037 419 | 2056 420 | 1867 421 | 2126 422 | 1932 423 | 2117 424 | 1807 425 | 1801 426 | 1743 427 | 2041 428 | 1843 429 | 2388 430 | 2221 431 | 1833 432 | 2677 433 | 1778 434 | 2661 435 | 2306 436 | 2394 437 | 2106 438 | 2430 439 | 2371 440 | 2606 441 | 2353 442 | 2269 443 | 2317 444 | 2645 445 | 2372 446 | 2550 447 | 2043 448 | 1968 449 | 2165 450 | 2310 451 | 1985 452 | 2446 453 | 1982 454 | 2377 455 | 2207 456 | 1818 457 | 1913 458 | 1766 459 | 1722 460 | 1894 461 | 2020 462 | 1881 463 | 2621 464 | 2409 465 | 2261 466 | 2458 467 | 2096 468 | 1712 469 | 2594 470 | 2293 471 | 2048 472 | 2359 473 | 1839 474 | 2392 475 | 2254 476 | 1911 477 | 2101 478 | 2367 479 | 1889 480 | 1753 481 | 2555 482 | 2246 483 | 2264 484 | 2010 485 | 2336 486 | 2651 487 | 2017 488 | 2140 489 | 1842 490 | 2019 491 | 1890 492 | 2525 493 | 2134 494 | 2492 495 | 2652 496 | 2040 497 | 2145 498 | 2575 499 | 2166 500 | 1999 501 | 2434 502 | 1711 503 | 2276 504 | 2450 505 | 2389 506 | 2669 507 | 2595 508 | 1814 509 | 2039 510 | 2502 511 | 1896 512 | 2168 513 | 2344 514 | 2637 515 | 2031 516 | 1977 517 | 2380 518 | 1936 519 | 2047 520 | 2460 521 | 2102 522 | 1745 523 | 2650 524 | 2046 525 | 2514 526 | 1980 527 | 2352 528 | 2113 529 | 1713 530 | 2058 531 | 2558 532 | 1718 533 | 1864 534 | 1876 535 | 2338 536 | 1879 537 | 1891 538 | 2186 539 | 2451 540 | 2181 541 | 2638 542 | 2644 543 | 2103 544 | 2591 545 | 2266 546 | 2468 547 | 1869 548 | 2582 549 | 2674 550 | 2361 551 | 2462 552 | 1748 553 | 2215 554 | 2615 555 | 2236 556 | 2248 557 | 2493 558 | 2342 559 | 2449 560 | 2274 561 | 1824 562 | 1852 563 | 1870 564 | 2441 565 | 2356 566 | 1835 567 | 2694 568 | 2602 569 | 2685 570 | 1893 571 | 2544 572 | 2536 573 | 1994 574 | 1853 575 | 1838 576 | 1786 577 | 1930 578 | 2539 579 | 1892 580 | 2265 581 | 2618 582 | 2486 583 | 2583 584 | 2061 585 | 1796 586 | 1806 587 | 2084 588 | 1933 589 | 2095 590 | 2136 591 | 2078 592 | 1884 593 | 2438 594 | 2286 595 | 2138 596 | 1750 597 | 2184 598 | 1799 599 | 2278 600 | 2410 601 | 2642 602 | 2435 603 | 1956 604 | 2399 605 | 1774 606 | 2129 607 | 1898 608 | 1823 609 | 1938 610 | 2299 611 | 1862 612 | 2420 613 | 2673 614 | 1984 615 | 2204 616 | 1717 617 | 2074 618 | 2213 619 | 2436 620 | 2297 621 | 2592 622 | 2667 623 | 2703 624 | 2511 625 | 1779 626 | 1782 627 | 2625 628 | 2365 629 | 2315 630 | 2381 631 | 1788 632 | 1714 633 | 2302 634 | 1927 635 | 2325 636 | 2506 637 | 2169 638 | 2328 639 | 2629 640 | 2128 641 | 2655 642 | 2282 643 | 2073 644 | 2395 645 | 2247 646 | 2521 647 | 2260 648 | 1868 649 | 1988 650 | 2324 651 | 2705 652 | 2541 653 | 1731 654 | 2681 655 | 2707 656 | 2465 657 | 
1785 658 | 2149 659 | 2045 660 | 2505 661 | 2611 662 | 2217 663 | 2180 664 | 1904 665 | 2453 666 | 2484 667 | 1871 668 | 2309 669 | 2349 670 | 2482 671 | 2004 672 | 1965 673 | 2406 674 | 2162 675 | 1805 676 | 2654 677 | 2007 678 | 1947 679 | 1981 680 | 2112 681 | 2141 682 | 1720 683 | 1758 684 | 2080 685 | 2330 686 | 2030 687 | 2432 688 | 2089 689 | 2547 690 | 1820 691 | 1815 692 | 2675 693 | 1840 694 | 2658 695 | 2370 696 | 2251 697 | 1908 698 | 2029 699 | 2068 700 | 2513 701 | 2549 702 | 2267 703 | 2580 704 | 2327 705 | 2351 706 | 2111 707 | 2022 708 | 2321 709 | 2614 710 | 2252 711 | 2104 712 | 1822 713 | 2552 714 | 2243 715 | 1798 716 | 2396 717 | 2663 718 | 2564 719 | 2148 720 | 2562 721 | 2684 722 | 2001 723 | 2151 724 | 2706 725 | 2240 726 | 2474 727 | 2303 728 | 2634 729 | 2680 730 | 2055 731 | 2090 732 | 2503 733 | 2347 734 | 2402 735 | 2238 736 | 1950 737 | 2054 738 | 2016 739 | 1872 740 | 2233 741 | 1710 742 | 2032 743 | 2540 744 | 2628 745 | 1795 746 | 2616 747 | 1903 748 | 2531 749 | 2567 750 | 1946 751 | 1897 752 | 2222 753 | 2227 754 | 2627 755 | 1856 756 | 2464 757 | 2241 758 | 2481 759 | 2130 760 | 2311 761 | 2083 762 | 2223 763 | 2284 764 | 2235 765 | 2097 766 | 1752 767 | 2515 768 | 2527 769 | 2385 770 | 2189 771 | 2283 772 | 2182 773 | 2079 774 | 2375 775 | 2174 776 | 2437 777 | 1993 778 | 2517 779 | 2443 780 | 2224 781 | 2648 782 | 2171 783 | 2290 784 | 2542 785 | 2038 786 | 1855 787 | 1831 788 | 1759 789 | 1848 790 | 2445 791 | 1827 792 | 2429 793 | 2205 794 | 2598 795 | 2657 796 | 1728 797 | 2065 798 | 1918 799 | 2427 800 | 2573 801 | 2620 802 | 2292 803 | 1777 804 | 2008 805 | 1875 806 | 2288 807 | 2256 808 | 2033 809 | 2470 810 | 2585 811 | 2610 812 | 2082 813 | 2230 814 | 1915 815 | 1847 816 | 2337 817 | 2512 818 | 2386 819 | 2006 820 | 2653 821 | 2346 822 | 1951 823 | 2110 824 | 2639 825 | 2520 826 | 1939 827 | 2683 828 | 2139 829 | 2220 830 | 1910 831 | 2237 832 | 1900 833 | 1836 834 | 2197 835 | 1716 836 | 1860 837 | 2077 838 | 2519 839 | 2538 840 | 2323 841 | 1914 842 | 1971 843 | 1845 844 | 2132 845 | 1802 846 | 1907 847 | 2640 848 | 2496 849 | 2281 850 | 2198 851 | 2416 852 | 2285 853 | 1755 854 | 2431 855 | 2071 856 | 2249 857 | 2123 858 | 1727 859 | 2459 860 | 2304 861 | 2199 862 | 1791 863 | 1809 864 | 1780 865 | 2210 866 | 2417 867 | 1874 868 | 1878 869 | 2116 870 | 1961 871 | 1863 872 | 2579 873 | 2477 874 | 2228 875 | 2332 876 | 2578 877 | 2457 878 | 2024 879 | 1934 880 | 2316 881 | 1841 882 | 1764 883 | 1737 884 | 2322 885 | 2239 886 | 2294 887 | 1729 888 | 2488 889 | 1974 890 | 2473 891 | 2098 892 | 2612 893 | 1834 894 | 2340 895 | 2423 896 | 2175 897 | 2280 898 | 2617 899 | 2208 900 | 2560 901 | 1741 902 | 2600 903 | 2059 904 | 1747 905 | 2242 906 | 2700 907 | 2232 908 | 2057 909 | 2147 910 | 2682 911 | 1792 912 | 1826 913 | 2120 914 | 1895 915 | 2364 916 | 2163 917 | 1851 918 | 2391 919 | 2414 920 | 2452 921 | 1803 922 | 1989 923 | 2623 924 | 2200 925 | 2528 926 | 2415 927 | 1804 928 | 2146 929 | 2619 930 | 2687 931 | 1762 932 | 2172 933 | 2270 934 | 2678 935 | 2593 936 | 2448 937 | 1882 938 | 2257 939 | 2500 940 | 1899 941 | 2478 942 | 2412 943 | 2107 944 | 1746 945 | 2428 946 | 2115 947 | 1800 948 | 1901 949 | 2397 950 | 2530 951 | 1912 952 | 2108 953 | 2206 954 | 2091 955 | 1740 956 | 2219 957 | 1976 958 | 2099 959 | 2142 960 | 2671 961 | 2668 962 | 2216 963 | 2272 964 | 2229 965 | 2666 966 | 2456 967 | 2534 968 | 2697 969 | 2688 970 | 2062 971 | 2691 972 | 2689 973 | 2154 974 | 2590 975 | 2626 976 | 2390 977 | 1813 978 | 2067 979 | 1952 980 | 
2518 981 | 2358 982 | 1789 983 | 2076 984 | 2049 985 | 2119 986 | 2013 987 | 2124 988 | 2556 989 | 2105 990 | 2093 991 | 1885 992 | 2305 993 | 2354 994 | 2135 995 | 2601 996 | 1770 997 | 1995 998 | 2504 999 | 1749 1000 | 2157 1001 | -------------------------------------------------------------------------------- /gae/data/ind.cora.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.cora.tx -------------------------------------------------------------------------------- /gae/data/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zfjsail/gae-pytorch/c0b95cac8eb2928d0b5d6d65fee938fe97f60262/gae/data/ind.cora.x -------------------------------------------------------------------------------- /gae/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.module import Module 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class GraphConvolution(Module): 8 | """ 9 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 10 | """ 11 | 12 | def __init__(self, in_features, out_features, dropout=0., act=F.relu): 13 | super(GraphConvolution, self).__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.dropout = dropout 17 | self.act = act 18 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | torch.nn.init.xavier_uniform_(self.weight) 23 | 24 | def forward(self, input, adj): 25 | input = F.dropout(input, self.dropout, self.training) 26 | support = torch.mm(input, self.weight) 27 | output = torch.spmm(adj, support) 28 | output = self.act(output) 29 | return output 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ + ' (' \ 33 | + str(self.in_features) + ' -> ' \ 34 | + str(self.out_features) + ')' 35 | -------------------------------------------------------------------------------- /gae/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from gae.layers import GraphConvolution 6 | 7 | 8 | class GCNModelVAE(nn.Module): 9 | def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout): 10 | super(GCNModelVAE, self).__init__() 11 | self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu) 12 | self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x) 13 | self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x) 14 | self.dc = InnerProductDecoder(dropout, act=lambda x: x) 15 | 16 | def encode(self, x, adj): 17 | hidden1 = self.gc1(x, adj) 18 | return self.gc2(hidden1, adj), self.gc3(hidden1, adj) 19 | 20 | def reparameterize(self, mu, logvar): 21 | if self.training: 22 | std = torch.exp(logvar) 23 | eps = torch.randn_like(std) 24 | return eps.mul(std).add_(mu) 25 | else: 26 | return mu 27 | 28 | def forward(self, x, adj): 29 | mu, logvar = self.encode(x, adj) 30 | z = self.reparameterize(mu, logvar) 31 | return self.dc(z), mu, logvar 32 | 33 | 34 | class InnerProductDecoder(nn.Module): 35 | """Decoder for using inner product for prediction.""" 36 | 37 | def __init__(self, dropout, act=torch.sigmoid): 38 | 
super(InnerProductDecoder, self).__init__() 39 | self.dropout = dropout 40 | self.act = act 41 | 42 | def forward(self, z): 43 | z = F.dropout(z, self.dropout, training=self.training) 44 | adj = self.act(torch.mm(z, z.t())) 45 | return adj 46 | -------------------------------------------------------------------------------- /gae/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.modules.loss 3 | import torch.nn.functional as F 4 | 5 | 6 | def loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight): 7 | cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight) 8 | 9 | # see Appendix B from VAE paper: 10 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 11 | # https://arxiv.org/abs/1312.6114 12 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 13 | KLD = -0.5 / n_nodes * torch.mean(torch.sum( 14 | 1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1)) 15 | return cost + KLD 16 | -------------------------------------------------------------------------------- /gae/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import time 6 | 7 | import numpy as np 8 | import scipy.sparse as sp 9 | import torch 10 | from torch import optim 11 | 12 | from gae.model import GCNModelVAE 13 | from gae.optimizer import loss_function 14 | from gae.utils import load_data, mask_test_edges, preprocess_graph, get_roc_score 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model', type=str, default='gcn_vae', help="models used") 18 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 19 | parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.') 20 | parser.add_argument('--hidden1', type=int, default=32, help='Number of units in hidden layer 1.') 21 | parser.add_argument('--hidden2', type=int, default=16, help='Number of units in hidden layer 2.') 22 | parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.') 23 | parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).') 24 | parser.add_argument('--dataset-str', type=str, default='cora', help='type of dataset.') 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | def gae_for(args): 30 | print("Using {} dataset".format(args.dataset_str)) 31 | adj, features = load_data(args.dataset_str) 32 | n_nodes, feat_dim = features.shape 33 | 34 | # Store original adjacency matrix (without diagonal entries) for later 35 | adj_orig = adj 36 | adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape) 37 | adj_orig.eliminate_zeros() 38 | 39 | adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj) 40 | adj = adj_train 41 | 42 | # Some preprocessing 43 | adj_norm = preprocess_graph(adj) 44 | adj_label = adj_train + sp.eye(adj_train.shape[0]) 45 | # adj_label = sparse_to_tuple(adj_label) 46 | adj_label = torch.FloatTensor(adj_label.toarray()) 47 | 48 | pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum() 49 | norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2) 50 | 51 | model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout) 52 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 53 | 54 | 
hidden_emb = None 55 | for epoch in range(args.epochs): 56 | t = time.time() 57 | model.train() 58 | optimizer.zero_grad() 59 | recovered, mu, logvar = model(features, adj_norm) 60 | loss = loss_function(preds=recovered, labels=adj_label, 61 | mu=mu, logvar=logvar, n_nodes=n_nodes, 62 | norm=norm, pos_weight=pos_weight) 63 | loss.backward() 64 | cur_loss = loss.item() 65 | optimizer.step() 66 | 67 | hidden_emb = mu.data.numpy() 68 | roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false) 69 | 70 | print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss), 71 | "val_ap=", "{:.5f}".format(ap_curr), 72 | "time=", "{:.5f}".format(time.time() - t) 73 | ) 74 | 75 | print("Optimization Finished!") 76 | 77 | roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false) 78 | print('Test ROC score: ' + str(roc_score)) 79 | print('Test AP score: ' + str(ap_score)) 80 | 81 | 82 | if __name__ == '__main__': 83 | gae_for(args) 84 | -------------------------------------------------------------------------------- /gae/utils.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import scipy.sparse as sp 6 | import torch 7 | from sklearn.metrics import roc_auc_score, average_precision_score 8 | 9 | 10 | def load_data(dataset): 11 | # load the data: x, tx, allx, graph 12 | names = ['x', 'tx', 'allx', 'graph'] 13 | objects = [] 14 | for i in range(len(names)): 15 | ''' 16 | fix Pickle incompatibility of numpy arrays between Python 2 and 3 17 | https://stackoverflow.com/questions/11305790/pickle-incompatibility-of-numpy-arrays-between-python-2-and-3 18 | ''' 19 | with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as rf: 20 | u = pkl._Unpickler(rf) 21 | u.encoding = 'latin1' 22 | cur_data = u.load() 23 | objects.append(cur_data) 24 | # objects.append( 25 | # pkl.load(open("data/ind.{}.{}".format(dataset, names[i]), 'rb'))) 26 | x, tx, allx, graph = tuple(objects) 27 | test_idx_reorder = parse_index_file( 28 | "data/ind.{}.test.index".format(dataset)) 29 | test_idx_range = np.sort(test_idx_reorder) 30 | 31 | if dataset == 'citeseer': 32 | # Fix citeseer dataset (there are some isolated nodes in the graph) 33 | # Find isolated nodes, add them as zero-vecs into the right position 34 | test_idx_range_full = range( 35 | min(test_idx_reorder), max(test_idx_reorder) + 1) 36 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) 37 | tx_extended[test_idx_range - min(test_idx_range), :] = tx 38 | tx = tx_extended 39 | 40 | features = sp.vstack((allx, tx)).tolil() 41 | features[test_idx_reorder, :] = features[test_idx_range, :] 42 | features = torch.FloatTensor(np.array(features.todense())) 43 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph)) 44 | 45 | return adj, features 46 | 47 | 48 | def parse_index_file(filename): 49 | index = [] 50 | for line in open(filename): 51 | index.append(int(line.strip())) 52 | return index 53 | 54 | 55 | def sparse_to_tuple(sparse_mx): 56 | if not sp.isspmatrix_coo(sparse_mx): 57 | sparse_mx = sparse_mx.tocoo() 58 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose() 59 | values = sparse_mx.data 60 | shape = sparse_mx.shape 61 | return coords, values, shape 62 | 63 | 64 | def mask_test_edges(adj): 65 | # Function to build test set with 10% positive links 66 | # NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper. 
67 | # TODO: Clean up. 68 | 69 | # Remove diagonal elements 70 | adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape) 71 | adj.eliminate_zeros() 72 | # Check that diag is zero: 73 | assert np.diag(adj.todense()).sum() == 0 74 | 75 | adj_triu = sp.triu(adj) 76 | adj_tuple = sparse_to_tuple(adj_triu) 77 | edges = adj_tuple[0] 78 | edges_all = sparse_to_tuple(adj)[0] 79 | num_test = int(np.floor(edges.shape[0] / 10.)) 80 | num_val = int(np.floor(edges.shape[0] / 20.)) 81 | 82 | all_edge_idx = list(range(edges.shape[0])) 83 | np.random.shuffle(all_edge_idx) 84 | val_edge_idx = all_edge_idx[:num_val] 85 | test_edge_idx = all_edge_idx[num_val:(num_val + num_test)] 86 | test_edges = edges[test_edge_idx] 87 | val_edges = edges[val_edge_idx] 88 | train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0) 89 | 90 | def ismember(a, b, tol=5): 91 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) 92 | return np.any(rows_close) 93 | 94 | test_edges_false = [] 95 | while len(test_edges_false) < len(test_edges): 96 | idx_i = np.random.randint(0, adj.shape[0]) 97 | idx_j = np.random.randint(0, adj.shape[0]) 98 | if idx_i == idx_j: 99 | continue 100 | if ismember([idx_i, idx_j], edges_all): 101 | continue 102 | if test_edges_false: 103 | if ismember([idx_j, idx_i], np.array(test_edges_false)): 104 | continue 105 | if ismember([idx_i, idx_j], np.array(test_edges_false)): 106 | continue 107 | test_edges_false.append([idx_i, idx_j]) 108 | 109 | val_edges_false = [] 110 | while len(val_edges_false) < len(val_edges): 111 | idx_i = np.random.randint(0, adj.shape[0]) 112 | idx_j = np.random.randint(0, adj.shape[0]) 113 | if idx_i == idx_j: 114 | continue 115 | if ismember([idx_i, idx_j], train_edges): 116 | continue 117 | if ismember([idx_j, idx_i], train_edges): 118 | continue 119 | if ismember([idx_i, idx_j], val_edges): 120 | continue 121 | if ismember([idx_j, idx_i], val_edges): 122 | continue 123 | if val_edges_false: 124 | if ismember([idx_j, idx_i], np.array(val_edges_false)): 125 | continue 126 | if ismember([idx_i, idx_j], np.array(val_edges_false)): 127 | continue 128 | val_edges_false.append([idx_i, idx_j]) 129 | 130 | assert ~ismember(test_edges_false, edges_all) 131 | assert ~ismember(val_edges_false, edges_all) 132 | assert ~ismember(val_edges, train_edges) 133 | assert ~ismember(test_edges, train_edges) 134 | assert ~ismember(val_edges, test_edges) 135 | 136 | data = np.ones(train_edges.shape[0]) 137 | 138 | # Re-build adj matrix 139 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape) 140 | adj_train = adj_train + adj_train.T 141 | 142 | # NOTE: these edge lists only contain single direction of edge! 
143 | return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false 144 | 145 | 146 | def preprocess_graph(adj): 147 | adj = sp.coo_matrix(adj) 148 | adj_ = adj + sp.eye(adj.shape[0]) 149 | rowsum = np.array(adj_.sum(1)) 150 | degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten()) 151 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo() 152 | # return sparse_to_tuple(adj_normalized) 153 | return sparse_mx_to_torch_sparse_tensor(adj_normalized) 154 | 155 | 156 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 157 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 158 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 159 | indices = torch.from_numpy( 160 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 161 | values = torch.from_numpy(sparse_mx.data) 162 | shape = torch.Size(sparse_mx.shape) 163 | return torch.sparse.FloatTensor(indices, values, shape) 164 | 165 | 166 | def get_roc_score(emb, adj_orig, edges_pos, edges_neg): 167 | def sigmoid(x): 168 | return 1 / (1 + np.exp(-x)) 169 | 170 | # Predict on test set of edges 171 | adj_rec = np.dot(emb, emb.T) 172 | preds = [] 173 | pos = [] 174 | for e in edges_pos: 175 | preds.append(sigmoid(adj_rec[e[0], e[1]])) 176 | pos.append(adj_orig[e[0], e[1]]) 177 | 178 | preds_neg = [] 179 | neg = [] 180 | for e in edges_neg: 181 | preds_neg.append(sigmoid(adj_rec[e[0], e[1]])) 182 | neg.append(adj_orig[e[0], e[1]]) 183 | 184 | preds_all = np.hstack([preds, preds_neg]) 185 | labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))]) 186 | roc_score = roc_auc_score(labels_all, preds_all) 187 | ap_score = average_precision_score(labels_all, preds_all) 188 | 189 | return roc_score, ap_score 190 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.0.0 2 | numpy==1.14.0 3 | torch==0.4.1 4 | networkx==2.1 5 | scikit_learn==0.19.2 6 | --------------------------------------------------------------------------------
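For readers who want to reuse the learned embeddings outside of `gae/train.py`, here is a minimal sketch of how link probabilities are recovered from them, mirroring the inner-product decoder in `gae/model.py` and the scoring logic in `get_roc_score` in `gae/utils.py`. The `score_edges` helper and the toy embedding matrix below are illustrative assumptions, not part of the repository; with real data, `hidden_emb` would be the matrix of latent means `mu` that `gae/train.py` extracts after training.

```python
# Minimal sketch (assumptions noted above): score candidate edges with the same
# inner-product decoder used by get_roc_score, i.e. p(A_ij = 1) = sigmoid(z_i . z_j).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def score_edges(hidden_emb, candidate_edges):
    """hidden_emb: (n_nodes, hidden2) array of latent means (mu).
    candidate_edges: iterable of (i, j) node-index pairs (hypothetical input).
    Returns the predicted link probability for each pair."""
    adj_rec = np.dot(hidden_emb, hidden_emb.T)  # reconstructed edge logits
    return np.array([sigmoid(adj_rec[i, j]) for i, j in candidate_edges])

if __name__ == '__main__':
    # Toy embedding matrix standing in for the trained hidden_emb (not real Cora output).
    rng = np.random.RandomState(42)
    toy_emb = rng.randn(5, 16).astype(np.float32)  # 5 nodes, hidden2 = 16
    print(score_edges(toy_emb, [(0, 1), (2, 3)]))
```

This is the same computation `get_roc_score` performs over the held-out positive and negative edge lists produced by `mask_test_edges`; the sketch simply exposes it for arbitrary node pairs.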