├── ICASSP_GGCN_preprint.pdf
├── Illustration.png
├── README.md
├── data
│   ├── citeseer_features.npz
│   ├── citeseer_graph.npz
│   ├── cora_features.npz
│   ├── cora_graph.npz
│   ├── pubmed_features.npz
│   ├── pubmed_graph.npz
│   └── test
│       ├── ind.citeseer.allx
│       ├── ind.citeseer.graph
│       ├── ind.citeseer.test.index
│       ├── ind.citeseer.tx
│       ├── ind.citeseer.x
│       ├── ind.cora.allx
│       ├── ind.cora.graph
│       ├── ind.cora.test.index
│       ├── ind.cora.tx
│       ├── ind.cora.x
│       ├── ind.pubmed.allx
│       ├── ind.pubmed.graph
│       ├── ind.pubmed.test.index
│       ├── ind.pubmed.tx
│       └── ind.pubmed.x
└── src
    ├── __init__.py
    ├── graph_data.py
    ├── layers.py
    ├── loss.py
    ├── models.py
    ├── run_exist_nodes.py
    ├── run_iso_nodes.py
    ├── train_citation_gae.py
    ├── utils.py
    ├── validate_gae_implementation.py
    └── validate_implementation.py
/ICASSP_GGCN_preprint.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/ICASSP_GGCN_preprint.pdf
--------------------------------------------------------------------------------
/Illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/Illustration.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## GENERATIVE GRAPH CONVOLUTIONAL NETWORK FOR GROWING GRAPHS (ICASSP 2019)
2 |
3 | #### Authors: Da Xu*, Chuanwei Ruan*, Kamiya Motwani, Sushant Kumar, Evren Korpeoglu, Kannan Achan
4 |
5 | #### Please contact Da.Xu@walmartlabs.com or Chuanwei.Ruan@walmartlabs.com for questions.
6 |
7 | ### Introduction
8 | Modeling the generative process of growing graphs has wide applications in social networks and recommendation systems. Despite the emerging literature on learning graph representations and graph generation, most existing methods cannot handle isolated new nodes without nontrivial modifications. The challenge arises because learning to generate representations for nodes in the observed graph relies heavily on topological features, whereas for new nodes only node attributes are available.
9 |
10 | Here we propose a unified generative graph convolutional network that learns node representations for all nodes adaptively in a generative model framework, by sampling graph generation sequences constructed from observed graph data. We optimize over a variational lower bound that consists of a graph reconstruction term and an adaptive Kullback-Leibler divergence regularization term.
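
In standard variational graph autoencoder notation, one step of this objective takes the schematic form below (a sketch only; the paper's sequential objective sums such terms over the sampled generation sequence, with the KL regularization adapted per step):

```latex
\mathcal{L} \;=\; \mathbb{E}_{q(Z \mid X, A)}\big[\log p(A \mid Z)\big]
\;-\; \mathrm{KL}\big(q(Z \mid X, A)\,\|\,p(Z)\big)
```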
11 |
12 | 
13 |
14 |
15 | ### Inductive representation learning on temporal graphs
16 | In our [Inductive Representation Learning on Temporal Graphs (ICLR 2020)](https://openreview.net/pdf?id=rJeW1yHYwH) paper, we discuss how to learn node embeddings inductively on temporal graphs, which extends the growing-graph setting discussed in this paper. The implementation is available at the [GitHub page](https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs).
17 |
18 | ### Datasets
19 | The public Cora, Citeseer, and Pubmed datasets are provided in the data directory. The raw-format data are in the data/test folder, and the .npz files in the data directory have been preprocessed into sparse matrix format.
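
For instance, a preprocessed pair can be loaded with `scipy.sparse` (a minimal sketch; the training scripts load these files internally, and paths are assumed relative to the repository root):

```python
import scipy.sparse as sp

features = sp.load_npz('data/cora_features.npz')  # sparse node-feature matrix
adj = sp.load_npz('data/cora_graph.npz')          # sparse adjacency matrix
print(features.shape, adj.shape)
```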
20 |
21 | ### Running experiments
22 |
23 | * For link prediction tasks on isolated new nodes:
24 | ```bash
25 | python ./src/run_iso_nodes.py --data_set [cora, citeseer, pubmed]
26 | ```
27 |
28 | * For link prediction on existing nodes:
29 | ```bash
30 | python ./src/run_exist_nodes.py --data_set [cora, citeseer, pubmed]
31 | ```
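
For example, to run the existing-node experiment on Cora (substitute one of the three dataset names for `--data_set`):

```bash
python ./src/run_exist_nodes.py --data_set cora
```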
32 |
33 | ### Cite
34 |
35 | ```
36 | @INPROCEEDINGS{8682360,
37 | author={D. {Xu} and C. {Ruan} and K. {Motwani} and E. {Korpeoglu} and S. {Kumar} and K. {Achan}},
38 | booktitle={ICASSP 2019 - 2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
39 | title={Generative Graph Convolutional Network for Growing Graphs},
40 | year={2019},
41 | volume={},
42 | number={},
43 | pages={3167-3171},
44 | keywords={Task analysis;Encoding;Decoding;Standards;Adaptation models;Training;Social networking (online);Graph representation learning;sequential generative model;variational autoencoder;growing graph},
45 | doi={10.1109/ICASSP.2019.8682360},
46 | ISSN={2379-190X},
47 | month={May},}
48 | ```
49 |
--------------------------------------------------------------------------------
/data/citeseer_features.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/citeseer_features.npz
--------------------------------------------------------------------------------
/data/citeseer_graph.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/citeseer_graph.npz
--------------------------------------------------------------------------------
/data/cora_features.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/cora_features.npz
--------------------------------------------------------------------------------
/data/cora_graph.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/cora_graph.npz
--------------------------------------------------------------------------------
/data/pubmed_features.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/pubmed_features.npz
--------------------------------------------------------------------------------
/data/pubmed_graph.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/pubmed_graph.npz
--------------------------------------------------------------------------------
/data/test/ind.citeseer.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.citeseer.allx
--------------------------------------------------------------------------------
/data/test/ind.citeseer.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.citeseer.graph
--------------------------------------------------------------------------------
/data/test/ind.citeseer.test.index:
--------------------------------------------------------------------------------
1 | 2488
2 | 2644
3 | 3261
4 | 2804
5 | 3176
6 | 2432
7 | 3310
8 | 2410
9 | 2812
10 | 2520
11 | 2994
12 | 3282
13 | 2680
14 | 2848
15 | 2670
16 | 3005
17 | 2977
18 | 2592
19 | 2967
20 | 2461
21 | 3184
22 | 2852
23 | 2768
24 | 2905
25 | 2851
26 | 3129
27 | 3164
28 | 2438
29 | 2793
30 | 2763
31 | 2528
32 | 2954
33 | 2347
34 | 2640
35 | 3265
36 | 2874
37 | 2446
38 | 2856
39 | 3149
40 | 2374
41 | 3097
42 | 3301
43 | 2664
44 | 2418
45 | 2655
46 | 2464
47 | 2596
48 | 3262
49 | 3278
50 | 2320
51 | 2612
52 | 2614
53 | 2550
54 | 2626
55 | 2772
56 | 3007
57 | 2733
58 | 2516
59 | 2476
60 | 2798
61 | 2561
62 | 2839
63 | 2685
64 | 2391
65 | 2705
66 | 3098
67 | 2754
68 | 3251
69 | 2767
70 | 2630
71 | 2727
72 | 2513
73 | 2701
74 | 3264
75 | 2792
76 | 2821
77 | 3260
78 | 2462
79 | 3307
80 | 2639
81 | 2900
82 | 3060
83 | 2672
84 | 3116
85 | 2731
86 | 3316
87 | 2386
88 | 2425
89 | 2518
90 | 3151
91 | 2586
92 | 2797
93 | 2479
94 | 3117
95 | 2580
96 | 3182
97 | 2459
98 | 2508
99 | 3052
100 | 3230
101 | 3215
102 | 2803
103 | 2969
104 | 2562
105 | 2398
106 | 3325
107 | 2343
108 | 3030
109 | 2414
110 | 2776
111 | 2383
112 | 3173
113 | 2850
114 | 2499
115 | 3312
116 | 2648
117 | 2784
118 | 2898
119 | 3056
120 | 2484
121 | 3179
122 | 3132
123 | 2577
124 | 2563
125 | 2867
126 | 3317
127 | 2355
128 | 3207
129 | 3178
130 | 2968
131 | 3319
132 | 2358
133 | 2764
134 | 3001
135 | 2683
136 | 3271
137 | 2321
138 | 2567
139 | 2502
140 | 3246
141 | 2715
142 | 3066
143 | 2390
144 | 2381
145 | 3162
146 | 2741
147 | 2498
148 | 2790
149 | 3038
150 | 3321
151 | 2481
152 | 3050
153 | 3161
154 | 3122
155 | 2801
156 | 2957
157 | 3177
158 | 2965
159 | 2621
160 | 3208
161 | 2921
162 | 2802
163 | 2357
164 | 2677
165 | 2519
166 | 2860
167 | 2696
168 | 2368
169 | 3241
170 | 2858
171 | 2419
172 | 2762
173 | 2875
174 | 3222
175 | 3064
176 | 2827
177 | 3044
178 | 2471
179 | 3062
180 | 2982
181 | 2736
182 | 2322
183 | 2709
184 | 2766
185 | 2424
186 | 2602
187 | 2970
188 | 2675
189 | 3299
190 | 2554
191 | 2964
192 | 2597
193 | 2753
194 | 2979
195 | 2523
196 | 2912
197 | 2896
198 | 2317
199 | 3167
200 | 2813
201 | 2482
202 | 2557
203 | 3043
204 | 3244
205 | 2985
206 | 2460
207 | 2363
208 | 3272
209 | 3045
210 | 3192
211 | 2453
212 | 2656
213 | 2834
214 | 2443
215 | 3202
216 | 2926
217 | 2711
218 | 2633
219 | 2384
220 | 2752
221 | 3285
222 | 2817
223 | 2483
224 | 2919
225 | 2924
226 | 2661
227 | 2698
228 | 2361
229 | 2662
230 | 2819
231 | 3143
232 | 2316
233 | 3196
234 | 2739
235 | 2345
236 | 2578
237 | 2822
238 | 3229
239 | 2908
240 | 2917
241 | 2692
242 | 3200
243 | 2324
244 | 2522
245 | 3322
246 | 2697
247 | 3163
248 | 3093
249 | 3233
250 | 2774
251 | 2371
252 | 2835
253 | 2652
254 | 2539
255 | 2843
256 | 3231
257 | 2976
258 | 2429
259 | 2367
260 | 3144
261 | 2564
262 | 3283
263 | 3217
264 | 3035
265 | 2962
266 | 2433
267 | 2415
268 | 2387
269 | 3021
270 | 2595
271 | 2517
272 | 2468
273 | 3061
274 | 2673
275 | 2348
276 | 3027
277 | 2467
278 | 3318
279 | 2959
280 | 3273
281 | 2392
282 | 2779
283 | 2678
284 | 3004
285 | 2634
286 | 2974
287 | 3198
288 | 2342
289 | 2376
290 | 3249
291 | 2868
292 | 2952
293 | 2710
294 | 2838
295 | 2335
296 | 2524
297 | 2650
298 | 3186
299 | 2743
300 | 2545
301 | 2841
302 | 2515
303 | 2505
304 | 3181
305 | 2945
306 | 2738
307 | 2933
308 | 3303
309 | 2611
310 | 3090
311 | 2328
312 | 3010
313 | 3016
314 | 2504
315 | 2936
316 | 3266
317 | 3253
318 | 2840
319 | 3034
320 | 2581
321 | 2344
322 | 2452
323 | 2654
324 | 3199
325 | 3137
326 | 2514
327 | 2394
328 | 2544
329 | 2641
330 | 2613
331 | 2618
332 | 2558
333 | 2593
334 | 2532
335 | 2512
336 | 2975
337 | 3267
338 | 2566
339 | 2951
340 | 3300
341 | 2869
342 | 2629
343 | 2747
344 | 3055
345 | 2831
346 | 3105
347 | 3168
348 | 3100
349 | 2431
350 | 2828
351 | 2684
352 | 3269
353 | 2910
354 | 2865
355 | 2693
356 | 2884
357 | 3228
358 | 2783
359 | 3247
360 | 2770
361 | 3157
362 | 2421
363 | 2382
364 | 2331
365 | 3203
366 | 3240
367 | 2351
368 | 3114
369 | 2986
370 | 2688
371 | 2439
372 | 2996
373 | 3079
374 | 3103
375 | 3296
376 | 2349
377 | 2372
378 | 3096
379 | 2422
380 | 2551
381 | 3069
382 | 2737
383 | 3084
384 | 3304
385 | 3022
386 | 2542
387 | 3204
388 | 2949
389 | 2318
390 | 2450
391 | 3140
392 | 2734
393 | 2881
394 | 2576
395 | 3054
396 | 3089
397 | 3125
398 | 2761
399 | 3136
400 | 3111
401 | 2427
402 | 2466
403 | 3101
404 | 3104
405 | 3259
406 | 2534
407 | 2961
408 | 3191
409 | 3000
410 | 3036
411 | 2356
412 | 2800
413 | 3155
414 | 3224
415 | 2646
416 | 2735
417 | 3020
418 | 2866
419 | 2426
420 | 2448
421 | 3226
422 | 3219
423 | 2749
424 | 3183
425 | 2906
426 | 2360
427 | 2440
428 | 2946
429 | 2313
430 | 2859
431 | 2340
432 | 3008
433 | 2719
434 | 3058
435 | 2653
436 | 3023
437 | 2888
438 | 3243
439 | 2913
440 | 3242
441 | 3067
442 | 2409
443 | 3227
444 | 2380
445 | 2353
446 | 2686
447 | 2971
448 | 2847
449 | 2947
450 | 2857
451 | 3263
452 | 3218
453 | 2861
454 | 3323
455 | 2635
456 | 2966
457 | 2604
458 | 2456
459 | 2832
460 | 2694
461 | 3245
462 | 3119
463 | 2942
464 | 3153
465 | 2894
466 | 2555
467 | 3128
468 | 2703
469 | 2323
470 | 2631
471 | 2732
472 | 2699
473 | 2314
474 | 2590
475 | 3127
476 | 2891
477 | 2873
478 | 2814
479 | 2326
480 | 3026
481 | 3288
482 | 3095
483 | 2706
484 | 2457
485 | 2377
486 | 2620
487 | 2526
488 | 2674
489 | 3190
490 | 2923
491 | 3032
492 | 2334
493 | 3254
494 | 2991
495 | 3277
496 | 2973
497 | 2599
498 | 2658
499 | 2636
500 | 2826
501 | 3148
502 | 2958
503 | 3258
504 | 2990
505 | 3180
506 | 2538
507 | 2748
508 | 2625
509 | 2565
510 | 3011
511 | 3057
512 | 2354
513 | 3158
514 | 2622
515 | 3308
516 | 2983
517 | 2560
518 | 3169
519 | 3059
520 | 2480
521 | 3194
522 | 3291
523 | 3216
524 | 2643
525 | 3172
526 | 2352
527 | 2724
528 | 2485
529 | 2411
530 | 2948
531 | 2445
532 | 2362
533 | 2668
534 | 3275
535 | 3107
536 | 2496
537 | 2529
538 | 2700
539 | 2541
540 | 3028
541 | 2879
542 | 2660
543 | 3324
544 | 2755
545 | 2436
546 | 3048
547 | 2623
548 | 2920
549 | 3040
550 | 2568
551 | 3221
552 | 3003
553 | 3295
554 | 2473
555 | 3232
556 | 3213
557 | 2823
558 | 2897
559 | 2573
560 | 2645
561 | 3018
562 | 3326
563 | 2795
564 | 2915
565 | 3109
566 | 3086
567 | 2463
568 | 3118
569 | 2671
570 | 2909
571 | 2393
572 | 2325
573 | 3029
574 | 2972
575 | 3110
576 | 2870
577 | 3284
578 | 2816
579 | 2647
580 | 2667
581 | 2955
582 | 2333
583 | 2960
584 | 2864
585 | 2893
586 | 2458
587 | 2441
588 | 2359
589 | 2327
590 | 3256
591 | 3099
592 | 3073
593 | 3138
594 | 2511
595 | 2666
596 | 2548
597 | 2364
598 | 2451
599 | 2911
600 | 3237
601 | 3206
602 | 3080
603 | 3279
604 | 2934
605 | 2981
606 | 2878
607 | 3130
608 | 2830
609 | 3091
610 | 2659
611 | 2449
612 | 3152
613 | 2413
614 | 2722
615 | 2796
616 | 3220
617 | 2751
618 | 2935
619 | 3238
620 | 2491
621 | 2730
622 | 2842
623 | 3223
624 | 2492
625 | 3074
626 | 3094
627 | 2833
628 | 2521
629 | 2883
630 | 3315
631 | 2845
632 | 2907
633 | 3083
634 | 2572
635 | 3092
636 | 2903
637 | 2918
638 | 3039
639 | 3286
640 | 2587
641 | 3068
642 | 2338
643 | 3166
644 | 3134
645 | 2455
646 | 2497
647 | 2992
648 | 2775
649 | 2681
650 | 2430
651 | 2932
652 | 2931
653 | 2434
654 | 3154
655 | 3046
656 | 2598
657 | 2366
658 | 3015
659 | 3147
660 | 2944
661 | 2582
662 | 3274
663 | 2987
664 | 2642
665 | 2547
666 | 2420
667 | 2930
668 | 2750
669 | 2417
670 | 2808
671 | 3141
672 | 2997
673 | 2995
674 | 2584
675 | 2312
676 | 3033
677 | 3070
678 | 3065
679 | 2509
680 | 3314
681 | 2396
682 | 2543
683 | 2423
684 | 3170
685 | 2389
686 | 3289
687 | 2728
688 | 2540
689 | 2437
690 | 2486
691 | 2895
692 | 3017
693 | 2853
694 | 2406
695 | 2346
696 | 2877
697 | 2472
698 | 3210
699 | 2637
700 | 2927
701 | 2789
702 | 2330
703 | 3088
704 | 3102
705 | 2616
706 | 3081
707 | 2902
708 | 3205
709 | 3320
710 | 3165
711 | 2984
712 | 3185
713 | 2707
714 | 3255
715 | 2583
716 | 2773
717 | 2742
718 | 3024
719 | 2402
720 | 2718
721 | 2882
722 | 2575
723 | 3281
724 | 2786
725 | 2855
726 | 3014
727 | 2401
728 | 2535
729 | 2687
730 | 2495
731 | 3113
732 | 2609
733 | 2559
734 | 2665
735 | 2530
736 | 3293
737 | 2399
738 | 2605
739 | 2690
740 | 3133
741 | 2799
742 | 2533
743 | 2695
744 | 2713
745 | 2886
746 | 2691
747 | 2549
748 | 3077
749 | 3002
750 | 3049
751 | 3051
752 | 3087
753 | 2444
754 | 3085
755 | 3135
756 | 2702
757 | 3211
758 | 3108
759 | 2501
760 | 2769
761 | 3290
762 | 2465
763 | 3025
764 | 3019
765 | 2385
766 | 2940
767 | 2657
768 | 2610
769 | 2525
770 | 2941
771 | 3078
772 | 2341
773 | 2916
774 | 2956
775 | 2375
776 | 2880
777 | 3009
778 | 2780
779 | 2370
780 | 2925
781 | 2332
782 | 3146
783 | 2315
784 | 2809
785 | 3145
786 | 3106
787 | 2782
788 | 2760
789 | 2493
790 | 2765
791 | 2556
792 | 2890
793 | 2400
794 | 2339
795 | 3201
796 | 2818
797 | 3248
798 | 3280
799 | 2570
800 | 2569
801 | 2937
802 | 3174
803 | 2836
804 | 2708
805 | 2820
806 | 3195
807 | 2617
808 | 3197
809 | 2319
810 | 2744
811 | 2615
812 | 2825
813 | 2603
814 | 2914
815 | 2531
816 | 3193
817 | 2624
818 | 2365
819 | 2810
820 | 3239
821 | 3159
822 | 2537
823 | 2844
824 | 2758
825 | 2938
826 | 3037
827 | 2503
828 | 3297
829 | 2885
830 | 2608
831 | 2494
832 | 2712
833 | 2408
834 | 2901
835 | 2704
836 | 2536
837 | 2373
838 | 2478
839 | 2723
840 | 3076
841 | 2627
842 | 2369
843 | 2669
844 | 3006
845 | 2628
846 | 2788
847 | 3276
848 | 2435
849 | 3139
850 | 3235
851 | 2527
852 | 2571
853 | 2815
854 | 2442
855 | 2892
856 | 2978
857 | 2746
858 | 3150
859 | 2574
860 | 2725
861 | 3188
862 | 2601
863 | 2378
864 | 3075
865 | 2632
866 | 2794
867 | 3270
868 | 3071
869 | 2506
870 | 3126
871 | 3236
872 | 3257
873 | 2824
874 | 2989
875 | 2950
876 | 2428
877 | 2405
878 | 3156
879 | 2447
880 | 2787
881 | 2805
882 | 2720
883 | 2403
884 | 2811
885 | 2329
886 | 2474
887 | 2785
888 | 2350
889 | 2507
890 | 2416
891 | 3112
892 | 2475
893 | 2876
894 | 2585
895 | 2487
896 | 3072
897 | 3082
898 | 2943
899 | 2757
900 | 2388
901 | 2600
902 | 3294
903 | 2756
904 | 3142
905 | 3041
906 | 2594
907 | 2998
908 | 3047
909 | 2379
910 | 2980
911 | 2454
912 | 2862
913 | 3175
914 | 2588
915 | 3031
916 | 3012
917 | 2889
918 | 2500
919 | 2791
920 | 2854
921 | 2619
922 | 2395
923 | 2807
924 | 2740
925 | 2412
926 | 3131
927 | 3013
928 | 2939
929 | 2651
930 | 2490
931 | 2988
932 | 2863
933 | 3225
934 | 2745
935 | 2714
936 | 3160
937 | 3124
938 | 2849
939 | 2676
940 | 2872
941 | 3287
942 | 3189
943 | 2716
944 | 3115
945 | 2928
946 | 2871
947 | 2591
948 | 2717
949 | 2546
950 | 2777
951 | 3298
952 | 2397
953 | 3187
954 | 2726
955 | 2336
956 | 3268
957 | 2477
958 | 2904
959 | 2846
960 | 3121
961 | 2899
962 | 2510
963 | 2806
964 | 2963
965 | 3313
966 | 2679
967 | 3302
968 | 2663
969 | 3053
970 | 2469
971 | 2999
972 | 3311
973 | 2470
974 | 2638
975 | 3120
976 | 3171
977 | 2689
978 | 2922
979 | 2607
980 | 2721
981 | 2993
982 | 2887
983 | 2837
984 | 2929
985 | 2829
986 | 3234
987 | 2649
988 | 2337
989 | 2759
990 | 2778
991 | 2771
992 | 2404
993 | 2589
994 | 3123
995 | 3209
996 | 2729
997 | 3252
998 | 2606
999 | 2579
1000 | 2552
1001 |
--------------------------------------------------------------------------------
/data/test/ind.citeseer.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.citeseer.tx
--------------------------------------------------------------------------------
/data/test/ind.citeseer.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.citeseer.x
--------------------------------------------------------------------------------
/data/test/ind.cora.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.cora.allx
--------------------------------------------------------------------------------
/data/test/ind.cora.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.cora.graph
--------------------------------------------------------------------------------
/data/test/ind.cora.test.index:
--------------------------------------------------------------------------------
1 | 2692
2 | 2532
3 | 2050
4 | 1715
5 | 2362
6 | 2609
7 | 2622
8 | 1975
9 | 2081
10 | 1767
11 | 2263
12 | 1725
13 | 2588
14 | 2259
15 | 2357
16 | 1998
17 | 2574
18 | 2179
19 | 2291
20 | 2382
21 | 1812
22 | 1751
23 | 2422
24 | 1937
25 | 2631
26 | 2510
27 | 2378
28 | 2589
29 | 2345
30 | 1943
31 | 1850
32 | 2298
33 | 1825
34 | 2035
35 | 2507
36 | 2313
37 | 1906
38 | 1797
39 | 2023
40 | 2159
41 | 2495
42 | 1886
43 | 2122
44 | 2369
45 | 2461
46 | 1925
47 | 2565
48 | 1858
49 | 2234
50 | 2000
51 | 1846
52 | 2318
53 | 1723
54 | 2559
55 | 2258
56 | 1763
57 | 1991
58 | 1922
59 | 2003
60 | 2662
61 | 2250
62 | 2064
63 | 2529
64 | 1888
65 | 2499
66 | 2454
67 | 2320
68 | 2287
69 | 2203
70 | 2018
71 | 2002
72 | 2632
73 | 2554
74 | 2314
75 | 2537
76 | 1760
77 | 2088
78 | 2086
79 | 2218
80 | 2605
81 | 1953
82 | 2403
83 | 1920
84 | 2015
85 | 2335
86 | 2535
87 | 1837
88 | 2009
89 | 1905
90 | 2636
91 | 1942
92 | 2193
93 | 2576
94 | 2373
95 | 1873
96 | 2463
97 | 2509
98 | 1954
99 | 2656
100 | 2455
101 | 2494
102 | 2295
103 | 2114
104 | 2561
105 | 2176
106 | 2275
107 | 2635
108 | 2442
109 | 2704
110 | 2127
111 | 2085
112 | 2214
113 | 2487
114 | 1739
115 | 2543
116 | 1783
117 | 2485
118 | 2262
119 | 2472
120 | 2326
121 | 1738
122 | 2170
123 | 2100
124 | 2384
125 | 2152
126 | 2647
127 | 2693
128 | 2376
129 | 1775
130 | 1726
131 | 2476
132 | 2195
133 | 1773
134 | 1793
135 | 2194
136 | 2581
137 | 1854
138 | 2524
139 | 1945
140 | 1781
141 | 1987
142 | 2599
143 | 1744
144 | 2225
145 | 2300
146 | 1928
147 | 2042
148 | 2202
149 | 1958
150 | 1816
151 | 1916
152 | 2679
153 | 2190
154 | 1733
155 | 2034
156 | 2643
157 | 2177
158 | 1883
159 | 1917
160 | 1996
161 | 2491
162 | 2268
163 | 2231
164 | 2471
165 | 1919
166 | 1909
167 | 2012
168 | 2522
169 | 1865
170 | 2466
171 | 2469
172 | 2087
173 | 2584
174 | 2563
175 | 1924
176 | 2143
177 | 1736
178 | 1966
179 | 2533
180 | 2490
181 | 2630
182 | 1973
183 | 2568
184 | 1978
185 | 2664
186 | 2633
187 | 2312
188 | 2178
189 | 1754
190 | 2307
191 | 2480
192 | 1960
193 | 1742
194 | 1962
195 | 2160
196 | 2070
197 | 2553
198 | 2433
199 | 1768
200 | 2659
201 | 2379
202 | 2271
203 | 1776
204 | 2153
205 | 1877
206 | 2027
207 | 2028
208 | 2155
209 | 2196
210 | 2483
211 | 2026
212 | 2158
213 | 2407
214 | 1821
215 | 2131
216 | 2676
217 | 2277
218 | 2489
219 | 2424
220 | 1963
221 | 1808
222 | 1859
223 | 2597
224 | 2548
225 | 2368
226 | 1817
227 | 2405
228 | 2413
229 | 2603
230 | 2350
231 | 2118
232 | 2329
233 | 1969
234 | 2577
235 | 2475
236 | 2467
237 | 2425
238 | 1769
239 | 2092
240 | 2044
241 | 2586
242 | 2608
243 | 1983
244 | 2109
245 | 2649
246 | 1964
247 | 2144
248 | 1902
249 | 2411
250 | 2508
251 | 2360
252 | 1721
253 | 2005
254 | 2014
255 | 2308
256 | 2646
257 | 1949
258 | 1830
259 | 2212
260 | 2596
261 | 1832
262 | 1735
263 | 1866
264 | 2695
265 | 1941
266 | 2546
267 | 2498
268 | 2686
269 | 2665
270 | 1784
271 | 2613
272 | 1970
273 | 2021
274 | 2211
275 | 2516
276 | 2185
277 | 2479
278 | 2699
279 | 2150
280 | 1990
281 | 2063
282 | 2075
283 | 1979
284 | 2094
285 | 1787
286 | 2571
287 | 2690
288 | 1926
289 | 2341
290 | 2566
291 | 1957
292 | 1709
293 | 1955
294 | 2570
295 | 2387
296 | 1811
297 | 2025
298 | 2447
299 | 2696
300 | 2052
301 | 2366
302 | 1857
303 | 2273
304 | 2245
305 | 2672
306 | 2133
307 | 2421
308 | 1929
309 | 2125
310 | 2319
311 | 2641
312 | 2167
313 | 2418
314 | 1765
315 | 1761
316 | 1828
317 | 2188
318 | 1972
319 | 1997
320 | 2419
321 | 2289
322 | 2296
323 | 2587
324 | 2051
325 | 2440
326 | 2053
327 | 2191
328 | 1923
329 | 2164
330 | 1861
331 | 2339
332 | 2333
333 | 2523
334 | 2670
335 | 2121
336 | 1921
337 | 1724
338 | 2253
339 | 2374
340 | 1940
341 | 2545
342 | 2301
343 | 2244
344 | 2156
345 | 1849
346 | 2551
347 | 2011
348 | 2279
349 | 2572
350 | 1757
351 | 2400
352 | 2569
353 | 2072
354 | 2526
355 | 2173
356 | 2069
357 | 2036
358 | 1819
359 | 1734
360 | 1880
361 | 2137
362 | 2408
363 | 2226
364 | 2604
365 | 1771
366 | 2698
367 | 2187
368 | 2060
369 | 1756
370 | 2201
371 | 2066
372 | 2439
373 | 1844
374 | 1772
375 | 2383
376 | 2398
377 | 1708
378 | 1992
379 | 1959
380 | 1794
381 | 2426
382 | 2702
383 | 2444
384 | 1944
385 | 1829
386 | 2660
387 | 2497
388 | 2607
389 | 2343
390 | 1730
391 | 2624
392 | 1790
393 | 1935
394 | 1967
395 | 2401
396 | 2255
397 | 2355
398 | 2348
399 | 1931
400 | 2183
401 | 2161
402 | 2701
403 | 1948
404 | 2501
405 | 2192
406 | 2404
407 | 2209
408 | 2331
409 | 1810
410 | 2363
411 | 2334
412 | 1887
413 | 2393
414 | 2557
415 | 1719
416 | 1732
417 | 1986
418 | 2037
419 | 2056
420 | 1867
421 | 2126
422 | 1932
423 | 2117
424 | 1807
425 | 1801
426 | 1743
427 | 2041
428 | 1843
429 | 2388
430 | 2221
431 | 1833
432 | 2677
433 | 1778
434 | 2661
435 | 2306
436 | 2394
437 | 2106
438 | 2430
439 | 2371
440 | 2606
441 | 2353
442 | 2269
443 | 2317
444 | 2645
445 | 2372
446 | 2550
447 | 2043
448 | 1968
449 | 2165
450 | 2310
451 | 1985
452 | 2446
453 | 1982
454 | 2377
455 | 2207
456 | 1818
457 | 1913
458 | 1766
459 | 1722
460 | 1894
461 | 2020
462 | 1881
463 | 2621
464 | 2409
465 | 2261
466 | 2458
467 | 2096
468 | 1712
469 | 2594
470 | 2293
471 | 2048
472 | 2359
473 | 1839
474 | 2392
475 | 2254
476 | 1911
477 | 2101
478 | 2367
479 | 1889
480 | 1753
481 | 2555
482 | 2246
483 | 2264
484 | 2010
485 | 2336
486 | 2651
487 | 2017
488 | 2140
489 | 1842
490 | 2019
491 | 1890
492 | 2525
493 | 2134
494 | 2492
495 | 2652
496 | 2040
497 | 2145
498 | 2575
499 | 2166
500 | 1999
501 | 2434
502 | 1711
503 | 2276
504 | 2450
505 | 2389
506 | 2669
507 | 2595
508 | 1814
509 | 2039
510 | 2502
511 | 1896
512 | 2168
513 | 2344
514 | 2637
515 | 2031
516 | 1977
517 | 2380
518 | 1936
519 | 2047
520 | 2460
521 | 2102
522 | 1745
523 | 2650
524 | 2046
525 | 2514
526 | 1980
527 | 2352
528 | 2113
529 | 1713
530 | 2058
531 | 2558
532 | 1718
533 | 1864
534 | 1876
535 | 2338
536 | 1879
537 | 1891
538 | 2186
539 | 2451
540 | 2181
541 | 2638
542 | 2644
543 | 2103
544 | 2591
545 | 2266
546 | 2468
547 | 1869
548 | 2582
549 | 2674
550 | 2361
551 | 2462
552 | 1748
553 | 2215
554 | 2615
555 | 2236
556 | 2248
557 | 2493
558 | 2342
559 | 2449
560 | 2274
561 | 1824
562 | 1852
563 | 1870
564 | 2441
565 | 2356
566 | 1835
567 | 2694
568 | 2602
569 | 2685
570 | 1893
571 | 2544
572 | 2536
573 | 1994
574 | 1853
575 | 1838
576 | 1786
577 | 1930
578 | 2539
579 | 1892
580 | 2265
581 | 2618
582 | 2486
583 | 2583
584 | 2061
585 | 1796
586 | 1806
587 | 2084
588 | 1933
589 | 2095
590 | 2136
591 | 2078
592 | 1884
593 | 2438
594 | 2286
595 | 2138
596 | 1750
597 | 2184
598 | 1799
599 | 2278
600 | 2410
601 | 2642
602 | 2435
603 | 1956
604 | 2399
605 | 1774
606 | 2129
607 | 1898
608 | 1823
609 | 1938
610 | 2299
611 | 1862
612 | 2420
613 | 2673
614 | 1984
615 | 2204
616 | 1717
617 | 2074
618 | 2213
619 | 2436
620 | 2297
621 | 2592
622 | 2667
623 | 2703
624 | 2511
625 | 1779
626 | 1782
627 | 2625
628 | 2365
629 | 2315
630 | 2381
631 | 1788
632 | 1714
633 | 2302
634 | 1927
635 | 2325
636 | 2506
637 | 2169
638 | 2328
639 | 2629
640 | 2128
641 | 2655
642 | 2282
643 | 2073
644 | 2395
645 | 2247
646 | 2521
647 | 2260
648 | 1868
649 | 1988
650 | 2324
651 | 2705
652 | 2541
653 | 1731
654 | 2681
655 | 2707
656 | 2465
657 | 1785
658 | 2149
659 | 2045
660 | 2505
661 | 2611
662 | 2217
663 | 2180
664 | 1904
665 | 2453
666 | 2484
667 | 1871
668 | 2309
669 | 2349
670 | 2482
671 | 2004
672 | 1965
673 | 2406
674 | 2162
675 | 1805
676 | 2654
677 | 2007
678 | 1947
679 | 1981
680 | 2112
681 | 2141
682 | 1720
683 | 1758
684 | 2080
685 | 2330
686 | 2030
687 | 2432
688 | 2089
689 | 2547
690 | 1820
691 | 1815
692 | 2675
693 | 1840
694 | 2658
695 | 2370
696 | 2251
697 | 1908
698 | 2029
699 | 2068
700 | 2513
701 | 2549
702 | 2267
703 | 2580
704 | 2327
705 | 2351
706 | 2111
707 | 2022
708 | 2321
709 | 2614
710 | 2252
711 | 2104
712 | 1822
713 | 2552
714 | 2243
715 | 1798
716 | 2396
717 | 2663
718 | 2564
719 | 2148
720 | 2562
721 | 2684
722 | 2001
723 | 2151
724 | 2706
725 | 2240
726 | 2474
727 | 2303
728 | 2634
729 | 2680
730 | 2055
731 | 2090
732 | 2503
733 | 2347
734 | 2402
735 | 2238
736 | 1950
737 | 2054
738 | 2016
739 | 1872
740 | 2233
741 | 1710
742 | 2032
743 | 2540
744 | 2628
745 | 1795
746 | 2616
747 | 1903
748 | 2531
749 | 2567
750 | 1946
751 | 1897
752 | 2222
753 | 2227
754 | 2627
755 | 1856
756 | 2464
757 | 2241
758 | 2481
759 | 2130
760 | 2311
761 | 2083
762 | 2223
763 | 2284
764 | 2235
765 | 2097
766 | 1752
767 | 2515
768 | 2527
769 | 2385
770 | 2189
771 | 2283
772 | 2182
773 | 2079
774 | 2375
775 | 2174
776 | 2437
777 | 1993
778 | 2517
779 | 2443
780 | 2224
781 | 2648
782 | 2171
783 | 2290
784 | 2542
785 | 2038
786 | 1855
787 | 1831
788 | 1759
789 | 1848
790 | 2445
791 | 1827
792 | 2429
793 | 2205
794 | 2598
795 | 2657
796 | 1728
797 | 2065
798 | 1918
799 | 2427
800 | 2573
801 | 2620
802 | 2292
803 | 1777
804 | 2008
805 | 1875
806 | 2288
807 | 2256
808 | 2033
809 | 2470
810 | 2585
811 | 2610
812 | 2082
813 | 2230
814 | 1915
815 | 1847
816 | 2337
817 | 2512
818 | 2386
819 | 2006
820 | 2653
821 | 2346
822 | 1951
823 | 2110
824 | 2639
825 | 2520
826 | 1939
827 | 2683
828 | 2139
829 | 2220
830 | 1910
831 | 2237
832 | 1900
833 | 1836
834 | 2197
835 | 1716
836 | 1860
837 | 2077
838 | 2519
839 | 2538
840 | 2323
841 | 1914
842 | 1971
843 | 1845
844 | 2132
845 | 1802
846 | 1907
847 | 2640
848 | 2496
849 | 2281
850 | 2198
851 | 2416
852 | 2285
853 | 1755
854 | 2431
855 | 2071
856 | 2249
857 | 2123
858 | 1727
859 | 2459
860 | 2304
861 | 2199
862 | 1791
863 | 1809
864 | 1780
865 | 2210
866 | 2417
867 | 1874
868 | 1878
869 | 2116
870 | 1961
871 | 1863
872 | 2579
873 | 2477
874 | 2228
875 | 2332
876 | 2578
877 | 2457
878 | 2024
879 | 1934
880 | 2316
881 | 1841
882 | 1764
883 | 1737
884 | 2322
885 | 2239
886 | 2294
887 | 1729
888 | 2488
889 | 1974
890 | 2473
891 | 2098
892 | 2612
893 | 1834
894 | 2340
895 | 2423
896 | 2175
897 | 2280
898 | 2617
899 | 2208
900 | 2560
901 | 1741
902 | 2600
903 | 2059
904 | 1747
905 | 2242
906 | 2700
907 | 2232
908 | 2057
909 | 2147
910 | 2682
911 | 1792
912 | 1826
913 | 2120
914 | 1895
915 | 2364
916 | 2163
917 | 1851
918 | 2391
919 | 2414
920 | 2452
921 | 1803
922 | 1989
923 | 2623
924 | 2200
925 | 2528
926 | 2415
927 | 1804
928 | 2146
929 | 2619
930 | 2687
931 | 1762
932 | 2172
933 | 2270
934 | 2678
935 | 2593
936 | 2448
937 | 1882
938 | 2257
939 | 2500
940 | 1899
941 | 2478
942 | 2412
943 | 2107
944 | 1746
945 | 2428
946 | 2115
947 | 1800
948 | 1901
949 | 2397
950 | 2530
951 | 1912
952 | 2108
953 | 2206
954 | 2091
955 | 1740
956 | 2219
957 | 1976
958 | 2099
959 | 2142
960 | 2671
961 | 2668
962 | 2216
963 | 2272
964 | 2229
965 | 2666
966 | 2456
967 | 2534
968 | 2697
969 | 2688
970 | 2062
971 | 2691
972 | 2689
973 | 2154
974 | 2590
975 | 2626
976 | 2390
977 | 1813
978 | 2067
979 | 1952
980 | 2518
981 | 2358
982 | 1789
983 | 2076
984 | 2049
985 | 2119
986 | 2013
987 | 2124
988 | 2556
989 | 2105
990 | 2093
991 | 1885
992 | 2305
993 | 2354
994 | 2135
995 | 2601
996 | 1770
997 | 1995
998 | 2504
999 | 1749
1000 | 2157
1001 |
--------------------------------------------------------------------------------
/data/test/ind.cora.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.cora.tx
--------------------------------------------------------------------------------
/data/test/ind.cora.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.cora.x
--------------------------------------------------------------------------------
/data/test/ind.pubmed.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.pubmed.allx
--------------------------------------------------------------------------------
/data/test/ind.pubmed.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.pubmed.graph
--------------------------------------------------------------------------------
/data/test/ind.pubmed.test.index:
--------------------------------------------------------------------------------
1 | 18747
2 | 19392
3 | 19181
4 | 18843
5 | 19221
6 | 18962
7 | 19560
8 | 19097
9 | 18966
10 | 19014
11 | 18756
12 | 19313
13 | 19000
14 | 19569
15 | 19359
16 | 18854
17 | 18970
18 | 19073
19 | 19661
20 | 19180
21 | 19377
22 | 18750
23 | 19401
24 | 18788
25 | 19224
26 | 19447
27 | 19017
28 | 19241
29 | 18890
30 | 18908
31 | 18965
32 | 19001
33 | 18849
34 | 19641
35 | 18852
36 | 19222
37 | 19172
38 | 18762
39 | 19156
40 | 19162
41 | 18856
42 | 18763
43 | 19318
44 | 18826
45 | 19712
46 | 19192
47 | 19695
48 | 19030
49 | 19523
50 | 19249
51 | 19079
52 | 19232
53 | 19455
54 | 18743
55 | 18800
56 | 19071
57 | 18885
58 | 19593
59 | 19394
60 | 19390
61 | 18832
62 | 19445
63 | 18838
64 | 19632
65 | 19548
66 | 19546
67 | 18825
68 | 19498
69 | 19266
70 | 19117
71 | 19595
72 | 19252
73 | 18730
74 | 18913
75 | 18809
76 | 19452
77 | 19520
78 | 19274
79 | 19555
80 | 19388
81 | 18919
82 | 19099
83 | 19637
84 | 19403
85 | 18720
86 | 19526
87 | 18905
88 | 19451
89 | 19408
90 | 18923
91 | 18794
92 | 19322
93 | 19431
94 | 18912
95 | 18841
96 | 19239
97 | 19125
98 | 19258
99 | 19565
100 | 18898
101 | 19482
102 | 19029
103 | 18778
104 | 19096
105 | 19684
106 | 19552
107 | 18765
108 | 19361
109 | 19171
110 | 19367
111 | 19623
112 | 19402
113 | 19327
114 | 19118
115 | 18888
116 | 18726
117 | 19510
118 | 18831
119 | 19490
120 | 19576
121 | 19050
122 | 18729
123 | 18896
124 | 19246
125 | 19012
126 | 18862
127 | 18873
128 | 19193
129 | 19693
130 | 19474
131 | 18953
132 | 19115
133 | 19182
134 | 19269
135 | 19116
136 | 18837
137 | 18872
138 | 19007
139 | 19212
140 | 18798
141 | 19102
142 | 18772
143 | 19660
144 | 19511
145 | 18914
146 | 18886
147 | 19672
148 | 19360
149 | 19213
150 | 18810
151 | 19420
152 | 19512
153 | 18719
154 | 19432
155 | 19350
156 | 19127
157 | 18782
158 | 19587
159 | 18924
160 | 19488
161 | 18781
162 | 19340
163 | 19190
164 | 19383
165 | 19094
166 | 18835
167 | 19487
168 | 19230
169 | 18791
170 | 18882
171 | 18937
172 | 18928
173 | 18755
174 | 18802
175 | 19516
176 | 18795
177 | 18786
178 | 19273
179 | 19349
180 | 19398
181 | 19626
182 | 19130
183 | 19351
184 | 19489
185 | 19446
186 | 18959
187 | 19025
188 | 18792
189 | 18878
190 | 19304
191 | 19629
192 | 19061
193 | 18785
194 | 19194
195 | 19179
196 | 19210
197 | 19417
198 | 19583
199 | 19415
200 | 19443
201 | 18739
202 | 19662
203 | 18904
204 | 18910
205 | 18901
206 | 18960
207 | 18722
208 | 18827
209 | 19290
210 | 18842
211 | 19389
212 | 19344
213 | 18961
214 | 19098
215 | 19147
216 | 19334
217 | 19358
218 | 18829
219 | 18984
220 | 18931
221 | 18742
222 | 19320
223 | 19111
224 | 19196
225 | 18887
226 | 18991
227 | 19469
228 | 18990
229 | 18876
230 | 19261
231 | 19270
232 | 19522
233 | 19088
234 | 19284
235 | 19646
236 | 19493
237 | 19225
238 | 19615
239 | 19449
240 | 19043
241 | 19674
242 | 19391
243 | 18918
244 | 19155
245 | 19110
246 | 18815
247 | 19131
248 | 18834
249 | 19715
250 | 19603
251 | 19688
252 | 19133
253 | 19053
254 | 19166
255 | 19066
256 | 18893
257 | 18757
258 | 19582
259 | 19282
260 | 19257
261 | 18869
262 | 19467
263 | 18954
264 | 19371
265 | 19151
266 | 19462
267 | 19598
268 | 19653
269 | 19187
270 | 19624
271 | 19564
272 | 19534
273 | 19581
274 | 19478
275 | 18985
276 | 18746
277 | 19342
278 | 18777
279 | 19696
280 | 18824
281 | 19138
282 | 18728
283 | 19643
284 | 19199
285 | 18731
286 | 19168
287 | 18948
288 | 19216
289 | 19697
290 | 19347
291 | 18808
292 | 18725
293 | 19134
294 | 18847
295 | 18828
296 | 18996
297 | 19106
298 | 19485
299 | 18917
300 | 18911
301 | 18776
302 | 19203
303 | 19158
304 | 18895
305 | 19165
306 | 19382
307 | 18780
308 | 18836
309 | 19373
310 | 19659
311 | 18947
312 | 19375
313 | 19299
314 | 18761
315 | 19366
316 | 18754
317 | 19248
318 | 19416
319 | 19658
320 | 19638
321 | 19034
322 | 19281
323 | 18844
324 | 18922
325 | 19491
326 | 19272
327 | 19341
328 | 19068
329 | 19332
330 | 19559
331 | 19293
332 | 18804
333 | 18933
334 | 18935
335 | 19405
336 | 18936
337 | 18945
338 | 18943
339 | 18818
340 | 18797
341 | 19570
342 | 19464
343 | 19428
344 | 19093
345 | 19433
346 | 18986
347 | 19161
348 | 19255
349 | 19157
350 | 19046
351 | 19292
352 | 19434
353 | 19298
354 | 18724
355 | 19410
356 | 19694
357 | 19214
358 | 19640
359 | 19189
360 | 18963
361 | 19218
362 | 19585
363 | 19041
364 | 19550
365 | 19123
366 | 19620
367 | 19376
368 | 19561
369 | 18944
370 | 19706
371 | 19056
372 | 19283
373 | 18741
374 | 19319
375 | 19144
376 | 19542
377 | 18821
378 | 19404
379 | 19080
380 | 19303
381 | 18793
382 | 19306
383 | 19678
384 | 19435
385 | 19519
386 | 19566
387 | 19278
388 | 18946
389 | 19536
390 | 19020
391 | 19057
392 | 19198
393 | 19333
394 | 19649
395 | 19699
396 | 19399
397 | 19654
398 | 19136
399 | 19465
400 | 19321
401 | 19577
402 | 18907
403 | 19665
404 | 19386
405 | 19596
406 | 19247
407 | 19473
408 | 19568
409 | 19355
410 | 18925
411 | 19586
412 | 18982
413 | 19616
414 | 19495
415 | 19612
416 | 19023
417 | 19438
418 | 18817
419 | 19692
420 | 19295
421 | 19414
422 | 19676
423 | 19472
424 | 19107
425 | 19062
426 | 19035
427 | 18883
428 | 19409
429 | 19052
430 | 19606
431 | 19091
432 | 19651
433 | 19475
434 | 19413
435 | 18796
436 | 19369
437 | 19639
438 | 19701
439 | 19461
440 | 19645
441 | 19251
442 | 19063
443 | 19679
444 | 19545
445 | 19081
446 | 19363
447 | 18995
448 | 19549
449 | 18790
450 | 18855
451 | 18833
452 | 18899
453 | 19395
454 | 18717
455 | 19647
456 | 18768
457 | 19103
458 | 19245
459 | 18819
460 | 18779
461 | 19656
462 | 19076
463 | 18745
464 | 18971
465 | 19197
466 | 19711
467 | 19074
468 | 19128
469 | 19466
470 | 19139
471 | 19309
472 | 19324
473 | 18814
474 | 19092
475 | 19627
476 | 19060
477 | 18806
478 | 18929
479 | 18737
480 | 18942
481 | 18906
482 | 18858
483 | 19456
484 | 19253
485 | 19716
486 | 19104
487 | 19667
488 | 19574
489 | 18903
490 | 19237
491 | 18864
492 | 19556
493 | 19364
494 | 18952
495 | 19008
496 | 19323
497 | 19700
498 | 19170
499 | 19267
500 | 19345
501 | 19238
502 | 18909
503 | 18892
504 | 19109
505 | 19704
506 | 18902
507 | 19275
508 | 19680
509 | 18723
510 | 19242
511 | 19112
512 | 19169
513 | 18956
514 | 19343
515 | 19650
516 | 19541
517 | 19698
518 | 19521
519 | 19087
520 | 18976
521 | 19038
522 | 18775
523 | 18968
524 | 19671
525 | 19412
526 | 19407
527 | 19573
528 | 19027
529 | 18813
530 | 19357
531 | 19460
532 | 19673
533 | 19481
534 | 19036
535 | 19614
536 | 18787
537 | 19195
538 | 18732
539 | 18884
540 | 19613
541 | 19657
542 | 19575
543 | 19226
544 | 19589
545 | 19234
546 | 19617
547 | 19707
548 | 19484
549 | 18740
550 | 19424
551 | 18784
552 | 19419
553 | 19159
554 | 18865
555 | 19105
556 | 19315
557 | 19480
558 | 19664
559 | 19378
560 | 18803
561 | 19605
562 | 18870
563 | 19042
564 | 19426
565 | 18848
566 | 19223
567 | 19509
568 | 19532
569 | 18752
570 | 19691
571 | 18718
572 | 19209
573 | 19362
574 | 19090
575 | 19492
576 | 19567
577 | 19687
578 | 19018
579 | 18830
580 | 19530
581 | 19554
582 | 19119
583 | 19442
584 | 19558
585 | 19527
586 | 19427
587 | 19291
588 | 19543
589 | 19422
590 | 19142
591 | 18897
592 | 18950
593 | 19425
594 | 19002
595 | 19588
596 | 18978
597 | 19551
598 | 18930
599 | 18736
600 | 19101
601 | 19215
602 | 19150
603 | 19263
604 | 18949
605 | 18974
606 | 18759
607 | 19335
608 | 19200
609 | 19129
610 | 19328
611 | 19437
612 | 18988
613 | 19429
614 | 19368
615 | 19406
616 | 19049
617 | 18811
618 | 19296
619 | 19256
620 | 19385
621 | 19602
622 | 18770
623 | 19337
624 | 19580
625 | 19476
626 | 19045
627 | 19132
628 | 19089
629 | 19120
630 | 19265
631 | 19483
632 | 18767
633 | 19227
634 | 18934
635 | 19069
636 | 18820
637 | 19006
638 | 19459
639 | 18927
640 | 19037
641 | 19280
642 | 19441
643 | 18823
644 | 19015
645 | 19114
646 | 19618
647 | 18957
648 | 19176
649 | 18853
650 | 19648
651 | 19201
652 | 19444
653 | 19279
654 | 18751
655 | 19302
656 | 19505
657 | 18733
658 | 19601
659 | 19533
660 | 18863
661 | 19708
662 | 19387
663 | 19346
664 | 19152
665 | 19206
666 | 18851
667 | 19338
668 | 19681
669 | 19380
670 | 19055
671 | 18766
672 | 19085
673 | 19591
674 | 19547
675 | 18958
676 | 19146
677 | 18840
678 | 19051
679 | 19021
680 | 19207
681 | 19235
682 | 19086
683 | 18979
684 | 19300
685 | 18939
686 | 19100
687 | 19619
688 | 19287
689 | 18980
690 | 19277
691 | 19326
692 | 19108
693 | 18920
694 | 19625
695 | 19374
696 | 19078
697 | 18734
698 | 19634
699 | 19339
700 | 18877
701 | 19423
702 | 19652
703 | 19683
704 | 19044
705 | 18983
706 | 19330
707 | 19529
708 | 19714
709 | 19468
710 | 19075
711 | 19540
712 | 18839
713 | 19022
714 | 19286
715 | 19537
716 | 19175
717 | 19463
718 | 19167
719 | 19705
720 | 19562
721 | 19244
722 | 19486
723 | 19611
724 | 18801
725 | 19178
726 | 19590
727 | 18846
728 | 19450
729 | 19205
730 | 19381
731 | 18941
732 | 19670
733 | 19185
734 | 19504
735 | 19633
736 | 18997
737 | 19113
738 | 19397
739 | 19636
740 | 19709
741 | 19289
742 | 19264
743 | 19353
744 | 19584
745 | 19126
746 | 18938
747 | 19669
748 | 18964
749 | 19276
750 | 18774
751 | 19173
752 | 19231
753 | 18973
754 | 18769
755 | 19064
756 | 19040
757 | 19668
758 | 18738
759 | 19082
760 | 19655
761 | 19236
762 | 19352
763 | 19609
764 | 19628
765 | 18951
766 | 19384
767 | 19122
768 | 18875
769 | 18992
770 | 18753
771 | 19379
772 | 19254
773 | 19301
774 | 19506
775 | 19135
776 | 19010
777 | 19682
778 | 19400
779 | 19579
780 | 19316
781 | 19553
782 | 19208
783 | 19635
784 | 19644
785 | 18891
786 | 19024
787 | 18989
788 | 19250
789 | 18850
790 | 19317
791 | 18915
792 | 19607
793 | 18799
794 | 18881
795 | 19479
796 | 19031
797 | 19365
798 | 19164
799 | 18744
800 | 18760
801 | 19502
802 | 19058
803 | 19517
804 | 18735
805 | 19448
806 | 19243
807 | 19453
808 | 19285
809 | 18857
810 | 19439
811 | 19016
812 | 18975
813 | 19503
814 | 18998
815 | 18981
816 | 19186
817 | 18994
818 | 19240
819 | 19631
820 | 19070
821 | 19174
822 | 18900
823 | 19065
824 | 19220
825 | 19229
826 | 18880
827 | 19308
828 | 19372
829 | 19496
830 | 18771
831 | 19325
832 | 19538
833 | 19033
834 | 18874
835 | 19077
836 | 19211
837 | 18764
838 | 19458
839 | 19571
840 | 19121
841 | 19019
842 | 19059
843 | 19497
844 | 18969
845 | 19666
846 | 19297
847 | 19219
848 | 19622
849 | 19184
850 | 18977
851 | 19702
852 | 19539
853 | 19329
854 | 19095
855 | 19675
856 | 18972
857 | 19514
858 | 19703
859 | 19188
860 | 18866
861 | 18812
862 | 19314
863 | 18822
864 | 18845
865 | 19494
866 | 19411
867 | 18916
868 | 19686
869 | 18967
870 | 19294
871 | 19143
872 | 19204
873 | 18805
874 | 19689
875 | 19233
876 | 18758
877 | 18748
878 | 19011
879 | 19685
880 | 19336
881 | 19608
882 | 19454
883 | 19124
884 | 18868
885 | 18807
886 | 19544
887 | 19621
888 | 19228
889 | 19154
890 | 19141
891 | 19145
892 | 19153
893 | 18860
894 | 19163
895 | 19393
896 | 19268
897 | 19160
898 | 19305
899 | 19259
900 | 19471
901 | 19524
902 | 18783
903 | 19396
904 | 18894
905 | 19430
906 | 19690
907 | 19348
908 | 19597
909 | 19592
910 | 19677
911 | 18889
912 | 19331
913 | 18773
914 | 19137
915 | 19009
916 | 18932
917 | 19599
918 | 18816
919 | 19054
920 | 19067
921 | 19477
922 | 19191
923 | 18921
924 | 18940
925 | 19578
926 | 19183
927 | 19004
928 | 19072
929 | 19710
930 | 19005
931 | 19610
932 | 18955
933 | 19457
934 | 19148
935 | 18859
936 | 18993
937 | 19642
938 | 19047
939 | 19418
940 | 19535
941 | 19600
942 | 19312
943 | 19039
944 | 19028
945 | 18879
946 | 19003
947 | 19026
948 | 19013
949 | 19149
950 | 19177
951 | 19217
952 | 18987
953 | 19354
954 | 19525
955 | 19202
956 | 19084
957 | 19032
958 | 18749
959 | 18867
960 | 19048
961 | 18999
962 | 19260
963 | 19630
964 | 18727
965 | 19356
966 | 19083
967 | 18926
968 | 18789
969 | 19370
970 | 18861
971 | 19311
972 | 19557
973 | 19531
974 | 19436
975 | 19140
976 | 19310
977 | 19501
978 | 18721
979 | 19604
980 | 19713
981 | 19262
982 | 19563
983 | 19507
984 | 19440
985 | 19572
986 | 19513
987 | 19515
988 | 19518
989 | 19421
990 | 19470
991 | 19499
992 | 19663
993 | 19508
994 | 18871
995 | 19528
996 | 19500
997 | 19307
998 | 19288
999 | 19594
1000 | 19271
1001 |
--------------------------------------------------------------------------------
/data/test/ind.pubmed.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.pubmed.tx
--------------------------------------------------------------------------------
/data/test/ind.pubmed.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/data/test/ind.pubmed.x
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
--------------------------------------------------------------------------------
/src/graph_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import networkx as nx
4 | import scipy.sparse as sp
5 | from torch.utils.data import Dataset
6 |
7 | def generate_random_graph_data(n_nodes=100, x_dim=10, seed=999, noise=True, with_test=False, test_ratio=0.3, partial_ratio=None):
8 | np.random.seed(seed)
9 | X = np.random.rand(n_nodes * x_dim).reshape(n_nodes, x_dim) - 0.5
10 | X = X.astype(np.float32)
11 |
12 | gram = X.dot(X.T)
13 | if noise:
14 | gram += np.random.normal(size=gram.shape) * 0.01
15 | adj = gram > 0.5
16 | if partial_ratio is not None:
17 | X = X[:, :int(partial_ratio * X.shape[1])]
18 |
19 | # G = nx.from_numpy_matrix(adj)
20 | # largest_cc_node_list = list(max(nx.connected_components(G), key=len))
21 | # X_select = X[largest_cc_node_list, :]
22 | # adj_select = nx.to_numpy_matrix(G, nodelist=largest_cc_node_list)
23 | adj_select = adj
24 | X_select = X
25 | if with_test:
26 | cut_idx = int(adj_select.shape[0] * (1 - test_ratio))
27 |
28 | adj_train = adj_select[:cut_idx, :cut_idx]
29 | X_train = X_select[:cut_idx, :]
30 | return adj_train, X_train, adj_select, X_select
31 | else:
32 | return adj_select, X_select
33 |
34 | def read_ice_cream(with_test=False, permute=False, test_ratio=0.3):
35 | features = sp.load_npz('../data/feat_sparse.npz')
36 | adj_orig = sp.load_npz('../data/P_sparse.npz')
37 |
38 | adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
39 | adj_orig = adj_orig > 0.8
40 |
41 | X_select = features.todense().astype(np.float32)
42 | adj_select = adj_orig.todense().astype(np.float32)
43 |
44 | if permute:
45 | x_idx = np.random.permutation(adj_select.shape[0])
46 | adj_select = adj_select[np.ix_(x_idx, x_idx)]
47 | X_select = X_select[x_idx, :]
48 |
49 | if with_test:
50 | cut_idx = int(adj_select.shape[0] * (1 - test_ratio))
51 |
52 | adj_train = adj_select[:cut_idx, :cut_idx]
53 | X_train = X_select[:cut_idx, :]
54 | return adj_train, X_train, adj_select, X_select
55 | else:
56 | return adj_select, X_select
57 |
58 | def bfs_seq(G, start_id):
59 | """
60 | get a bfs node sequence
61 | :param G:
62 | :param start_id:
63 | :return:
64 | """
65 | dictionary = dict(nx.bfs_successors(G, start_id))
66 | start = [start_id]
67 | output = [start_id]
68 | while len(start) > 0:
69 | next = []
70 | while len(start) > 0:
71 | current = start.pop(0)
72 | neighbor = dictionary.get(current)
73 | if neighbor is not None:
74 | next = next + neighbor
75 | output = output + next
76 | start = next
77 | return output
78 |
79 | def preprocess_graph_torch(adj):
80 | with torch.no_grad():
81 | rowsum = adj.sum(1)
82 | degree_mat_inv_sqrt = torch.diag(rowsum ** -.5)
83 | adj_normalized = adj.mm(degree_mat_inv_sqrt).t().mm(degree_mat_inv_sqrt)
84 | return adj_normalized
85 |
86 |
87 |
 88 | def preprocess_graph(adj):
 89 |     # symmetric normalization: D^{-1/2} A D^{-1/2}; self-loops are added by the caller
 90 |     # adj = sp.coo_matrix(adj); adj_ = adj + sp.eye(adj.shape[0])
91 | rowsum = np.array(adj.sum(1))
92 | degree_mat_inv_sqrt = np.diag(np.power(rowsum, -0.5).flatten())
93 | adj_normalized = adj.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
94 | return adj_normalized.astype(np.float32)
95 |
96 |
97 | class GraphSequenceBfsRandSampler(Dataset):
98 | def __init__(self, adj, X, num_permutation=10000, seed=None, fix=False):
99 |
100 | # self.adj = nx.to_numpy_matrix(G)
101 | self.adj = adj
102 | self.len = adj.shape[0]
103 | self.X = X
104 | self.num_permutation = num_permutation
105 | self.fix = fix
106 |
107 | def __len__(self):
108 | return self.num_permutation
109 |
110 | def __getitem__(self, idx):
111 | adj_copy = self.adj.copy()
112 | X_copy = self.X.copy()
113 |
114 | # initial permutation
115 | len_batch = adj_copy.shape[0]
116 | if not self.fix:
117 | x_idx = np.random.permutation(adj_copy.shape[0])
118 | adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
119 | X_copy = X_copy[x_idx, :]
120 |
121 |
122 | adj_copy = adj_copy.astype(np.float32)
123 | X_copy = X_copy.astype(np.float32)
124 |
125 | return torch.from_numpy(adj_copy), torch.from_numpy(X_copy)
126 |
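
A minimal usage sketch for the sampler above, assuming this module's definitions are in scope (each draw returns a freshly permuted copy of the same graph):

```python
import numpy as np
from torch.utils.data import DataLoader

adj, X = generate_random_graph_data(n_nodes=50, x_dim=10, seed=0)
dataset = GraphSequenceBfsRandSampler(adj.astype(np.float32), X, num_permutation=4)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for adj_batch, x_batch in loader:
    # each item is one random node permutation of the full graph
    print(adj_batch.shape, x_batch.shape)  # [1, 50, 50], [1, 50, 10]
```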
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 | from torch.nn.modules.module import Module
7 |
8 | from graph_data import preprocess_graph, preprocess_graph_torch
9 |
10 |
11 | class WeightTransitionLinearUnit(Module):
12 | def __init__(self, orig_dim, map_dim):
13 | super(WeightTransitionLinearUnit, self).__init__()
14 | self.linear_1 = torch.nn.Linear(map_dim, orig_dim)
15 | self.linear_2 = torch.nn.Linear(map_dim, map_dim)
16 | self.linear_3 = torch.nn.Linear(map_dim, map_dim)
17 | self.orig_dim = orig_dim
18 | self.map_dim = map_dim
19 |
20 | # self.bias = Parameter(torch.FloatTensor(map_dim))
21 | # self.weight = Parameter(torch.FloatTensor(orig_dim, map_dim))
22 | self.reset_parameters()
23 |
24 | def reset_parameters(self):
25 | # stdv = 1. / math.sqrt(self.weight.size(1))
26 | # torch.nn.init.xavier_normal_(self.weight)
27 | # self.bias.data.uniform_(-stdv, stdv)
28 | pass
29 |
 30 |     def forward(self, last_w, z_cov):
 31 |         # standardize the z-difference covariance before mapping it
 32 |         z_cov = (z_cov - z_cov.mean()) / z_cov.std()
 33 |         hidden = F.relu(self.linear_1(z_cov).t())
 34 |         w_update = self.linear_2(hidden)
 35 |         update_gate = torch.sigmoid(self.linear_3(hidden))
 36 |         # w_update = torch.mm(self.weight, z_cov) + self.bias
 37 |         # w_update = (w_update - w_update.min()) / w_update.max() * 0.001
 38 |         # update_gate = torch.clamp(update_gate, min=0, max=1)
 39 |
 40 |         w_update = torch.clamp(w_update, min=-0.1, max=0.1)
 41 |
 42 |         # GRU-like gated interpolation between the old weights and the
 43 |         # clamped proposed update
 44 |         return (1 - update_gate) * last_w + w_update * update_gate
45 |
46 |
47 |
48 | class RecursiveGraphConvolutionalNetwork(Module):
49 |
50 |
51 | def __init__(self, in_features, hidden_dim, out_features, bias=True, dropout=0.3):
52 | super(RecursiveGraphConvolutionalNetwork, self).__init__()
53 |
54 | self.dropout = dropout
55 |
56 | self.init_hidden_weight = Parameter(torch.FloatTensor(in_features, hidden_dim))
57 | self.init_hidden_bias = Parameter(torch.FloatTensor(hidden_dim))
58 |
59 | self.init_mean_weight = Parameter(torch.FloatTensor(hidden_dim, out_features))
60 | self.init_mean_bias = Parameter(torch.FloatTensor(out_features))
61 |
62 | self.init_log_std_weight = Parameter(torch.FloatTensor(hidden_dim, out_features))
63 | self.init_log_std_bias = Parameter(torch.FloatTensor(out_features))
64 |
65 | self.hidden_w_transition = WeightTransitionLinearUnit(in_features, hidden_dim)
66 | self.mean_w_transition = WeightTransitionLinearUnit(hidden_dim, out_features)
67 | self.log_std_w_transition = WeightTransitionLinearUnit(hidden_dim, out_features)
68 |
69 | self.reset_parameters()
70 |
71 | def init_all_weights(self):
72 | self.hidden_weight = self.init_hidden_weight + 0.0
73 | self.hidden_bias = self.init_hidden_bias + 0.0
74 |
75 | self.mean_weight = self.init_mean_weight + 0.0
76 | self.mean_bias = self.init_mean_bias + 0.0
77 |
78 | self.log_std_weight = self.init_log_std_weight + 0.0
79 | self.log_std_bias = self.init_log_std_bias + 0.0
80 |
81 |
82 | def convo_ops(self, input, adj):
83 | input = F.dropout(input, self.dropout, training=self.training)
84 | support = torch.mm(input, self.hidden_weight)
85 | hidden = F.relu(torch.spmm(adj, support) + self.hidden_bias)
86 |
87 | hidden = F.dropout(hidden, self.dropout, training=self.training)
88 | support_mean = torch.mm(hidden, self.mean_weight)
89 | mean = torch.spmm(adj, support_mean)
90 | support_std = torch.mm(hidden, self.log_std_weight)
91 | log_std = torch.spmm(adj, support_std)
92 |
93 | return mean, log_std
94 |
95 | def weight_transition(self, last_z, current_z):
96 | # compute the 'covariance' matrix for the difference of z
97 | z_diff = last_z - current_z
98 | z_cov = torch.mm(torch.t(z_diff), z_diff)
99 | # self.hidden_weight = self.hidden_w_transition(self.hidden_weight, z_cov) * 1.0
100 | self.mean_weight = self.mean_w_transition(self.mean_weight, z_cov) * 1.0
101 | self.log_std_weight = self.log_std_w_transition(self.log_std_weight, z_cov) * 1.0
102 |
103 |
104 | def forward(self, adj, input, update_size, input_new=None):
105 | # print(adj.size(0))
106 | # print(update_size)
107 | if adj.size(0) < update_size:
 108 |             raise ValueError('adj must be at least as large as update_size!')
109 |
110 | self.init_all_weights()
111 |
112 | normal = torch.distributions.Normal(0, 1)
113 |
114 | # print(input.size())
115 |
116 | adj_h = torch.eye(update_size)
117 | last_z_mean, last_z_log_std = self.convo_ops(input[:update_size], adj_h)
118 |
119 | num_step = int(math.ceil(adj.size()[0] / update_size))
120 | # adj_frame = torch.eye(adj.size(0))
121 |
122 | z_prior, z_post = [], []
123 | z_prior.append((last_z_mean, last_z_log_std))
124 | for step in range(num_step - 1):
125 | start_idx = step * update_size
126 | end_idx = min(adj.size()[0], start_idx + update_size)
127 | adj_feed_norm = preprocess_graph_torch(adj[:end_idx, :end_idx])
128 | curr_z_mean, curr_z_log_std = self.convo_ops(input[:end_idx], adj_feed_norm)
129 |
130 | # cache the z for loss computation later
131 | z_post.append((curr_z_mean, curr_z_log_std))
132 |
 133 |             # sample z to approximate the posterior
134 | current_z = curr_z_mean + normal.sample(curr_z_mean.size()) * torch.exp(curr_z_log_std)
135 | last_z = last_z_mean + normal.sample(last_z_mean.size()) * torch.exp(last_z_log_std)
136 |
137 | # update w_(t-1) to w_(t) based on difference between z_(t-1) and z_(t)
138 |
139 | self.weight_transition(last_z, current_z)
140 |
 141 |             # update the hypothetical z based on adj_h
142 | adj_ext = torch.eye(adj_feed_norm.size(0) + update_size)
143 | adj_ext[:end_idx, :end_idx] = adj_feed_norm
144 | next_end_idx = end_idx + update_size
145 | last_z_mean, last_z_log_std = self.convo_ops(input[:next_end_idx], adj_ext)
146 | z_prior.append((last_z_mean, last_z_log_std))
147 |
148 | if self.training:
149 | return z_prior, z_post
150 | else:
151 | adj_h_all = torch.eye(input.size(0) + input_new.size(0))
152 | adj_h_all[:adj.size(0), :adj.size(0)] = adj
153 | input_all = torch.cat((input, input_new))
154 | z_out_mean, z_out_log_std = self.convo_ops(input_all, adj_h_all)
155 | return z_out_mean, z_out_log_std
156 |
157 |
158 |
159 |
160 | def reset_parameters(self):
161 | stdv = 1. / math.sqrt(self.init_hidden_weight.size(1))
162 | self.init_hidden_weight.data.uniform_(-stdv, stdv)
163 | self.init_hidden_bias.data.uniform_(-stdv, stdv)
164 | self.init_mean_weight.data.uniform_(-stdv, stdv)
165 | self.init_mean_bias.data.uniform_(-stdv, stdv)
166 | self.init_log_std_weight.data.uniform_(-stdv, stdv)
167 | self.init_log_std_bias.data.uniform_(-stdv, stdv)
168 |
169 |
170 |
 171 |     def __repr__(self):
 172 |         return self.__class__.__name__ + ' (' \
 173 |                + str(self.init_hidden_weight.size(0)) + ' -> ' \
 174 |                + str(self.init_mean_weight.size(1)) + ')'
175 |
176 |
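
A hedged sketch of how the paired posterior/prior statistics returned in training mode could enter a per-step Gaussian KL term (the repository's actual objective lives in src/loss.py and may differ; `gaussian_kl` below is an illustrative helper, not part of the codebase):

```python
import torch

def gaussian_kl(q_mean, q_log_std, p_mean, p_log_std):
    # KL( N(q_mean, exp(q_log_std)^2) || N(p_mean, exp(p_log_std)^2) ),
    # summed over nodes and latent dimensions (diagonal Gaussians)
    log_ratio = q_log_std - p_log_std
    return 0.5 * torch.sum(
        torch.exp(2 * log_ratio)
        + ((q_mean - p_mean) / torch.exp(p_log_std)) ** 2
        - 1 - 2 * log_ratio)

# z_prior[t] holds the prior statistics for the nodes scored at step t and
# z_post[t] the matching posterior, so one possible pairing is:
# kl_total = sum(gaussian_kl(qm, qls, pm, pls)
#                for (pm, pls), (qm, qls) in zip(z_prior, z_post))
```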
177 | class GraphConvolution(Module):
178 | """
179 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
180 | """
181 |
182 | def __init__(self, in_features, out_features, bias=False, act=lambda x: x, dropout=0.0):
183 | super(GraphConvolution, self).__init__()
184 | self.in_features = in_features
185 | self.out_features = out_features
186 | self.dropout = dropout
187 | self.act = act
188 | self.weight = Parameter(torch.FloatTensor(in_features, out_features))
189 | if bias:
190 | self.bias = Parameter(torch.FloatTensor(out_features))
191 | else:
192 | self.register_parameter('bias', None)
193 | self.reset_parameters()
194 |
195 | def reset_parameters(self):
196 | stdv = 1. / math.sqrt(self.weight.size(1))
197 | #self.weight.data.uniform_(-stdv, stdv)
198 | torch.nn.init.xavier_uniform_(self.weight)
199 | if self.bias is not None:
200 | self.bias.data.uniform_(-stdv, stdv)
201 |
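   | # one GCN propagation step: output = act(adj @ (dropout(input) @ W) + b)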
202 | def forward(self, input, adj):
203 |         input = F.dropout(input, self.dropout, training=self.training)
204 | support = torch.mm(input, self.weight)
205 | output = torch.spmm(adj, support)
206 | if self.bias is not None:
207 | output = output + self.bias
208 | return self.act(output)
209 |
210 | def __repr__(self):
211 | return self.__class__.__name__ + ' (' \
212 | + str(self.in_features) + ' -> ' \
213 | + str(self.out_features) + ')'
214 |
215 |
216 | class GraphVae(Module):
217 | def __init__(self, features_dim, hidden_dim, out_dim, bias=False, dropout=0.3):
218 | super(GraphVae, self).__init__()
219 | self.features_dim = features_dim
220 | self.out_dim = out_dim
221 | self.dropout = dropout
222 |
223 | self.gc1 = GraphConvolution(features_dim, hidden_dim, bias=bias, dropout=dropout, act=F.relu)
224 | self.gc_mean = GraphConvolution(hidden_dim, out_dim, bias=bias, dropout=dropout)
225 | self.gc_log_std = GraphConvolution(hidden_dim, out_dim, bias=bias, dropout=dropout)
226 |
227 | def forward(self, adj, input):
228 | hidden = self.gc1(input, adj)
229 | z_mean = self.gc_mean(hidden, adj)
230 | z_log_std = self.gc_log_std(hidden, adj)
231 | return z_mean, z_log_std
232 |
233 | def __repr__(self):
234 | return self.__class__.__name__ + ' (' \
235 |                + str(self.features_dim) + ' -> ' \
236 |                + str(self.out_dim) + ')'
237 |
238 | class GraphAE(Module):
239 | def __init__(self, features_dim, hidden_dim, out_dim, bias=False, dropout=0.3):
240 | super(GraphAE, self).__init__()
241 | self.features_dim = features_dim
242 | self.out_dim = out_dim
243 | self.dropout = dropout
244 |
245 | self.gc1 = GraphConvolution(features_dim, hidden_dim, bias=bias, dropout=dropout, act=F.relu)
246 | self.gc_z = GraphConvolution(hidden_dim, out_dim, bias=bias, dropout=dropout)
247 |
248 | def forward(self, adj, input):
249 | hidden = self.gc1(input, adj)
250 | z = self.gc_z(hidden, adj)
251 | return z
252 |
253 | def __repr__(self):
254 | return self.__class__.__name__ + ' (' \
255 |                + str(self.features_dim) + ' -> ' \
256 |                + str(self.out_dim) + ')'
257 |
258 | class MLP(Module):
259 | def __init__(self, features_dim, hidden_dim, out_dim, bias=True, dropout=0.3):
260 | super(MLP, self).__init__()
261 | self.features_dim = features_dim
262 | self.out_dim = out_dim
263 | self.dropout = dropout
264 |
265 | self.linear = torch.nn.Linear(features_dim, hidden_dim)
266 | self.z_mean = torch.nn.Linear(hidden_dim, out_dim)
267 | self.z_log_std = torch.nn.Linear(hidden_dim, out_dim)
268 |
269 | def forward(self, input):
270 | hidden = F.relu(self.linear(input))
271 | z_mean = F.dropout(self.z_mean(hidden), self.dropout, training=self.training)
272 | z_log_std = F.dropout(self.z_log_std(hidden), self.dropout, training=self.training)
273 | return z_mean, z_log_std
274 |
275 | def __repr__(self):
276 | return self.__class__.__name__ + ' (' \
277 |                + str(self.features_dim) + ' -> ' \
278 |                + str(self.out_dim) + ')'
279 |
280 | class RecursiveGraphConvolutionStep(Module):
281 | """
282 | Given A, X and X_new, sample Z_new from P(Z_new|A_hat, X_new)
283 | """
284 |
285 | def __init__(self, features_dim, hidden_dim, out_dim, bias=True, dropout=0.3):
286 | super(RecursiveGraphConvolutionStep, self).__init__()
287 | self.features_dim = features_dim
288 | self.out_dim = out_dim
289 | self.dropout = dropout
290 |
291 | self.gc1 = GraphConvolution(features_dim, hidden_dim, dropout=dropout, act=F.relu)
292 | self.gc_mean = GraphConvolution(hidden_dim, out_dim, dropout=dropout)
293 | self.gc_log_std = GraphConvolution(hidden_dim, out_dim, dropout=dropout)
294 |
295 | def forward(self, adj, input, input_new=None):
296 | hidden_old = self.gc1(input, adj)
297 | z_mean_old = self.gc_mean(hidden_old, adj)
298 | z_log_std_old = self.gc_log_std(hidden_old, adj)
299 |
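   | # new nodes arrive without observed links, so they are encoded against an
   | # identity adjacency (self-loops only) using the same shared weights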
300 | if input_new is not None:
301 | adj_new = torch.eye(input_new.size()[0])
302 | hidden_new = self.gc1(input_new, adj_new)
303 | z_mean_new = self.gc_mean(hidden_new, adj_new)
304 | z_log_std_new = self.gc_log_std(hidden_new, adj_new)
305 | return z_mean_old, z_log_std_old, z_mean_new, z_log_std_new
306 | else:
307 | return z_mean_old, z_log_std_old
308 |
309 | def __repr__(self):
310 | return self.__class__.__name__ + ' (' \
311 |                + str(self.features_dim) + ' -> ' \
312 |                + str(self.out_dim) + ')'
313 |
314 |
315 | class RecursiveGraphConvolutionStepAddOn(Module):
316 | """
317 | Given A, X and X_new, sample Z_new from P(Z_new|A_hat, X_new)
318 | """
319 |
320 | def __init__(self, features_dim, hidden_dim, out_dim, random_add=False, bias=False, dropout=0.3):
321 | super(RecursiveGraphConvolutionStepAddOn, self).__init__()
322 | self.features_dim = features_dim
323 | self.out_dim = out_dim
324 | self.dropout = dropout
325 | self.random_add = random_add
326 |
327 | self.gc1 = GraphConvolution(features_dim, hidden_dim, bias=bias, dropout=dropout, act=F.relu)
328 | self.gc_mean = GraphConvolution(hidden_dim, out_dim, bias=bias, dropout=dropout)
329 | self.gc_log_std = GraphConvolution(hidden_dim, out_dim, bias=bias, dropout=dropout)
330 |
331 | def forward(self, adj, input, input_new=None):
332 | if input_new is None:
333 | adj = adj.numpy()
334 | adj_norm = torch.from_numpy(preprocess_graph(adj))
335 | hidden = self.gc1(input, adj_norm)
336 |             z_mean = self.gc_mean(hidden, adj_norm)
337 | z_log_std = self.gc_log_std(hidden, adj_norm)
338 | return z_mean, z_log_std
339 |
340 | num_total_nodes = input_new.size()[0] + input.size()[0]
341 | if self.training:
342 | with torch.no_grad():
343 | adj_new = torch.zeros(num_total_nodes, num_total_nodes)
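   | # optionally seed the unknown block with Bernoulli noise at a rate derived
   | # from the observed edge count, then symmetrize so sampled links stay undirected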
344 | if self.random_add:
345 | num_edges = float(((adj > 0).sum() - adj.shape[0]) / 2)
346 | p0 = num_edges / (num_total_nodes ** 2)
347 | adj_new.bernoulli_(p0)
348 | adj_new = adj_new - adj_new.tril()
349 | adj_new = ((adj_new + adj_new.t()) > 0).float()
350 |
351 | adj_new += torch.eye(num_total_nodes)
352 | adj_new[:input.size()[0], :input.size()[0]] = adj
353 | adj_new = adj_new.numpy()
354 | adj_norm = torch.from_numpy(preprocess_graph(adj_new))
355 | input_all = torch.cat((input, input_new))
356 |
357 | hidden = self.gc1(input_all, adj_norm)
358 | z_mean = self.gc_mean(hidden, adj_norm)
359 | z_log_std = self.gc_log_std(hidden, adj_norm)
360 |
361 | z_mean_old = z_mean[:input.size()[0], :]
362 | z_log_std_old = z_log_std[:input.size()[0], :]
363 |
364 | z_mean_new = z_mean[input.size()[0]:, :]
365 | z_log_std_new = z_log_std[input.size()[0]:, :]
366 |
367 | return z_mean_old, z_log_std_old, z_mean_new, z_log_std_new
368 |
369 | else:
370 | adj = adj.numpy()
371 | adj_norm = torch.from_numpy(preprocess_graph(adj))
372 | hidden_old = self.gc1(input, adj_norm)
373 | z_mean_old = self.gc_mean(hidden_old, adj_norm)
374 | z_log_std_old = self.gc_log_std(hidden_old, adj_norm)
375 |
376 | adj_new = torch.eye(input_new.size()[0])
377 | hidden_new = self.gc1(input_new, adj_new)
378 | z_mean_new = self.gc_mean(hidden_new, adj_new)
379 | z_log_std_new = self.gc_log_std(hidden_new, adj_new)
380 | return z_mean_old, z_log_std_old, z_mean_new, z_log_std_new
381 |
382 | def __repr__(self):
383 | return self.__class__.__name__ + ' (' \
384 |                + str(self.features_dim) + ' -> ' \
385 |                + str(self.out_dim) + ')'
386 |
387 |
388 | class GraphFuse(Module):
389 | def __init__(self, features_dim, hidden_dim, out_dim, bias=True, dropout=0.3):
390 | super(GraphFuse, self).__init__()
391 | self.features_dim = features_dim
392 | self.out_dim = out_dim
393 | self.dropout = dropout
394 |
397 | self.mixture_weight = Parameter(torch.FloatTensor(1))
398 |
399 | self.hidden_weight = Parameter(torch.FloatTensor(features_dim, hidden_dim))
400 | self.hidden_bias = Parameter(torch.FloatTensor(hidden_dim))
401 |
402 | self.mean_weight = Parameter(torch.FloatTensor(hidden_dim, out_dim))
403 | self.mean_bias = Parameter(torch.FloatTensor(out_dim))
404 |
405 | self.log_std_weight = Parameter(torch.FloatTensor(hidden_dim, out_dim))
406 | self.log_std_bias = Parameter(torch.FloatTensor(out_dim))
407 |
408 | self.reset_parameters()
409 |
410 | def reset_parameters(self):
411 | stdv = 1. / math.sqrt(self.hidden_weight.size(1))
412 | self.mixture_weight.data.uniform_(-stdv, stdv)
413 | self.hidden_weight.data.uniform_(-stdv, stdv)
414 | self.hidden_bias.data.uniform_(-stdv, stdv)
415 | self.mean_weight.data.uniform_(-stdv, stdv)
416 | self.mean_bias.data.uniform_(-stdv, stdv)
417 | self.log_std_weight.data.uniform_(-stdv, stdv)
418 | self.log_std_bias.data.uniform_(-stdv, stdv)
419 |
420 |
421 | def convo_ops(self, input, adj):
422 | input = F.dropout(input, self.dropout, training=self.training)
423 | support = torch.mm(input, self.hidden_weight)
424 | hidden = F.relu(torch.spmm(adj, support) + self.hidden_bias)
425 |
426 | hidden = F.dropout(hidden, self.dropout, training=self.training)
427 | support_mean = torch.mm(hidden, self.mean_weight)
428 | mean = torch.spmm(adj, support_mean)
429 | support_std = torch.mm(hidden, self.log_std_weight)
430 | log_std = torch.spmm(adj, support_std)
431 |
432 | return mean, log_std
433 |
434 | def mlp_ops(self, input):
435 | input = F.dropout(input, self.dropout, training=self.training)
436 | hidden = torch.mm(input, self.hidden_weight)
437 | hidden = F.relu(hidden + self.hidden_bias)
438 | hidden = F.dropout(hidden, self.dropout, training=self.training)
439 |
440 | mean = torch.mm(hidden, self.mean_weight) + self.mean_bias
441 | log_std = torch.mm(hidden, self.log_std_weight) + self.log_std_bias
442 |
443 |
444 | return mean, log_std
445 |
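   | # a learned sigmoid gate fuses the GCN posterior (topology-aware) with the
   | # MLP posterior (attributes only); without an adjacency only the MLP is used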
446 | def forward(self, input, adj=None):
447 | mixture_ratio = torch.sigmoid(self.mixture_weight)
448 | if adj is None:
449 | return self.mlp_ops(input)
450 | else:
451 | z_mean_gcn, z_log_std_gcn = self.convo_ops(input, adj)
452 | z_mean_mlp, z_log_std_mlp = self.mlp_ops(input)
453 |             z_mean = z_mean_gcn * mixture_ratio + z_mean_mlp * (1 - mixture_ratio)
454 | z_log_std = z_log_std_gcn * mixture_ratio + z_log_std_mlp * (1 - mixture_ratio)
455 | return z_mean, z_log_std
456 |
457 | def __repr__(self):
458 | return self.__class__.__name__ + ' (' \
459 |                + str(self.features_dim) + ' -> ' \
460 |                + str(self.out_dim) + ')'
461 |
462 |
463 | class GraphFuseSimple(Module):
464 | def __init__(self, n_nodes, features_dim, hidden_dim, out_dim, bias=True, dropout=0.3):
465 |         super(GraphFuseSimple, self).__init__()
466 | self.features_dim = features_dim
467 | self.out_dim = out_dim
468 | self.dropout = dropout
469 |
472 | self.mixture_weight = Parameter(torch.FloatTensor(1))
473 |
474 | self.hidden_weight = Parameter(torch.FloatTensor(features_dim, hidden_dim))
475 | self.hidden_bias = Parameter(torch.FloatTensor(hidden_dim))
476 |
477 | self.gcn_hidden_weight = Parameter(torch.FloatTensor(n_nodes, hidden_dim))
478 | self.gcn_hidden_bias = Parameter(torch.FloatTensor(hidden_dim))
479 |
480 | self.mean_weight = Parameter(torch.FloatTensor(hidden_dim, out_dim))
481 | self.mean_bias = Parameter(torch.FloatTensor(out_dim))
482 |
483 | self.log_std_weight = Parameter(torch.FloatTensor(hidden_dim, out_dim))
484 | self.log_std_bias = Parameter(torch.FloatTensor(out_dim))
485 |
486 | self.reset_parameters()
487 |
488 | def reset_parameters(self):
489 | stdv = 1. / math.sqrt(self.hidden_weight.size(1))
490 | self.mixture_weight.data.uniform_(-stdv, stdv)
491 | self.gcn_hidden_weight.data.uniform_(-stdv, stdv)
492 | self.gcn_hidden_bias.data.uniform_(-stdv, stdv)
493 | self.hidden_weight.data.uniform_(-stdv, stdv)
494 | self.hidden_bias.data.uniform_(-stdv, stdv)
495 | self.mean_weight.data.uniform_(-stdv, stdv)
496 | self.mean_bias.data.uniform_(-stdv, stdv)
497 | self.log_std_weight.data.uniform_(-stdv, stdv)
498 | self.log_std_bias.data.uniform_(-stdv, stdv)
499 |
500 |
501 | def convo_ops(self, input, adj):
502 | input = F.dropout(input, self.dropout, training=self.training)
503 | support = torch.mm(input, self.gcn_hidden_weight)
504 | hidden = F.relu(torch.spmm(adj, support) + self.gcn_hidden_bias)
505 |
506 | hidden = F.dropout(hidden, self.dropout, training=self.training)
507 | support_mean = torch.mm(hidden, self.mean_weight)
508 | mean = torch.spmm(adj, support_mean)
509 | support_std = torch.mm(hidden, self.log_std_weight)
510 | log_std = torch.spmm(adj, support_std)
511 |
512 | return mean, log_std
513 |
514 | def mlp_ops(self, input):
515 | input = F.dropout(input, self.dropout, training=self.training)
516 | hidden = torch.mm(input, self.hidden_weight)
517 | hidden = F.relu(hidden + self.hidden_bias)
518 | hidden = F.dropout(hidden, self.dropout, training=self.training)
519 |
520 | mean = torch.mm(hidden, self.mean_weight) + self.mean_bias
521 | log_std = torch.mm(hidden, self.log_std_weight) + self.log_std_bias
522 |
523 | return mean, log_std
524 |
525 | def forward(self, input, adj=None):
526 | mixture_ratio = torch.sigmoid(self.mixture_weight)
527 | if adj is None:
528 | return self.mlp_ops(input)
529 | else:
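   | # featureless GCN branch: node identities (one-hot rows) replace attributes,
   | # while the MLP branch still consumes the raw features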
530 | gcn_input = torch.eye(adj.size(0))
531 | z_mean_gcn, z_log_std_gcn = self.convo_ops(gcn_input, adj)
532 | z_mean_mlp, z_log_std_mlp = self.mlp_ops(input)
533 |             z_mean = z_mean_gcn * mixture_ratio + z_mean_mlp * (1 - mixture_ratio)
534 | z_log_std = z_log_std_gcn * mixture_ratio + z_log_std_mlp * (1 - mixture_ratio)
535 | return z_mean, z_log_std
536 |
537 | def __repr__(self):
538 | return self.__class__.__name__ + ' (' \
539 |                + str(self.features_dim) + ' -> ' \
540 |                + str(self.out_dim) + ')'
541 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import math
3 | import numpy as np
4 | import torch
5 | from graph_data import preprocess_graph
6 | from utils import sample_reconstruction
7 |
8 | def weighted_cross_entropy_with_logits(logits, targets, pos_weight):
9 | """
10 | see: https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
11 | """
12 | #logits = torch.clamp(logits, min=-10, max=10)
13 |
14 | x = logits
15 | z = targets
16 | l = 1 + (pos_weight - 1) * targets
17 |
18 | loss = (1 - z) * x + l * (torch.log(1 + torch.exp(-torch.abs(x))) + torch.clamp(-x, min=0))
19 | return loss
20 |
21 | # return targets * -torch.log(torch.sigmoid(logits)) * pos_weight + (1 - targets) * -torch.log(1 - torch.sigmoid(logits))
22 |
23 | def KL_normal(z_mean_1, z_std_1, z_mean_2, z_std_2):
24 |
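   | # closed-form KL between diagonal Gaussians, per dimension:
   | #   log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2,
   | # summed over latent dimensions and averaged over nodes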
25 | kl = torch.log(z_std_2 / z_std_1) + ((z_std_1 ** 2) + (z_mean_1 - z_mean_2) ** 2) / (2 * z_std_2 ** 2) - 0.5
26 | return kl.sum(1).mean()
27 | #return torch.mean(kl)
28 |
29 | def reconstruction_loss(adj, adj_h, mask=None, test=False, fixed_norm=None):
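   | # pos_weight up-weights the sparse positive (edge) entries and norm rescales
   | # the mean loss, as in the GAE-style reconstruction objective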
30 | if not test:
31 | norm = adj.shape[0] ** 2 / float((adj.shape[0] ** 2 - adj.sum()) * 2)
32 | pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
33 | else:
34 | norm = 1.0
35 | pos_weight = 1.0
36 |
37 | if fixed_norm is not None:
38 | norm = fixed_norm
39 | pos_weight = 1.0
40 |
41 |
42 | element_loss = weighted_cross_entropy_with_logits(adj_h, adj, pos_weight)
43 | if mask is not None:
44 | element_loss = element_loss[mask]
45 | neg_log_lik = norm * torch.mean(element_loss)
46 | return neg_log_lik
47 |
48 | def vae_loss(z_mean, z_log_std, adj, fixed_norm=None):
49 | adj_h = sample_reconstruction(z_mean, z_log_std)
50 | neg_log_lik = reconstruction_loss(adj, adj_h, fixed_norm=fixed_norm)
51 | z_std = torch.exp(z_log_std)
52 | kl = KL_normal(z_mean, z_std, 0.0, 1.0)
53 | # kl = torch.mean(torch.log(1 / z_std) + (z_std ** 2 + (z_mean - 0) ** 2) * 0.5)
54 | return neg_log_lik + kl / z_mean.size()[0]
55 |
56 |
57 | def r_vae_loss(z_mean_old, z_log_std_old, z_mean_new, z_log_std_new, adj, fixed_norm=None):
58 |
59 | z_mean = torch.cat((z_mean_old, z_mean_new))
60 | z_log_std = torch.cat((z_log_std_old, z_log_std_new))
61 |
62 | adj_h = sample_reconstruction(z_mean, z_log_std)
63 | loss = reconstruction_loss(adj, adj_h, fixed_norm=fixed_norm)
64 | z_std_new = torch.exp(z_log_std_new)
65 | kl = KL_normal(z_mean_new, z_std_new, 0.0, 1.0)
66 | # kl = torch.mean(torch.log(1 / z_std_new) + (z_std_new ** 2 + (z_mean_new - 0) ** 2) * 0.5)
67 | loss += kl * (z_mean_new.size()[0] / z_mean.size()[0] ** 2)
68 |
69 | return loss
70 |
71 | def r_vae_loss_addon(last_z_mean, last_z_log_std, z_mean_old, z_log_std_old, z_mean_new, z_log_std_new, adj, fixed_norm=None):
72 |
73 | z_mean = torch.cat((z_mean_old, z_mean_new))
74 | z_log_std = torch.cat((z_log_std_old, z_log_std_new))
75 |
76 | adj_h = sample_reconstruction(z_mean, z_log_std)
77 | loss = reconstruction_loss(adj, adj_h, fixed_norm=fixed_norm)
78 |
79 | last_z_std = torch.exp(last_z_log_std)
80 | z_std_old = torch.exp(z_log_std_old)
81 |
82 |
83 | kl_last = KL_normal(z_mean_old, z_std_old, last_z_mean, last_z_std)
84 | #kl_last = torch.mean(torch.log(last_z_std / z_std_old) + (z_std_old ** 2 + (z_mean_old - last_z_mean) ** 2) * 0.5)
85 | kl_last *= (z_mean_old.size()[0] / z_mean.size()[0] ** 2)
86 |
87 | z_std_new = torch.exp(z_log_std_new)
88 | kl_new = KL_normal(z_mean_new, z_std_new, 0.0, 1.0)
89 | #kl_new = torch.mean(torch.log(1 / z_std_new) + (z_std_new ** 2 + (z_mean_new - 0) ** 2) * 0.5)
90 | kl_new *= (1.0 / z_mean.size()[0] ** 2)
91 | loss += kl_last
92 | loss += kl_new
93 | return loss
94 |
95 | def recursive_loss(gcn_step, adj, feat, size_update, fixed_norm=1.2):
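   | # unrolls graph growth: the first block is scored as a plain VAE, then each
   | # arriving chunk adds an r_vae_loss term weighted by the squared subgraph size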
96 | num_step = int(math.ceil(1.0 * adj.size()[0] / size_update))
97 |
98 | # print("num step: {}".format(num_step))
99 |
100 | loss = torch.tensor([0.0])
101 | for step in range(num_step):
102 |
103 | if step == 0:
104 | adj_feed = torch.eye(size_update)
105 | feat_feed = feat[:size_update, :]
106 | z_mean, z_log_std = gcn_step(adj_feed, feat_feed)
107 |
108 | adj_truth = adj[0:size_update, 0:size_update]
109 | loss += vae_loss(z_mean, z_log_std, adj_truth, fixed_norm=fixed_norm)
110 |
111 | continue
112 |
113 | start_idx = step * size_update
114 | end_idx = min(adj.size()[0], start_idx + size_update)
115 | adj_feed = adj[:start_idx, :start_idx].numpy()
116 | adj_feed_norm = preprocess_graph(adj_feed)
117 | adj_feed_norm = torch.from_numpy(adj_feed_norm)
118 |
119 | feat_feed = feat[:start_idx, :]
120 |         feat_new = feat[start_idx:end_idx, :]
121 |         z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(adj_feed_norm, feat_feed, feat_new)
122 | adj_truth = adj[:end_idx, :end_idx]
123 | curr_loss = r_vae_loss(z_mean_old, z_log_std_old, z_mean_new, z_log_std_new, adj_truth, fixed_norm=fixed_norm)
124 |
125 | loss += curr_loss * end_idx ** 2
126 | return loss / num_step
127 |
128 |
129 | def recursive_loss_with_noise(gcn_step, adj, feat, size_update, fixed_norm=1.2):
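   | # same unrolling, but old nodes are regularized towards the posterior of the
   | # previous step (r_vae_loss_addon) rather than the standard normal prior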
130 | num_step = int(math.ceil(1.0 * adj.size()[0] / size_update))
131 |
132 | #print("num step: {}".format(num_step))
133 |
134 | last_z_mean = None
135 | last_z_log_std = None
136 | for step in range(num_step):
137 |
138 | if step == 0:
139 | #adj_feed = torch.eye(size_update)
140 | adj_feed = adj[:size_update, :size_update]
141 | feat_feed = feat[:size_update, :]
142 | z_mean, z_log_std = gcn_step(adj_feed, feat_feed)
143 |
144 | adj_truth = adj[0:size_update, 0:size_update]
145 | loss = vae_loss(z_mean, z_log_std, adj_truth, fixed_norm=fixed_norm)
146 | last_z_mean, last_z_log_std = z_mean, z_log_std
147 | continue
148 |
149 | start_idx = step * size_update
150 | end_idx = min(adj.size()[0], start_idx + size_update)
151 | adj_feed = adj[:start_idx, :start_idx]
152 |
153 | feat_feed = feat[:start_idx, :]
154 |         feat_new = feat[start_idx:end_idx, :]
155 |         z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(adj_feed, feat_feed, feat_new)
156 | adj_truth = adj[:end_idx, :end_idx]
157 |
158 | curr_loss = r_vae_loss_addon(last_z_mean, last_z_log_std, z_mean_old, z_log_std_old, z_mean_new, z_log_std_new, adj_truth, fixed_norm=fixed_norm)
159 | loss += curr_loss * end_idx ** 2
160 |
161 | # update hidden latent spaces
162 | #last_z_mean = torch.cat((z_mean_old, z_mean_new))
163 | # last_z_log_std = torch.cat((z_log_std_old, z_log_std_new))
164 | adj_feed = adj[:end_idx, :end_idx]
165 | feat_feed = feat[:end_idx, :]
166 | last_z_mean, last_z_log_std = gcn_step(adj_feed, feat_feed)
167 |
168 |
169 | return loss / num_step
170 |
171 |
172 | def recursive_loss_with_noise_supervised(gcn_step, adj, label, feat, size_update, fixed_norm=1.2):
173 | num_step = int(math.ceil(adj.size()[0] / size_update))
174 |
175 | # print("num step: {}".format(num_step))
176 |
177 | last_z_mean = None
178 | last_z_log_std = None
179 | for step in range(num_step - 1):
180 |
181 | if step == 0:
182 | adj_feed = torch.eye(size_update)
183 | feat_feed = feat[:size_update, :]
184 | z_mean, z_log_std = gcn_step(adj_feed, feat_feed)
185 |
186 | label_truth = label[0:size_update, 0:size_update]
187 | loss = vae_loss(z_mean, z_log_std, label_truth, fixed_norm=fixed_norm)
188 | last_z_mean, last_z_log_std = z_mean, z_log_std
189 | continue
190 |
191 | start_idx = step * size_update
192 | end_idx = min(adj.size()[0], start_idx + size_update)
193 | adj_feed = adj[:start_idx, :start_idx]
194 |
195 | feat_feed = feat[:start_idx, :]
196 |         feat_new = feat[start_idx:end_idx, :]
197 |         z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(adj_feed, feat_feed, feat_new)
198 | label_truth = label[:end_idx, :end_idx]
199 |
200 | curr_loss = r_vae_loss_addon(last_z_mean, last_z_log_std, z_mean_old, z_log_std_old, z_mean_new, z_log_std_new, label_truth, fixed_norm=fixed_norm)
201 | loss += curr_loss * (end_idx + 1) ** 2
202 |
203 | # update hidden latent spaces
204 | last_z_mean = torch.cat((z_mean_old, z_mean_new))
205 | last_z_log_std = torch.cat((z_log_std_old, z_log_std_new))
206 | return loss / num_step
207 |
208 | def rgcn_loss(z_prior, z_post, adj):
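   | # sequence ELBO: size-weighted average of per-step KL(posterior || prior)
   | # plus a size^2-weighted average of per-step reconstruction terms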
209 | kl = torch.tensor(0.0)
210 | num_const1 = 0
211 | num_const2 = 0
212 | for i in range(len(z_post)):
213 | num_const2 += z_post[i][0].size(0) ** 2
214 | num_const1 += z_post[i][0].size(0)
215 | for i in range(len(z_post)):
216 | z_post_mean = z_post[i][0]
217 | z_post_std = torch.exp(z_post[i][1])
218 | z_prior_mean = z_prior[i][0]
219 | z_prior_std = torch.exp(z_prior[i][1])
220 | curr_kl = KL_normal(z_post_mean, z_post_std, z_prior_mean, z_prior_std) * z_post_mean.size(0) / num_const1
221 | #print(curr_kl)
222 | kl += curr_kl
223 | neg_log_lik = torch.tensor(0.0)
224 | for z_mean, z_log_std in z_post:
225 | adj_h = sample_reconstruction(z_mean, z_log_std)
226 | idx = z_mean.size(0)
227 | neg_log_lik += reconstruction_loss(adj[:idx, :idx], adj_h) * idx ** 2 / num_const2
228 | print(neg_log_lik)
229 | print(kl)
230 | return neg_log_lik + kl
231 |
232 | def rgcn_loss1(z_prior, z_post, adj):
233 | kl = torch.tensor(0.0)
234 | num_const = 0
235 | for i in range(len(z_post)):
236 | num_const += z_post[i][0].size(0) ** 2
237 | for i in range(len(z_post)):
238 | z_post_mean = z_post[i][0]
239 | z_post_std = torch.exp(z_post[i][1])
240 | z_prior_mean = z_prior[i][0]
241 | z_prior_std = torch.exp(z_prior[i][1])
242 | curr_kl = KL_normal(z_post_mean, z_post_std, z_prior_mean, z_prior_std) * z_post_mean.size(0) / num_const
243 | #print(curr_kl)
244 | kl += curr_kl
245 | neg_log_lik = torch.tensor(0.0)
246 | z_mean = z_post[-1][0]
247 | z_log_std = z_post[-1][1]
248 | adj_h = sample_reconstruction(z_mean, z_log_std)
249 | print(adj_h.shape)
250 | print(adj.shape)
251 | neg_log_lik = reconstruction_loss(adj[:adj_h.shape[0],:adj_h.shape[0]], adj_h)
252 | print(neg_log_lik)
253 | print(kl)
254 | return neg_log_lik + kl
255 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatsDLMathsRecomSys/Generative-Graph-Convolutional-Network-for-Growing-Graphs/5ba45db6726584cd3c547f8e2429bbea10ea90cf/src/models.py
--------------------------------------------------------------------------------
/src/run_exist_nodes.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import networkx as nx
4 | import numpy as np
5 | import scipy.sparse as sp
6 | import torch
7 | import pickle as pkl
8 |
9 | from sklearn.metrics import roc_auc_score
10 | from sklearn.metrics import average_precision_score
11 |
12 | from layers import GraphVae, MLP, RecursiveGraphConvolutionStepAddOn
13 | from loss import reconstruction_loss, vae_loss, recursive_loss_with_noise
14 |
15 | #meta
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--hidden_dim', type=int, default=400)
18 | parser.add_argument('--out_dim', type=int, default=200)
19 | parser.add_argument('--num_iters', type=int, default=200)
20 | parser.add_argument('--data_set', type=str, default='cora', choices = ['cora', 'citeseer', 'pubmed'])
21 | parser.add_argument('--seed', type=int, default=111)
22 | parser.add_argument('--size_update_ratio', type=float, default=0.33)
23 | parser.add_argument('--weight_decay', type=float, default=0.0)
24 | args = parser.parse_args()
25 |
26 | size_update_ratio = args.size_update_ratio
27 | hidden_dim = args.hidden_dim
28 | out_dim = args.out_dim
29 | cite_data = args.data_set
30 | weight_decay = args.weight_decay
31 | dataset_str = args.data_set
32 | norm=None
33 | num_iters = args.num_iters
34 | seed = args.seed
35 | np.random.seed(seed)
36 |
37 | ############## utility functions ##############
38 | # def read_citation_dat(dataset):
39 | # '''
40 | # dataset: {'cora', 'citeseer', 'pubmed'}
41 | # '''
42 | #
43 | # feat_fname = '../data/' + dataset + '_features.npz'
44 | # adj_fname = '../data/' + dataset + '_graph.npz'
45 | # features = sp.load_npz(feat_fname)
46 | # adj_orig = sp.load_npz(adj_fname)
47 | # adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
48 | # return adj_orig, features
49 | def parse_index_file(filename):
50 | index = []
51 | for line in open(filename):
52 | index.append(int(line.strip()))
53 | return index
54 |
55 | def load_data(dataset):
56 | # load the data: x, tx, allx, graph
57 | names = ['x', 'tx', 'allx', 'graph']
58 | objects = []
59 | for i in range(len(names)):
60 |         objects.append(pkl.load(open("../data/test/ind.{}.{}".format(dataset, names[i]), 'rb'), encoding='latin1'))
61 | x, tx, allx, graph = tuple(objects)
62 | test_idx_reorder = parse_index_file("../data/test/ind.{}.test.index".format(dataset))
63 | test_idx_range = np.sort(test_idx_reorder)
64 |
65 | if dataset == 'citeseer':
66 | # Fix citeseer dataset (there are some isolated nodes in the graph)
67 | # Find isolated nodes, add them as zero-vecs into the right position
68 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
69 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
70 | tx_extended[test_idx_range-min(test_idx_range), :] = tx
71 | tx = tx_extended
72 |
73 | features = sp.vstack((allx, tx)).tolil()
74 | features[test_idx_reorder, :] = features[test_idx_range, :]
75 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
76 |
77 | return adj, features
78 |
79 | def sparse_to_tuple(sparse_mx):
80 | if not sp.isspmatrix_coo(sparse_mx):
81 | sparse_mx = sparse_mx.tocoo()
82 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
83 | values = sparse_mx.data
84 | shape = sparse_mx.shape
85 | return coords, values, shape
86 |
87 | def mask_test_edges(adj):
88 | # Function to build test set with 10% positive links
89 | # NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.
90 | # TODO: Clean up.
91 |
92 | # Remove diagonal elements
93 | adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
94 | adj.eliminate_zeros()
95 | # Check that diag is zero:
96 | assert np.diag(adj.todense()).sum() == 0
97 |
98 | adj_triu = sp.triu(adj)
99 | adj_tuple = sparse_to_tuple(adj_triu)
100 | edges = adj_tuple[0]
101 | edges_all = sparse_to_tuple(adj)[0]
102 | num_test = int(np.floor(edges.shape[0] / 10.))
103 | num_val = int(np.floor(edges.shape[0] / 20.))
104 |
105 |     all_edge_idx = np.arange(edges.shape[0])
106 | np.random.shuffle(all_edge_idx)
107 | val_edge_idx = all_edge_idx[:num_val]
108 | test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
109 | test_edges = edges[test_edge_idx]
110 | val_edges = edges[val_edge_idx]
111 | train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)
112 |
113 | def ismember(a, b, tol=5):
114 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
115 | return np.any(rows_close)
116 |
117 | test_edges_false = []
118 | while len(test_edges_false) < len(test_edges):
119 | idx_i = np.random.randint(0, adj.shape[0])
120 | idx_j = np.random.randint(0, adj.shape[0])
121 | if idx_i == idx_j:
122 | continue
123 | if ismember([idx_i, idx_j], edges_all):
124 | continue
125 | if test_edges_false:
126 | if ismember([idx_j, idx_i], np.array(test_edges_false)):
127 | continue
128 | if ismember([idx_i, idx_j], np.array(test_edges_false)):
129 | continue
130 | test_edges_false.append([idx_i, idx_j])
131 |
132 | val_edges_false = []
133 | while len(val_edges_false) < len(val_edges):
134 | idx_i = np.random.randint(0, adj.shape[0])
135 | idx_j = np.random.randint(0, adj.shape[0])
136 | if idx_i == idx_j:
137 | continue
138 | if ismember([idx_i, idx_j], train_edges):
139 | continue
140 | if ismember([idx_j, idx_i], train_edges):
141 | continue
142 | if ismember([idx_i, idx_j], val_edges):
143 | continue
144 | if ismember([idx_j, idx_i], val_edges):
145 | continue
146 | if val_edges_false:
147 | if ismember([idx_j, idx_i], np.array(val_edges_false)):
148 | continue
149 | if ismember([idx_i, idx_j], np.array(val_edges_false)):
150 | continue
151 | val_edges_false.append([idx_i, idx_j])
152 |
153 | assert ~ismember(test_edges_false, edges_all)
154 | assert ~ismember(val_edges_false, edges_all)
155 | assert ~ismember(val_edges, train_edges)
156 | assert ~ismember(test_edges, train_edges)
157 | assert ~ismember(val_edges, test_edges)
158 |
159 | data = np.ones(train_edges.shape[0])
160 |
161 | # Re-build adj matrix
162 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
163 | adj_train = adj_train + adj_train.T
164 |
165 | # NOTE: these edge lists only contain single direction of edge!
166 | return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false
167 |
168 | def preprocess_graph(adj):
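   | # symmetric renormalization of Kipf & Welling: D^{-1/2} (A + I) D^{-1/2}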
169 | adj = sp.coo_matrix(adj)
170 | adj_ = adj + sp.eye(adj.shape[0])
171 | rowsum = np.array(adj_.sum(1))
172 | degree_mat_inv_sqrt = np.diag(np.power(rowsum, -0.5).flatten())
173 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
174 |
175 | #print(adj_normalized[:20, :].sum(1))
176 |
177 | return adj_normalized.astype(np.float32)
178 |
179 | def get_roc_score(edges_pos, edges_neg, emb):
180 |
181 | def sigmoid(x):
182 | return 1 / (1 + np.exp(-x))
183 |
184 | # Predict on test set of edges
185 | adj_rec = np.dot(emb, emb.T)
186 | preds = []
187 | pos = []
188 | for e in edges_pos:
189 | preds.append(sigmoid(adj_rec[e[0], e[1]]))
190 | pos.append(adj_orig[e[0], e[1]])
191 |
192 | preds_neg = []
193 | neg = []
194 | for e in edges_neg:
195 | preds_neg.append(sigmoid(adj_rec[e[0], e[1]]))
196 | neg.append(adj_orig[e[0], e[1]])
197 |
198 | preds_all = np.hstack([preds, preds_neg])
199 |     labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
200 | roc_score = roc_auc_score(labels_all, preds_all)
201 | ap_score = average_precision_score(labels_all, preds_all)
202 |
203 | return roc_score, ap_score
204 |
205 |
206 | ############ prepare data ##############
207 | adj, feat = load_data(dataset_str)
208 |
209 | features_dim = feat.shape[1]
210 | size_update = int(feat.shape[0] * size_update_ratio)
211 | print("size_update: {}".format(size_update))
212 |
213 | # Store original adjacency matrix (without diagonal entries) for later
214 | adj_orig = adj
215 | adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
216 | adj_orig.eliminate_zeros()
217 |
218 | adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
219 | adj = adj_train
220 |
221 | adj_label = adj_train + sp.eye(adj_train.shape[0])
222 |
223 | adj_norm = torch.from_numpy(preprocess_graph(adj))
224 | adj_label = torch.from_numpy(adj_label.todense().astype(np.float32))
225 | feat = torch.from_numpy(feat.todense().astype(np.float32))
226 |
227 | ############## init model ##############
228 | gcn_vae = GraphVae(features_dim, hidden_dim, out_dim, bias=False, dropout=0.0)
229 | optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=1e-3)
230 |
231 |
232 | gcn_step = RecursiveGraphConvolutionStepAddOn(features_dim, hidden_dim, out_dim, dropout=0.0)
233 | optimizer = torch.optim.Adam(gcn_step.parameters(), lr=1e-3, weight_decay=weight_decay)
234 | # mlp = MLP(features_dim, hidden_dim, out_dim, dropout=0.0)
235 | # optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-2)
236 |
237 | for batch_idx in range(num_iters):
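   | # each iteration re-permutes the node ordering, so the model is trained on
   | # many random growth sequences of the same graph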
238 | # # train GCN
239 | # optimizer_vae.zero_grad()
240 | # gcn_vae.train()
241 | # z_mean, z_log_std = gcn_vae(adj_norm, feat)
242 | # vae_train_loss = vae_loss(z_mean, z_log_std, adj_label)
243 | # vae_train_loss.backward()
244 | # optimizer_vae.step()
245 |
246 | # train r-gcn
247 | adj_copy = adj_label.numpy().copy()
248 | X_copy = feat.numpy().copy()
249 |
250 | # initial permutation
251 | len_batch = adj_copy.shape[0]
252 | x_idx = np.random.permutation(adj_copy.shape[0])
253 | adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
254 | X_copy = X_copy[x_idx, :]
255 |
256 | adj_copy = torch.from_numpy(adj_copy.astype(np.float32))
257 | X_copy = torch.from_numpy(X_copy.astype(np.float32))
258 | optimizer.zero_grad()
259 | gcn_step.train()
260 | loss = recursive_loss_with_noise(gcn_step, adj_copy, X_copy, size_update, norm)
261 | loss.backward()
262 | optimizer.step()
263 |
264 | # train mlp
265 | # optimizer_mlp.zero_grad()
266 | # mlp.train()
267 | # z_mean, z_log_std = mlp(feat)
268 | # mlp_train_loss = vae_loss(z_mean, z_log_std, adj_label)
269 | # mlp_train_loss.backward()
270 | # optimizer_mlp.step()
271 |
272 | # print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
273 | # 100. * batch_idx / num_iters,
274 | # vae_train_loss.item()))
275 |
276 | print('R-GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
277 | 100. * batch_idx / num_iters,
278 | loss.item()))
279 |
280 | if batch_idx % 10 == 0:
281 | # print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
282 | # 100. * batch_idx / num_iters,
283 | # vae_train_loss.item()))
284 | # print('MLP [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
285 | # 100. * batch_idx / num_iters,
286 | # mlp_train_loss.item()))
287 |
288 | with torch.no_grad():
289 |
290 | # # test original gcn
291 | # gcn_vae.eval()
292 | # z_mean, z_log_std = gcn_vae(adj_norm, feat)
293 | #
294 | # normal = torch.distributions.Normal(0, 1)
295 | # z = normal.sample(z_mean.size())
296 | # z = z * torch.exp(z_log_std) + z_mean
297 | #
298 | # roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
299 | # print('GCN val AP: {:.6f}'.format(ap))
300 | # print('GCN val AUC: {:.6f}'.format(roc))
301 |
302 |
303 | # test r-gcn
304 | gcn_step.eval()
305 | z_mean, z_log_std = gcn_step(adj_label, feat)
306 | normal = torch.distributions.Normal(0, 1)
307 | z = normal.sample(z_mean.size())
308 | z = z * torch.exp(z_log_std) + z_mean
309 |
310 | roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
311 | print('R-GCN val AP: {:.6f}'.format(ap))
312 | print('R-GCN val AUC: {:.6f}'.format(roc))
313 |
314 |
315 | # mlp.eval()
316 | # z_mean, z_log_std = mlp(feat)
317 | # normal = torch.distributions.Normal(0, 1)
318 | # z = normal.sample(z_mean.size())
319 | # z = z * torch.exp(z_log_std) + z_mean
320 | # roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
321 | # print('MLP val AP: {:.6f}'.format(ap))
322 | # print('MLP val AUC: {:.6f}'.format(roc))
323 |
324 |
325 | with torch.no_grad():
326 | # gcn_vae.eval()
327 | # z_mean, z_log_std = gcn_vae(adj_norm, feat)
328 | # normal = torch.distributions.Normal(0, 1)
329 | # z = normal.sample(z_mean.size())
330 | # z = z * torch.exp(z_log_std) + z_mean
331 | # roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
332 | # print('GCN test AP: {:.6f}'.format(ap))
333 | # print('GCN test AUC: {:.6f}'.format(roc))
334 |
335 | gcn_step.eval()
336 | z_mean, z_log_std = gcn_step(adj_label, feat)
337 | normal = torch.distributions.Normal(0, 1)
338 | z = normal.sample(z_mean.size())
339 | z = z * torch.exp(z_log_std) + z_mean
340 |
341 | roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
342 | print('R-GCN test AP: {:.6f}'.format(ap))
343 | print('R-GCN test AUC: {:.6f}'.format(roc))
344 |
--------------------------------------------------------------------------------
/src/run_iso_nodes.py:
--------------------------------------------------------------------------------
1 | #!/Users/d0x00ar/anaconda3/bin/python3.6
2 | import math
3 | import argparse
4 | import os
5 | import networkx as nx
6 | import numpy as np
7 | import torch
8 | import scipy.sparse as sp
9 | from layers import RecursiveGraphConvolutionStep, RecursiveGraphConvolutionStepAddOn, GraphVae, MLP
10 | from graph_data import generate_random_graph_data, GraphSequenceBfsRandSampler, preprocess_graph
11 | from loss import recursive_loss, reconstruction_loss, vae_loss, recursive_loss_with_noise
12 | from utils import link_split_mask, sample_reconstruction, get_roc_auc_score, get_average_precision_score, get_equal_mask
13 |
14 | #os.chdir("/Users/d0x00ar/Documents/GitHub/R-GraphVAE/src")
15 | import logging
16 |
17 | #meta
18 | def read_citation_dat(dataset, with_test=False, permute=False, test_ratio=0.3):
19 | '''
20 | dataset: {'cora', 'citeseer', 'pubmed'}
21 | '''
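   | # with_test=True holds out the trailing test_ratio fraction of nodes as
   | # unseen; the returned train graph contains only the earlier nodes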
22 |
23 | feat_fname = '../data/' + dataset + '_features.npz'
24 | adj_fname = '../data/' + dataset + '_graph.npz'
25 | features = sp.load_npz(feat_fname)
26 | adj_orig = sp.load_npz(adj_fname)
27 |
28 | adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
29 |
30 | X_select = features.todense().astype(np.float32)
31 | adj_select = adj_orig.todense().astype(np.float32)
32 |
33 | if permute:
34 |         x_idx = np.random.permutation(adj_select.shape[0])
35 | adj_select = adj_select[np.ix_(x_idx, x_idx)]
36 | X_select = X_select[x_idx, :]
37 |
38 | if with_test:
39 | cut_idx = int(adj_select.shape[0] * (1 - test_ratio))
40 |
41 | adj_train = adj_select[:cut_idx, :cut_idx]
42 | X_train = X_select[:cut_idx, :]
43 | return adj_train, X_train, adj_select, X_select
44 | else:
45 | return adj_select, X_select
46 |
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--hidden_dim', type=int, default=400)
49 | parser.add_argument('--out_dim', type=int, default=200)
50 | parser.add_argument('--update_ratio', type=float, default=0.33)
51 | parser.add_argument('--data_set', type=str, default='cora', choices = ['cora', 'citeseer', 'pubmed'])
52 | parser.add_argument('--seed', default=None)
53 | parser.add_argument('--refit', type=int, default=0)
54 | parser.add_argument('--permute', type=int, default=1)
55 | args = parser.parse_args()
56 |
57 |
58 | hidden_dim = args.hidden_dim
59 | out_dim = args.out_dim
60 | cite_data = args.data_set
61 |
62 | g_adj, X, g_adj_all, X_all = read_citation_dat(cite_data, with_test=True, permute=False)
63 | num_nodes = g_adj_all.shape[0]
64 | num_edges = ((g_adj_all > 0).sum() - num_nodes) / 2
65 | print([num_nodes, num_edges])
66 |
67 | size_update = int(num_nodes * args.update_ratio * 0.7)
68 |
69 | seed = int(args.seed) if args.seed else None
70 | unseen = True
71 | refit = args.refit > 0
72 | permute = args.permute > 0
73 |
74 | norm=None
75 | special = 'nodropout_DEBUG'
76 |
77 | filename = '_'.join(['equal_size_cite', special, cite_data,
78 | 'size', str(size_update),
79 | 'hidden', str(hidden_dim),
80 | 'out', str(out_dim),
81 | 'fix', str(seed is not None),
82 | 'unseen', str(unseen),
83 | 'refit', str(refit),
84 | 'norm', str(norm),
85 | 'permute', str(permute),
86 | 'seed', str(seed)])
87 | filename = '../data/exp_results/' + filename
88 |
89 |
90 | logging.basicConfig(level=logging.DEBUG, filename=filename,
91 | format="%(asctime)-15s %(levelname)-8s %(message)s")
92 |
93 | if seed is not None:
94 | np.random.seed(seed)
95 |
96 | features_dim = X.shape[1]
97 |
98 | dataset = GraphSequenceBfsRandSampler(g_adj, X, num_permutation=400, seed=seed, fix=False)
99 |
100 | params = {'batch_size': 1,
101 | 'shuffle': True,
102 | 'num_workers': 2}
103 |
104 | dataloader = torch.utils.data.DataLoader(dataset, **params)
105 |
106 | # gcn_step = RecursiveGraphConvolutionStep(features_dim, hidden_dim, out_dim)
107 | gcn_step = RecursiveGraphConvolutionStepAddOn(features_dim, hidden_dim, out_dim, dropout=0.0)
108 | gcn_vae = GraphVae(features_dim, hidden_dim, out_dim, dropout=0)
109 | mlp = MLP(features_dim, hidden_dim, out_dim, dropout=0)
110 |
111 | optimizer = torch.optim.Adam(gcn_step.parameters(), lr=1e-3)
112 | optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=1e-3)
113 | optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)
114 | train_loss = 0
115 |
116 | print('checkpoint1')
117 |
118 | for batch_idx, (adj, feat) in enumerate(dataloader):
119 | adj = adj[0]
120 | feat = feat[0]
121 |
122 | if adj.size()[0] <= size_update:
123 | print("sample size {} too small, skipped!".format(adj.size()[0]))
124 | continue
125 |
126 | # train R-GCN
127 |
128 | optimizer.zero_grad()
129 | gcn_step.train()
130 | # loss = recursive_loss(gcn_step, adj, feat, size_update)
131 | loss = recursive_loss_with_noise(gcn_step, adj, feat, size_update, norm)
132 | loss.backward()
133 | train_loss += loss.item()
134 | optimizer.step()
135 |
136 | # train GCN
137 | optimizer_vae.zero_grad()
138 | gcn_vae.train()
139 | adj_vae_norm = torch.from_numpy(preprocess_graph(adj.numpy()))
140 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat)
141 | vae_train_loss = vae_loss(z_mean, z_log_std, adj, norm)
142 | vae_train_loss.backward()
143 | optimizer_vae.step()
144 |
145 | # train mlp
146 | optimizer_mlp.zero_grad()
147 | mlp.train()
148 | z_mean, z_log_std = mlp(feat)
149 | mlp_train_loss = vae_loss(z_mean, z_log_std, adj, norm)
150 | mlp_train_loss.backward()
151 | optimizer_mlp.step()
152 |
153 | if batch_idx % 10 == 0:
154 | info ='R-GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, len(dataloader),
155 | 100. * batch_idx / len(dataloader),
156 | loss.item())
157 | print(info)
158 | logging.info(info)
159 |
160 | info = 'GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, len(dataloader),
161 | 100. * batch_idx / len(dataloader),
162 | vae_train_loss.item())
163 | print(info)
164 | logging.info(info)
165 |
166 | info = 'MLP [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, len(dataloader),
167 | 100. * batch_idx / len(dataloader),
168 | mlp_train_loss.item())
169 | print(info)
170 | logging.info(info)
171 |
172 | with torch.no_grad():
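   | # evaluate only links incident to the held-out nodes: zero out the observed
   | # block and the diagonal, then balance positive and negative pairs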
173 | adj_feed_norm = torch.from_numpy(preprocess_graph(g_adj))
174 | adj_truth_all = torch.from_numpy(g_adj_all.astype(np.float32))
175 |
176 | feat = torch.from_numpy(X)
177 | feat_all = torch.from_numpy(X_all)
178 |
179 | mask = torch.ones_like(adj_truth_all).type(torch.ByteTensor)
180 | mask[:feat.size()[0], :feat.size()[0]] = 0
181 |
182 | diag_mask = torch.eye(mask.size(0)).type(torch.ByteTensor)
183 | mask = mask * (1 - diag_mask)
184 |
185 | mask = get_equal_mask(adj_truth_all, mask)
186 |
187 |
188 | # # test r-gcn
189 | gcn_step.eval()
190 | z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(torch.from_numpy(g_adj), feat, feat_all[feat.size()[0]:, :])
191 | z_mean = torch.cat((z_mean_old, z_mean_new))
192 | z_log_std = torch.cat((z_log_std_old, z_log_std_new))
193 |
194 | adj_h = sample_reconstruction(z_mean, z_log_std)
195 | if refit:
196 | adj_hat = (adj_h > 0).type(torch.FloatTensor)
197 | adj_hat[:feat.size(0), :feat.size(0)] = torch.from_numpy(g_adj)
198 |         z_mean, z_log_std = gcn_step(adj_hat, feat_all)
199 | adj_h = sample_reconstruction(z_mean, z_log_std)
200 |
201 |
202 | test_loss = reconstruction_loss(adj_truth_all, adj_h, mask, test=True)
203 | auc_rgcn = get_roc_auc_score(adj_truth_all, adj_h, mask)
204 | ap_rgcn = get_average_precision_score(adj_truth_all, adj_h, mask)
205 |
206 | info = 'R-GCN test loss: {:.6f}'.format(test_loss)
207 | print(info)
208 | logging.info(info)
209 |
210 |
211 | # test original gcn
212 | gcn_vae.eval()
213 | adj_vae_norm = torch.eye(feat_all.size()[0])
214 | adj_vae_norm[:feat.size()[0], :feat.size()[0]] = adj_feed_norm
215 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
216 | adj_h = sample_reconstruction(z_mean, z_log_std)
217 | test_loss = reconstruction_loss(adj_truth_all, adj_h, mask, test=True)
218 | auc_gcn = get_roc_auc_score(adj_truth_all, adj_h, mask)
219 | ap_gcn = get_average_precision_score(adj_truth_all, adj_h, mask)
220 |
221 | info = 'Original GCN test loss: {:.6f}'.format(test_loss)
222 | print(info)
223 | logging.info(info)
224 |
225 |
226 | # test on mlp
227 | mlp.eval()
228 | z_mean, z_log_std = mlp(feat_all)
229 | adj_h = sample_reconstruction(z_mean, z_log_std)
230 | test_loss = reconstruction_loss(adj_truth_all, adj_h, mask, test=True)
231 | auc_mlp = get_roc_auc_score(adj_truth_all, adj_h, mask)
232 | ap_mlp = get_average_precision_score(adj_truth_all, adj_h, mask)
233 | info = 'MLP test loss: {:.6f}'.format(test_loss)
234 | print(info)
235 | logging.info(info)
236 |
237 | print('AUC:')
238 | info = 'R-GCN auc: {:.6f}'.format(auc_rgcn)
239 | print(info)
240 | logging.info(info)
241 | info = 'Original GCN auc: {:.6f}'.format(auc_gcn)
242 | print(info)
243 | logging.info(info)
244 | info = 'MLP auc: {:.6f}'.format(auc_mlp)
245 | print(info)
246 | logging.info(info)
247 |
248 |
249 |
250 | info = 'R-GCN AP: {:.6f}'.format(ap_rgcn)
251 | print(info)
252 | logging.info(info)
253 | info = 'Original GCN AP: {:.6f}'.format(ap_gcn)
254 | print(info)
255 | logging.info(info)
256 | info = 'MLP AP: {:.6f}'.format(ap_mlp)
257 | print(info)
258 | logging.info(info)
259 |
--------------------------------------------------------------------------------
/src/train_citation_gae.py:
--------------------------------------------------------------------------------
1 | #!/Users/d0x00ar/anaconda3/bin/python3.6
2 | import math
3 | import networkx as nx
4 | import numpy as np
5 | import torch
6 | import os
7 | import scipy.sparse as sp
8 | from torch.utils.data import Dataset
9 |
10 | from layers import RecursiveGraphConvolutionStep, RecursiveGraphConvolutionStepAddOn, GraphVae, MLP
11 | from graph_data import generate_random_graph_data, GraphSequenceBfsRandSampler, preprocess_graph
12 | from loss import recursive_loss, reconstruction_loss, vae_loss, recursive_loss_with_noise_supervised
13 | from utils import link_split_mask, sample_reconstruction, get_roc_auc_score, get_average_precision_score, get_equal_mask
14 |
15 | #os.chdir("/Users/d0x00ar/Documents/GitHub/R-GraphVAE/src")
16 | import logging
17 |
18 | #meta
19 | size_update = 20
20 | hidden_dim = 64
21 | out_dim = 32
22 | cite_data = 'citeseer'
23 | random_h = True
24 | seed = 888
25 | unseen = True
26 | refit = False
27 | permute = False
28 | norm=None
29 | special = 'GAE'
30 |
31 | num_permutation = 400
32 | weight_decay = 0
33 | dropout=0.0
34 |
35 | head_info = '_'.join(['equal_size_cite', special, cite_data,
36 | 'size', str(size_update),
37 | 'hidden', str(hidden_dim),
38 | 'out', str(out_dim),
39 | 'random_h', str(random_h),
40 | 'fix', str(seed is not None),
41 | 'unseen', str(unseen),
42 | 'refit', str(refit),
43 | 'norm', str(norm),
44 | 'permute', str(permute),
45 | 'seed', str(seed)])
46 |
47 | filename = 'train_original_gae_64_32_citeseer_nodropout'
48 | filename = '../data/important_results/' + filename
49 |
50 |
51 | logging.basicConfig(level=logging.DEBUG, filename=filename,
52 | format="%(asctime)-15s %(levelname)-8s %(message)s")
53 | logging.info('This use the sample implementation.')
54 | logging.info(head_info)
55 |
56 | if seed is not None:
57 | np.random.seed(seed)
58 |
59 | ############### Prepare input data ###############
60 | def read_citation_dat(dataset, with_test=False, permute=False, test_ratio=0.3):
61 | '''
62 | dataset: {'cora', 'citeseer', 'pubmed'}
63 | '''
64 |
65 | feat_fname = '../data/' + dataset + '_features.npz'
66 | adj_fname = '../data/' + dataset + '_graph.npz'
67 | features = sp.load_npz(feat_fname)
68 | adj_orig = sp.load_npz(adj_fname)
69 |
70 | adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
71 |
72 | X_select = features.todense().astype(np.float32)
73 | adj_select = adj_orig.todense().astype(np.float32)
74 |
75 | if permute:
76 |         x_idx = np.random.permutation(adj_select.shape[0])
77 | adj_select = adj_select[np.ix_(x_idx, x_idx)]
78 | X_select = X_select[x_idx, :]
79 |
80 | if with_test:
81 | cut_idx = int(adj_select.shape[0] * (1 - test_ratio))
82 |
83 | adj_train = adj_select[:cut_idx, :cut_idx]
84 | X_train = X_select[:cut_idx, :]
85 | return adj_train, X_train, adj_select, X_select
86 | else:
87 | return adj_select, X_select
88 |
89 |
90 |
91 | g_adj, X, g_adj_all, X_all = read_citation_dat(cite_data, with_test=True)
92 |
93 | features_dim = X.shape[1]
94 |
95 |
96 |
97 | params = {'batch_size': 1,
98 | 'shuffle': True,
99 | 'num_workers': 2}
100 |
101 |
102 | ############### Init models ###############
103 | gcn_vae = GraphVae(features_dim, hidden_dim, out_dim, dropout=dropout)
104 | mlp = MLP(features_dim, hidden_dim, out_dim)
105 |
106 | optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=1e-2)
107 | optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-2)
108 | train_loss = 0
109 |
110 | cache = None
111 | ################ training loop #####################
112 | adj = torch.from_numpy(g_adj)
113 | feat = torch.from_numpy(X)
114 | for batch_idx in range(num_permutation):
115 |
116 | if adj.size()[0] <= size_update:
117 | print("sample size {} too small, skipped!".format(adj.size()[0]))
118 | continue
119 |
120 | # train GCN
121 | optimizer_vae.zero_grad()
122 | gcn_vae.train()
123 | adj_vae_norm = torch.from_numpy(preprocess_graph(adj.numpy()))
124 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat)
125 | vae_train_loss = vae_loss(z_mean, z_log_std, adj, norm)
126 | vae_train_loss.backward()
127 | optimizer_vae.step()
128 |
129 | # train mlp
130 | optimizer_mlp.zero_grad()
131 | mlp.train()
132 | z_mean, z_log_std = mlp(feat)
133 | mlp_train_loss = vae_loss(z_mean, z_log_std, adj)
134 | mlp_train_loss.backward()
135 | optimizer_mlp.step()
136 |
137 | if batch_idx % 10 == 0:
138 |
139 | info = 'GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_permutation,
140 | 100. * batch_idx / num_permutation,
141 | vae_train_loss.item())
142 | print(info)
143 | logging.info(info)
144 |
145 | info = 'MLP [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_permutation,
146 | 100. * batch_idx / num_permutation,
147 | mlp_train_loss.item())
148 | print(info)
149 | logging.info(info)
150 |
151 | with torch.no_grad():
152 |
153 | adj_feed_norm = torch.from_numpy(preprocess_graph(g_adj))
154 | adj_all = torch.from_numpy(g_adj_all)
155 | feat = torch.from_numpy(X)
156 | feat_all = torch.from_numpy(X_all)
157 |
158 | mask = torch.ones_like(adj_all).type(torch.ByteTensor)
159 | mask[:feat.size()[0], :feat.size()[0]] = 0
160 |
161 | diag_mask = torch.eye(mask.size(0)).type(torch.ByteTensor)
162 | mask = mask * (1 - diag_mask)
163 |
164 | mask = get_equal_mask(adj_all, mask, thresh=0.8)
165 |
166 | # test original gcn
167 | gcn_vae.eval()
168 | adj_vae_norm = torch.eye(feat_all.size()[0])
169 | adj_vae_norm[:feat.size()[0], :feat.size()[0]] = adj_feed_norm
170 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
171 | adj_h = sample_reconstruction(z_mean, z_log_std)
172 | test_loss = reconstruction_loss(adj_all, adj_h, mask, test=True)
173 | auc_gcn = get_roc_auc_score(adj_all, adj_h, mask)
174 | ap_gcn = get_average_precision_score(adj_all, adj_h, mask)
175 |
176 | info = 'Original GCN test loss: {:.6f}'.format(test_loss)
177 | print(info)
178 | logging.info(info)
179 |
180 |
181 | # test on mlp
182 | mlp.eval()
183 | z_mean, z_log_std = mlp(feat_all)
184 | adj_h = sample_reconstruction(z_mean, z_log_std)
185 | test_loss = reconstruction_loss(adj_all, adj_h, mask, test=True)
186 | auc_mlp = get_roc_auc_score(adj_all, adj_h, mask)
187 | ap_mlp = get_average_precision_score(adj_all, adj_h, mask)
188 | info = 'MLP test loss: {:.6f}'.format(test_loss)
189 | print(info)
190 | logging.info(info)
191 |
192 | ###### refit model ######
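   | # decode a provisional adjacency, keep confident predicted links for the
   | # unseen nodes, splice the true observed block back in, and re-encode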
193 | # test original gcn
194 | gcn_vae.eval()
195 | adj_vae_norm = torch.eye(feat_all.size()[0])
196 | adj_vae_norm[:feat.size()[0], :feat.size()[0]] = adj_feed_norm
197 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
198 | adj_h = sample_reconstruction(z_mean, z_log_std)
199 |     # re-fit
200 | adj_fake = (adj_h.sigmoid() > 0.8).type(torch.FloatTensor)
201 | adj_fake[:feat.size(0), :feat.size(0)] = torch.from_numpy(g_adj)
202 | adj_fake = adj_fake * (1 - torch.eye(adj_fake.size(0))) + torch.eye(adj_fake.size(0))
203 | adj_fake_norm = torch.from_numpy(preprocess_graph(adj_fake.numpy()))
204 |
205 | z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
206 | z_mean_new, z_log_std_new = gcn_vae(adj_fake_norm, feat_all)
207 |
208 | z_mean[feat.size(0):, :] = z_mean_new[feat.size(0):, :]
209 |     z_log_std[feat.size(0):, :] = z_log_std_new[feat.size(0):, :]
210 |
211 | adj_h = sample_reconstruction(z_mean, z_log_std)
212 | auc_gcn_fake = get_roc_auc_score(adj_all, adj_h, mask)
213 | ap_gcn_fake = get_average_precision_score(adj_all, adj_h, mask)
214 |
215 |
216 | print('AUC:')
217 | info = 'Original GCN auc: {:.6f}'.format(auc_gcn)
218 | print(info)
219 | logging.info(info)
220 | info = 'MLP auc: {:.6f}'.format(auc_mlp)
221 | print(info)
222 | logging.info(info)
223 | info = 'FAKE auc: {:.6f}'.format(auc_gcn_fake)
224 | print(info)
225 |
226 |
227 | print('AP:')
228 | info = 'Original GCN AP: {:.6f}'.format(ap_gcn)
229 | print(info)
230 | logging.info(info)
231 | info = 'MLP AP: {:.6f}'.format(ap_mlp)
232 | print(info)
233 | logging.info(info)
234 | info = 'FAKE AP: {:.6f}'.format(ap_gcn_fake)
235 | print(info)
236 | logging.info(info)
237 |
238 |
239 |
240 |
241 | # ###### get embeddings for traing and test objects #####
242 | # with torch.no_grad():
243 | #
244 | # adj_feed_norm = torch.from_numpy(preprocess_graph(g_adj))
245 | # if isinstance(label_all, np.ndarray):
246 | # label_all = torch.from_numpy(label_all)
247 | # feat = torch.from_numpy(X)
248 | # feat_all = torch.from_numpy(X_all)
249 | #
250 | # # gcn_step.eval()
251 | # # z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(torch.from_numpy(g_adj), feat, feat_all[feat.size()[0]:, :])
252 | # #
253 | # # z_mean = torch.cat((z_mean_old, z_mean_new))
254 | # # z_log_std = torch.cat((z_log_std_old, z_log_std_new))
255 | #
256 | # gcn_vae.eval()
257 | # adj_vae_norm = torch.eye(feat_all.size()[0])
258 | # adj_vae_norm[:feat.size()[0], :feat.size()[0]] = adj_feed_norm
259 | # z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
260 | #
261 | # normal = torch.distributions.Normal(0, 1)
262 | # z = normal.sample(z_mean.size())
263 | # z = z * torch.exp(z_log_std) + z_mean
264 | # z = z.numpy()
265 | #
266 | # new_z = z.copy()
267 | # def invert(p):
268 | # p = list(p)
269 | # return np.array([p.index(l) for l in range(len(p))])
270 | #
271 | # new_x_idx = invert(x_idx)
272 | # new_z = z[new_x_idx, :]
273 | #
274 | # test_idx = x_idx[z_mean_old.size(0)+1:]
275 | #
276 | # ## TO DO
277 | # np.savetxt('../saved_models/ice_cream_gcn_embed.txt', new_z)
278 | # np.savetxt('../saved_models/test_idx_gcn_array.txt', test_idx)
279 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import metrics
4 |
5 | def link_split_mask(adj, mask_ratio=0.1, seed=999):
6 | adj_t = np.tril(adj)
7 | np.random.seed(seed)
8 | edge_lists = np.where(adj_t > 0)
9 | x_idx = np.random.permutation(len(edge_lists[0]))
10 |     row_list_pos = edge_lists[0][x_idx][:int(mask_ratio * len(edge_lists[0]))]
11 |     col_list_pos = edge_lists[1][x_idx][:int(mask_ratio * len(edge_lists[0]))]
12 |
13 |
14 | edge_lists = np.where(adj_t <= 0)
15 | neg_row = edge_lists[0][edge_lists[0] > edge_lists[1]]
16 | neg_col = edge_lists[1][edge_lists[0] > edge_lists[1]]
17 | x_idx = np.random.permutation(len(neg_row))
18 | row_list_neg = neg_row[x_idx][:len(row_list_pos)]
19 | col_list_neg = neg_col[x_idx][:len(row_list_pos)]
20 |
21 | row = np.concatenate((row_list_pos,row_list_neg))
22 | col = np.concatenate((col_list_pos, col_list_neg))
23 |
24 | return np.concatenate((row,col)), np.concatenate((col, row))
25 |
26 | def sample_reconstruction(z_mean, z_log_std):
27 | num_nodes = z_mean.size()[0]
28 | normal = torch.distributions.Normal(0, 1)
29 |
30 |     # sample z to approximate the posterior of A
31 | z = normal.sample(z_mean.size())
32 | z = z * torch.exp(z_log_std) + z_mean
33 | adj_h = torch.mm(z, z.permute(1, 0))
34 | return adj_h
35 |
36 | def get_roc_auc_score(adj, adj_h, mask):
37 | adj_n = adj[mask].numpy() > 0.9
38 | adj_h_n = adj_h[mask].sigmoid().numpy()
39 | return metrics.roc_auc_score(adj_n, adj_h_n)
40 |
41 | def get_average_precision_score(adj, adj_h, mask):
42 | adj_n = adj[mask].numpy() > 0.9
43 | adj_h_n = adj_h[mask].sigmoid().numpy()
44 | return metrics.average_precision_score(adj_n, adj_h_n)
45 |
46 | def get_equal_mask(adj_true, test_mask, thresh=0):
47 | """create a mask which gives equal number of positive and negtive edges"""
48 | adj_true = adj_true > thresh
49 | pos_link_mask = adj_true * test_mask
50 | num_links = int(pos_link_mask.sum().item())
51 |
52 | if num_links > 0.5 * test_mask.sum().item():
53 | raise ValueError('test nodes over connected!')
54 |
55 |     neg_link_mask = (~adj_true) * test_mask
56 |     neg_link_mask = neg_link_mask.numpy().astype(np.uint8)
57 | row, col = np.where(neg_link_mask > 0)
58 | new_idx = np.random.permutation(len(row))
59 | row, col = row[new_idx][:num_links], col[new_idx][:num_links]
60 | neg_link_mask *= 0
61 | neg_link_mask[row, col] = 1
62 | neg_link_mask = torch.from_numpy(neg_link_mask)
63 |
64 | assert((pos_link_mask * neg_link_mask).sum().item() == 0)
65 | assert(neg_link_mask.sum().item() == num_links)
66 | assert(((pos_link_mask + neg_link_mask) * test_mask != (pos_link_mask + neg_link_mask)).sum().item() == 0)
67 | return pos_link_mask + neg_link_mask
68 |
--------------------------------------------------------------------------------
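A minimal smoke test for the helpers in src/utils.py might look like the following; the toy graph, seed, and dimensions are made up, and it assumes the interpreter is started from src/ so that `from utils import ...` resolves.

```python
import torch
from utils import (sample_reconstruction, get_roc_auc_score,
                   get_average_precision_score)

torch.manual_seed(0)
n, d = 6, 4
adj = torch.zeros(n, n)
adj[0, 1] = adj[1, 0] = 1.0                       # a couple of positive links
adj[2, 3] = adj[3, 2] = 1.0
mask = ~torch.eye(n, dtype=torch.bool)            # score every off-diagonal pair

z_mean = torch.randn(n, d)
z_log_std = torch.full((n, d), -2.0)              # small posterior std
adj_h = sample_reconstruction(z_mean, z_log_std)  # inner products of sampled z
print('toy AUC:', get_roc_auc_score(adj, adj_h, mask))
print('toy AP :', get_average_precision_score(adj, adj_h, mask))
```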
/src/validate_gae_implementation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import networkx as nx
4 | import numpy as np
5 | import scipy.sparse as sp
6 | import torch
7 | import pickle as pkl
8 |
9 | from sklearn.metrics import roc_auc_score
10 | from sklearn.metrics import average_precision_score
11 |
12 | from layers import GraphVae, MLP, GraphAE
13 | from loss import reconstruction_loss, vae_loss
14 |
15 | #meta
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--hidden_dim', type=int, default=32)
18 | parser.add_argument('--out_dim', type=int, default=16)
19 | parser.add_argument('--num_iters', type=int, default=200)
20 | parser.add_argument('--data_set', type=str, default='cora', choices = ['cora', 'citeseer', 'pubmed'])
21 | parser.add_argument('--seed', type=int, default=888)
22 | args = parser.parse_args()
23 |
24 | hidden_dim = args.hidden_dim
25 | out_dim = args.out_dim
26 | cite_data = args.data_set
27 |
28 | norm=None
29 | num_iters = args.num_iters
30 | seed = args.seed
31 | np.random.seed(seed)
32 |
33 | ############## utility functions ##############
34 | # def read_citation_dat(dataset):
35 | # '''
36 | # dataset: {'cora', 'citeseer', 'pubmed'}
37 | # '''
38 | #
39 | # feat_fname = '../data/' + dataset + '_features.npz'
40 | # adj_fname = '../data/' + dataset + '_graph.npz'
41 | # features = sp.load_npz(feat_fname)
42 | # adj_orig = sp.load_npz(adj_fname)
43 | # adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
44 | # return adj_orig, features
45 | def parse_index_file(filename):
46 | index = []
47 | for line in open(filename):
48 | index.append(int(line.strip()))
49 | return index
50 |
51 | def load_data(dataset):
52 | # load the data: x, tx, allx, graph
53 | names = ['x', 'tx', 'allx', 'graph']
54 | objects = []
55 | for i in range(len(names)):
56 |         objects.append(pkl.load(open("../data/test/ind.{}.{}".format(dataset, names[i]), 'rb'), encoding='latin1'))
57 | x, tx, allx, graph = tuple(objects)
58 | test_idx_reorder = parse_index_file("../data/test/ind.{}.test.index".format(dataset))
59 | test_idx_range = np.sort(test_idx_reorder)
60 |
61 | if dataset == 'citeseer':
62 | # Fix citeseer dataset (there are some isolated nodes in the graph)
63 | # Find isolated nodes, add them as zero-vecs into the right position
64 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
65 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
66 | tx_extended[test_idx_range-min(test_idx_range), :] = tx
67 | tx = tx_extended
68 |
69 | features = sp.vstack((allx, tx)).tolil()
70 | features[test_idx_reorder, :] = features[test_idx_range, :]
71 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
72 |
73 | return adj, features
74 |
75 | def sparse_to_tuple(sparse_mx):
76 | if not sp.isspmatrix_coo(sparse_mx):
77 | sparse_mx = sparse_mx.tocoo()
78 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
79 | values = sparse_mx.data
80 | shape = sparse_mx.shape
81 | return coords, values, shape
82 |
83 | def mask_test_edges(adj):
84 | # Function to build test set with 10% positive links
85 | # NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.
86 | # TODO: Clean up.
87 |
88 | # Remove diagonal elements
89 | adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
90 | adj.eliminate_zeros()
91 | # Check that diag is zero:
92 | assert np.diag(adj.todense()).sum() == 0
93 |
94 | adj_triu = sp.triu(adj)
95 | adj_tuple = sparse_to_tuple(adj_triu)
96 | edges = adj_tuple[0]
97 | edges_all = sparse_to_tuple(adj)[0]
98 | num_test = int(np.floor(edges.shape[0] / 10.))
99 | num_val = int(np.floor(edges.shape[0] / 20.))
100 |
101 |     all_edge_idx = np.arange(edges.shape[0])
102 |     np.random.shuffle(all_edge_idx)
103 | val_edge_idx = all_edge_idx[:num_val]
104 | test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
105 | test_edges = edges[test_edge_idx]
106 | val_edges = edges[val_edge_idx]
107 | train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)
108 |
109 | def ismember(a, b, tol=5):
110 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
111 | return np.any(rows_close)
112 |
113 | test_edges_false = []
114 | while len(test_edges_false) < len(test_edges):
115 | idx_i = np.random.randint(0, adj.shape[0])
116 | idx_j = np.random.randint(0, adj.shape[0])
117 | if idx_i == idx_j:
118 | continue
119 | if ismember([idx_i, idx_j], edges_all):
120 | continue
121 | if test_edges_false:
122 | if ismember([idx_j, idx_i], np.array(test_edges_false)):
123 | continue
124 | if ismember([idx_i, idx_j], np.array(test_edges_false)):
125 | continue
126 | test_edges_false.append([idx_i, idx_j])
127 |
128 | val_edges_false = []
129 | while len(val_edges_false) < len(val_edges):
130 | idx_i = np.random.randint(0, adj.shape[0])
131 | idx_j = np.random.randint(0, adj.shape[0])
132 | if idx_i == idx_j:
133 | continue
134 | if ismember([idx_i, idx_j], train_edges):
135 | continue
136 | if ismember([idx_j, idx_i], train_edges):
137 | continue
138 | if ismember([idx_i, idx_j], val_edges):
139 | continue
140 | if ismember([idx_j, idx_i], val_edges):
141 | continue
142 | if val_edges_false:
143 | if ismember([idx_j, idx_i], np.array(val_edges_false)):
144 | continue
145 | if ismember([idx_i, idx_j], np.array(val_edges_false)):
146 | continue
147 | val_edges_false.append([idx_i, idx_j])
148 |
149 | assert ~ismember(test_edges_false, edges_all)
150 | assert ~ismember(val_edges_false, edges_all)
151 | assert ~ismember(val_edges, train_edges)
152 | assert ~ismember(test_edges, train_edges)
153 | assert ~ismember(val_edges, test_edges)
154 |
155 | data = np.ones(train_edges.shape[0])
156 |
157 | # Re-build adj matrix
158 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
159 | adj_train = adj_train + adj_train.T
160 |
161 | # NOTE: these edge lists only contain single direction of edge!
162 | return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false
163 |
164 | def preprocess_graph(adj):
165 | adj = sp.coo_matrix(adj)
166 | adj_ = adj + sp.eye(adj.shape[0])
167 | rowsum = np.array(adj_.sum(1))
168 | degree_mat_inv_sqrt = np.diag(np.power(rowsum, -0.5).flatten())
169 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
170 |
171 | #print(adj_normalized[:20, :].sum(1))
172 |
173 | return adj_normalized.astype(np.float32)
174 |
175 | def get_roc_score(edges_pos, edges_neg, emb):
176 |
177 | def sigmoid(x):
178 | return 1 / (1 + np.exp(-x))
179 |
180 | # Predict on test set of edges
181 | adj_rec = np.dot(emb, emb.T)
182 | preds = []
183 | pos = []
184 | for e in edges_pos:
185 | preds.append(sigmoid(adj_rec[e[0], e[1]]))
186 | pos.append(adj_orig[e[0], e[1]])
187 |
188 | preds_neg = []
189 | neg = []
190 | for e in edges_neg:
191 | preds_neg.append(sigmoid(adj_rec[e[0], e[1]]))
192 | neg.append(adj_orig[e[0], e[1]])
193 |
194 | preds_all = np.hstack([preds, preds_neg])
195 |     labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
196 | roc_score = roc_auc_score(labels_all, preds_all)
197 | ap_score = average_precision_score(labels_all, preds_all)
198 |
199 | return roc_score, ap_score
200 |
201 |
202 | ############ prepare data ##############
203 | adj, feat = load_data(args.data_set)
204 |
205 | features_dim = feat.shape[1]
206 |
207 | # Store original adjacency matrix (without diagonal entries) for later
208 | adj_orig = adj
209 | adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
210 | adj_orig.eliminate_zeros()
211 |
212 | adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
213 | adj = adj_train
214 |
215 | adj_label = adj_train + sp.eye(adj_train.shape[0])
216 |
217 | adj_norm = torch.from_numpy(preprocess_graph(adj))
218 | adj_label = torch.from_numpy(adj_label.todense().astype(np.float32))
219 | feat = torch.from_numpy(feat.todense().astype(np.float32))
220 |
221 | ############## init model ##############
222 | gcn_vae = GraphAE(features_dim, hidden_dim, out_dim, bias=False, dropout=0.0)
223 | optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=1e-2)
224 |
225 | mlp = MLP(features_dim, hidden_dim, out_dim, dropout=0.0)
226 | optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-2)
227 |
228 | for batch_idx in range(num_iters):
229 | # train GCN
230 | optimizer_vae.zero_grad()
231 | gcn_vae.train()
232 | z = gcn_vae(adj_norm, feat)
233 | adj_h = torch.mm(z, z.t())
234 | vae_train_loss = reconstruction_loss(adj_label, adj_h, norm)
235 | vae_train_loss.backward()
236 | optimizer_vae.step()
237 |
238 | #train mlp
239 | optimizer_mlp.zero_grad()
240 | mlp.train()
241 | z_mean, z_log_std = mlp(feat)
242 | mlp_train_loss = vae_loss(z_mean, z_log_std, adj_label)
243 | mlp_train_loss.backward()
244 | optimizer_mlp.step()
245 | print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
246 | 100. * batch_idx / num_iters,
247 | vae_train_loss.item()))
248 |
249 | if batch_idx % 10 == 0:
250 | # print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
251 | # 100. * batch_idx / num_iters,
252 | # vae_train_loss.item()))
253 | # print('MLP [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
254 | # 100. * batch_idx / num_iters,
255 | # mlp_train_loss.item()))
256 |
257 | with torch.no_grad():
258 |
259 | # test original gcn
260 | gcn_vae.eval()
261 | z = gcn_vae(adj_norm, feat)
262 | roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
263 | print('GCN val AP: {:.6f}'.format(ap))
264 | print('GCN val AUC: {:.6f}'.format(roc))
265 |
266 |
267 | mlp.eval()
268 | z_mean, z_log_std = mlp(feat)
269 | normal = torch.distributions.Normal(0, 1)
270 | z = normal.sample(z_mean.size())
271 | z = z * torch.exp(z_log_std) + z_mean
272 | roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
273 | print('MLP val AP: {:.6f}'.format(ap))
274 | print('MLP val AUC: {:.6f}'.format(roc))
275 |
276 |
277 | with torch.no_grad():
278 |
279 | mlp.eval()
280 | z_mean, z_log_std = mlp(feat)
281 | normal = torch.distributions.Normal(0, 1)
282 | z = normal.sample(z_mean.size())
283 | z = z * torch.exp(z_log_std) + z_mean
284 | roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
285 |     print('MLP test AP: {:.6f}'.format(ap))
286 |     print('MLP test AUC: {:.6f}'.format(roc))
287 |
288 |     gcn_vae.eval()
290 | z = gcn_vae(adj_norm, feat)
291 | roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
292 | print('GCN test AP: {:.6f}'.format(ap))
293 | print('GCN test AUC: {:.6f}'.format(roc))
294 |
--------------------------------------------------------------------------------
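To make the normalization in preprocess_graph above concrete: it computes the symmetric renormalization D^{-1/2}(A + I)D^{-1/2} used by GCNs. Below is a small self-contained check on a made-up 3-node path graph; the helper is copied verbatim from the script, everything else is illustrative.

```python
import numpy as np
import scipy.sparse as sp

def preprocess_graph(adj):  # copied from the script above
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = np.diag(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
    return adj_normalized.astype(np.float32)

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=np.float32)
A_tilde = A + np.eye(3, dtype=np.float32)
d_inv_sqrt = np.diag(A_tilde.sum(axis=1) ** -0.5)
expected = d_inv_sqrt @ A_tilde @ d_inv_sqrt
assert np.allclose(np.asarray(preprocess_graph(A)), expected, atol=1e-6)
print('preprocess_graph matches D^-1/2 (A+I) D^-1/2')
```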
/src/validate_implementation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import networkx as nx
4 | import numpy as np
5 | import scipy.sparse as sp
6 | import torch
7 | import pickle as pkl
8 |
9 | from sklearn.metrics import roc_auc_score
10 | from sklearn.metrics import average_precision_score
11 |
12 | from layers import GraphVae, MLP
13 | from loss import reconstruction_loss, vae_loss
14 |
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--hidden_dim', type=int, default=400)
18 | parser.add_argument('--out_dim', type=int, default=200)
19 | parser.add_argument('--num_iters', type=int, default=200)
20 | parser.add_argument('--data_set', type=str, default='cora', choices = ['cora', 'citeseer', 'pubmed'])
21 | parser.add_argument('--seed', type=int, default=888)
22 | args = parser.parse_args()
23 |
24 | hidden_dim = args.hidden_dim
25 | out_dim = args.out_dim
26 | cite_data = args.data_set
27 |
28 | norm=None
29 | num_iters = args.num_iters
30 | seed = args.seed
31 | np.random.seed(seed)
32 |
33 | ############## utility functions ##############
34 | # def read_citation_dat(dataset):
35 | # '''
36 | # dataset: {'cora', 'citeseer', 'pubmed'}
37 | # '''
38 | #
39 | # feat_fname = '../data/' + dataset + '_features.npz'
40 | # adj_fname = '../data/' + dataset + '_graph.npz'
41 | # features = sp.load_npz(feat_fname)
42 | # adj_orig = sp.load_npz(adj_fname)
43 | # adj_orig = adj_orig + sp.eye(adj_orig.shape[0])
44 | # return adj_orig, features
45 | def parse_index_file(filename):
46 | index = []
47 | for line in open(filename):
48 | index.append(int(line.strip()))
49 | return index
50 |
51 | def load_data(dataset):
52 | # load the data: x, tx, allx, graph
53 | names = ['x', 'tx', 'allx', 'graph']
54 | objects = []
55 | for i in range(len(names)):
56 |         objects.append(pkl.load(open("../data/test/ind.{}.{}".format(dataset, names[i]), 'rb'), encoding='latin1'))
57 | x, tx, allx, graph = tuple(objects)
58 | test_idx_reorder = parse_index_file("../data/test/ind.{}.test.index".format(dataset))
59 | test_idx_range = np.sort(test_idx_reorder)
60 |
61 | if dataset == 'citeseer':
62 | # Fix citeseer dataset (there are some isolated nodes in the graph)
63 | # Find isolated nodes, add them as zero-vecs into the right position
64 | test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
65 | tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
66 | tx_extended[test_idx_range-min(test_idx_range), :] = tx
67 | tx = tx_extended
68 |
69 | features = sp.vstack((allx, tx)).tolil()
70 | features[test_idx_reorder, :] = features[test_idx_range, :]
71 | adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
72 |
73 | return adj, features
74 |
75 | def sparse_to_tuple(sparse_mx):
76 | if not sp.isspmatrix_coo(sparse_mx):
77 | sparse_mx = sparse_mx.tocoo()
78 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
79 | values = sparse_mx.data
80 | shape = sparse_mx.shape
81 | return coords, values, shape
82 |
83 | def mask_test_edges(adj):
84 | # Function to build test set with 10% positive links
85 | # NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.
86 | # TODO: Clean up.
87 |
88 | # Remove diagonal elements
89 | adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
90 | adj.eliminate_zeros()
91 | # Check that diag is zero:
92 | assert np.diag(adj.todense()).sum() == 0
93 |
94 | adj_triu = sp.triu(adj)
95 | adj_tuple = sparse_to_tuple(adj_triu)
96 | edges = adj_tuple[0]
97 | edges_all = sparse_to_tuple(adj)[0]
98 | num_test = int(np.floor(edges.shape[0] / 10.))
99 | num_val = int(np.floor(edges.shape[0] / 20.))
100 |
101 |     all_edge_idx = np.arange(edges.shape[0])
102 |     np.random.shuffle(all_edge_idx)
103 | val_edge_idx = all_edge_idx[:num_val]
104 | test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
105 | test_edges = edges[test_edge_idx]
106 | val_edges = edges[val_edge_idx]
107 | train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)
108 |
109 | def ismember(a, b, tol=5):
110 | rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
111 | return np.any(rows_close)
112 |
113 | test_edges_false = []
114 | while len(test_edges_false) < len(test_edges):
115 | idx_i = np.random.randint(0, adj.shape[0])
116 | idx_j = np.random.randint(0, adj.shape[0])
117 | if idx_i == idx_j:
118 | continue
119 | if ismember([idx_i, idx_j], edges_all):
120 | continue
121 | if test_edges_false:
122 | if ismember([idx_j, idx_i], np.array(test_edges_false)):
123 | continue
124 | if ismember([idx_i, idx_j], np.array(test_edges_false)):
125 | continue
126 | test_edges_false.append([idx_i, idx_j])
127 |
128 | val_edges_false = []
129 | while len(val_edges_false) < len(val_edges):
130 | idx_i = np.random.randint(0, adj.shape[0])
131 | idx_j = np.random.randint(0, adj.shape[0])
132 | if idx_i == idx_j:
133 | continue
134 | if ismember([idx_i, idx_j], train_edges):
135 | continue
136 | if ismember([idx_j, idx_i], train_edges):
137 | continue
138 | if ismember([idx_i, idx_j], val_edges):
139 | continue
140 | if ismember([idx_j, idx_i], val_edges):
141 | continue
142 | if val_edges_false:
143 | if ismember([idx_j, idx_i], np.array(val_edges_false)):
144 | continue
145 | if ismember([idx_i, idx_j], np.array(val_edges_false)):
146 | continue
147 | val_edges_false.append([idx_i, idx_j])
148 |
149 | assert ~ismember(test_edges_false, edges_all)
150 | assert ~ismember(val_edges_false, edges_all)
151 | assert ~ismember(val_edges, train_edges)
152 | assert ~ismember(test_edges, train_edges)
153 | assert ~ismember(val_edges, test_edges)
154 |
155 | data = np.ones(train_edges.shape[0])
156 |
157 | # Re-build adj matrix
158 | adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
159 | adj_train = adj_train + adj_train.T
160 |
161 | # NOTE: these edge lists only contain single direction of edge!
162 | return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false
163 |
164 | def preprocess_graph(adj):
165 | adj = sp.coo_matrix(adj)
166 | adj_ = adj + sp.eye(adj.shape[0])
167 | rowsum = np.array(adj_.sum(1))
168 | degree_mat_inv_sqrt = np.diag(np.power(rowsum, -0.5).flatten())
169 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
170 |
171 | #print(adj_normalized[:20, :].sum(1))
172 |
173 | return adj_normalized.astype(np.float32)
174 |
175 | def get_roc_score(edges_pos, edges_neg, emb):
176 |
177 | def sigmoid(x):
178 | return 1 / (1 + np.exp(-x))
179 |
180 | # Predict on test set of edges
181 | adj_rec = np.dot(emb, emb.T)
182 | preds = []
183 | pos = []
184 | for e in edges_pos:
185 | preds.append(sigmoid(adj_rec[e[0], e[1]]))
186 | pos.append(adj_orig[e[0], e[1]])
187 |
188 | preds_neg = []
189 | neg = []
190 | for e in edges_neg:
191 | preds_neg.append(sigmoid(adj_rec[e[0], e[1]]))
192 | neg.append(adj_orig[e[0], e[1]])
193 |
194 | preds_all = np.hstack([preds, preds_neg])
195 |     labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
196 | roc_score = roc_auc_score(labels_all, preds_all)
197 | ap_score = average_precision_score(labels_all, preds_all)
198 |
199 | return roc_score, ap_score
200 |
201 |
202 | ############ prepare data ##############
203 | adj, feat = load_data(args.data_set)
204 |
205 | features_dim = feat.shape[1]
206 |
207 | # Store original adjacency matrix (without diagonal entries) for later
208 | adj_orig = adj
209 | adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
210 | adj_orig.eliminate_zeros()
211 |
212 | adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
213 | adj = adj_train
214 |
215 | adj_label = adj_train + sp.eye(adj_train.shape[0])
216 |
217 | adj_norm = torch.from_numpy(preprocess_graph(adj))
218 | adj_label = torch.from_numpy(adj_label.todense().astype(np.float32))
219 | feat = torch.from_numpy(feat.todense().astype(np.float32))
220 |
221 | ############## init model ##############
222 | gcn_vae = GraphVae(features_dim, hidden_dim, out_dim, bias=False, dropout=0.0)
223 | optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=0.01)
224 |
225 | mlp = MLP(features_dim, hidden_dim, out_dim, dropout=0.0)
226 | optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-2)
227 |
228 | for batch_idx in range(num_iters):
229 | # train GCN
230 | optimizer_vae.zero_grad()
231 | gcn_vae.train()
232 | z_mean, z_log_std = gcn_vae(adj_norm, feat)
233 | vae_train_loss = vae_loss(z_mean, z_log_std, adj_label)
234 | vae_train_loss.backward()
235 | optimizer_vae.step()
236 |
237 | #train mlp
238 | optimizer_mlp.zero_grad()
239 | mlp.train()
240 | z_mean, z_log_std = mlp(feat)
241 | mlp_train_loss = vae_loss(z_mean, z_log_std, adj_label)
242 | mlp_train_loss.backward()
243 | optimizer_mlp.step()
244 | print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
245 | 100. * batch_idx / num_iters,
246 | vae_train_loss.item()))
247 |
248 | if batch_idx % 10 == 0:
249 | # print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
250 | # 100. * batch_idx / num_iters,
251 | # vae_train_loss.item()))
252 | # print('MLP [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx, num_iters,
253 | # 100. * batch_idx / num_iters,
254 | # mlp_train_loss.item()))
255 |
256 | with torch.no_grad():
257 |
258 | # test original gcn
259 | gcn_vae.eval()
260 | z_mean, z_log_std = gcn_vae(adj_norm, feat)
261 |
262 | normal = torch.distributions.Normal(0, 1)
263 | z = normal.sample(z_mean.size())
264 | z = z * torch.exp(z_log_std) + z_mean
265 |
266 | roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
267 | print('GCN val AP: {:.6f}'.format(ap))
268 | print('GCN val AUC: {:.6f}'.format(roc))
269 |
270 |
271 | mlp.eval()
272 | z_mean, z_log_std = mlp(feat)
273 | normal = torch.distributions.Normal(0, 1)
274 | z = normal.sample(z_mean.size())
275 | z = z * torch.exp(z_log_std) + z_mean
276 | roc, ap = get_roc_score(val_edges, val_edges_false, z.numpy())
277 | print('MLP val AP: {:.6f}'.format(ap))
278 | print('MLP val AUC: {:.6f}'.format(roc))
279 |
280 |
281 | with torch.no_grad():
282 |
283 | mlp.eval()
284 | z_mean, z_log_std = mlp(feat)
285 | normal = torch.distributions.Normal(0, 1)
286 | z = normal.sample(z_mean.size())
287 | z = z * torch.exp(z_log_std) + z_mean
288 | roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
289 |     print('MLP test AP: {:.6f}'.format(ap))
290 |     print('MLP test AUC: {:.6f}'.format(roc))
291 |
292 | gcn_vae.eval()
293 | z_mean, z_log_std = gcn_vae(adj_norm, feat)
294 | normal = torch.distributions.Normal(0, 1)
295 | z = normal.sample(z_mean.size())
296 | z = z * torch.exp(z_log_std) + z_mean
297 | roc, ap = get_roc_score(test_edges, test_edges_false, z.numpy())
298 | print('GCN test AP: {:.6f}'.format(ap))
299 | print('GCN test AUC: {:.6f}'.format(roc))
300 |
--------------------------------------------------------------------------------
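Both validation scripts draw latent samples with the reparameterization trick, z = mu + eps * exp(log_std) with eps ~ N(0, 1), which keeps the sample differentiable in mu and log_std. A quick empirical check with made-up scalar parameters (not part of the repository):

```python
import torch

torch.manual_seed(0)
mu, log_std = torch.tensor(2.0), torch.tensor(-1.0)
eps = torch.distributions.Normal(0, 1).sample((100000,))
z = eps * torch.exp(log_std) + mu
print(z.mean().item())  # close to 2.0
print(z.std().item())   # close to exp(-1) = 0.3679
```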