├── layers
│   ├── __init__.py
│   ├── layers.py
│   ├── att_layers.py
│   └── hyp_layers.py
├── models
│   ├── __init__.py
│   ├── Aggregation.py
│   ├── decoders.py
│   ├── model.py
│   ├── hyp_model.py
│   ├── encoders.py
│   └── base_models.py
├── utils
│   ├── __init__.py
│   ├── eval_utils.py
│   ├── math_utils.py
│   ├── hyperbolicity.py
│   ├── polblogs.py
│   └── train_utils.py
├── graph_evaluate
│   ├── eval
│   │   ├── __init__.py
│   │   ├── MANIFEST.in
│   │   ├── orca
│   │   │   ├── test.txt
│   │   │   └── orca
│   │   ├── setup.py
│   │   ├── orcamodule.cpp
│   │   └── mmd.py
│   ├── eval_results
│   │   ├── HyperDiff_MUTAG.csv
│   │   ├── HypDiff_CL100.csv
│   │   └── HypDiff_MUTAG.csv
│   ├── data
│   │   ├── BA.pkl
│   │   ├── Grid.pkl
│   │   ├── SynER_origin.pkl
│   │   ├── SynEgo1000_origin.pkl
│   │   └── SynCommunity1000_origin.pkl
│   ├── baselines
│   │   ├── graphvae
│   │   │   ├── data.py
│   │   │   └── train.py
│   │   └── mmsb.py
│   ├── args_eval.py
│   ├── dist_helper.py
│   └── spectre_utils.py
├── data
│   ├── BA.pkl
│   ├── SynEgo1000_origin.pkl
│   └── SynCommunity1000_origin.pkl
├── optimizers
│   ├── __init__.py
│   └── radam.py
├── manifolds
│   ├── __init__.py
│   ├── euclidean.py
│   ├── base.py
│   ├── poincare.py
│   └── hyperboloid.py
├── hyperbolic_learning
│   ├── hyperbolic_kmeans
│   │   ├── models
│   │   │   ├── karate_vectors
│   │   │   ├── polbooks_vectors
│   │   │   ├── football_vectors
│   │   │   └── enron_vectors
│   │   └── util_hk.py
│   └── hyperkmeans.py
├── main.py
├── .gitignore
├── lp_train.py
├── ddpm_utils.py
├── Synthatic_graph_generator.py
├── config.py
├── mmd_rnn.py
├── ddpm.py
└── GlobalProperties.py

--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graph_evaluate/eval/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graph_evaluate/eval_results/HyperDiff_MUTAG.csv:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graph_evaluate/eval/MANIFEST.in:
--------------------------------------------------------------------------------
include orca/orca.h

--------------------------------------------------------------------------------
/data/BA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/data/BA.pkl

--------------------------------------------------------------------------------
/graph_evaluate/eval/orca/test.txt:
--------------------------------------------------------------------------------
4 4
0 1
1 2
2 3
3 0

--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
from torch.optim import Adam
from .radam import RiemannianAdam

--------------------------------------------------------------------------------
/data/SynEgo1000_origin.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/data/SynEgo1000_origin.pkl

--------------------------------------------------------------------------------
/graph_evaluate/data/BA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/data/BA.pkl

--------------------------------------------------------------------------------
/graph_evaluate/data/Grid.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/data/Grid.pkl

--------------------------------------------------------------------------------
/graph_evaluate/eval/orca/orca:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/eval/orca/orca
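The orca/test.txt file above uses ORCA's plain edge-list input convention: the header line gives the node and edge counts ("4 4"), and each subsequent line is one undirected edge, so the test graph is the 4-cycle 0-1-2-3-0.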
--------------------------------------------------------------------------------
/data/SynCommunity1000_origin.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/data/SynCommunity1000_origin.pkl

--------------------------------------------------------------------------------
/graph_evaluate/data/SynER_origin.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/data/SynER_origin.pkl

--------------------------------------------------------------------------------
/graph_evaluate/data/SynEgo1000_origin.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/data/SynEgo1000_origin.pkl

--------------------------------------------------------------------------------
/graph_evaluate/data/SynCommunity1000_origin.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RingBDStack/HypDiff/HEAD/graph_evaluate/data/SynCommunity1000_origin.pkl

--------------------------------------------------------------------------------
/graph_evaluate/eval_results/HypDiff_CL100.csv:
--------------------------------------------------------------------------------
sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test

--------------------------------------------------------------------------------
/graph_evaluate/eval_results/HypDiff_MUTAG.csv:
--------------------------------------------------------------------------------
sample_time,epoch,degree_validate,clustering_validate,orbits4_validate,degree_test,clustering_test,orbits4_test

--------------------------------------------------------------------------------
/manifolds/__init__.py:
--------------------------------------------------------------------------------
from .base import ManifoldParameter
from .euclidean import Euclidean
from .hyperboloid import Hyperboloid
from .poincare import PoincareBall
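As a quick orientation for the manifold API exported above — a minimal sketch, assuming PoincareBall implements the per-call-curvature signatures declared in manifolds/base.py (shown later in this dump), under which expmap0 and logmap0 act as mutual inverses near the origin:

import torch
from manifolds import PoincareBall

manifold = PoincareBall()
c = torch.tensor([1.0])                     # curvature
u = 0.01 * torch.randn(5, 2)                # tangent vectors at the origin
p = manifold.expmap0(u, c)                  # map onto the Poincare ball
u_rec = manifold.logmap0(p, c)              # recovers u up to numerical error
dist2 = manifold.sqdist(p[:1], p[1:2], c)   # squared geodesic distance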
--------------------------------------------------------------------------------
/graph_evaluate/eval/setup.py:
--------------------------------------------------------------------------------
from distutils.core import setup, Extension

orca_module = Extension('orca',
                        sources = ['orcamodule.cpp'],
                        extra_compile_args=['-std=c++11'],)

setup (name = 'orca',
       version = '1.0',
       description = 'ORCA motif counting package',
       ext_modules = [orca_module])

--------------------------------------------------------------------------------
/utils/eval_utils.py:
--------------------------------------------------------------------------------
from sklearn.metrics import average_precision_score, accuracy_score, f1_score

def acc_f1(output, labels, average='binary'):
    preds = output.max(1)[1].type_as(labels)
    if preds.is_cuda:
        preds = preds.cpu()
        labels = labels.cpu()
    # sklearn expects (y_true, y_pred)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average=average)
    return accuracy, f1
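For illustration, a small usage sketch of acc_f1 with made-up inputs: output holds per-class scores, and the row-wise argmax is compared against the integer labels:

import torch
from utils.eval_utils import acc_f1

logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.8]])  # scores for 2 classes
labels = torch.tensor([0, 1, 1])
acc, f1 = acc_f1(logits, labels, average='binary')  # predictions [0, 1, 0] -> acc = 2/3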
18 | """ 19 | 20 | def __init__(self, inputDim, OutDim=1024): 21 | """ 22 | 23 | :param inputDim: the feture size of input matrix; Number of the columns or dim of feature matrix 24 | :param normalize: either use the normalizer layer or not 25 | :param layers: the graph feature size or the size of fature matrix before aggregation 26 | """ 27 | super(GcnPool, self).__init__() 28 | self.featureTrnsfr = torch.nn.Linear(inputDim, OutDim) 29 | 30 | def forward(self, in_tensor, activation=torch.tanh): 31 | z = self.featureTrnsfr(in_tensor) 32 | z = torch.mean(z, 1) 33 | z = activation(z) 34 | return z 35 | 36 | -------------------------------------------------------------------------------- /hyperbolic_learning/hyperbolic_kmeans/models/karate_vectors: -------------------------------------------------------------------------------- 1 | 34 2 2 | 34 -0.16521078398865507 -0.09834702478146498 3 | 1 0.2689998915588179 0.5478055699508698 4 | 33 -0.26380420700980345 -0.1677194174662416 5 | 3 0.31240940150649543 0.0031734391458109502 6 | 2 0.36839439079301894 0.3861976571742569 7 | 4 0.27547103116866095 0.29822321218139114 8 | 32 -0.0358142838555646 -0.40380633280311373 9 | 9 -0.1916348374039683 -0.06906006529177099 10 | 14 0.2774297689512082 0.26447762992066004 11 | 24 -0.38373312303561846 -0.7255246141134708 12 | 6 -0.039745222376482606 0.7389129778789855 13 | 7 -0.02835266818205453 0.748750477830748 14 | 8 0.2709259185438642 0.2700594035020807 15 | 28 -0.14834510588483896 -0.5240884515277903 16 | 30 -0.5132144185279716 -0.5511242117055789 17 | 31 -0.22955148588244847 -0.17999443956056677 18 | 5 0.00320613508453613 0.6986272507698765 19 | 11 0.005080764560599365 0.7096177956301655 20 | 20 0.2498985936841383 0.40111342773510456 21 | 26 -0.19372443239639672 -0.7343748396730037 22 | 25 -0.13518007854169406 -0.887882400872604 23 | 29 0.09303844969178927 -0.3920492121331586 24 | 10 0.2295729985861106 -0.19871969202909198 25 | 13 0.26324374066283 0.36331737930502583 26 | 17 -0.043738498194513586 0.7282793624811396 27 | 18 0.3162375382266646 0.4259633708990057 28 | 22 0.3147177279453047 0.4323962921750595 29 | 27 -0.725832073305434 -0.6006317777972117 30 | 15 -0.8155313711279882 0.05716011417293502 31 | 16 -0.7100356671015798 -0.2692177086944283 32 | 19 -0.8928907512049694 -0.16613301056307728 33 | 21 -0.4530969924533113 -0.12936076680960737 34 | 23 -0.29289458790271106 -0.18056768694093223 35 | 12 0.21778469100045408 0.5502469869500473 36 | -------------------------------------------------------------------------------- /manifolds/euclidean.py: -------------------------------------------------------------------------------- 1 | """Euclidean manifold.""" 2 | 3 | from manifolds.base import Manifold 4 | 5 | 6 | class Euclidean(Manifold): 7 | """ 8 | Euclidean Manifold class. 9 | """ 10 | 11 | def __init__(self): 12 | super(Euclidean, self).__init__() 13 | self.name = 'Euclidean' 14 | 15 | def normalize(self, p): 16 | dim = p.size(-1) 17 | p.view(-1, dim).renorm_(2, 0, 1.) 
--------------------------------------------------------------------------------
/hyperbolic_learning/hyperbolic_kmeans/models/karate_vectors:
--------------------------------------------------------------------------------
34 2
34 -0.16521078398865507 -0.09834702478146498
1 0.2689998915588179 0.5478055699508698
33 -0.26380420700980345 -0.1677194174662416
3 0.31240940150649543 0.0031734391458109502
2 0.36839439079301894 0.3861976571742569
4 0.27547103116866095 0.29822321218139114
32 -0.0358142838555646 -0.40380633280311373
9 -0.1916348374039683 -0.06906006529177099
14 0.2774297689512082 0.26447762992066004
24 -0.38373312303561846 -0.7255246141134708
6 -0.039745222376482606 0.7389129778789855
7 -0.02835266818205453 0.748750477830748
8 0.2709259185438642 0.2700594035020807
28 -0.14834510588483896 -0.5240884515277903
30 -0.5132144185279716 -0.5511242117055789
31 -0.22955148588244847 -0.17999443956056677
5 0.00320613508453613 0.6986272507698765
11 0.005080764560599365 0.7096177956301655
20 0.2498985936841383 0.40111342773510456
26 -0.19372443239639672 -0.7343748396730037
25 -0.13518007854169406 -0.887882400872604
29 0.09303844969178927 -0.3920492121331586
10 0.2295729985861106 -0.19871969202909198
13 0.26324374066283 0.36331737930502583
17 -0.043738498194513586 0.7282793624811396
18 0.3162375382266646 0.4259633708990057
22 0.3147177279453047 0.4323962921750595
27 -0.725832073305434 -0.6006317777972117
15 -0.8155313711279882 0.05716011417293502
16 -0.7100356671015798 -0.2692177086944283
19 -0.8928907512049694 -0.16613301056307728
21 -0.4530969924533113 -0.12936076680960737
23 -0.29289458790271106 -0.18056768694093223
12 0.21778469100045408 0.5502469869500473

--------------------------------------------------------------------------------
/manifolds/euclidean.py:
--------------------------------------------------------------------------------
"""Euclidean manifold."""

from manifolds.base import Manifold


class Euclidean(Manifold):
    """
    Euclidean Manifold class.
    """

    def __init__(self):
        super(Euclidean, self).__init__()
        self.name = 'Euclidean'

    def normalize(self, p):
        dim = p.size(-1)
        p.view(-1, dim).renorm_(2, 0, 1.)
        return p

    def sqdist(self, p1, p2, c):
        return (p1 - p2).pow(2).sum(dim=-1)

    def egrad2rgrad(self, p, dp, c):
        return dp

    def proj(self, p, c):
        return p

    def proj_tan(self, u, p, c):
        return u

    def proj_tan0(self, u, c):
        return u

    def expmap(self, u, p, c):
        return p + u

    def logmap(self, p1, p2, c):
        return p2 - p1

    def expmap0(self, u, c):
        return u

    def logmap0(self, p, c):
        return p

    def mobius_add(self, x, y, c, dim=-1):
        return x + y

    def mobius_matvec(self, m, x, c):
        mx = x @ m.transpose(-1, -2)
        return mx

    def init_weights(self, w, c, irange=1e-5):
        w.data.uniform_(-irange, irange)
        return w

    def inner(self, p, c, u, v=None, keepdim=False):
        if v is None:
            v = u
        return (u * v).sum(dim=-1, keepdim=keepdim)

    def ptransp(self, x, y, v, c):
        return v

    def ptransp0(self, x, v, c):
        return x + v

--------------------------------------------------------------------------------
/utils/math_utils.py:
--------------------------------------------------------------------------------
"""Math utils functions."""

import torch


def cosh(x, clamp=15):
    return x.clamp(-clamp, clamp).cosh()


def sinh(x, clamp=15):
    return x.clamp(-clamp, clamp).sinh()


def tanh(x, clamp=15):
    return x.clamp(-clamp, clamp).tanh()


def arcosh(x):
    return Arcosh.apply(x)


def arsinh(x):
    return Arsinh.apply(x)


def artanh(x):
    return Artanh.apply(x)


class Artanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.clamp(-1 + 1e-15, 1 - 1e-15)
        ctx.save_for_backward(x)
        z = x.double()
        return (torch.log_(1 + z).sub_(torch.log_(1 - z))).mul_(0.5).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output / (1 - input ** 2)


class Arsinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        z = x.double()
        return (z + torch.sqrt_(1 + z.pow(2))).clamp_min_(1e-15).log_().to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output / (1 + input ** 2) ** 0.5


class Arcosh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.clamp(min=1.0 + 1e-15)
        ctx.save_for_backward(x)
        z = x.double()
        return (z + torch.sqrt_(z.pow(2) - 1)).clamp_min_(1e-15).log_().to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output / (input ** 2 - 1) ** 0.5
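The custom autograd functions above keep the inverse hyperbolic maps numerically stable near their singularities; their hand-written backward passes can be sanity-checked against the closed-form derivatives:

import torch
from utils.math_utils import artanh

x = (torch.rand(4) * 0.9).requires_grad_(True)
artanh(x).sum().backward()
print(torch.allclose(x.grad, 1 / (1 - x ** 2)))  # d/dx artanh(x) = 1 / (1 - x^2)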
--------------------------------------------------------------------------------
/utils/hyperbolicity.py:
--------------------------------------------------------------------------------
import os
import pickle as pkl
import sys
import time

import networkx as nx
import numpy as np
from tqdm import tqdm

from utils.data_utils import load_data_lp


def hyperbolicity_sample(G, num_samples=50000):
    curr_time = time.time()
    hyps = []
    for i in tqdm(range(num_samples)):
        curr_time = time.time()
        node_tuple = np.random.choice(G.nodes(), 4, replace=False)
        s = []
        try:
            d01 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[1], weight=None)
            d23 = nx.shortest_path_length(G, source=node_tuple[2], target=node_tuple[3], weight=None)
            d02 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[2], weight=None)
            d13 = nx.shortest_path_length(G, source=node_tuple[1], target=node_tuple[3], weight=None)
            d03 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[3], weight=None)
            d12 = nx.shortest_path_length(G, source=node_tuple[1], target=node_tuple[2], weight=None)
            s.append(d01 + d23)
            s.append(d02 + d13)
            s.append(d03 + d12)
            s.sort()
            hyps.append((s[-1] - s[-2]) / 2)
        except Exception as e:
            continue
    print('Time for hyp: ', time.time() - curr_time)
    return max(hyps)


if __name__ == '__main__':
    dataset = 'pubmed'
    data_path = os.path.join(os.environ['DATAPATH'], dataset)
    data = load_data_lp(dataset, use_feats=False, data_path=data_path)
    graph = nx.from_scipy_sparse_matrix(data['adj_train'])
    print('Computing hyperbolicity', graph.number_of_nodes(), graph.number_of_edges())
    hyp = hyperbolicity_sample(graph)
    print('Hyp: ', hyp)
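hyperbolicity_sample estimates Gromov delta-hyperbolicity by the four-point condition: for each sampled quadruple it sorts the three sums of opposite-pair distances and takes half the gap between the two largest. A worked example on the 4-cycle stored in orca/test.txt:

import networkx as nx

G = nx.cycle_graph(4)                        # the 4-cycle 0-1-2-3-0
d = dict(nx.all_pairs_shortest_path_length(G))
s = sorted([d[0][1] + d[2][3], d[0][2] + d[1][3], d[0][3] + d[1][2]])
delta = (s[-1] - s[-2]) / 2                  # sums are [2, 2, 4], so delta = 1.0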
--------------------------------------------------------------------------------
/graph_evaluate/eval/orcamodule.cpp:
--------------------------------------------------------------------------------
#include <Python.h>
#include <cstdio>
#include <cstdlib>

#include "orca/orca.h"

static PyObject *
orca_motifs(PyObject *self, PyObject *args)
{
    const char *orbit_type;
    int graphlet_size;
    const char *input_filename;
    const char *output_filename;
    int sts;

    if (!PyArg_ParseTuple(args, "siss", &orbit_type, &graphlet_size, &input_filename, &output_filename))
        return NULL;
    sts = system(orbit_type);
    motif_counts(orbit_type, graphlet_size, input_filename, output_filename);
    return PyLong_FromLong(sts);
}

static PyMethodDef OrcaMethods[] = {
    {"motifs", orca_motifs, METH_VARARGS,
     "Compute motif counts."},
    {NULL, NULL, 0, NULL}  /* sentinel terminating the method table */
};

static struct PyModuleDef orcamodule = {
    PyModuleDef_HEAD_INIT,
    "orca",   /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    OrcaMethods
};

PyMODINIT_FUNC
PyInit_orca(void)
{
    return PyModule_Create(&orcamodule);
}

int main(int argc, char *argv[]) {

    wchar_t *program = Py_DecodeLocale(argv[0], NULL);
    if (program == NULL) {
        fprintf(stderr, "Fatal error: cannot decode argv[0]\n");
        exit(1);
    }

    /* Add a built-in module, before Py_Initialize */
    PyImport_AppendInittab("orca", PyInit_orca);

    /* Pass argv[0] to the Python interpreter */
    Py_SetProgramName(program);

    /* Initialize the Python interpreter. Required. */
    Py_Initialize();

    /* Optionally import the module; alternatively,
       import can be deferred until the embedded script
       imports it. */
    PyImport_ImportModule("orca");

    PyMem_RawFree(program);

}

--------------------------------------------------------------------------------
/models/decoders.py:
--------------------------------------------------------------------------------
"""Graph decoders."""
import manifolds
import torch.nn as nn
import torch.nn.functional as F

from layers.att_layers import GraphAttentionLayer
from layers.layers import GraphConvolution, Linear


class Decoder(nn.Module):
    """
    Decoder abstract class for node classification tasks.
    """

    def __init__(self, c):
        super(Decoder, self).__init__()
        self.c = c

    def decode(self, x, adj):
        if self.decode_adj:
            input = (x, adj)
            probs, _ = self.cls.forward(input)
        else:
            probs = self.cls.forward(x)
        return probs


class GCNDecoder(Decoder):
    """
    Graph Convolution Decoder.
    """

    def __init__(self, c, args):
        super(GCNDecoder, self).__init__(c)
        act = lambda x: x
        self.cls = GraphConvolution(args.dim, args.n_classes, args.dropout, act, args.bias)
        self.decode_adj = True


class GATDecoder(Decoder):
    """
    Graph Attention Decoder.
    """

    def __init__(self, c, args):
        super(GATDecoder, self).__init__(c)
        self.cls = GraphAttentionLayer(args.dim, args.n_classes, args.dropout, F.elu, args.alpha, 1, True)
        self.decode_adj = True


class LinearDecoder(Decoder):
    """
    MLP Decoder for Hyperbolic/Euclidean node classification models.
    """

    def __init__(self, c, args):
        super(LinearDecoder, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        self.input_dim = args.dim
        self.output_dim = args.n_classes
        self.bias = args.bias
        self.cls = Linear(self.input_dim, self.output_dim, args.dropout, lambda x: x, self.bias)
        self.decode_adj = False

    def decode(self, x, adj):
        h = self.manifold.proj_tan0(self.manifold.logmap0(x, c=self.c), c=self.c)
        return super(LinearDecoder, self).decode(h, adj)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, c={}'.format(
                self.input_dim, self.output_dim, self.bias, self.c
        )


model2decoder = {
    'GCN': GCNDecoder,
    'GAT': GATDecoder,
    'HNN': LinearDecoder,
    'HGCN': LinearDecoder,
    'MLP': LinearDecoder,
    'Shallow': LinearDecoder,
}
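A hedged sketch of picking a decoder through the model2decoder table above; the args namespace and its values here are hypothetical stand-ins for the fields that config.py (not shown) normally provides:

import torch
from types import SimpleNamespace
from models.decoders import model2decoder

args = SimpleNamespace(dim=16, n_classes=2, dropout=0.0, bias=1,
                       manifold='PoincareBall')                    # hypothetical values
decoder = model2decoder['HGCN'](c=torch.tensor([1.0]), args=args)  # LinearDecoder
x = 0.01 * torch.randn(5, 16)        # assumed to be points on the manifold
probs = decoder.decode(x, adj=None)  # LinearDecoder ignores adj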
32 | """ 33 | 34 | def __init__(self, in_features, out_features, dropout, act, use_bias): 35 | super(GraphConvolution, self).__init__() 36 | self.dropout = dropout 37 | self.linear = nn.Linear(in_features, out_features, use_bias) 38 | self.act = act 39 | self.in_features = in_features 40 | self.out_features = out_features 41 | 42 | def forward(self, input): 43 | x, adj = input 44 | hidden = self.linear.forward(x) 45 | hidden = F.dropout(hidden, self.dropout, training=self.training) 46 | if adj.is_sparse: 47 | support = torch.spmm(adj, hidden) 48 | else: 49 | support = torch.mm(adj, hidden) 50 | output = self.act(support), adj 51 | return output 52 | 53 | def extra_repr(self): 54 | return 'input_dim={}, output_dim={}'.format( 55 | self.in_features, self.out_features 56 | ) 57 | 58 | 59 | class Linear(Module): 60 | """ 61 | Simple Linear layer with dropout. 62 | """ 63 | 64 | def __init__(self, in_features, out_features, dropout, act, use_bias): 65 | super(Linear, self).__init__() 66 | self.dropout = dropout 67 | self.linear = nn.Linear(in_features, out_features, use_bias) 68 | self.act = act 69 | 70 | def forward(self, x): 71 | hidden = self.linear.forward(x) 72 | hidden = F.dropout(hidden, self.dropout, training=self.training) 73 | out = self.act(hidden) 74 | return out 75 | 76 | 77 | class FermiDiracDecoder(Module): 78 | """Fermi Dirac to compute edge probabilities based on distances.""" 79 | 80 | def __init__(self, r, t): 81 | super(FermiDiracDecoder, self).__init__() 82 | self.r = r 83 | self.t = t 84 | 85 | def forward(self, dist): 86 | probs = 1. / (torch.exp((dist - self.r) / self.t) + 1.0) 87 | return probs 88 | 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.pyc 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
*.env
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
--------------------------------------------------------------------------------
/graph_evaluate/baselines/graphvae/data.py:
--------------------------------------------------------------------------------
import networkx as nx
import numpy as np
import torch

class GraphAdjSampler(torch.utils.data.Dataset):
    def __init__(self, G_list, max_num_nodes, features='id'):
        self.max_num_nodes = max_num_nodes
        self.adj_all = []
        self.len_all = []
        self.feature_all = []

        for G in G_list:
            adj = nx.to_numpy_matrix(G)
            # the diagonal entries are 1 since they denote node probability
            self.adj_all.append(
                    np.asarray(adj) + np.identity(G.number_of_nodes()))
            self.len_all.append(G.number_of_nodes())
            if features == 'id':
                self.feature_all.append(np.identity(max_num_nodes))
            elif features == 'deg':
                degs = np.sum(np.array(adj), 1)
                degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
                                             'constant'),
                                      axis=1)
                self.feature_all.append(degs)
            elif features == 'struct':
                degs = np.sum(np.array(adj), 1)
                degs = np.expand_dims(np.pad(degs, [0, max_num_nodes - G.number_of_nodes()],
                                             'constant'),
                                      axis=1)
                clusterings = np.array(list(nx.clustering(G).values()))
                clusterings = np.expand_dims(np.pad(clusterings,
                                                    [0, max_num_nodes - G.number_of_nodes()],
                                                    'constant'),
                                             axis=1)
                self.feature_all.append(np.hstack([degs, clusterings]))

    def __len__(self):
        return len(self.adj_all)

    def __getitem__(self, idx):
        adj = self.adj_all[idx]
        num_nodes = adj.shape[0]
        adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_padded[:num_nodes, :num_nodes] = adj

        adj_decoded = np.zeros(self.max_num_nodes * (self.max_num_nodes + 1) // 2)
        node_idx = 0

        adj_vectorized = adj_padded[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes))) == 1]
        # the following 2 lines recover the upper triangle of the adj matrix
        #recovered = np.zeros((self.max_num_nodes, self.max_num_nodes))
        #recovered[np.triu(np.ones((self.max_num_nodes, self.max_num_nodes))) == 1] = adj_vectorized
        #print(recovered)

        return {'adj': adj_padded,
                'adj_decoded': adj_vectorized,
                'features': self.feature_all[idx].copy()}
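GraphAdjSampler is a standard torch Dataset, so it plugs straight into a DataLoader; a minimal sketch, assuming it is run from inside graph_evaluate/ and with a networkx version that still provides to_numpy_matrix:

import networkx as nx
import torch
from baselines.graphvae.data import GraphAdjSampler

graphs = [nx.cycle_graph(n) for n in (4, 5, 6)]
dataset = GraphAdjSampler(graphs, max_num_nodes=6, features='id')
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch['adj'].shape)          # (2, 6, 6): padded adjacency plus identity
print(batch['adj_decoded'].shape)  # (2, 21): upper triangle incl. diagonal, 6*7/2 entries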
76 | class ManifoldParameter(Parameter):
77 |     """
78 |     Subclass of torch.nn.Parameter for Riemannian optimization.
79 |     """
80 |     def __new__(cls, data, requires_grad, manifold, c):
81 |         return Parameter.__new__(cls, data, requires_grad)
82 | 
83 |     def __init__(self, data, requires_grad, manifold, c):
84 |         self.c = c
85 |         self.manifold = manifold
86 | 
87 |     def __repr__(self):
88 |         return '{} Parameter containing:\n'.format(self.manifold.name) + super(Parameter, self).__repr__()
89 | 
-------------------------------------------------------------------------------- /utils/polblogs.py: --------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable, List, Optional
3 | 
4 | import torch
5 | 
6 | from torch_geometric.data import (
7 |     Data,
8 |     InMemoryDataset,
9 |     download_url,
10 |     extract_tar,
11 | )
12 | 
13 | 
14 | class PolBlogs(InMemoryDataset):
15 |     r"""The Political Blogs dataset from the `"The Political Blogosphere and
16 |     the 2004 US Election: Divided they Blog"
17 |     <https://dl.acm.org/doi/10.1145/1134271.1134277>`_ paper.
18 | 
19 |     :class:`PolBlogs` is a graph with 1,490 vertices (representing political
20 |     blogs) and 19,025 edges (links between blogs).
21 |     The links are automatically extracted from a crawl of the front page of the
22 |     blog.
23 |     Each vertex receives a label indicating the political leaning of the blog:
24 |     liberal or conservative.
25 | 
26 |     Args:
27 |         root (str): Root directory where the dataset should be saved.
28 |         transform (callable, optional): A function/transform that takes in an
29 |             :obj:`torch_geometric.data.Data` object and returns a transformed
30 |             version. The data object will be transformed before every access.
31 |             (default: :obj:`None`)
32 |         pre_transform (callable, optional): A function/transform that takes in
33 |             an :obj:`torch_geometric.data.Data` object and returns a
34 |             transformed version. The data object will be transformed before
35 |             being saved to disk. (default: :obj:`None`)
36 | 
37 |     **STATS:**
38 | 
39 |     ..
list-table:: 40 | :widths: 10 10 10 10 41 | :header-rows: 1 42 | 43 | * - #nodes 44 | - #edges 45 | - #features 46 | - #classes 47 | * - 1,490 48 | - 19,025 49 | - 0 50 | - 2 51 | """ 52 | 53 | url = 'https://netset.telecom-paris.fr/datasets/polblogs.tar.gz' 54 | 55 | def __init__(self, root: str, transform: Optional[Callable] = None, 56 | pre_transform: Optional[Callable] = None): 57 | super().__init__(root, transform, pre_transform) 58 | self.data, self.slices = torch.load(self.processed_paths[0]) 59 | 60 | @property 61 | def raw_file_names(self) -> List[str]: 62 | return ['adjacency.csv', 'labels.csv'] 63 | 64 | @property 65 | def processed_file_names(self) -> str: 66 | return 'data.pt' 67 | 68 | def download(self): 69 | path = download_url(self.url, self.raw_dir) 70 | extract_tar(path, self.raw_dir) 71 | os.unlink(path) 72 | 73 | def process(self): 74 | import pandas as pd 75 | 76 | edge_index = pd.read_csv(self.raw_paths[0], header=None, sep='\t', 77 | usecols=[0, 1]) 78 | edge_index = torch.tensor(edge_index.values.astype(float)).t().contiguous() 79 | # edge_index = torch.from_numpy(edge_index.values).t().contiguous() 80 | 81 | y = pd.read_csv(self.raw_paths[1], header=None, sep='\t') 82 | y = torch.tensor(y.values.astype(float)).view(-1) 83 | # y = torch.from_numpy(y.values).view(-1) 84 | 85 | data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0)) 86 | 87 | if self.pre_transform is not None: 88 | data = self.pre_transform(data) 89 | 90 | torch.save(self.collate([data]), self.processed_paths[0]) -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn.modules.loss 7 | 8 | 9 | def format_metrics(metrics, split): 10 | """Format metric in metric dict for logging.""" 11 | return " ".join( 12 | ["{}_{}: {:.4f}".format(split, metric_name, metric_val) for metric_name, metric_val in metrics.items()]) 13 | 14 | 15 | def get_dir_name(models_dir): 16 | """Gets a directory to save the model. 17 | 18 | If the directory already exists, then append a new integer to the end of 19 | it. This method is useful so that we don't overwrite existing models 20 | when launching new jobs. 21 | 22 | Args: 23 | models_dir: The directory where all the models are. 24 | 25 | Returns: 26 | The name of a new directory to save the training logs and model weights. 
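# --- Editor's example (annotation; not part of train_utils.py) ---------------
# get_dir_name (below) picks the next free integer-named subdirectory, e.g.:
#
#   get_dir_name('logs/run')   # 'logs/run' missing -> creates and returns logs/run/0
#   get_dir_name('logs/run')   # '0' now exists     -> creates and returns logs/run/1
# ------------------------------------------------------------------------------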
27 |     """
28 |     if not os.path.exists(models_dir):
29 |         save_dir = os.path.join(models_dir, '0')
30 |         os.makedirs(save_dir)
31 |     else:
32 |         existing_dirs = np.array(
33 |             [
34 |                 d
35 |                 for d in os.listdir(models_dir)
36 |                 if os.path.isdir(os.path.join(models_dir, d))
37 |             ]
38 |         ).astype(int)  # fixed: np.int was removed in NumPy 1.24; the builtin int is equivalent here
39 |         if len(existing_dirs) > 0:
40 |             dir_id = str(existing_dirs.max() + 1)
41 |         else:
42 |             dir_id = "1"
43 |         save_dir = os.path.join(models_dir, dir_id)
44 |         os.makedirs(save_dir)
45 |     return save_dir
46 | 
47 | 
48 | def add_flags_from_config(parser, config_dict):
49 |     """
50 |     Adds a flag (and default value) to an ArgumentParser for each parameter in a config
51 |     """
52 |     import argparse  # needed for the argparse.ArgumentError check below
53 |     def OrNone(default):
54 |         def func(x):
55 |             # Convert "none" to proper None object
56 |             if x.lower() == "none":
57 |                 return None
58 |             # If default is None (and x is not None), return x without conversion as str
59 |             elif default is None:
60 |                 return str(x)
61 |             # Otherwise, default has non-None type; convert x to that type
62 |             else:
63 |                 return type(default)(x)
64 | 
65 |         return func
66 | 
67 |     for param in config_dict:
68 |         default, description = config_dict[param]
69 |         try:
70 |             if isinstance(default, dict):
71 |                 parser = add_flags_from_config(parser, default)
72 |             elif isinstance(default, list):
73 |                 if len(default) > 0:
74 |                     # pass a list as argument
75 |                     parser.add_argument(
76 |                         f"--{param}",
77 |                         action="append",
78 |                         type=type(default[0]),
79 |                         default=default,
80 |                         help=description
81 |                     )
82 |                 else:
83 |                     # empty default list: fall back to an untyped append flag
84 |                     parser.add_argument(f"--{param}", action="append", default=default, help=description)
85 |             else:
86 |                 # scalar default: parse with OrNone so the string "none" maps to None
87 |                 parser.add_argument(f"--{param}", type=OrNone(default), default=default, help=description)
88 |         except argparse.ArgumentError:
89 |             print(
90 |                 f"Could not add flag for param {param} because it was already present."
91 |             )
92 |     return parser
93 | 
94 | 
95 | 
-------------------------------------------------------------------------------- /graph_evaluate/args_eval.py: --------------------------------------------------------------------------------
1 | 
2 | ### program configuration
3 | class Args():
4 |     def __init__(self):
5 |         ### if clean tensorboard
6 |         self.clean_tensorboard = False
7 |         ### Which CUDA GPU device is used for training
8 |         self.cuda = 1
9 | 
10 |         # The dependent Bernoulli sequence version of GraphRNN
11 |         self.note = 'HypDiff'
12 | 
13 |         self.graph_type = 'MUTAG' # options: protein, "MUTAG", "QM9", "IMDB-BINARY", "COLLAB"
14 | 
15 |         # if none, then auto calculate
16 |         self.max_num_node = None # max number of nodes in a graph
17 |         self.max_prev_node = None # max previous node that looks back
18 | 
19 |         ### network config
20 |         ## GraphRNN
21 |         if 'small' in self.graph_type:
22 |             self.parameter_shrink = 2
23 |         else:
24 |             self.parameter_shrink = 1
25 |         self.hidden_size_rnn = int(128/self.parameter_shrink) # hidden size for main RNN
26 |         self.hidden_size_rnn_output = 16 # hidden size for output RNN
27 |         self.embedding_size_rnn = int(64/self.parameter_shrink) # the size for LSTM input
28 |         self.embedding_size_rnn_output = 8 # the embedding size for output rnn
29 |         self.embedding_size_output = int(64/self.parameter_shrink) # the embedding size for output (VAE/MLP)
30 | 
31 |         self.batch_size = 32 # normal: 32, and the rest should be changed accordingly
32 |         self.test_batch_size = 32
33 |         self.test_total_size = (self.test_batch_size+1)*10
34 |         self.num_layers = 4
35 | 
36 |         ### training config
37 |         self.num_workers = 4 # num workers to load data, default 4
38 |         self.batch_ratio = 32 # how many batches of samples per epoch, default 32, e.g., 1 epoch = 32 batches
39 |         self.epochs = 100 # now one epoch means self.batch_ratio x batch_size
40 |         self.epochs_test_start = 50
41 |         self.epochs_test = 50
42 |         self.epochs_log = 50
43 |         self.epochs_save = 50
44 | 
45 |         self.lr = 0.003
46 |         self.milestones = [400, 1000]
47 |         self.lr_rate = 0.3
48 | 
49 |         self.sample_time = 2 # sample time in each time step, when validating
50 | 
51 |         ### output config
52 |         # self.dir_input = "/dfs/scratch0/jiaxuany0/"
53 |         self.dir_input = "./"
54 |         self.model_save_path = self.dir_input+'model_save/' # only for nll evaluation
55 |         self.graph_save_path = self.dir_input+'graphs/'
56 |         self.figure_save_path = self.dir_input+'figures/'
57 |         self.timing_save_path = self.dir_input+'timing/'
58 |         self.figure_prediction_save_path = self.dir_input+'figures_prediction/'
59 |         self.nll_save_path = self.dir_input+'nll/'
60 | 
61 | 
62 |         self.load = False # if load model, default lr is very low
63 |         self.load_epoch = 3000
64 |         self.save = True
65 | 
66 | 
67 |         ### baseline config
68 |         # self.generator_baseline = 'Gnp'
69 |         self.generator_baseline = 'BA'
70 | 
71 |         # self.metric_baseline = 'general'
72 |         # self.metric_baseline = 'degree'
73 |         self.metric_baseline = 'clustering'
74 | 
75 | 
76 |         ### filenames to save intermediate and final outputs
77 |         self.fname = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_'
78 |         self.fname_pred = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_pred_'
79 |         self.fname_train = self.note+'_'+self.graph_type+'_'+str(self.num_layers)+'_'+ str(self.hidden_size_rnn)+'_train_'
80 |         self.fname_test = self.note + '_' + self.graph_type + '_' + str(self.num_layers) + '_' + str(self.hidden_size_rnn) + '_test_'
81 |         self.fname_baseline = 
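# --- Editor's example (annotation; not part of the original file) ------------
# A minimal sketch of how config.py drives add_flags_from_config: each config
# entry is a (default, help) pair, and nested dicts are flattened into flags.
# `toy_config` below is illustrative only.
if __name__ == "__main__":
    import argparse

    toy_config = {
        'lr': (0.01, 'learning rate'),
        'act': ('relu', 'activation (the string "none" maps to None)'),
    }
    demo_parser = argparse.ArgumentParser()
    demo_parser = add_flags_from_config(demo_parser, toy_config)
    demo_args = demo_parser.parse_args(['--lr', '0.1', '--act', 'none'])
    assert demo_args.lr == 0.1 and demo_args.act is None
# ------------------------------------------------------------------------------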
self.graph_save_path + self.graph_type + self.generator_baseline+'_'+self.metric_baseline 82 | 83 | -------------------------------------------------------------------------------- /lp_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import json 5 | import os 6 | import pickle 7 | import time 8 | 9 | import numpy as np 10 | import optimizers 11 | import torch 12 | from config import parser 13 | from models.base_models import LPModel 14 | from utils.data_utils import load_data 15 | import os 16 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 17 | 18 | import scipy.sparse as sp 19 | from tqdm import tqdm 20 | 21 | def train(args): 22 | np.random.seed(args.seed) 23 | torch.manual_seed(args.seed) 24 | if int(args.cuda) >= 0: 25 | torch.cuda.manual_seed(args.seed) 26 | args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu' 27 | save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset) 28 | if not os.path.exists(save_dir): 29 | os.makedirs(save_dir) 30 | 31 | # Load data 32 | data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset)) 33 | 34 | args.n_nodes, args.feat_dim = data['features'].shape 35 | 36 | args.nb_false_edges = len(data['train_edges_false']) 37 | args.nb_edges = len(data['train_edges']) 38 | Model = LPModel 39 | # No validation for reconstruction task 40 | args.eval_freq = args.epochs + 1 41 | 42 | if not args.lr_reduce_freq: 43 | args.lr_reduce_freq = args.epochs 44 | 45 | model = Model(args) 46 | optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr, 47 | weight_decay=args.weight_decay) 48 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 49 | optimizer, 50 | step_size=int(args.lr_reduce_freq), 51 | gamma=float(args.gamma) 52 | ) 53 | tot_params = sum([np.prod(p.size()) for p in model.parameters()]) 54 | if args.cuda is not None and int(args.cuda) >= 0 : 55 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) 56 | model = model.to(args.device) 57 | for x, val in data.items(): 58 | if torch.is_tensor(data[x]): 59 | data[x] = data[x].to(args.device) 60 | # Train model 61 | t_total = time.time() 62 | counter = 0 63 | best_val_metrics = model.init_metric_dict() 64 | best_test_metrics = None 65 | best_emb = None 66 | for epoch in tqdm(range(args.epochs)): 67 | tt = time.time() 68 | model.train() 69 | optimizer.zero_grad() 70 | embeddings,t,h0,adj= model.encode(data['features'], data['adj_train_norm']) 71 | # print(h0) 72 | 73 | train_metrics = model.compute_metrics(embeddings, data, t,h0,adj,'train') 74 | # print(train_metrics['loss']) 75 | train_metrics['loss'].backward() 76 | if args.grad_clip is not None: 77 | max_norm = float(args.grad_clip) 78 | all_params = list(model.parameters()) 79 | for param in all_params: 80 | torch.nn.utils.clip_grad_norm_(param, max_norm) 81 | optimizer.step() 82 | lr_scheduler.step() 83 | 84 | if (epoch + 1) % args.eval_freq == 0: 85 | model.eval() 86 | embeddings,t,h0,adj = model.encode(data['features'], data['adj_train_norm']) 87 | val_metrics = model.compute_metrics(embeddings, data, t,h0,adj,'val') 88 | 89 | model.eval() 90 | 91 | best_test_metrics = model.compute_metrics(embeddings, data,t,h0, adj,'test') 92 | 93 | # print(best_test_metrics) 94 | print('End encoding!') 95 | np.save(os.path.join(save_dir, 'embeddings.npy'), h0.cpu().detach().numpy()) 96 | if hasattr(model.encoder, 'att_adj'): 97 | filename = os.path.join(save_dir, args.dataset + 
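# --- Editor's note (annotation on the training loop in lp_train.py above) ----
# args.grad_clip is applied per parameter tensor there: each tensor is clipped
# to max_norm on its own. Clipping by the global gradient norm instead is a
# single call over all parameters, e.g.:
#
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# ------------------------------------------------------------------------------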
'_att_adj.p') 98 | pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb')) 99 | print('Dumped attention adj: ' + filename) 100 | 101 | json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w')) 102 | torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth')) 103 | 104 | if __name__ == '__main__': 105 | args = parser.parse_args() 106 | train(args) 107 | -------------------------------------------------------------------------------- /hyperbolic_learning/hyperkmeans.py: -------------------------------------------------------------------------------- 1 | # import libraries 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | 6 | import networkx as nx 7 | import sys 8 | import os 9 | import torch 10 | # import modules within repository 11 | 12 | from utils import * 13 | from .hyperbolic_kmeans.hkmeans import HyperbolicKMeans, plot_clusters 14 | # ignore warnings 15 | import warnings 16 | warnings.filterwarnings('ignore'); 17 | # # display multiple outputs within a cell 18 | # from IPython.core.interactiveshell import InteractiveShell 19 | # InteractiveShell.ast_node_interactivity = "all"; 20 | 21 | def hkmeanscom(args): 22 | # load polbooks data 23 | 24 | # load pre-trained embedding coordinates 25 | save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset) 26 | # if args.save: 27 | # if not args.save_dir: 28 | # # dt = datetime.datetime.now() 29 | # # date = f"{dt.year}_{dt.month}_{dt.day}" 30 | # # models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date) 31 | # # save_dir = get_dir_name(models_dir) 32 | # save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset) 33 | # else: 34 | # save_dir = args.save_dir 35 | #file_path = os.path.join('/home/wyc/Code/HyperDiff/logs/lp/2023_10_25/9/', 'embeddings.npy') 36 | file_path=os.path.join(save_dir, 'embeddings.npy') 37 | emb_data= np.load(file_path, allow_pickle=True) 38 | # print(emb_data.shape) 39 | #emb_data=torch.randn(105,2) 40 | # fit unsupervised clustering 41 | 42 | m=3 43 | hkmeans = HyperbolicKMeans(n_clusters=m) 44 | hkmeans.fit(emb_data, max_epochs=5) 45 | labels=hkmeans.assignments 46 | center=hkmeans.centroids 47 | lable_tmp=[] 48 | for i in range(len(labels)): 49 | for j in range(m): 50 | if(labels[i][j]==1): 51 | lable_tmp.append(j) 52 | 53 | #print(label_tmp.shape) 54 | # print(center.shape) 55 | lable_path=os.path.join(save_dir,'label.npy') 56 | center_path=os.path.join(save_dir,'center.npy') 57 | np.save(lable_path,lable_tmp) 58 | np.save(center_path,center) 59 | 60 | def graph_hkmeanscom(args, emb_data, dataloader): 61 | 62 | save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset) 63 | # if args.save: 64 | # if not args.save_dir: 65 | # # dt = datetime.datetime.now() 66 | # # date = f"{dt.year}_{dt.month}_{dt.day}" 67 | # # models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date) 68 | # # save_dir = get_dir_name(models_dir) 69 | # save_dir = os.path.join(os.environ['LOG_DIR'], args.dataset) 70 | # else: 71 | # save_dir = args.save_dir 72 | #file_path = os.path.join('/home/wyc/Code/HyperDiff/logs/lp/2023_10_25/9/', 'embeddings.npy') 73 | # file_path=os.path.join(save_dir, 'embeddings.npy') 74 | # emb_data= np.load(file_path, allow_pickle=True) 75 | # print(emb_data.shape) 76 | #emb_data=torch.randn(105,2) 77 | # fit unsupervised clustering 78 | # emb_data = torch.tensor(emb_data).float() 79 | emb_data = emb_data.numpy() 80 | label=[] 81 | centers=[] 82 | # for batch_idx, data in enumerate(dataloader): 83 | for batch_idx in 
range(len(emb_data)): 84 | m=3 85 | hkmeans = HyperbolicKMeans(n_clusters=m) 86 | hkmeans.fit(emb_data[batch_idx], max_epochs=5) 87 | labels=hkmeans.assignments 88 | center=hkmeans.centroids 89 | lable_tmp=[] 90 | for i in range(len(labels)): 91 | for j in range(m): 92 | if(labels[i][j]==1): 93 | lable_tmp.append(j) 94 | label.append(lable_tmp) 95 | centers.append(center) 96 | #print(label_tmp.shape) 97 | # print(center.shape) 98 | lable_path=os.path.join(save_dir,'label'+str(batch_idx)+'.npy') 99 | center_path=os.path.join(save_dir,'center'+str(batch_idx)+'.npy') 100 | np.save(lable_path,lable_tmp) 101 | np.save(center_path,center) 102 | lable_path = os.path.join(save_dir, 'label' + '.npy') 103 | center_path = os.path.join(save_dir, 'center' + '.npy') 104 | np.save(lable_path,label) 105 | np.save(center_path,centers) -------------------------------------------------------------------------------- /ddpm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | class DDPMSampler(): 6 | def __init__(self, beta_1, beta_T, T, diffusion_fn, device, shape): 7 | ''' 8 | beta_1 : beta_1 of diffusion process 9 | beta_T : beta_T of diffusion process 10 | T : step of diffusion process 11 | diffusion_fn : trained diffusion network 12 | shape : data shape 13 | ''' 14 | 15 | self.betas = torch.linspace(start=beta_1, end=beta_T, steps=T) 16 | self.alphas = 1 - self.betas 17 | self.alpha_bars = torch.cumprod(1 - torch.linspace(start=beta_1, end=beta_T, steps=T), dim=0).to(device=device) 18 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]]) 19 | self.shape = shape 20 | self.deta = 0.01 21 | self.diffusion_fn = diffusion_fn 22 | self.device = device 23 | 24 | def getpri(self, t): 25 | c_num = np.sqrt(self.c) 26 | T = 2000 27 | out = self.deta * torch.tanh(c_num * t / T) 28 | #print(c_num * t / T) 29 | b, *_ = t.shape 30 | return out.reshape(b, *((1,) * (2 - 1))) 31 | 32 | def _one_diffusion_step(self, x, direction, restrict, target): 33 | ''' 34 | x : perturbated data 35 | ''' 36 | if not restrict: 37 | for idx in reversed(range(len(self.alpha_bars))): 38 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x) 39 | sqrt_tilde_beta = torch.sqrt( 40 | (1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx]) 41 | if target == 'pred_noise': 42 | predict_epsilon = self.diffusion_fn(x, idx, get_target=False) 43 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * ( 44 | x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon) 45 | x = mu_theta_xt + sqrt_tilde_beta * noise 46 | elif target == 'pred_x0': 47 | predict_x0 = self.diffusion_fn(x, idx) 48 | predict_epsilon = torch.sqrt(1 / (1 - self.alphas[idx])) * ( 49 | x - torch.sqrt(self.alphas[idx]) * predict_x0) 50 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * ( 51 | x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon) 52 | x = mu_theta_xt + sqrt_tilde_beta * noise 53 | yield x 54 | else: 55 | for idx in reversed(range(len(self.alpha_bars))): 56 | pri = self.getpri(idx) 57 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x) 58 | noise = direction * torch.abs(noise) 59 | sqrt_tilde_beta = torch.sqrt( 60 | (1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx]) 61 | if target == 'pred_noise': 62 | predict_epsilon = self.diffusion_fn(x, idx) 63 | predict_x0 = (x - torch.sqrt(1 - 
self.alpha_bars[idx]) * direction * torch.abs(predict_epsilon)) / ( 64 | torch.sqrt(self.alpha_bars[idx]) + self.deta) 65 | if (idx > 0): 66 | x = x - (torch.sqrt(self.alpha_bars[idx]) * predict_x0 + pri * predict_x0 + torch.sqrt( 67 | 1 - self.alpha_bars[idx]) * noise) + (torch.sqrt( 68 | self.alpha_bars[idx - 1]) * predict_x0 + pri * predict_x0 + torch.sqrt( 69 | 1 - self.alpha_bars[idx - 1]) * noise) 70 | elif target == 'pred_x0': 71 | predict_x0 = self.diffusion_fn(x, idx) 72 | if (idx > 0): 73 | x = x - (torch.sqrt(self.alpha_bars[idx]) * predict_x0 + pri * predict_x0 + torch.sqrt( 74 | 1 - self.alpha_bars[idx]) * noise) + (torch.sqrt( 75 | self.alpha_bars[idx - 1]) * predict_x0 + pri * predict_x0 + torch.sqrt( 76 | 1 - self.alpha_bars[idx - 1]) * noise) 77 | 78 | yield x 79 | 80 | @torch.no_grad() 81 | def sampling(self, sampling_number, x, only_final=False, restrict=False, target='pred_noise'): 82 | ''' 83 | sampling_number : a number of generation 84 | only_final : If True, return is an only output of final schedule step 85 | ''' 86 | direction = torch.sign(x) 87 | sample = torch.randn([sampling_number, *self.shape]).to(device=self.device).squeeze() 88 | sampling_list = [] 89 | 90 | final = None 91 | for idx, sample in enumerate(tqdm(self._one_diffusion_step(sample, direction, restrict, target))): 92 | final = sample 93 | if not only_final: 94 | sampling_list.append(final) 95 | 96 | return final if only_final else torch.stack(sampling_list) 97 | -------------------------------------------------------------------------------- /hyperbolic_learning/hyperbolic_kmeans/models/polbooks_vectors: -------------------------------------------------------------------------------- 1 | 107 2 2 | 9 0.7046087383288888 -0.25695412491576375 3 | 13 0.6575087840075093 -0.05568487622114636 4 | 4 0.8598665471490705 -0.13553175864582365 5 | 74 -0.37153819952051337 0.3034554345041999 6 | 63 -0.5247701919513179 0.04584799000844434 7 | 62 -0.5583092856261503 0.45696246955500464 8 | 66 -0.35728736689394935 0.3807207153774604 9 | 31 -0.787557578613153 0.22120987080437357 10 | 12 0.5697490908570185 -0.35353126652451766 11 | 40 0.40915195097138174 -0.1509147811159358 12 | 47 0.377767465725793 -0.16926802155061224 13 | 10 0.4355397119807345 -0.6037119572549787 14 | 67 -0.4858918980120315 0.10559152927482718 15 | 68 -0.5586289423175957 0.05557727475080089 16 | 11 0.8951682758438829 0.014742970395967218 17 | 34 -0.6306660669860991 -0.03489077661775147 18 | 75 -0.27674629241409077 0.4058329417182256 19 | 14 0.5253450794088335 -0.08661378496083 20 | 33 -0.2581962667282671 -0.6031935138354038 21 | 69 -0.42231695853894513 -0.0205663118394071 22 | 77 -0.3291226550056587 0.384164080164691 23 | 93 -0.3073028888911153 0.34161288851651705 24 | 7 0.5458366979313702 0.6661007097449605 25 | 32 -0.7570641545713138 -0.09530332106986927 26 | 21 0.3106881258013218 -0.4957906694566841 27 | 38 0.8533017997801482 0.20070801988958967 28 | 15 0.4673156588092308 -0.24451741646134822 29 | 24 0.7238454687969046 -0.13948760429166224 30 | 25 0.427489031558284 -0.29146462120281647 31 | 27 0.4126167494578887 -0.26718986138879747 32 | 28 0.5827909367237898 -0.27659730540310895 33 | 37 0.7299304101385393 0.07782241346597629 34 | 72 -0.4649704194902326 0.08121139419423207 35 | 82 -0.3458127429133858 -0.6919888322005062 36 | 5 0.5388624357734805 0.7750015496968019 37 | 8 0.452643713628992 0.6434430570294597 38 | 41 0.4501239562719716 -0.147840765456634 39 | 49 -0.007733441531764411 -0.5108433097812095 40 | 54 0.7581308146562159 
0.15648323808729356 41 | 70 -0.5376267300158273 0.04374937228035323 42 | 73 -0.3305833158831965 0.12703423586823148 43 | 79 -0.5006259039064803 0.0665023900715221 44 | 6 0.5487647595968044 0.7260958505375845 45 | 23 0.41769313306879857 -0.08742382839300884 46 | 39 0.7852887769969916 0.13523798275105278 47 | 53 0.770167474928376 0.15115747513206249 48 | 60 -0.4334118174502487 -0.045023650894214035 49 | 65 -0.6802683703565194 0.1749400975537152 50 | 76 -0.2508920071172869 0.43341772413286056 51 | 83 -0.36202827605002624 -0.6872729902297926 52 | 104 -0.2546221724591613 0.41571626231370745 53 | 1 0.5946113745038827 0.7851266336414179 54 | 42 0.46758531822396115 -0.06954809982416535 55 | 45 0.3781365894977165 -0.23450732316279857 56 | 51 -0.34205692099311624 -0.7531920277472836 57 | 58 0.3878163947467862 -0.06752657881867119 58 | 64 -0.3784590158925541 -0.5794727870251969 59 | 87 -0.32013671596150467 0.9271894066280484 60 | 89 -0.30792975118171445 0.9099070990577488 61 | 95 -0.36226964051187727 0.2830448497221914 62 | 97 -0.3307511381087155 0.3726131752386288 63 | 98 -0.32285601767354244 0.3348311759374826 64 | 16 0.8862803361071081 0.0029536855310105226 65 | 18 0.4812659824789473 -0.08896815804349639 66 | 20 0.8688423570983056 -0.014098659697747337 67 | 22 0.7389183522161881 -0.185853088978547 68 | 26 0.3952625905985228 -0.08735968694759723 69 | 35 -0.13649541364094123 -0.08189212185499643 70 | 36 0.6237769068009177 -0.02868928448725484 71 | 43 0.4334247956840668 -0.09887729386301554 72 | 44 0.4870004738538052 -0.0747985682653158 73 | 52 -0.23816987707852605 -0.6497901075990279 74 | 57 0.8346190148074809 0.20236801816427516 75 | 59 0.1703778973695289 -0.3344983339198897 76 | 61 0.1972247495263647 -0.2540104154107525 77 | 78 -0.5030185942444328 0.02136353759224305 78 | 80 0.9472890786928801 0.2387639449204936 79 | 84 -0.3300353754157964 -0.6617693917327515 80 | 86 -0.307973765565951 0.9432053202929374 81 | 92 -0.2123840731062162 0.405746625870189 82 | 94 -0.337203217448736 0.2822663998874448 83 | 96 -0.5117144773993654 0.23944489340241412 84 | 102 -0.3706916654679104 0.14071699192497858 85 | 105 -0.32354570984039493 0.17143833544202752 86 | 2 0.6156756478991998 0.7570261332029108 87 | 3 0.5682977829453524 0.7969028081386588 88 | 30 0.3223361388984562 0.24611825721490577 89 | 46 0.4424988525596052 0.028039677969000826 90 | 48 0.2352630414937468 -0.5833834687256998 91 | 55 0.878768553798213 0.004721356496922857 92 | 56 0.2945303014563393 -0.2069669965886845 93 | 71 -0.546842387216187 0.11076899719710057 94 | 81 -0.16784511245829084 0.368454137545002 95 | 85 -0.32243106074375705 -0.6261203357369363 96 | 88 -0.2350095581126571 0.8838913359906228 97 | 90 -0.30855400767978175 0.9195097283604609 98 | 91 -0.1748924341009564 0.45618665176890594 99 | 101 -0.31765109145924136 0.4279281429110989 100 | 17 0.883973315010489 0.004369254541300928 101 | 19 0.5296696990435625 0.12419781301320709 102 | 29 0.23348016714885222 0.7451252007658975 103 | 50 0.33584705280091015 -0.5833526513556448 104 | 100 -0.38004640793288247 -0.5660352105162821 105 | 103 -0.4042121637754467 0.12343161127317125 106 | 99 -0.3917665005815977 -0.5740914599287628 107 | u 0.2889956057642895 0.9330238105507371 108 | v 0.2889336261520143 0.9327643872900935 109 | -------------------------------------------------------------------------------- /Synthatic_graph_generator.py: -------------------------------------------------------------------------------- 1 | 2 | import networkx as nx 3 | import scipy 4 | import numpy as np 5 | from plotter 
import plotG
6 | import numpy
7 | from operator import itemgetter
8 | import random
9 | 
10 | 
11 | 
12 | 
13 | def Synthetic_data(type= "grid", rand = False):
14 |     if rand==True:
15 |         if type == "grid":
16 |             G = grid(random.randint(10,15), random.randint(10,15))
17 |         elif type == "community":
18 |             G = n_community([50, 50], p_inter=0.05)
19 |         elif type == "ego":
20 |             G = ego()
21 |         elif type == "lobster":
22 |             G = lobster()
23 |         elif type == "multi_rel_com":
24 |             G = multi_rel_com()
25 |     else:
26 |         numpy.random.seed(4812)
27 |         np.random.RandomState(1234)
28 |         random.seed(245)
29 | 
30 |         if type == "grid":
31 |             G = grid()
32 |         elif type== "community":
33 |             G = n_community([50,50,50,50], p_inter=0.05)
34 |         elif type == "ego":
35 |             G = ego()
36 |         elif type=="lobster":
37 |             G=lobster()
38 |         elif type =="multi_rel_com":
39 |             G = multi_rel_com()
40 | 
41 |     plotG(G, type)
42 |     return nx.adjacency_matrix(G), scipy.sparse.lil_matrix(scipy.sparse.identity(G.number_of_nodes()))
43 | 
44 | def grid(m= 10, n=10 ):
45 |     # https://networkx.github.io/documentation/stable/auto_examples/drawing/plot_four_grids.html
46 |     G = nx.grid_2d_graph(m, n) # m x n grid
47 |     return G
48 | 
49 | def n_community(c_sizes, p_inter=0.1, p_intera=0.4):
50 |     graphs = [nx.gnp_random_graph(c_sizes[i], p_intera, seed=i) for i in range(len(c_sizes))]
51 |     G = nx.disjoint_union_all(graphs)
52 |     communities = list(nx.connected_components(G))
53 |     for i in range(len(communities)):
54 |         subG1 = communities[i]
55 |         nodes1 = list(subG1)
56 |         for j in range(i+1, len(communities)):
57 |             subG2 = communities[j]
58 |             nodes2 = list(subG2)
59 |             has_inter_edge = False
60 |             for n1 in nodes1:
61 |                 for n2 in nodes2:
62 |                     if np.random.rand() < p_inter:
63 |                         G.add_edge(n1, n2)
64 |                         has_inter_edge = True
65 |             if not has_inter_edge:
66 |                 G.add_edge(nodes1[0], nodes2[0])
67 |     # print('connected comp: ', len(list(nx.connected_components(G))))
68 |     return G
69 | 
70 | def multi_rel_com(comunities =[[50,50,50,50], [100,100]], graph_size= 200):
71 |     """
72 | 
73 |     :param comunities: a list of lists; each inner list defines a set of communities and the size of each one;
74 |         the inter- and intra-community edge probabilities are drawn at random.
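# --- Editor's example (annotation; not part of Synthatic_graph_generator.py) -
# n_community (above) builds one G(n, p) block per community and guarantees at
# least one inter-community edge between every pair of blocks, so the result
# is connected. A usage sketch with made-up sizes:
#
#   G = n_community([20, 20, 20], p_inter=0.05, p_intera=0.4)
#   assert nx.number_connected_components(G) == 1
# ------------------------------------------------------------------------------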
75 | :node_num the graph size 76 | :return: 77 | """ 78 | graphs = [] 79 | for community in comunities: 80 | graphs.append(ncommunity(community, graph_size, random.uniform(.0001,.01), random.uniform(.2,.7))) 81 | 82 | H = nx.compose(graphs[0], graphs[1]) 83 | for i in range(2, len(graphs)): 84 | H = nx.compose(H, graphs[i]) 85 | return H 86 | 87 | 88 | def ncommunity(c_sizes, graph_size, p_inter=0.1, p_intera=0.4 ): 89 | graphs = [nx.gnp_random_graph(c_sizes[i], p_intera, seed=i) for i in range(len(c_sizes))] 90 | G = nx.disjoint_union_all(graphs) 91 | communities = list(nx.connected_components(G)) 92 | for i in range(len(communities)): 93 | subG1 = communities[i] 94 | nodes1 = list(subG1) 95 | for j in range(i + 1, len(communities)): 96 | subG2 = communities[j] 97 | nodes2 = list(subG2) 98 | has_inter_edge = False 99 | for n1 in nodes1: 100 | for n2 in nodes2: 101 | if np.random.rand() < p_inter: 102 | G.add_edge(n1, n2) 103 | has_inter_edge = True 104 | if not has_inter_edge: 105 | G.add_edge(nodes1[0], nodes2[0]) 106 | 107 | x = list(range(graph_size)) 108 | random.shuffle(x) 109 | 110 | if(len(G)> graph_size): 111 | G.add_nodes_from([i for i in range(len(G), graph_size)]) 112 | mapping = {k: v for k, v in zip(list(range(graph_size)), x)} 113 | G = nx.relabel_nodes(G, mapping) 114 | return G 115 | 116 | def lobster(): 117 | p1 = 0.7 118 | p2 = 0.7 119 | mean_node = 80 120 | G = nx.random_lobster(mean_node, p1, p2) 121 | return G 122 | 123 | def ego(): 124 | # Create a BA model graph 125 | n = 2000 126 | m = 3 127 | G = nx.generators.barabasi_albert_graph(n, m) 128 | # find node with largest degree 129 | node_and_degree = G.degree() 130 | (largest_hub, degree) = sorted(node_and_degree, key=itemgetter(1))[-1] 131 | # Create ego graph of main hub 132 | hub_ego = nx.ego_graph(G, largest_hub) 133 | return hub_ego 134 | 135 | 136 | if __name__ == '__main__': 137 | Synthetic_data("multi_rel_com") 138 | print("closed") 139 | -------------------------------------------------------------------------------- /graph_evaluate/baselines/graphvae/train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import os 7 | from random import shuffle 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.init as init 11 | from torch.autograd import Variable 12 | import torch.nn.functional as F 13 | from torch import optim 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | 16 | import data 17 | from baselines.graphvae.model import GraphVAE 18 | from baselines.graphvae.data import GraphAdjSampler 19 | 20 | CUDA = 2 21 | 22 | LR_milestones = [500, 1000] 23 | 24 | def build_model(args, max_num_nodes): 25 | out_dim = max_num_nodes * (max_num_nodes + 1) // 2 26 | if args.feature_type == 'id': 27 | input_dim = max_num_nodes 28 | elif args.feature_type == 'deg': 29 | input_dim = 1 30 | elif args.feature_type == 'struct': 31 | input_dim = 2 32 | model = GraphVAE(input_dim, 64, 256, max_num_nodes) 33 | return model 34 | 35 | def train(args, dataloader, model): 36 | epoch = 1 37 | optimizer = optim.Adam(list(model.parameters()), lr=args.lr) 38 | scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=args.lr) 39 | 40 | model.train() 41 | for epoch in range(5000): 42 | for batch_idx, data in enumerate(dataloader): 43 | model.zero_grad() 44 | features = data['features'].float() 45 | adj_input = data['adj'].float() 46 | 47 | features = 
Variable(features).cuda() 48 | adj_input = Variable(adj_input).cuda() 49 | 50 | loss = model(features, adj_input) 51 | print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss) 52 | loss.backward() 53 | 54 | optimizer.step() 55 | scheduler.step() 56 | break 57 | 58 | def arg_parse(): 59 | parser = argparse.ArgumentParser(description='GraphVAE arguments.') 60 | io_parser = parser.add_mutually_exclusive_group(required=False) 61 | io_parser.add_argument('--dataset', dest='dataset', 62 | help='Input dataset.') 63 | 64 | parser.add_argument('--lr', dest='lr', type=float, 65 | help='Learning rate.') 66 | parser.add_argument('--batch_size', dest='batch_size', type=int, 67 | help='Batch size.') 68 | parser.add_argument('--num_workers', dest='num_workers', type=int, 69 | help='Number of workers to load data.') 70 | parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int, 71 | help='Predefined maximum number of nodes in train/test graphs. -1 if determined by \ 72 | training data.') 73 | parser.add_argument('--feature', dest='feature_type', 74 | help='Feature used for encoder. Can be: id, deg') 75 | 76 | parser.set_defaults(dataset='grid', 77 | feature_type='id', 78 | lr=0.001, 79 | batch_size=1, 80 | num_workers=1, 81 | max_num_nodes=-1) 82 | return parser.parse_args() 83 | 84 | def main(): 85 | prog_args = arg_parse() 86 | 87 | os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA) 88 | print('CUDA', CUDA) 89 | ### running log 90 | 91 | if prog_args.dataset == 'enzymes': 92 | graphs= data.Graph_load_batch(min_num_nodes=10, name='ENZYMES') 93 | num_graphs_raw = len(graphs) 94 | elif prog_args.dataset == 'grid': 95 | graphs = [] 96 | for i in range(2,3): 97 | for j in range(2,3): 98 | graphs.append(nx.grid_2d_graph(i,j)) 99 | num_graphs_raw = len(graphs) 100 | 101 | if prog_args.max_num_nodes == -1: 102 | max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) 103 | else: 104 | max_num_nodes = prog_args.max_num_nodes 105 | # remove graphs with number of nodes greater than max_num_nodes 106 | graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes] 107 | 108 | graphs_len = len(graphs) 109 | print('Number of graphs removed due to upper-limit of number of nodes: ', 110 | num_graphs_raw - graphs_len) 111 | graphs_test = graphs[int(0.8 * graphs_len):] 112 | #graphs_train = graphs[0:int(0.8*graphs_len)] 113 | graphs_train = graphs 114 | 115 | print('total graph num: {}, training set: {}'.format(len(graphs),len(graphs_train))) 116 | print('max number node: {}'.format(max_num_nodes)) 117 | 118 | dataset = GraphAdjSampler(graphs_train, max_num_nodes, features=prog_args.feature_type) 119 | #sample_strategy = torch.utils.data.sampler.WeightedRandomSampler( 120 | # [1.0 / len(dataset) for i in range(len(dataset))], 121 | # num_samples=prog_args.batch_size, 122 | # replacement=False) 123 | dataset_loader = torch.utils.data.DataLoader( 124 | dataset, 125 | batch_size=prog_args.batch_size, 126 | num_workers=prog_args.num_workers) 127 | model = build_model(prog_args, max_num_nodes).cuda() 128 | train(prog_args, dataset_loader, model) 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /hyperbolic_learning/hyperbolic_kmeans/util_hk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def norm(x, axis=None): 4 | return np.linalg.norm(x, axis=axis) 5 | 6 | #------------------------- 7 | #----- Poincaré Disk ----- 8 | 
#-------------------------
9 | 
10 | # NOTE: POSSIBLE ISSUE WITH DIFFERENT WAYS TO SPECIFY MINKOWSKI DOT PRODUCT
11 | # arbitrary sign choices give different signatures (+, +, +, -), (+, -, -, -)
12 | 
13 | # distance in poincare disk
14 | def poincare_dist(u, v, eps=1e-5):
15 |     d = 1 + 2 * norm(u-v)**2 / ((1 - norm(u)**2) * (1 - norm(v)**2) + eps)
16 |     return np.arccosh(d)
17 | 
18 | # compute symmetric poincare distance matrix
19 | def poincare_distances(embedding):
20 |     n = embedding.shape[0]
21 |     dist_matrix = np.zeros((n, n))
22 |     for i in range(n):
23 |         for j in range(i+1, n):
24 |             dist_matrix[i][j] = poincare_dist(embedding[i], embedding[j])
25 |     return dist_matrix
26 | 
27 | # convert array from poincare disk to hyperboloid
28 | def poincare_pts_to_hyperboloid(Y, eps=1e-6, metric='lorentz'):
29 |     mink_pts = np.zeros((Y.shape[0], Y.shape[1]+1))
30 | 
31 |     r = norm(Y, axis=1)
32 |     if metric == 'minkowski':
33 |         mink_pts[:, 0] = 2/(1 - r**2 + eps) * (1 + r**2)/2
34 |         for i in range(Y.shape[1]):  # fixed off-by-one: range(1, Y.shape[1]) skipped Y[:, 0]
35 |             mink_pts[:, i+1] = 2/(1 - r**2 + eps) * Y[:, i]
36 |         #mink_pts[:, 2] = 2/(1 - r**2 + eps) * Y[:, 1]
37 |     else:
38 |         for i in range(Y.shape[1]):
39 |             mink_pts[:, i] = 2/(1 - r**2 + eps) * Y[:, i]
40 |         # mink_pts[:, 1] = 2/(1 - r**2 + eps) * Y[:, 1]
41 |         mink_pts[:,Y.shape[1]] = 2/(1 - r**2 + eps) * (1 + r**2)/2
42 |     return mink_pts
43 | 
44 | # convert single point to hyperboloid
45 | def poincare_pt_to_hyperboloid(y, eps=1e-6, metric='lorentz'):
46 |     #print(y.shape)
47 |     d=y.shape
48 |     # print(d[0])
49 |     mink_pt = np.zeros((d[0]+1,))
50 |     r = norm(y)
51 |     if metric == 'minkowski':
52 |         mink_pt[0] = 2/(1 - r**2 + eps) * (1 + r**2)/2
53 |         for i in range(1,d[0]+1):
54 |             mink_pt[i]=2/(1 - r**2 + eps) * y[i-1]  # fixed: y[i] would index past the end of y
55 | 
56 |     else:
57 |         for i in range(d[0]):
58 |             mink_pt[i]=2/(1 - r**2 + eps) * y[i]
59 | 
60 |         mink_pt[d[0]] = 2/(1 - r**2 + eps) * (1 + r**2)/2
61 |     return mink_pt
62 | 
63 | #------------------------------
64 | #----- Hyperboloid Model ------
65 | #------------------------------
66 | 
67 | # NOTE: POSSIBLE ISSUE WITH DIFFERENT WAYS TO SPECIFY MINKOWSKI DOT PRODUCT
68 | # arbitrary sign choices give different signatures (+, +, +, -), (+, -, -, -)
69 | 
70 | # define hyperboloid bilinear form
71 | def hyperboloid_dot(u, v):
72 | 
73 |     return np.dot(u[:-1], v[:-1]) - u[-1]*v[-1]
74 | 
75 | # define alternate minkowski/hyperboloid bilinear form
76 | def minkowski_dot(u, v):
77 |     return u[0]*v[0] - np.dot(u[1:], v[1:])
78 | 
79 | # hyperboloid distance function
80 | def hyperboloid_dist(u, v, eps=1e-6, metric='lorentz'):
81 |     if metric == 'minkowski':
82 |         dist = np.arccosh(-1*minkowski_dot(u, v))
83 |     else:
84 |         dist = np.arccosh(-1*hyperboloid_dot(u, v))
85 |     if np.isnan(dist):
86 |         #print('Hyperboloid dist returned nan value')
87 |         return eps
88 |     else:
89 |         return dist
90 | 
91 | # compute symmetric hyperboloid distance matrix
92 | def hyperboloid_distances(embedding):
93 |     n = embedding.shape[0]
94 |     dist_matrix = np.zeros((n, n))
95 |     for i in range(n):
96 |         for j in range(i+1, n):
97 |             dist_matrix[i][j] = hyperboloid_dist(embedding[i], embedding[j])
98 |     return dist_matrix
99 | 
100 | # convert array to poincare disk
101 | def hyperboloid_pts_to_poincare(X, eps=1e-6, metric='lorentz'):
102 |     poincare_pts = np.zeros((X.shape[0], X.shape[1]-1))
103 |     if metric == 'minkowski':
104 |         for i in range(X.shape[1]-1):
105 |             poincare_pts[:, i] = X[:, i+1] / ((X[:, 0]+1) + eps)
106 |         # poincare_pts[:, 1] = X[:, 2] / ((X[:, 0]+1) + eps)
107 |     else:
108 |         for i in range(X.shape[1]-1):
109 |             poincare_pts[:, i] = X[:, i] /
((X[:, X.shape[1]-1]+1) + eps) 110 | # poincare_pts[:, 1] = X[:, 1] / ((X[:, 2]+1) + eps) 111 | return poincare_pts 112 | 113 | # project within disk 114 | def proj(theta,eps=0.1): 115 | if norm(theta) >= 1: 116 | theta = theta/norm(theta) - eps 117 | return theta 118 | 119 | # convert single point to poincare 120 | def hyperboloid_pt_to_poincare(x, eps=1e-6, metric='lorentz'): 121 | d=x.shape 122 | d=d[0]-1 123 | poincare_pt = np.zeros((d, )) 124 | if metric == 'minkowski': 125 | for i in range(d): 126 | poincare_pt[i] = x[i+1] / ((x[0]+1) + eps) 127 | #poincare_pt[1] = x[2] / ((x[0]+1) + eps) 128 | else: 129 | for i in range(d): 130 | poincare_pt[i] = x[i] / ((x[d]+1) + eps) 131 | #poincare_pt[1] = x[1] / ((x[2]+1) + eps) 132 | return proj(poincare_pt) 133 | 134 | # helper function to generate samples 135 | def generate_data(n, radius=0.7, hyperboloid=False): 136 | theta = np.random.uniform(0, 2*np.pi, n) 137 | u = np.random.uniform(0, radius, n) 138 | r = np.sqrt(u) 139 | x = r * np.cos(theta) 140 | y = r * np.sin(theta) 141 | init_data = np.hstack((x.reshape(-1,1), y.reshape(-1,1))) 142 | if hyperboloid: 143 | return poincare_pts_to_hyperboloid(init_data) 144 | else: 145 | return init_data -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from utils.train_utils import add_flags_from_config 4 | 5 | config_args = { 6 | 'training_lp_config': { 7 | 'lr': (0.0001, 'learning rate'), 8 | 'dropout': (0.2, 'dropout probability'), 9 | 'cuda': (0, 'which cuda device to use (-1 for cpu training)'), 10 | 'epochs': (10, 'maximum number of epochs to train for'), 11 | 'weight-decay': (0.00001, 'l2 regularization strength'), 12 | 'optimizer': ('Adam', 'which optimizer to use, can be any of [Adam, RiemannianAdam]'), 13 | 'momentum': (0.99, 'momentum in optimizer'), 14 | 'seed': (1432, 'seed for training'), 15 | 'log-freq': (5, 'how often to compute print train/val metrics (in epochs)'), 16 | 'eval-freq': (1, 'how often to compute val metrics (in epochs)'), 17 | 'save-dir': (None, 'path to save training logs and model weights (defaults to logs/task/date/run/)'), 18 | 'sweep-c': (0, ''), 19 | 'lr-reduce-freq': (None, 'reduce lr every lr-reduce-freq or None to keep lr constant'), 20 | 'gamma': (0.2, 'gamma for lr scheduler'), 21 | 'print-epoch': (True, ''), 22 | 'grad-clip': (10, 'max norm for gradient clipping, or None for no gradient clipping'), 23 | 'min-epochs': (None, 'do not early stop before min-epochs') 24 | }, 25 | 'model_config': { 26 | 'model': ('HGCN', 'which encoder to use, can be any of [Shallow, MLP, HNN, GCN, GAT, HyperGCN]'), 27 | 'dim': (512, 'embedding dimension'), 28 | 'hid1': (64, 'embedding dimension'), 29 | 'hid2': (32, 'embedding dimension'), 30 | 'manifold': ('PoincareBall', 'which manifold to use, can be any of [Hyperboloid, PoincareBall]'), 31 | 'c': (0.6, 'hyperbolic radius, set to None for trainable curvature'), 32 | 'r': (2., 'fermi-dirac decoder parameter for lp'), 33 | 't': (1., 'fermi-dirac decoder parameter for lp'), 34 | 'pos-weight': (0, 'whether to upweight positive class in node classification tasks'), 35 | 'num-layers': (2, 'number of hidden layers in encoder'), 36 | 'bias': (0, 'whether to use bias (1) or not (0)'), 37 | 'act': ('relu', 'which activation function to use (or None for no activation)'), 38 | 'n-heads': (4, 'number of attention heads for graph attention networks, must be a divisor dim'), 39 | 'alpha': 
(0.6, 'alpha for leakyrelu in graph attention networks'),
40 |         'double-precision': ('0', 'whether to use double precision'),
41 |         'use-att': (0, 'whether to use hyperbolic attention or not'),
42 |         'local-agg': (0, 'whether to use local tangent space aggregation or not')
43 |     },
44 |     'data_config': {
45 |         'dataset': ('MUTAG', 'which dataset to use'),
46 |         'val-prop': (0.05, 'proportion of validation edges for link prediction'),
47 |         'test-prop': (0.1, 'proportion of test edges for link prediction'),
48 |         'use-feats': (0, 'whether to use node features or not'),
49 |         'normalize-feats': (1, 'whether to normalize input node features'),
50 |         'normalize-adj': (1, 'whether to row-normalize the adjacency matrix'),
51 |         'split-seed': (1432, 'seed for data splits (train/test/val)'),
52 |     },
53 |     'work_type_config':{
54 |         'type':('train','which type to choose; you can select train or test'),
55 |         'diff_epoc':(1, 'maximum number of epochs to train for'),
56 |         'taskselect':('graphtask', '[lptask, graphtask]')
57 |     },
58 |     'training_diffusion_config': {
59 |         'target': ('pred_noise', 'diffusion prediction target, one of [pred_noise, pred_x0]'),
60 |         'restrict': (False, 'whether to use the geometric constraints'),
61 |         'lr_diff': (0.0001, 'model learning rate'),
62 |         'epoch_diff': (5000, 'maximum number of epochs to train for'),
63 |         'epoch_load': (5000, 'epoch of the saved diffusion checkpoint to load')
64 |     },
65 |     'training_hvae_config': {
66 |         'Vis_step': (1000, 'at every Vis_step, the plots will be updated'),
67 |         'redraw': (False, 'whether to update the log plot each step'),
68 |         'epoch_number': (5000, 'maximum number of epochs to train for'),
69 |         'graphEmDim': (64, 'the dimension of the graph embedding layer; z'),
70 |         'graph_save_path': (None, 'the directory to save generated synthetic graphs'),
71 |         'use_feature': (True, 'whether to use features or the identity matrix'),
72 |         'PATH': ('model', 'a string which determines the path where the model will be saved'),
73 |         'decoder': ('FC', 'the decoder type; FC is the only option in this repo'),
74 |         'encoder_type': ("HAvePool", 'the encoder; the only option in this repo is Ave'),
75 |         'batchSize': (200, 'the size of each batch; the number of graphs in the mini-batch'),
76 |         'UseGPU': (True, 'whether to use the GPU, if available'),
77 |         'model_vae': ('graphVAE', 'only option is graphVAE'),
78 |         'device': ("cuda:0", 'which device should be used'),
79 |         'task': ("graphGeneration", 'the only option in this repo is graphGeneration'),
80 |         'bfsOrdering': (True, 'use BFS ordering for graph permutations'),
81 |         'directed': (True, 'whether the dataset is directed'),
82 |         'beta': (None, 'beta coefficient'),
83 |         'plot_testGraphs': (True, 'whether to plot the test-set graphs'),
84 |         'ideal_Evalaution': (False, 'whether to run the ideal 50/50 dataset-subset comparison')
85 |     }
86 | }
87 | 
88 | parser = argparse.ArgumentParser()
89 | for _, config_dict in config_args.items():
90 |     parser = add_flags_from_config(parser, config_dict)
91 | 
-------------------------------------------------------------------------------- /hyperbolic_learning/hyperbolic_kmeans/models/football_vectors: --------------------------------------------------------------------------------
1 | 117 2
2 | 0 0.42188138800334557 0.8702295000280139
3 | 1 -0.7430413097369526 0.5666952237075653
4 | 104 0.2970899928950441 0.5325533436235209
5 | 2 0.6497297026873939 -0.6856454560611063
6 | 3 0.16392461975958997 -0.9126095254214484
7 | 6 0.678231596647116 -0.6130084570341932
8 | 15 0.6282751096273168 -0.45919999234685255
9 | 5 0.26319182633836974 -0.8527788154222775
10 | 7
0.8421496097584927 0.4417972661576641 11 | 67 0.7860557212530045 -0.013037359000339057 12 | 53 0.7766308780426957 -0.03989598553487827 13 | 88 0.7651371923536704 -0.03512223911149387 14 | 4 0.4855796692650613 0.7923458757665107 15 | 9 0.4952014288217301 0.7440561152811678 16 | 16 0.4053047514191927 0.7457701780718974 17 | 23 0.4532526150428279 0.6597626332092877 18 | 35 -0.2872175883536931 0.7123215404822753 19 | 65 -0.48531004704939684 -0.5406296034629752 20 | 25 -0.6896931005599571 0.5101796542783507 21 | 27 -0.5478127377463308 -0.5823788492576576 22 | 37 -0.6631894353043104 0.4828444146225076 23 | 45 -0.664961009243156 0.506734783815641 24 | 89 -0.6558598575751666 0.47699201176345546 25 | 109 -0.6544102051006812 0.48557368633382364 26 | 13 0.6919428266655674 -0.4881451025715024 27 | 47 0.5495204687253182 -0.41003739645991555 28 | 60 0.5536755534631704 -0.44922870642719803 29 | 64 0.5375112085721053 -0.43221431366403174 30 | 72 0.16456573026135618 -0.7666484533833076 31 | 74 0.18872569326512992 -0.7171933271226401 32 | 100 0.5531304303426134 -0.44768949913229045 33 | 106 0.5372344340016927 -0.4278984013773863 34 | 40 0.13351780827259405 -0.8672161235947721 35 | 81 0.17643636248246572 -0.67077750546337 36 | 84 0.18553093745835023 -0.6577770728471485 37 | 69 0.3144417580209658 0.12555830790107225 38 | 98 0.17499232176473134 -0.6688868668383939 39 | 32 0.6083417971178993 -0.4455610823677657 40 | 39 0.5788726656621597 -0.45679018925798476 41 | 55 -0.29451030717200777 0.6688962428178572 42 | 8 0.8139037392198217 0.49946338876044744 43 | 21 0.7843480754569245 0.42782031611938404 44 | 22 0.7705527125755028 0.46345384768429465 45 | 68 0.7055542339482841 0.3929416977701263 46 | 73 0.7989400656025484 -0.0062699200610682815 47 | 77 0.7280217198160184 0.4077318389352939 48 | 78 0.7043916880995074 0.4042232044086235 49 | 82 0.024059188871463237 0.21742026582401325 50 | 111 0.6832495906871358 0.3865246356741833 51 | 51 0.7057081547321264 0.4137348404052715 52 | 17 -0.6209349512653264 -0.5921779426756915 53 | 18 -0.9568513355337628 -0.028961295441562687 54 | 34 -0.8793930292113498 0.007545846275370741 55 | 38 -0.8336429157024776 -0.15244387246838179 56 | 43 -0.7957197160367058 -0.13515417247144274 57 | 110 0.7658711068718839 -0.03423949740073783 58 | 92 -0.17900986113160003 0.03240935591729328 59 | 114 0.7185068626677757 -0.030397344008380425 60 | 20 -0.6946645202838097 -0.6198552062583775 61 | 62 -0.5354342907159726 -0.556084961629967 62 | 87 -0.4829173195677484 -0.5364823254774639 63 | 19 -0.37735805714304527 0.8350277211562183 64 | 31 -0.9491833623801758 0.04461872344370403 65 | 61 -0.8040440454878017 -0.019478302788374164 66 | 29 -0.3054273167851486 0.8881051337194189 67 | 30 -0.30960518196772085 0.7830973712669057 68 | 44 -0.3054042957823949 0.30062145218607383 69 | 79 -0.2812055861650795 0.6433649410031218 70 | 70 -0.46562473280199956 -0.5091832567775092 71 | 76 -0.43023186838760585 -0.5010379783682303 72 | 46 0.9080886176617939 -0.01862394625619717 73 | 66 -0.16519510653994302 -0.025275179175582134 74 | 80 -0.29016440976525315 0.568509720514648 75 | 91 -0.13220419267595476 0.06904534040474417 76 | 49 0.8584287832454888 -0.051740689076473424 77 | 48 -0.1704537685927944 0.013773263130853588 78 | 86 -0.11693661718011275 0.028376491494056803 79 | 83 0.7613758113906995 -0.04631935473775823 80 | 33 -0.6666481234192434 0.4853705645697798 81 | 41 0.3855598780615985 0.6447969869895869 82 | 93 0.35279817092536125 0.5906493923793291 83 | 57 -0.1662507889054939 0.008788227980833261 84 | 101 
-0.2692586296222677 0.6152430343497629 85 | 103 -0.6374943301710123 0.4741800820714387 86 | 105 -0.6574466251603406 0.4844889527844997 87 | 14 -0.9238699378795275 -0.18728102764492438 88 | 11 0.3413944042744255 0.15631467830796275 89 | 26 -0.8749226254287519 -0.17189251880218095 90 | 52 0.1618743840263358 -0.8206197371082704 91 | 58 -0.12060630746277126 -0.41624319807510185 92 | 102 0.15483083555827648 -0.7081801483284993 93 | 108 0.6965756418779019 0.39864071886003705 94 | 10 0.3447160526816999 -0.873910933042636 95 | 107 0.1601979778482424 -0.6409828171477547 96 | 24 0.3710716418028123 0.1694750619992933 97 | 12 -0.9585286839732432 -0.20007044766478904 98 | 54 -0.8322340187675641 -0.0030236380073669082 99 | 71 -0.8159834482360491 -0.0048955317787182815 100 | 99 -0.7919323703816471 0.01192185682549786 101 | 95 -0.45907300576872945 -0.4760308609479501 102 | 96 -0.4407647385776408 -0.46335252848259345 103 | 113 -0.44819978356281687 -0.46686258015656995 104 | 94 -0.27079784623148784 0.6281210081592002 105 | 75 -0.1964334315266217 -0.022872303024380862 106 | 56 -0.5324865736146867 -0.7344455832753883 107 | 112 -0.14456478742259782 -0.1084979595125978 108 | 90 0.33127465218231866 0.21423198144162278 109 | 28 0.3873004624407157 0.26638888948682105 110 | 85 -0.762782285402684 -0.13427199809209076 111 | 50 0.45119037659275973 0.27250284126192936 112 | 63 -0.2478295438829226 -0.35553754494364714 113 | 97 -0.08812202043336287 -0.3007694838528239 114 | 36 -0.6491347195467851 -0.13150881154562727 115 | 59 -0.3060333756989114 -0.4452911477356562 116 | 42 -0.6850688881870777 -0.04750325407457203 117 | u 0.06711265361830102 0.3393743189309152 118 | v 0.01061882846551033 0.3187415835581022 119 | -------------------------------------------------------------------------------- /graph_evaluate/eval/mmd.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | from functools import partial 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | from scipy.linalg import toeplitz 7 | import pyemd 8 | # from PyEMD import EMD as pyemd 9 | from tqdm import tqdm 10 | 11 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 12 | 13 | def emd(x, y, distance_scaling=1.0): 14 | support_size = max(len(x), len(y)) 15 | d_mat = toeplitz(range(support_size)).astype(float) 16 | distance_mat = d_mat / distance_scaling 17 | 18 | # convert histogram values x and y to float, and make them equal len 19 | x = x.astype(float) 20 | y = y.astype(float) 21 | if len(x) < len(y): 22 | x = np.hstack((x, [0.0] * (support_size - len(x)))) 23 | elif len(y) < len(x): 24 | y = np.hstack((y, [0.0] * (support_size - len(y)))) 25 | 26 | emd = pyemd.emd(x, y, distance_mat) 27 | return emd 28 | 29 | def l2(x, y): 30 | dist = np.linalg.norm(x - y, 2) 31 | return dist 32 | 33 | 34 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0): 35 | ''' Gaussian kernel with squared distance in exponential term replaced by EMD 36 | Args: 37 | x, y: 1D pmf of two distributions with the same support 38 | sigma: standard deviation 39 | ''' 40 | support_size = max(len(x), len(y)) 41 | d_mat = toeplitz(range(support_size)).astype(float) 42 | distance_mat = d_mat / distance_scaling 43 | 44 | # convert histogram values x and y to float, and make them equal len 45 | x = x.astype(float) 46 | y = y.astype(float) 47 | if len(x) < len(y): 48 | x = np.hstack((x, [0.0] * (support_size - len(x)))) 49 | elif len(y) < len(x): 50 | y = np.hstack((y, [0.0] * 
(support_size - len(y)))) 51 | 52 | emd = pyemd.emd(x, y, distance_mat) 53 | return np.exp(-emd * emd / (2 * sigma * sigma)) 54 | 55 | def gaussian(x, y, sigma=1.0): 56 | dist = np.linalg.norm(x - y, 2) 57 | return np.exp(-dist * dist / (2 * sigma * sigma)) 58 | 59 | def kernel_parallel_unpacked(x, samples2, kernel): 60 | d = 0 61 | for s2 in samples2: 62 | d += kernel(x, s2) 63 | return d 64 | 65 | def kernel_parallel_worker(t): 66 | return kernel_parallel_unpacked(*t) 67 | 68 | def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs): 69 | ''' Discrepancy between 2 samples 70 | ''' 71 | d = 0 72 | if not is_parallel: 73 | for s1 in samples1: 74 | for s2 in samples2: 75 | d += kernel(s1, s2, *args, **kwargs) 76 | else: 77 | with concurrent.futures.ProcessPoolExecutor() as executor: 78 | for dist in list(tqdm(executor.map(kernel_parallel_worker, 79 | [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]))): 80 | # for dist in executor.map(kernel_parallel_worker, 81 | # [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]): 82 | d += dist 83 | d /= len(samples1) * len(samples2) 84 | return d 85 | 86 | 87 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): 88 | ''' MMD between two samples 89 | ''' 90 | # normalize histograms into pmf 91 | if is_hist: 92 | samples1 = [s1 / np.sum(s1) for s1 in samples1] 93 | samples2 = [s2 / np.sum(s2) for s2 in samples2] 94 | # print('===============================') 95 | # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs)) 96 | # print('--------------------------') 97 | # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs)) 98 | # print('--------------------------') 99 | # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs)) 100 | # print('===============================') 101 | return disc(samples1, samples1, kernel, *args, **kwargs) + \ 102 | disc(samples2, samples2, kernel, *args, **kwargs) - \ 103 | 2 * disc(samples1, samples2, kernel, *args, **kwargs) 104 | 105 | def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): 106 | ''' EMD between average of two samples 107 | ''' 108 | # normalize histograms into pmf 109 | if is_hist: 110 | samples1 = [np.mean(samples1)] 111 | samples2 = [np.mean(samples2)] 112 | # print('===============================') 113 | # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs)) 114 | # print('--------------------------') 115 | # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs)) 116 | # print('--------------------------') 117 | # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs)) 118 | # print('===============================') 119 | return disc(samples1, samples2, kernel, *args, **kwargs),[samples1[0],samples2[0]] 120 | 121 | 122 | def test(): 123 | s1 = np.array([0.2, 0.8]) 124 | s2 = np.array([0.3, 0.7]) 125 | samples1 = [s1, s2] 126 | 127 | s3 = np.array([0.25, 0.75]) 128 | s4 = np.array([0.35, 0.65]) 129 | samples2 = [s3, s4] 130 | 131 | s5 = np.array([0.8, 0.2]) 132 | s6 = np.array([0.7, 0.3]) 133 | samples3 = [s5, s6] 134 | 135 | print('between samples1 and samples2: ', compute_mmd(samples1, samples2, kernel=gaussian_emd, 136 | is_parallel=False, sigma=1.0)) 137 | print('between samples1 and samples3: ', compute_mmd(samples1, samples3, kernel=gaussian_emd, 138 | is_parallel=False, sigma=1.0)) 139 | 140 | if __name__ == '__main__': 141 | test() 142 | 143 | -------------------------------------------------------------------------------- 
/manifolds/poincare.py:
--------------------------------------------------------------------------------
1 | """Poincare ball manifold."""
2 | 
3 | import torch
4 | 
5 | from manifolds.base import Manifold
6 | from utils.math_utils import artanh, tanh
7 | 
8 | 
9 | class PoincareBall(Manifold):
10 |     """
11 |     PoincareBall manifold class.
12 | 
13 |     We use the following convention: x0^2 + x1^2 + ... + xd^2 < 1 / c
14 | 
15 |     Note that 1/sqrt(c) is the Poincare ball radius.
16 | 
17 |     """
18 | 
19 |     def __init__(self):
20 |         super(PoincareBall, self).__init__()
21 |         self.name = 'PoincareBall'
22 |         self.min_norm = 1e-15
23 |         self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}
24 | 
25 |     def sqdist(self, p1, p2, c):
26 |         sqrt_c = c ** 0.5
27 |         dist_c = artanh(
28 |             sqrt_c * self.mobius_add(-p1, p2, c, dim=-1).norm(dim=-1, p=2, keepdim=False)
29 |         )
30 |         dist = dist_c * 2 / sqrt_c
31 |         return dist ** 2
32 | 
33 |     def _lambda_x(self, x, c):
34 |         x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
35 |         return 2 / (1. - c * x_sqnorm).clamp_min(self.min_norm)
36 | 
37 |     def egrad2rgrad(self, p, dp, c):
38 |         lambda_p = self._lambda_x(p, c)
39 |         dp /= lambda_p.pow(2)
40 |         return dp
41 | 
42 |     def proj(self, x, c):
43 |         norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
44 |         maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
45 |         cond = norm > maxnorm
46 |         projected = x / norm * maxnorm
47 |         return torch.where(cond, projected, x)
48 | 
49 |     def proj_tan(self, u, p, c):
50 |         return u
51 | 
52 |     def proj_tan0(self, u, c):
53 |         return u
54 | 
55 |     def expmap(self, u, p, c):
56 |         sqrt_c = c ** 0.5
57 |         u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
58 |         second_term = (
59 |             tanh(sqrt_c / 2 * self._lambda_x(p, c) * u_norm)
60 |             * u
61 |             / (sqrt_c * u_norm)
62 |         )
63 |         gamma_1 = self.mobius_add(p, second_term, c)
64 |         return gamma_1
65 | 
66 |     def logmap(self, p1, p2, c):
67 |         sub = self.mobius_add(-p1, p2, c)
68 |         sub_norm = sub.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
69 |         lam = self._lambda_x(p1, c)
70 |         sqrt_c = c ** 0.5
71 |         return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm
72 | 
73 |     def expmap0(self, u, c):
74 |         sqrt_c = c ** 0.5
75 |         u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), self.min_norm)
76 |         gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
77 |         return gamma_1
78 | 
79 |     def logmap0(self, p, c):
80 |         sqrt_c = c ** 0.5
81 |         p_norm = p.norm(dim=-1, p=2, keepdim=True).clamp_min(self.min_norm)
82 |         scale = 1. 
/ sqrt_c * artanh(sqrt_c * p_norm) / p_norm 83 | return scale * p 84 | 85 | def mobius_add(self, x, y, c, dim=-1): 86 | x2 = x.pow(2).sum(dim=dim, keepdim=True) 87 | y2 = y.pow(2).sum(dim=dim, keepdim=True) 88 | xy = (x * y).sum(dim=dim, keepdim=True) 89 | num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y 90 | denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 91 | return num / denom.clamp_min(self.min_norm) 92 | 93 | def mobius_matvec(self, m, x, c): 94 | sqrt_c = c ** 0.5 95 | x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm) 96 | mx = x @ m.transpose(-1, -2) 97 | mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm) 98 | res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c) 99 | cond = (mx == 0).prod(-1, keepdim=True, dtype=torch.uint8) 100 | res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) 101 | res = torch.where(cond, res_0, res_c) 102 | return res 103 | 104 | def init_weights(self, w, c, irange=1e-5): 105 | w.data.uniform_(-irange, irange) 106 | return w 107 | 108 | def _gyration(self, u, v, w, c, dim: int = -1): 109 | u2 = u.pow(2).sum(dim=dim, keepdim=True) 110 | v2 = v.pow(2).sum(dim=dim, keepdim=True) 111 | uv = (u * v).sum(dim=dim, keepdim=True) 112 | uw = (u * w).sum(dim=dim, keepdim=True) 113 | vw = (v * w).sum(dim=dim, keepdim=True) 114 | c2 = c ** 2 115 | a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw 116 | b = -c2 * vw * u2 - c * uw 117 | d = 1 + 2 * c * uv + c2 * u2 * v2 118 | return w + 2 * (a * u + b * v) / d.clamp_min(self.min_norm) 119 | 120 | def inner(self, x, c, u, v=None, keepdim=False): 121 | if v is None: 122 | v = u 123 | lambda_x = self._lambda_x(x, c) 124 | return lambda_x ** 2 * (u * v).sum(dim=-1, keepdim=keepdim) 125 | 126 | def ptransp(self, x, y, u, c): 127 | lambda_x = self._lambda_x(x, c) 128 | lambda_y = self._lambda_x(y, c) 129 | return self._gyration(y, -x, u, c) * lambda_x / lambda_y 130 | 131 | def ptransp_(self, x, y, u, c): 132 | lambda_x = self._lambda_x(x, c) 133 | lambda_y = self._lambda_x(y, c) 134 | return self._gyration(y, -x, u, c) * lambda_x / lambda_y 135 | 136 | def ptransp0(self, x, u, c): 137 | lambda_x = self._lambda_x(x, c) 138 | return 2 * u / lambda_x.clamp_min(self.min_norm) 139 | 140 | def to_hyperboloid(self, x, c): 141 | K = 1./ c 142 | sqrtK = K ** 0.5 143 | sqnorm = torch.norm(x, p=2, dim=1, keepdim=True) ** 2 144 | return sqrtK * torch.cat([K + sqnorm, 2 * sqrtK * x], dim=1) / (K - sqnorm) 145 | 146 | -------------------------------------------------------------------------------- /manifolds/hyperboloid.py: -------------------------------------------------------------------------------- 1 | """Hyperboloid manifold.""" 2 | 3 | import torch 4 | 5 | from manifolds.base import Manifold 6 | from utils.math_utils import arcosh, cosh, sinh 7 | 8 | 9 | class Hyperboloid(Manifold): 10 | """ 11 | Hyperboloid manifold class. 12 | 13 | We use the following convention: -x0^2 + x1^2 + ... + xd^2 = -K 14 | 15 | c = 1 / K is the hyperbolic curvature. 
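For example, with c = 1 (so K = 1), the hyperboloid "origin" (sqrtK, 0, ..., 0) = (1, 0, ..., 0) lies on the manifold, since -1^2 + 0^2 + ... + 0^2 = -1 = -K.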
16 | """ 17 | 18 | def __init__(self): 19 | super(Hyperboloid, self).__init__() 20 | self.name = 'Hyperboloid' 21 | self.eps = {torch.float32: 1e-7, torch.float64: 1e-15} 22 | self.min_norm = 1e-15 23 | self.max_norm = 1e6 24 | 25 | def minkowski_dot(self, x, y, keepdim=True): 26 | res = torch.sum(x * y, dim=-1) - 2 * x[..., 0] * y[..., 0] 27 | if keepdim: 28 | res = res.view(res.shape + (1,)) 29 | return res 30 | 31 | def minkowski_norm(self, u, keepdim=True): 32 | dot = self.minkowski_dot(u, u, keepdim=keepdim) 33 | return torch.sqrt(torch.clamp(dot, min=self.eps[u.dtype])) 34 | 35 | def sqdist(self, x, y, c): 36 | K = 1. / c 37 | prod = self.minkowski_dot(x, y) 38 | theta = torch.clamp(-prod / K, min=1.0 + self.eps[x.dtype]) 39 | sqdist = K * arcosh(theta) ** 2 40 | # clamp distance to avoid nans in Fermi-Dirac decoder 41 | return torch.clamp(sqdist, max=50.0) 42 | 43 | def proj(self, x, c): 44 | K = 1. / c 45 | d = x.size(-1) - 1 46 | y = x.narrow(-1, 1, d) 47 | y_sqnorm = torch.norm(y, p=2, dim=1, keepdim=True) ** 2 48 | mask = torch.ones_like(x) 49 | mask[:, 0] = 0 50 | vals = torch.zeros_like(x) 51 | vals[:, 0:1] = torch.sqrt(torch.clamp(K + y_sqnorm, min=self.eps[x.dtype])) 52 | return vals + mask * x 53 | 54 | def proj_tan(self, u, x, c): 55 | K = 1. / c 56 | d = x.size(1) - 1 57 | ux = torch.sum(x.narrow(-1, 1, d) * u.narrow(-1, 1, d), dim=1, keepdim=True) 58 | mask = torch.ones_like(u) 59 | mask[:, 0] = 0 60 | vals = torch.zeros_like(u) 61 | vals[:, 0:1] = ux / torch.clamp(x[:, 0:1], min=self.eps[x.dtype]) 62 | return vals + mask * u 63 | 64 | def proj_tan0(self, u, c): 65 | narrowed = u.narrow(-1, 0, 1) 66 | vals = torch.zeros_like(u) 67 | vals[:, 0:1] = narrowed 68 | return u - vals 69 | 70 | def expmap(self, u, x, c): 71 | K = 1. / c 72 | sqrtK = K ** 0.5 73 | normu = self.minkowski_norm(u) 74 | normu = torch.clamp(normu, max=self.max_norm) 75 | theta = normu / sqrtK 76 | theta = torch.clamp(theta, min=self.min_norm) 77 | result = cosh(theta) * x + sinh(theta) * u / theta 78 | return self.proj(result, c) 79 | 80 | def logmap(self, x, y, c): 81 | K = 1. / c 82 | xy = torch.clamp(self.minkowski_dot(x, y) + K, max=-self.eps[x.dtype]) - K 83 | u = y + xy * x * c 84 | normu = self.minkowski_norm(u) 85 | normu = torch.clamp(normu, min=self.min_norm) 86 | dist = self.sqdist(x, y, c) ** 0.5 87 | result = dist * u / normu 88 | return self.proj_tan(result, x, c) 89 | 90 | def expmap0(self, u, c): 91 | K = 1. / c 92 | sqrtK = K ** 0.5 93 | d = u.size(-1) - 1 94 | x = u.narrow(-1, 1, d).view(-1, d) 95 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True) 96 | x_norm = torch.clamp(x_norm, min=self.min_norm) 97 | theta = x_norm / sqrtK 98 | res = torch.ones_like(u) 99 | res[:, 0:1] = sqrtK * cosh(theta) 100 | res[:, 1:] = sqrtK * sinh(theta) * x / x_norm 101 | return self.proj(res, c) 102 | 103 | def logmap0(self, x, c): 104 | K = 1. 
/ c
105 |         sqrtK = K ** 0.5
106 |         d = x.size(-1) - 1
107 |         y = x.narrow(-1, 1, d).view(-1, d)
108 |         y_norm = torch.norm(y, p=2, dim=1, keepdim=True)
109 |         y_norm = torch.clamp(y_norm, min=self.min_norm)
110 |         res = torch.zeros_like(x)
111 |         theta = torch.clamp(x[:, 0:1] / sqrtK, min=1.0 + self.eps[x.dtype])
112 |         res[:, 1:] = sqrtK * arcosh(theta) * y / y_norm
113 |         return res
114 | 
115 |     def mobius_add(self, x, y, c):
116 |         u = self.logmap0(y, c)
117 |         v = self.ptransp0(x, u, c)
118 |         return self.expmap(v, x, c)
119 | 
120 |     def mobius_matvec(self, m, x, c):
121 |         u = self.logmap0(x, c)
122 |         mu = u @ m.transpose(-1, -2)
123 |         return self.expmap0(mu, c)
124 | 
125 |     def ptransp(self, x, y, u, c):
126 |         logxy = self.logmap(x, y, c)
127 |         logyx = self.logmap(y, x, c)
128 |         sqdist = torch.clamp(self.sqdist(x, y, c), min=self.min_norm)
129 |         alpha = self.minkowski_dot(logxy, u) / sqdist
130 |         res = u - alpha * (logxy + logyx)
131 |         return self.proj_tan(res, y, c)
132 | 
133 |     def ptransp0(self, x, u, c):
134 |         K = 1. / c
135 |         sqrtK = K ** 0.5
136 |         x0 = x.narrow(-1, 0, 1)
137 |         d = x.size(-1) - 1
138 |         y = x.narrow(-1, 1, d)
139 |         y_norm = torch.clamp(torch.norm(y, p=2, dim=1, keepdim=True), min=self.min_norm)
140 |         y_normalized = y / y_norm
141 |         v = torch.ones_like(x)
142 |         v[:, 0:1] = - y_norm
143 |         v[:, 1:] = (sqrtK - x0) * y_normalized
144 |         alpha = torch.sum(y_normalized * u[:, 1:], dim=1, keepdim=True) / sqrtK
145 |         res = u - alpha * v
146 |         return self.proj_tan(res, x, c)
147 | 
148 |     def to_poincare(self, x, c):
149 |         K = 1. / c
150 |         sqrtK = K ** 0.5
151 |         d = x.size(-1) - 1
152 |         return sqrtK * x.narrow(-1, 1, d) / (x[:, 0:1] + sqrtK)
153 | 
154 | 
--------------------------------------------------------------------------------
/layers/att_layers.py:
--------------------------------------------------------------------------------
1 | """Attention layers (some modules are copied from https://github.com/Diego999/pyGAT)."""
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class DenseAtt(nn.Module):
9 |     def __init__(self, in_features, dropout):
10 |         super(DenseAtt, self).__init__()
11 |         self.dropout = dropout
12 |         self.linear = nn.Linear(2 * in_features, 1, bias=True)
13 |         self.in_features = in_features
14 | 
15 |     def forward(self, x, adj):
16 |         n = x.size(0)
17 |         # n x 1 x d
18 |         x_left = torch.unsqueeze(x, 1)
19 |         x_left = x_left.expand(-1, n, -1)
20 |         # 1 x n x d
21 |         x_right = torch.unsqueeze(x, 0)
22 |         x_right = x_right.expand(n, -1, -1)
23 | 
24 |         x_cat = torch.cat((x_left, x_right), dim=2)
25 |         att_adj = self.linear(x_cat).squeeze()
26 |         att_adj = torch.sigmoid(att_adj)
27 |         att_adj = torch.mul(adj.to_dense(), att_adj)
28 |         return att_adj
29 | 
30 | 
31 | class SpecialSpmmFunction(torch.autograd.Function):
32 |     """Special function for only sparse region backpropagation layer."""
33 | 
34 |     @staticmethod
35 |     def forward(ctx, indices, values, shape, b):
36 |         assert indices.requires_grad == False
37 |         a = torch.sparse_coo_tensor(indices, values, shape)
38 |         ctx.save_for_backward(a, b)
39 |         ctx.N = shape[0]
40 |         return torch.matmul(a, b)
41 | 
42 |     @staticmethod
43 |     def backward(ctx, grad_output):
44 |         a, b = ctx.saved_tensors
45 |         grad_values = grad_b = None
46 |         if ctx.needs_input_grad[1]:
47 |             grad_a_dense = grad_output.matmul(b.t())
48 |             edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
49 |             grad_values = grad_a_dense.view(-1)[edge_idx]
50 |         if ctx.needs_input_grad[3]:
51 |             grad_b = a.t().matmul(grad_output)
52 |         return 
None, grad_values, None, grad_b 53 | 54 | 55 | class SpecialSpmm(nn.Module): 56 | def forward(self, indices, values, shape, b): 57 | return SpecialSpmmFunction.apply(indices, values, shape, b) 58 | 59 | 60 | class SpGraphAttentionLayer(nn.Module): 61 | """ 62 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 63 | """ 64 | 65 | def __init__(self, in_features, out_features, dropout, alpha, activation): 66 | super(SpGraphAttentionLayer, self).__init__() 67 | self.in_features = in_features 68 | self.out_features = out_features 69 | self.alpha = alpha 70 | 71 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 72 | nn.init.xavier_normal_(self.W.data, gain=1.414) 73 | 74 | self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features))) 75 | nn.init.xavier_normal_(self.a.data, gain=1.414) 76 | 77 | self.dropout = nn.Dropout(dropout) 78 | self.leakyrelu = nn.LeakyReLU(self.alpha) 79 | self.special_spmm = SpecialSpmm() 80 | self.act = activation 81 | 82 | def forward(self, input, adj): 83 | N = input.size()[0] 84 | edge = adj._indices() 85 | 86 | h = torch.mm(input, self.W) 87 | # h: N x out 88 | assert not torch.isnan(h).any() 89 | 90 | # Self-attention on the nodes - Shared attention mechanism 91 | edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() 92 | # edge: 2*D x E 93 | 94 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) 95 | assert not torch.isnan(edge_e).any() 96 | # edge_e: E 97 | 98 | ones = torch.ones(size=(N, 1)) 99 | if h.is_cuda: 100 | ones = ones.cuda() 101 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), ones) 102 | # e_rowsum: N x 1 103 | 104 | edge_e = self.dropout(edge_e) 105 | # edge_e: E 106 | 107 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 108 | assert not torch.isnan(h_prime).any() 109 | # h_prime: N x out 110 | 111 | h_prime = h_prime.div(e_rowsum) 112 | # h_prime: N x out 113 | assert not torch.isnan(h_prime).any() 114 | return self.act(h_prime) 115 | 116 | def __repr__(self): 117 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 118 | 119 | 120 | class GraphAttentionLayer(nn.Module): 121 | def __init__(self, input_dim, output_dim, dropout, activation, alpha, nheads, concat): 122 | """Sparse version of GAT.""" 123 | super(GraphAttentionLayer, self).__init__() 124 | self.dropout = dropout 125 | self.output_dim = output_dim 126 | self.attentions = [SpGraphAttentionLayer(input_dim, 127 | output_dim, 128 | dropout=dropout, 129 | alpha=alpha, 130 | activation=activation) for _ in range(nheads)] 131 | self.concat = concat 132 | for i, attention in enumerate(self.attentions): 133 | self.add_module('attention_{}'.format(i), attention) 134 | 135 | def forward(self, input): 136 | x, adj = input 137 | x = F.dropout(x, self.dropout, training=self.training) 138 | if self.concat: 139 | h = torch.cat([att(x, adj) for att in self.attentions], dim=1) 140 | else: 141 | h_cat = torch.cat([att(x, adj).view((-1, self.output_dim, 1)) for att in self.attentions], dim=2) 142 | h = torch.mean(h_cat, dim=2) 143 | h = F.dropout(h, self.dropout, training=self.training) 144 | return (h, adj) 145 | -------------------------------------------------------------------------------- /graph_evaluate/dist_helper.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # 3 | # Adapted from https://github.com/lrjconan/GRAN/ which in 
turn is adapted from https://github.com/JiaxuanYou/graph-generation
4 | #
5 | ###############################################################################
6 | import pyemd
7 | import numpy as np
8 | import concurrent.futures
9 | from functools import partial
10 | from scipy.linalg import toeplitz
11 | 
12 | 
13 | def emd(x, y, distance_scaling=1.0):
14 |     support_size = max(len(x), len(y))
15 |     d_mat = toeplitz(range(support_size)).astype(float)
16 |     distance_mat = d_mat / distance_scaling
17 | 
18 |     # convert histogram values x and y to float, and make them equal len
19 |     x = x.astype(float)
20 |     y = y.astype(float)
21 |     if len(x) < len(y):
22 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
23 |     elif len(y) < len(x):
24 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
25 | 
26 |     emd = pyemd.emd(x, y, distance_mat)
27 |     return emd
28 | 
29 | 
30 | 
31 | def l2(x, y):
32 |     dist = np.linalg.norm(x - y, 2)
33 |     return dist
34 | 
35 | 
36 | def emd(x, y, sigma=1.0, distance_scaling=1.0):  # NOTE: shadows the emd() defined above; this is the version in effect
37 |     ''' EMD
38 |     Args:
39 |         x, y: 1D pmf of two distributions with the same support
40 |         sigma: standard deviation
41 |     '''
42 |     support_size = max(len(x), len(y))
43 |     d_mat = toeplitz(range(support_size)).astype(float)
44 |     distance_mat = d_mat / distance_scaling
45 | 
46 |     # convert histogram values x and y to float, and make them equal len
47 |     x = x.astype(float)
48 |     y = y.astype(float)
49 |     if len(x) < len(y):
50 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
51 |     elif len(y) < len(x):
52 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
53 | 
54 |     return np.abs(pyemd.emd(x, y, distance_mat))
55 | 
56 | 
57 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
58 |     ''' Gaussian kernel with squared distance in exponential term replaced by EMD
59 |     Args:
60 |         x, y: 1D pmf of two distributions with the same support
61 |         sigma: standard deviation
62 |     '''
63 |     support_size = max(len(x), len(y))
64 |     d_mat = toeplitz(range(support_size)).astype(float)
65 |     distance_mat = d_mat / distance_scaling
66 | 
67 |     # convert histogram values x and y to float, and make them equal len
68 |     x = x.astype(float)
69 |     y = y.astype(float)
70 |     if len(x) < len(y):
71 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
72 |     elif len(y) < len(x):
73 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
74 | 
75 |     emd = pyemd.emd(x, y, distance_mat)
76 |     return np.exp(-emd * emd / (2 * sigma * sigma))
77 | 
78 | 
79 | def gaussian(x, y, sigma=1.0):
80 |     support_size = max(len(x), len(y))
81 |     # convert histogram values x and y to float, and make them equal len
82 |     x = x.astype(float)
83 |     y = y.astype(float)
84 |     if len(x) < len(y):
85 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
86 |     elif len(y) < len(x):
87 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
88 | 
89 |     dist = np.linalg.norm(x - y, 2)
90 |     return np.exp(-dist * dist / (2 * sigma * sigma))
91 | 
92 | 
93 | def gaussian_tv(x, y, sigma=1.0):
94 |     support_size = max(len(x), len(y))
95 |     # convert histogram values x and y to float, and make them equal len
96 |     x = x.astype(float)
97 |     y = y.astype(float)
98 |     if len(x) < len(y):
99 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
100 |     elif len(y) < len(x):
101 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
102 | 
103 |     dist = np.abs(x - y).sum() / 2.0
104 |     return np.exp(-dist * dist / (2 * sigma * sigma))
105 | 
106 | 
107 | def kernel_parallel_unpacked(x, samples2, kernel):
108 |     d = 0
109 |     for s2 in samples2:
110 |         d += kernel(x, s2)
111 |     return d
112 | 
113 | 
114 | def kernel_parallel_worker(t):
115 |     return kernel_parallel_unpacked(*t)
116 | 
117 | 
118 | def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs):
119 |     ''' Discrepancy between 2 samples '''
120 |     d = 0
121 | 
122 |     if not is_parallel:
123 |         for s1 in samples1:
124 |             for s2 in samples2:
125 |                 d += kernel(s1, s2, *args, **kwargs)
126 |     else:
127 |         with concurrent.futures.ThreadPoolExecutor() as executor:
128 |             for dist in executor.map(kernel_parallel_worker, [
129 |                 (s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1
130 |             ]):
131 |                 d += dist
132 |     if len(samples1) * len(samples2) > 0:
133 |         d /= len(samples1) * len(samples2)
134 |     else:
135 |         d = 1e+6
136 |     return d
137 | 
138 | 
139 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
140 |     ''' MMD between two samples '''
141 |     # normalize histograms into pmf
142 |     if is_hist:
143 |         samples1 = [s1 / (np.sum(s1) + 1e-6) for s1 in samples1]
144 |         samples2 = [s2 / (np.sum(s2) + 1e-6) for s2 in samples2]
145 |     return disc(samples1, samples1, kernel, *args, **kwargs) + disc(samples2, samples2, kernel, *args, **kwargs) - \
146 |         2 * disc(samples1, samples2, kernel, *args, **kwargs)
147 | 
148 | 
149 | def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
150 |     ''' EMD between average of two samples '''
151 |     # normalize histograms into pmf
152 |     if is_hist:
153 |         samples1 = [np.mean(samples1)]
154 |         samples2 = [np.mean(samples2)]
155 |     return disc(samples1, samples2, kernel, *args,
156 |                 **kwargs), [samples1[0], samples2[0]]
157 | 
--------------------------------------------------------------------------------
/mmd_rnn.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | from functools import partial
3 | import networkx as nx
4 | import numpy as np
5 | from scipy.linalg import toeplitz
6 | import pyemd
7 | 
8 | # source: https://github.com/JiaxuanYou/graph-generation
9 | def emd(x, y, distance_scaling=1.0):
10 |     support_size = max(len(x), len(y))
11 |     d_mat = toeplitz(range(support_size)).astype(float)
12 |     distance_mat = d_mat / distance_scaling
13 | 
14 |     # convert histogram values x and y to float, and make them equal len
15 |     x = x.astype(float)
16 |     y = y.astype(float)
17 |     if len(x) < len(y):
18 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
19 |     elif len(y) < len(x):
20 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
21 | 
22 |     emd = pyemd.emd(x, y, distance_mat)
23 |     return emd
24 | 
25 | 
26 | def l2(x, y):
27 |     dist = np.linalg.norm(x - y, 2)
28 |     return dist
29 | 
30 | 
31 | def gaussian_tv(x, y, sigma=1.0):
32 |     support_size = max(len(x), len(y))
33 |     # convert histogram values x and y to float, and make them equal len
34 |     x = x.astype(float)
35 |     y = y.astype(float)
36 |     if len(x) < len(y):
37 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
38 |     elif len(y) < len(x):
39 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
40 | 
41 |     dist = np.abs(x - y).sum() / 2.0
42 |     return np.exp(-dist * dist / (2 * sigma * sigma))
43 | 
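# Worked example for gaussian_tv above (added for clarity; the numbers are
# illustrative): for pmfs x = [0.2, 0.8] and y = [0.3, 0.7], the total-variation
# distance is (|0.2 - 0.3| + |0.8 - 0.7|) / 2 = 0.1, so with sigma = 1.0 the
# kernel value is exp(-0.1 ** 2 / 2) ~= 0.995.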
44 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
45 |     ''' Gaussian kernel with squared distance in exponential term replaced by EMD
46 |     Args:
47 |         x, y: 1D pmf of two distributions with the same support
48 |         sigma: standard deviation
49 |     '''
50 |     support_size = max(len(x), len(y))
51 |     d_mat = toeplitz(range(support_size)).astype(float)
52 |     distance_mat = d_mat / distance_scaling
53 | 
54 |     # convert histogram values x and y to float, and make them equal len
55 |     x = x.astype(float)
56 |     y = y.astype(float)
57 |     if len(x) < len(y):
58 |         x = np.hstack((x, [0.0] * (support_size - len(x))))
59 |     elif len(y) < len(x):
60 |         y = np.hstack((y, [0.0] * (support_size - len(y))))
61 | 
62 |     emd = pyemd.emd(x, y, distance_mat)
63 |     return np.exp(-emd * emd / (2 * sigma * sigma))
64 | 
65 | 
66 | def gaussian(x, y, sigma=1.0):
67 |     dist = np.linalg.norm(x - y, 2)
68 |     return np.exp(-dist * dist / (2 * sigma * sigma))
69 | 
70 | 
71 | def kernel_parallel_unpacked(x, samples2, kernel):
72 |     d = 0
73 |     for s2 in samples2:
74 |         d += kernel(x, s2)
75 |     return d
76 | 
77 | 
78 | def kernel_parallel_worker(t):
79 |     return kernel_parallel_unpacked(*t)
80 | 
81 | 
82 | def disc(samples1, samples2, kernel, is_parallel=False, *args, **kwargs):
83 |     ''' Discrepancy between 2 samples
84 |     '''
85 |     d = 0
86 |     if not is_parallel:
87 |         for s1 in samples1:
88 |             for s2 in samples2:
89 |                 d += kernel(s1, s2, *args, **kwargs)
90 |     else:
91 |         with concurrent.futures.ProcessPoolExecutor() as executor:
92 |             for dist in executor.map(kernel_parallel_worker,
93 |                     [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]):
94 |                 d += dist
95 |     d /= len(samples1) * len(samples2)
96 |     return d
97 | 
98 | 
99 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
100 |     ''' MMD between two samples
101 |     '''
102 |     # normalize histograms into pmf
103 |     if is_hist:
104 |         samples1 = [s1 / np.sum(s1) for s1 in samples1]
105 |         samples2 = [s2 / np.sum(s2) for s2 in samples2]
106 |     # print('===============================')
107 |     # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs))
108 |     # print('--------------------------')
109 |     # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs))
110 |     # print('--------------------------')
111 |     # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs))
112 |     # print('===============================')
113 |     return disc(samples1, samples1, kernel, *args, **kwargs) + \
114 |         disc(samples2, samples2, kernel, *args, **kwargs) - \
115 |         2 * disc(samples1, samples2, kernel, *args, **kwargs)
116 |     # return disc(samples1, samples1, kernel, *args, **kwargs)
117 | 
118 | def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
119 |     ''' EMD between average of two samples
120 |     '''
121 |     # normalize histograms into pmf
122 |     if is_hist:
123 |         samples1 = [np.mean(samples1)]
124 |         samples2 = [np.mean(samples2)]
125 |     # print('===============================')
126 |     # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs))
127 |     # print('--------------------------')
128 |     # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs))
129 |     # print('--------------------------')
130 |     # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs))
131 |     # print('===============================')
132 |     return disc(samples1, samples2, kernel, *args, **kwargs), [samples1[0], samples2[0]]
133 | 
134 | 
135 | def test():
136 |     s1 = np.array([0.2, 0.8])
137 |     s2 = np.array([0.3, 0.7])
138 |     samples1 = [s1, s2]
139 | 
140 |     s3 = np.array([0.25, 0.75])
141 |     s4 = np.array([0.35, 0.65])
142 |     samples2 = [s3, s4]
143 | 
144 |     s5 = np.array([0.8, 0.2])
145 |     s6 = np.array([0.7, 0.3])
146 |     samples3 = [s5, s6]
147 | 
148 |     print('between samples1 and samples2: ', compute_mmd(samples1, samples2, kernel=gaussian_emd,
149 |                                                          is_parallel=False, sigma=1.0))
150 |     print('between samples1 and samples3: ', compute_mmd(samples1, samples3, kernel=gaussian_emd,
151 | 
is_parallel=False, sigma=1.0)) 152 | print('between samples1 and samples3: ', compute_mmd(samples1, samples3, kernel=gaussian_tv)) 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import models.Aggregation as Aggregation 2 | import dgl 3 | from utils.util import * 4 | 5 | class AveEncoder(torch.nn.Module): 6 | def __init__(self, in_feature_dim, hiddenLayers=[256, 256, 256], GraphLatntDim=1024): 7 | super(AveEncoder, self).__init__() 8 | 9 | hiddenLayers = [in_feature_dim] + hiddenLayers + [GraphLatntDim] 10 | self.normLayers = torch.nn.ModuleList( 11 | [torch.nn.LayerNorm(hiddenLayers[i + 1], elementwise_affine=False) for i in range(len(hiddenLayers) - 1)]) 12 | self.normLayers.append(torch.nn.LayerNorm(hiddenLayers[-1], elementwise_affine=False)) 13 | self.GCNlayers = torch.nn.ModuleList([dgl.nn.pytorch.conv.GraphConv(hiddenLayers[i], hiddenLayers[i + 1], 14 | activation=None, bias=True, weight=True) for 15 | i in range(len(hiddenLayers) - 1)]) 16 | 17 | self.poolingLayer = Aggregation.AvePool() 18 | 19 | self.stochastic_mean_layer = node_mlp(GraphLatntDim, [GraphLatntDim]) 20 | self.stochastic_log_std_layer = node_mlp(GraphLatntDim, [GraphLatntDim]) 21 | 22 | def forward(self, graph, features, batchSize, activation=torch.nn.LeakyReLU(0.01)): 23 | h = features 24 | for i in range(len(self.GCNlayers)): 25 | h = self.GCNlayers[i](graph, h) 26 | h = activation(h) 27 | # if((itd', torch.sqrt(1 - self.alpha_bars).to(x.device), epsilon) 119 | xt = torch.einsum('t,d->td', torch.sqrt(self.alpha_bars).to(x.device), x) 120 | 121 | xt_q = xt + epsilons 122 | et = self.backbone(xt_q, idx) 123 | 124 | # Gaussian D_KL(m1, s1, m2, s2) = log(s2/s1) + 0.5 * (s1**2 + (m1-m2)**2) / s2**2 - 0.5 125 | 126 | # LT = D_KL( q(xT|x0) ; p(xT) ) 127 | sigma_T = torch.sqrt(1 - self.alpha_bars[-1]).to(x.device) 128 | LT = torch.log(1 / sigma_T) + 0.5 * (sigma_T ** 2 + xt[-1].square().mean()) - 0.5 129 | 130 | # Lt = D_KL( q(x_t-1|xt,x0) ; p(x_t-1|xt) ) = 0.5 * ( 1 - SNR_t-1/SNRt ) ||e-et||**2 131 | Lt = 0.5 * ((1 - self.snr[:-1] / self.snr[1:]).to(x.device) * (et - epsilon.reshape(1, -1)).square().mean( 132 | dim=1)[1:]).sum() 133 | 134 | # L0 = - log p(x0|x1) = - log( 1/sigma/sqrt(2pi) * exp( -1/2/sigma**2 * (x-mu)**2 ) 135 | sigma_0 = torch.sqrt((1 - self.alpha_bars[0]) / (1 - self.alpha_bars[1]) * self.beta[1]) 136 | L0 = - torch.log(1 / sigma_0 / torch.sqrt(2 * torch.tensor(np.pi).to(x.device)) * torch.exp( 137 | -0.5 * (et[0] - epsilon) ** 2) + 1e-10).mean() # add 1e-10 to avoid inf 138 | 139 | nll = LT + Lt + L0 140 | 141 | return nll 142 | -------------------------------------------------------------------------------- /graph_evaluate/baselines/mmsb.py: -------------------------------------------------------------------------------- 1 | """Stochastic block model.""" 2 | 3 | import argparse 4 | import os 5 | from time import time 6 | 7 | import edward as ed 8 | import networkx as nx 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from edward.models import Bernoulli, Multinomial, Beta, Dirichlet, PointMass, Normal 13 | from observations import karate 14 | from sklearn.metrics.cluster import adjusted_rand_score 15 | 16 | import utils 17 | 18 | CUDA = 2 19 | ed.set_seed(int(time())) 20 | #ed.set_seed(42) 21 | 22 | # DATA 23 | #X_data, Z_true = karate("data") 24 | 25 | def disjoint_cliques_test_graph(num_cliques, 
clique_size): 26 | G = nx.disjoint_union_all([nx.complete_graph(clique_size) for _ in range(num_cliques)]) 27 | return nx.to_numpy_matrix(G) 28 | 29 | def mmsb(N, K, data): 30 | # sparsity 31 | rho = 0.3 32 | # MODEL 33 | # probability of belonging to each of K blocks for each node 34 | gamma = Dirichlet(concentration=tf.ones([K])) 35 | # block connectivity 36 | Pi = Beta(concentration0=tf.ones([K, K]), concentration1=tf.ones([K, K])) 37 | # probability of belonging to each of K blocks for all nodes 38 | Z = Multinomial(total_count=1.0, probs=gamma, sample_shape=N) 39 | # adjacency 40 | X = Bernoulli(probs=(1 - rho) * tf.matmul(Z, tf.matmul(Pi, tf.transpose(Z)))) 41 | 42 | # INFERENCE (EM algorithm) 43 | qgamma = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([K])))) 44 | qPi = PointMass(params=tf.nn.sigmoid(tf.Variable(tf.random_normal([K, K])))) 45 | qZ = PointMass(params=tf.nn.softmax(tf.Variable(tf.random_normal([N, K])))) 46 | 47 | #qgamma = Normal(loc=tf.get_variable("qgamma/loc", [K]), 48 | # scale=tf.nn.softplus( 49 | # tf.get_variable("qgamma/scale", [K]))) 50 | #qPi = Normal(loc=tf.get_variable("qPi/loc", [K, K]), 51 | # scale=tf.nn.softplus( 52 | # tf.get_variable("qPi/scale", [K, K]))) 53 | #qZ = Normal(loc=tf.get_variable("qZ/loc", [N, K]), 54 | # scale=tf.nn.softplus( 55 | # tf.get_variable("qZ/scale", [N, K]))) 56 | 57 | #inference = ed.KLqp({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data}) 58 | inference = ed.MAP({gamma: qgamma, Pi: qPi, Z: qZ}, data={X: data}) 59 | 60 | #inference.run() 61 | n_iter = 6000 62 | inference.initialize(optimizer=tf.train.AdamOptimizer(learning_rate=0.01), n_iter=n_iter) 63 | 64 | tf.global_variables_initializer().run() 65 | 66 | for _ in range(inference.n_iter): 67 | info_dict = inference.update() 68 | inference.print_progress(info_dict) 69 | 70 | inference.finalize() 71 | print('qgamma after: ', qgamma.mean().eval()) 72 | return qZ.mean().eval(), qPi.eval() 73 | 74 | def arg_parse(): 75 | parser = argparse.ArgumentParser(description='MMSB arguments.') 76 | parser.add_argument('--dataset', dest='dataset', 77 | help='Input dataset.') 78 | parser.add_argument('--K', dest='K', type=int, 79 | help='Number of blocks.') 80 | parser.add_argument('--samples-per-G', dest='samples', type=int, 81 | help='Number of samples for every graph.') 82 | 83 | parser.set_defaults(dataset='community', 84 | K=4, 85 | samples=1) 86 | return parser.parse_args() 87 | 88 | def graph_gen_from_blockmodel(B, Z): 89 | n_blocks = len(B) 90 | B = np.array(B) 91 | Z = np.array(Z) 92 | adj_prob = np.dot(Z, np.dot(B, np.transpose(Z))) 93 | adj = np.random.binomial(1, adj_prob * 0.3) 94 | return nx.from_numpy_matrix(adj) 95 | 96 | if __name__ == '__main__': 97 | prog_args = arg_parse() 98 | os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA) 99 | print('CUDA', CUDA) 100 | 101 | X_dataset = [] 102 | #X_data = nx.to_numpy_matrix(nx.connected_caveman_graph(4, 7)) 103 | if prog_args.dataset == 'clique_test': 104 | X_data = disjoint_cliques_test_graph(4, 7) 105 | X_dataset.append(X_data) 106 | elif prog_args.dataset == 'citeseer': 107 | graphs = utils.citeseer_ego() 108 | X_dataset = [nx.to_numpy_matrix(g) for g in graphs] 109 | elif prog_args.dataset == 'community': 110 | graphs = [] 111 | for i in range(2, 3): 112 | for j in range(30, 81): 113 | for k in range(10): 114 | graphs.append(utils.caveman_special(i,j, p_edge=0.3)) 115 | X_dataset = [nx.to_numpy_matrix(g) for g in graphs] 116 | elif prog_args.dataset == 'grid': 117 | graphs = [] 118 | for i in range(10,20): 119 | for 
j in range(10,20): 120 | graphs.append(nx.grid_2d_graph(i,j)) 121 | X_dataset = [nx.to_numpy_matrix(g) for g in graphs] 122 | elif prog_args.dataset.startswith('community'): 123 | graphs = [] 124 | num_communities = int(prog_args.dataset[-1]) 125 | print('Creating dataset with ', num_communities, ' communities') 126 | c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities) 127 | for k in range(3000): 128 | graphs.append(utils.n_community(c_sizes, p_inter=0.01)) 129 | X_dataset = [nx.to_numpy_matrix(g) for g in graphs] 130 | 131 | print('Number of graphs: ', len(X_dataset)) 132 | K = prog_args.K # number of clusters 133 | gen_graphs = [] 134 | for i in range(len(X_dataset)): 135 | if i % 5 == 0: 136 | print(i) 137 | X_data = X_dataset[i] 138 | N = X_data.shape[0] # number of vertices 139 | 140 | Zp, B = mmsb(N, K, X_data) 141 | #print("Block: ", B) 142 | Z_pred = Zp.argmax(axis=1) 143 | print("Result (label flip can happen):") 144 | #print("prob: ", Zp) 145 | print("Predicted") 146 | print(Z_pred) 147 | #print(Z_true) 148 | #print("Adjusted Rand Index =", adjusted_rand_score(Z_pred, Z_true)) 149 | for j in range(prog_args.samples): 150 | gen_graphs.append(graph_gen_from_blockmodel(B, Zp)) 151 | 152 | save_path = '/lfs/local/0/rexy/graph-generation/eval_results/mmsb/' 153 | utils.save_graph_list(gen_graphs, os.path.join(save_path, prog_args.dataset + '.dat')) 154 | 155 | -------------------------------------------------------------------------------- /models/hyp_model.py: -------------------------------------------------------------------------------- 1 | import models.Aggregation as Aggregation 2 | import dgl 3 | from utils.util import * 4 | from models.hyp_layers import HNNLayer 5 | from manifolds import PoincareBall 6 | class HAveEncoder(torch.nn.Module): 7 | def __init__(self, in_feature_dim, hiddenLayers=[256, 256, 256], GraphLatntDim=1024): 8 | super(HAveEncoder, self).__init__() 9 | self.manifold = PoincareBall() 10 | hiddenLayers = [in_feature_dim] + hiddenLayers + [GraphLatntDim] 11 | self.normLayers = torch.nn.ModuleList( 12 | [torch.nn.LayerNorm(hiddenLayers[i + 1], elementwise_affine=False) for i in range(len(hiddenLayers) - 1)]) 13 | self.normLayers.append(torch.nn.LayerNorm(hiddenLayers[-1], elementwise_affine=False)) 14 | self.GCNlayers = torch.nn.ModuleList([dgl.nn.pytorch.conv.GraphConv(hiddenLayers[i], hiddenLayers[i + 1], 15 | activation=None, bias=True, weight=True) for 16 | i in range(len(hiddenLayers) - 1)]) 17 | 18 | self.poolingLayer = Aggregation.AvePool() 19 | self.HNNlayers= HNNLayer(GraphLatntDim, GraphLatntDim, self.manifold,self.manifold) 20 | self.stochastic_mean_layer = node_mlp(GraphLatntDim, [GraphLatntDim]) 21 | self.stochastic_log_std_layer = node_mlp(GraphLatntDim, [GraphLatntDim]) 22 | 23 | def forward(self, graph, features, batchSize, activation=torch.nn.LeakyReLU(0.01)): 24 | h = features 25 | for i in range(len(self.GCNlayers)): 26 | h = self.GCNlayers[i](graph, h) 27 | h = activation(h) 28 | # if((i 0 38 | dims, acts = get_dim_act(args) 39 | layers = [] 40 | for i in range(len(dims) - 1): 41 | in_dim, out_dim = dims[i], dims[i + 1] 42 | act = acts[i] 43 | layers.append(Linear(in_dim, out_dim, args.dropout, act, args.bias)) 44 | self.layers = nn.Sequential(*layers) 45 | self.encode_graph = False 46 | 47 | 48 | class HNN(Encoder): 49 | """ 50 | Hyperbolic Neural Networks. 
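Feature-only encoder (encode_graph = False): inputs are first mapped onto the manifold via proj(expmap0(proj_tan0(x, c), c), c) and then passed through a stack of HNNLayer modules.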
51 | """ 52 | 53 | def __init__(self, c, args): 54 | super(HNN, self).__init__(c) 55 | self.manifold = getattr(manifolds, args.manifold)() 56 | assert args.num_layers > 1 57 | dims, acts, _ = hyp_layers.get_dim_act_curv(args) 58 | hnn_layers = [] 59 | for i in range(len(dims) - 1): 60 | in_dim, out_dim = dims[i], dims[i + 1] 61 | act = acts[i] 62 | hnn_layers.append( 63 | hyp_layers.HNNLayer( 64 | self.manifold, in_dim, out_dim, self.c, args.dropout, act, args.bias) 65 | ) 66 | self.layers = nn.Sequential(*hnn_layers) 67 | self.encode_graph = False 68 | 69 | def encode(self, x, adj): 70 | x_hyp = self.manifold.proj(self.manifold.expmap0(self.manifold.proj_tan0(x, self.c), c=self.c), c=self.c) 71 | return super(HNN, self).encode(x_hyp, adj) 72 | 73 | class GCN(Encoder): 74 | """ 75 | Graph Convolution Networks. 76 | """ 77 | 78 | def __init__(self, c, args): 79 | super(GCN, self).__init__(c) 80 | assert args.num_layers > 0 81 | dims, acts = get_dim_act(args) 82 | gc_layers = [] 83 | for i in range(len(dims) - 1): 84 | in_dim, out_dim = dims[i], dims[i + 1] 85 | act = acts[i] 86 | gc_layers.append(GraphConvolution(in_dim, out_dim, args.dropout, act, args.bias)) 87 | self.layers = nn.Sequential(*gc_layers) 88 | self.encode_graph = True 89 | 90 | 91 | class HGCN(Encoder): 92 | """ 93 | Hyperbolic-GCN. 94 | """ 95 | 96 | def __init__(self, c, args): 97 | super(HGCN, self).__init__(c) 98 | self.manifold = getattr(manifolds, args.manifold)() 99 | assert args.num_layers > 1 100 | dims, acts, self.curvatures = hyp_layers.get_dim_act_curv(args) 101 | self.curvatures.append(self.c) 102 | hgc_layers = [] 103 | for i in range(len(dims) - 1): 104 | c_in, c_out = self.curvatures[i], self.curvatures[i + 1] 105 | in_dim, out_dim = dims[i], dims[i + 1] 106 | act = acts[i] 107 | hgc_layers.append( 108 | hyp_layers.HyperbolicGraphConvolution( 109 | self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att, args.local_agg 110 | ) 111 | ) 112 | self.layers = nn.Sequential(*hgc_layers) 113 | self.encode_graph = True 114 | 115 | def encode(self, x, adj): 116 | x_tan = self.manifold.proj_tan0(x, self.curvatures[0]) 117 | x_hyp = self.manifold.expmap0(x_tan, c=self.curvatures[0]) 118 | x_hyp = self.manifold.proj(x_hyp, c=self.curvatures[0]) 119 | return super(HGCN, self).encode(x_hyp, adj) 120 | 121 | 122 | class GAT(Encoder): 123 | """ 124 | Graph Attention Networks. 125 | """ 126 | 127 | def __init__(self, c, args): 128 | super(GAT, self).__init__(c) 129 | assert args.num_layers > 0 130 | dims, acts = get_dim_act(args) 131 | gat_layers = [] 132 | for i in range(len(dims) - 1): 133 | in_dim, out_dim = dims[i], dims[i + 1] 134 | act = acts[i] 135 | assert dims[i + 1] % args.n_heads == 0 136 | out_dim = dims[i + 1] // args.n_heads 137 | concat = True 138 | gat_layers.append( 139 | GraphAttentionLayer(in_dim, out_dim, args.dropout, act, args.alpha, args.n_heads, concat)) 140 | self.layers = nn.Sequential(*gat_layers) 141 | self.encode_graph = True 142 | 143 | 144 | class Shallow(Encoder): 145 | """ 146 | Shallow Embedding method. 147 | Learns embeddings or loads pretrained embeddings and uses an MLP for classification. 
148 | """ 149 | 150 | def __init__(self, c, args): 151 | super(Shallow, self).__init__(c) 152 | self.manifold = getattr(manifolds, args.manifold)() 153 | self.use_feats = args.use_feats 154 | weights = torch.Tensor(args.n_nodes, args.dim) 155 | if not args.pretrained_embeddings: 156 | weights = self.manifold.init_weights(weights, self.c) 157 | trainable = True 158 | else: 159 | weights = torch.Tensor(np.load(args.pretrained_embeddings)) 160 | assert weights.shape[0] == args.n_nodes, "The embeddings you passed seem to be for another dataset." 161 | trainable = False 162 | self.lt = manifolds.ManifoldParameter(weights, trainable, self.manifold, self.c) 163 | self.all_nodes = torch.LongTensor(list(range(args.n_nodes))) 164 | layers = [] 165 | if args.pretrained_embeddings is not None and args.num_layers > 0: 166 | # MLP layers after pre-trained embeddings 167 | dims, acts = get_dim_act(args) 168 | if self.use_feats: 169 | dims[0] = args.feat_dim + weights.shape[1] 170 | else: 171 | dims[0] = weights.shape[1] 172 | for i in range(len(dims) - 1): 173 | in_dim, out_dim = dims[i], dims[i + 1] 174 | act = acts[i] 175 | layers.append(Linear(in_dim, out_dim, args.dropout, act, args.bias)) 176 | self.layers = nn.Sequential(*layers) 177 | self.encode_graph = False 178 | 179 | def encode(self, x, adj): 180 | h = self.lt[self.all_nodes, :] 181 | if self.use_feats: 182 | h = torch.cat((h, x), 1) 183 | return super(Shallow, self).encode(h, adj) 184 | -------------------------------------------------------------------------------- /optimizers/radam.py: -------------------------------------------------------------------------------- 1 | """Riemannian adam optimizer geoopt implementation (https://github.com/geoopt/).""" 2 | import torch.optim 3 | from manifolds import Euclidean, ManifoldParameter 4 | 5 | # in order not to create it at each iteration 6 | _default_manifold = Euclidean() 7 | 8 | 9 | class OptimMixin(object): 10 | def __init__(self, *args, stabilize=None, **kwargs): 11 | self._stabilize = stabilize 12 | super().__init__(*args, **kwargs) 13 | 14 | def stabilize_group(self, group): 15 | pass 16 | 17 | def stabilize(self): 18 | """Stabilize parameters if they are off-manifold due to numerical reasons 19 | """ 20 | for group in self.param_groups: 21 | self.stabilize_group(group) 22 | 23 | 24 | def copy_or_set_(dest, source): 25 | """ 26 | A workaround to respect strides of :code:`dest` when copying :code:`source` 27 | (https://github.com/geoopt/geoopt/issues/70) 28 | Parameters 29 | ---------- 30 | dest : torch.Tensor 31 | Destination tensor where to store new data 32 | source : torch.Tensor 33 | Source data to put in the new tensor 34 | Returns 35 | ------- 36 | dest 37 | torch.Tensor, modified inplace 38 | """ 39 | if dest.stride() != source.stride(): 40 | return dest.copy_(source) 41 | else: 42 | return dest.set_(source) 43 | 44 | 45 | class RiemannianAdam(OptimMixin, torch.optim.Adam): 46 | r"""Riemannian Adam with the same API as :class:`torch.optim.Adam` 47 | Parameters 48 | ---------- 49 | params : iterable 50 | iterable of parameters to optimize or dicts defining 51 | parameter groups 52 | lr : float (optional) 53 | learning rate (default: 1e-3) 54 | betas : Tuple[float, float] (optional) 55 | coefficients used for computing 56 | running averages of gradient and its square (default: (0.9, 0.999)) 57 | eps : float (optional) 58 | term added to the denominator to improve 59 | numerical stability (default: 1e-8) 60 | weight_decay : float (optional) 61 | weight decay (L2 penalty) (default: 0) 62 
| amsgrad : bool (optional)
63 |         whether to use the AMSGrad variant of this
64 |         algorithm from the paper `On the Convergence of Adam and Beyond`_
65 |         (default: False)
66 |     Other Parameters
67 |     ----------------
68 |     stabilize : int
69 |         Stabilize parameters if they are off-manifold due to numerical
70 |         reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
71 |     .. _On the Convergence of Adam and Beyond:
72 |         https://openreview.net/forum?id=ryQu7f-RZ
73 |     """
74 | 
75 |     def step(self, closure=None):
76 |         """Performs a single optimization step.
77 |         Arguments
78 |         ---------
79 |         closure : callable (optional)
80 |             A closure that reevaluates the model
81 |             and returns the loss.
82 |         """
83 |         loss = None
84 |         if closure is not None:
85 |             loss = closure()
86 |         with torch.no_grad():
87 |             for group in self.param_groups:
88 |                 if "step" not in group:
89 |                     group["step"] = 0
90 |                 betas = group["betas"]
91 |                 weight_decay = group["weight_decay"]
92 |                 eps = group["eps"]
93 |                 learning_rate = group["lr"]
94 |                 amsgrad = group["amsgrad"]
95 |                 for point in group["params"]:
96 |                     grad = point.grad
97 |                     if grad is None:
98 |                         continue
99 |                     if isinstance(point, (ManifoldParameter)):
100 |                         manifold = point.manifold
101 |                         c = point.c
102 |                     else:
103 |                         manifold = _default_manifold
104 |                         c = None
105 |                     if grad.is_sparse:
106 |                         raise RuntimeError(
107 |                             "Riemannian Adam does not support sparse gradients yet (PR is welcome)"
108 |                         )
109 | 
110 |                     state = self.state[point]
111 | 
112 |                     # State initialization
113 |                     if len(state) == 0:
114 |                         state["step"] = 0
115 |                         # Exponential moving average of gradient values
116 |                         state["exp_avg"] = torch.zeros_like(point)
117 |                         # Exponential moving average of squared gradient values
118 |                         state["exp_avg_sq"] = torch.zeros_like(point)
119 |                         if amsgrad:
120 |                             # Maintains max of all exp. moving avg. of sq. grad. values
121 |                             state["max_exp_avg_sq"] = torch.zeros_like(point)
122 |                     # make local variables for easy access
123 |                     exp_avg = state["exp_avg"]
124 |                     exp_avg_sq = state["exp_avg_sq"]
125 |                     # actual step: apply weight decay, convert the Euclidean
126 |                     # gradient to a Riemannian one, then update the moments
127 |                     grad.add_(point, alpha=weight_decay)
128 |                     grad = manifold.egrad2rgrad(point, grad, c)
129 |                     exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
130 |                     exp_avg_sq.mul_(betas[1]).add_(
131 |                         manifold.inner(point, c, grad, keepdim=True), alpha=1 - betas[1]
132 |                     )
133 |                     if amsgrad:
134 |                         max_exp_avg_sq = state["max_exp_avg_sq"]
135 |                         # Maintains the maximum of all 2nd moment running avg. till now
136 |                         torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
137 |                         # Use the max. for normalizing running avg. of gradient
138 |                         denom = max_exp_avg_sq.sqrt().add_(eps)
139 |                     else:
140 |                         denom = exp_avg_sq.sqrt().add_(eps)
141 |                     group["step"] += 1
142 |                     bias_correction1 = 1 - betas[0] ** group["step"]
143 |                     bias_correction2 = 1 - betas[1] ** group["step"]
144 |                     step_size = (
145 |                         learning_rate * bias_correction2 ** 0.5 / bias_correction1
146 |                     )
147 | 
148 |                     # copy the state, we need it for retraction
149 |                     # get the direction for ascend
150 |                     direction = exp_avg / denom
151 |                     # transport the exponential averaging to the new point
152 |                     new_point = manifold.proj(manifold.expmap(-step_size * direction, point, c), c)
153 |                     exp_avg_new = manifold.ptransp(point, new_point, exp_avg, c)
154 |                     # use copy only for user facing point
155 |                     copy_or_set_(point, new_point)
156 |                     exp_avg.set_(exp_avg_new)
157 | 
158 |                 group["step"] += 1
159 |                 if self._stabilize is not None and group["step"] % self._stabilize == 0:
160 |                     self.stabilize_group(group)
161 |         return loss
162 | 
163 |     @torch.no_grad()
164 |     def stabilize_group(self, group):
165 |         for p in group["params"]:
166 |             if not isinstance(p, ManifoldParameter):
167 |                 continue
168 |             state = self.state[p]
169 |             if not state:  # due to None grads
170 |                 continue
171 |             manifold = p.manifold
172 |             c = p.c
173 |             exp_avg = state["exp_avg"]
174 |             copy_or_set_(p, manifold.proj(p, c))
175 |             exp_avg.set_(manifold.proj_tan(exp_avg, p, c))
176 | 
--------------------------------------------------------------------------------
/models/base_models.py:
--------------------------------------------------------------------------------
1 | """Base model class."""
2 | 
3 | import numpy as np
4 | from sklearn.metrics import roc_auc_score, average_precision_score
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | from layers.layers import FermiDiracDecoder
10 | import manifolds
11 | import models.encoders as encoders
12 | from torch_geometric.nn import GCNConv
13 | 
14 | class BaseModel(nn.Module):
15 |     """
16 |     Base model for graph embedding tasks. 
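Holds the manifold (selected by args.manifold), the curvature c (fixed from args.c or learned as an nn.Parameter), and the encoder chosen by args.model; q_sample implements the forward-diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise used by the diffusion components.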
17 | """ 18 | 19 | def __init__(self, args): 20 | super(BaseModel, self).__init__() 21 | self.manifold_name = args.manifold 22 | self.num_timesteps=1000 23 | if args.c is not None: 24 | self.c = torch.tensor([args.c]) 25 | if not args.cuda == -1: 26 | self.c = self.c.to(args.device) 27 | else: 28 | self.c = nn.Parameter(torch.Tensor([1.])) 29 | self.manifold = getattr(manifolds, self.manifold_name)() 30 | if self.manifold.name == 'Hyperboloid': 31 | args.feat_dim = args.feat_dim + 1 32 | self.nnodes = args.n_nodes 33 | self.encoder = getattr(encoders, args.model)(self.c, args) 34 | def exists(self,x): 35 | return x is not None 36 | 37 | def default(self,val, d): 38 | if self.exists(val): 39 | return val 40 | return d() if callable(d) else d 41 | def cal_u0(self,x): 42 | x=self.manifold.logmap0(x,c=1.0) 43 | u0=torch.mean(x,dim=0) 44 | u0=self.manifold.expmap0(u0,c=1.0) 45 | return u0 46 | def extract(self, a, t, x_shape): 47 | b, *_ = t.shape 48 | out = a.gather(-1, t) 49 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 50 | 51 | def tran_direction(self,direction_vector, gaussian_point): 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | 54 | transformed_vector = torch.sign(direction_vector) 55 | transformed_point= gaussian_point*transformed_vector 56 | return transformed_point 57 | 58 | def get_alphas(self,timesteps): 59 | def linear_beta_schedule(timesteps): 60 | scale = 1000 / timesteps 61 | beta_start = scale * 0.0001 62 | beta_end = scale * 0.02 63 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 64 | betas = linear_beta_schedule(timesteps) 65 | alphas = 1. - betas 66 | alphas_cumprod = torch.cumprod(alphas, dim=0) 67 | alphas_minus=torch.sqrt(1. - alphas_cumprod) 68 | return torch.sqrt(alphas_cumprod),alphas_minus 69 | def q_sample(self, x_start, t,direction, noise=None): 70 | noise = self.default(noise, lambda: torch.randn_like(x_start)) 71 | alphas_cumprod, minus_alphas_cumprod = self.get_alphas(1000) 72 | alphas_cumprod=alphas_cumprod.cuda() 73 | minus_alphas_cumprod=minus_alphas_cumprod.cuda() 74 | return ( 75 | self.extract(alphas_cumprod, t, x_start.shape) * x_start + 76 | self.extract(minus_alphas_cumprod, t, x_start.shape) * noise 77 | ) 78 | def encode(self, x, adj): 79 | if self.manifold.name == 'Hyperboloid': 80 | o = torch.zeros_like(x) 81 | x = torch.cat([o[:, 0:1], x], dim=1) 82 | h = self.encoder.encode(x, adj) 83 | # print(h) 84 | b, n = h.shape 85 | t=1 86 | x=h 87 | h0 = h 88 | return x,t,h0,adj 89 | 90 | 91 | def compute_metrics(self, embeddings, data, split): 92 | raise NotImplementedError 93 | 94 | def init_metric_dict(self): 95 | raise NotImplementedError 96 | 97 | def has_improved(self, m1, m2): 98 | raise NotImplementedError 99 | 100 | 101 | class MLP(nn.Module): 102 | def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim): 103 | super(MLP, self).__init__() 104 | input=input_dim+1 105 | self.fc1 = nn.Linear(input, hidden_dim1) 106 | self.fc2 = nn.Linear(hidden_dim1, output_dim) 107 | self.fc3 = nn.Linear(hidden_dim2, output_dim) 108 | self.relu = nn.ReLU() 109 | self.tanh=nn.Tanh() 110 | self.sigmoid = nn.Sigmoid() 111 | 112 | def forward(self, x, t): 113 | # t=[ : ,None] 114 | # result = torch.cat((x, t), dim=2) 115 | x = torch.cat((x, t.unsqueeze(1)), dim=1).float() 116 | #x = self.sigmoid(self.fc1(x)) 117 | x=self.tanh(self.fc1(x)) 118 | #x = self.sigmoid(self.fc2(x)) 119 | x=self.fc2(x) 120 | #x = self.relu(self.fc3(x)) 121 | #x=self.fc3(x) 122 | return x 123 | 124 | class 
gcndec(nn.Module):
125 |     def __init__(self,input_dim ):
126 |         super(gcndec, self).__init__()
127 |         self.conv1 = GCNConv(input_dim, 32)
128 |         self.conv2 = GCNConv(32,input_dim)
129 | 
130 | 
131 |     def forward(self, x, adj):
132 |         x=x.to(torch.float32)
133 |         #print(adj.shape)
134 |         x=self.conv1(x,adj)
135 |         #x=F.relu(x)
136 |         x=self.conv2(x,adj)
137 |         x=F.relu(x)
138 |         return x
139 | 
140 | class LPModel(BaseModel):
141 |     """
142 |     Base model for link prediction task.
143 |     """
144 | 
145 |     def __init__(self, args):
146 |         super(LPModel, self).__init__(args)
147 |         self.dc = FermiDiracDecoder(r=args.r, t=args.t)
148 |         self.nb_false_edges = args.nb_false_edges
149 |         self.nb_edges = args.nb_edges
150 |         self.mlp=MLP(args.dim, args.hid1, args.hid2, args.dim)
151 |         self.gcn=gcndec(args.dim)
152 |     def decode(self, h, idx):
153 |         if self.manifold_name == 'Euclidean':
154 |             h = self.manifold.normalize(h)
155 |         emb_in = h[idx[:, 0], :]
156 |         emb_out = h[idx[:, 1], :]
157 |         sqdist = self.manifold.sqdist(emb_in, emb_out, self.c)
158 |         probs = self.dc.forward(sqdist)
159 |         return probs
160 |     def sample(self,x,adj):
161 |         img=torch.randn_like(x)
162 |         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
163 |         batch=len(img)
164 |         t = torch.full((batch,), 1000, device=device).long()
165 |         h=self.gcn(img,adj)
166 | 
167 |         return h
168 | 
169 |     def compute_metrics(self, embeddings, data,t,h0, adj,split):
170 |         if split == 'train' or split == 'all':
171 |             sample_edges_false = data[f'{split}_edges_false']
172 |             edges_false = sample_edges_false[np.random.randint(0, len(data[f'{split}_edges_false']), self.nb_edges)]
173 |         else:
174 |             edges_false = data[f'{split}_edges_false']
175 | 
176 |         # print('------')
177 |         pos_scores = self.decode(h0, data[f'{split}_edges'])
178 |         neg_scores = self.decode(h0, edges_false)
179 | 
180 |         # print('------')
181 | 
182 |         loss = F.binary_cross_entropy(pos_scores, torch.ones_like(pos_scores))
183 | 
184 |         loss += F.binary_cross_entropy(neg_scores, torch.zeros_like(neg_scores))
185 |         #loss=loss+loss_diff
186 |         # print(loss)
187 |         #print(loss_diff)
188 |         if pos_scores.is_cuda:
189 |             pos_scores = pos_scores.cpu()
190 |             neg_scores = neg_scores.cpu()
191 |         labels = [1] * pos_scores.shape[0] + [0] * neg_scores.shape[0]
192 |         preds = list(pos_scores.data.numpy()) + list(neg_scores.data.numpy())
193 |         roc = roc_auc_score(labels, preds)
194 |         ap = average_precision_score(labels, preds)
195 | 
196 |         losses = loss
197 |         metrics = {'loss': losses, 'roc': roc, 'ap': ap}
198 |         return metrics
199 | 
200 |     def init_metric_dict(self):
201 |         return {'roc': -1, 'ap': -1}
202 | 
203 |     def has_improved(self, m1, m2):
204 |         return 0.5 * (m1['roc'] + m1['ap']) < 0.5 * (m2['roc'] + m2['ap'])
205 | 
206 | 
--------------------------------------------------------------------------------
/GlobalProperties.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class kernel(torch.nn.Module):
4 |     """
5 |     this class returns a list of kernels, ordered by the keywords in kernel_type
6 |     """
7 |     def __init__(self, **ker):
8 |         """
9 |         :param ker:
10 |             kernel_type; a list of strings determining the needed kernels
11 |         """
12 |         self.device = ker.get("device")
13 |         super(kernel, self).__init__()
14 |         self.kernel_type = ker.get("kernel_type")
15 |         kernel_set = set(self.kernel_type)
16 | 
17 |         if "in_degree_dist" in kernel_set or "out_degree_dist" in kernel_set:
18 |             self.degree_hist = Histogram(self.device, ker.get("degree_bin_width").to(self.device), ker.get("degree_bin_center").to(self.device))
19 | 
20 |         if "RPF" in kernel_set:
21 |             self.num_of_steps = ker.get("step_num")
22 |             self.hist = Histogram(self.device, ker.get("bin_width"), ker.get("bin_center"))
23 | 
24 |         if "trans_matrix" in kernel_set:
25 |             self.num_of_steps = ker.get("step_num")
26 | 
27 | 
28 | 
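# Usage sketch for the kernel module above (illustrative only; the bin layout,
# sizes, and adjacency shape below are assumptions, not repository code):
#     centers = torch.arange(0, 50, dtype=torch.float).view(1, -1, 1)  # one soft bin per degree
#     widths = torch.ones(1, 50, 1)                                    # triangular bins of slope 1
#     ker = kernel(device=torch.device('cpu'), kernel_type=["in_degree_dist"],
#                  degree_bin_width=widths, degree_bin_center=centers)
#     feats = ker(adj)  # adj: (batch, N, N) dense adjacency; returns a list of feature tensors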
ker.get("degree_bin_center").to(self.device)) 19 | 20 | if "RPF" in kernel_set: 21 | self.num_of_steps = ker.get("step_num") 22 | self.hist = Histogram(self.device, ker.get("bin_width"), ker.get("bin_center")) 23 | 24 | if "trans_matrix" in kernel_set: 25 | self.num_of_steps = ker.get("step_num") 26 | 27 | 28 | 29 | def forward(self,adj): 30 | vec = self.kernel_function(adj) 31 | # return self.hist(vec) 32 | return vec 33 | 34 | def kernel_function(self, adj): # TODO: another var for keeping the number of moments 35 | # ToDo: here we assumed the matrix is symetrix(undirected) which might not 36 | vec = [] # feature vector 37 | for kernel in self.kernel_type: 38 | if "TotalNumberOfTriangles" == kernel: 39 | vec.append(self.TotalNumberOfTriangles(adj)) 40 | if "in_degree_dist" == kernel: 41 | degree_hit = [] 42 | for i in range(adj.shape[0]): 43 | # degree = adj[i, subgraph_indexes[i]][:, subgraph_indexes[i]].sum(1).view(1, -1) 44 | degree = adj[i].sum(1).view(1, -1) 45 | degree_hit.append(self.degree_hist(degree.to(self.device))) 46 | vec.append(torch.cat(degree_hit)) 47 | if "out_degree_dist" == kernel: 48 | degree_hit = [] 49 | for i in range(adj.shape[0]): 50 | degree = adj[i].sum(0).view(1, -1) 51 | degree_hit.append(self.degree_hist(degree)) 52 | vec.append(torch.cat(degree_hit)) 53 | if "RPF" == kernel: 54 | raise("should be changed") #ToDo: need to be fixed 55 | tr_p = self.S_step_trasition_probablity(adj, self.num_of_steps) 56 | for i in range(len(tr_p)): 57 | vec.append(self.hist(torch.diag(tr_p[i]))) 58 | 59 | if "trans_matrix" == kernel: 60 | vec.extend(self.S_step_trasition_probablity(adj, self.num_of_steps)) 61 | # vec = torch.cat(vec,1) 62 | 63 | if "tri" == kernel: # compare the nodes degree in the given order 64 | tri, square = self.tri_square_count(adj) 65 | vec.append(tri), vec.append(square) 66 | 67 | if "TrianglesOfEachNode" == kernel: # this kernel returns a verctor, element i of this vector is the number of triangeles which are centered at node i 68 | vec.append(self.TrianglesOfEachNode(adj)) 69 | 70 | if "ThreeStepPath" == kernel: 71 | vec.append(self.TreeStepPathes(adj)) 72 | return vec 73 | 74 | def tri_square_count(self, adj): 75 | ind = torch.eye(adj[0].shape[0]).to(self.device) 76 | adj = adj - ind 77 | two__ = torch.matmul(adj, adj) 78 | tri_ = torch.matmul(two__, adj) 79 | squares = torch.matmul(two__, two__) 80 | return (torch.diagonal(tri_, dim1=1, dim2=2), torch.diagonal(squares, dim1=1, dim2=2)) 81 | 82 | def S_step_trasition_probablity(self, adj, s=4, dataset_scale=None ): 83 | """ 84 | this method take an adjacency matrix and return its j0: 109 | # TP_list.append(torch.matmul(p1,p1)) 110 | TP_list.append( p1) 111 | for i in range(s-1): 112 | TP_list.append(torch.matmul(p1, TP_list[-1] )) 113 | return TP_list 114 | 115 | def TrianglesOfEachNode(self, adj, ): 116 | """ 117 | this method take an adjacency matrix and count the number of triangles centered at each node; this method return a vector for each graph 118 | """ 119 | 120 | p1 = adj.to(self.device) 121 | p1 = p1 * (1 - torch.eye(adj.shape[-1], adj.shape[-1])).to(self.device) 122 | 123 | # to save memory Use ineficient loop 124 | tri = torch.diagonal(torch.matmul(p1, torch.matmul(p1, p1)),dim1=-2, dim2=-1)/6 125 | return tri 126 | 127 | def TreeStepPathes(self, adj, ): 128 | """ 129 | this method take an adjacency matrix and count the number of pathes between each two node with lenght 3; this method return a matrix for each graph 130 | """ 131 | 132 | p1 = adj.to(self.device) 133 | p1 = p1 * (1 - 
    def TrianglesOfEachNode(self, adj):
        """
        This method takes an adjacency matrix and counts the number of triangles
        centered at each node; it returns a vector for each graph.
        """
        p1 = adj.to(self.device)
        p1 = p1 * (1 - torch.eye(adj.shape[-1], adj.shape[-1])).to(self.device)  # zero the diagonal (remove self-loops)

        # diag(A^3)[i] counts closed 3-walks from node i, i.e. each triangle at i twice
        tri = torch.diagonal(torch.matmul(p1, torch.matmul(p1, p1)), dim1=-2, dim2=-1) / 2
        return tri

    def TreeStepPathes(self, adj):
        """
        This method takes an adjacency matrix and counts the number of walks of
        length 3 between each pair of nodes (A^3); it returns a matrix for each graph.
        """
        p1 = adj.to(self.device)
        p1 = p1 * (1 - torch.eye(adj.shape[-1], adj.shape[-1])).to(self.device)  # zero the diagonal (remove self-loops)

        tri = torch.matmul(p1, torch.matmul(p1, p1))
        return tri

    def TotalNumberOfTriangles(self, adj):
        """
        This method takes an adjacency matrix and counts the total number of
        triangles in the corresponding graph, trace(A^3) / 6.
        """
        p1 = adj.to(self.device)
        p1 = p1 * (1 - torch.eye(adj.shape[-1], adj.shape[-1])).to(self.device)  # zero the diagonal (remove self-loops)

        tri = torch.diagonal(torch.matmul(p1, torch.matmul(p1, p1)), dim1=-2, dim2=-1) / 6
        return tri.sum(-1)

class Histogram(torch.nn.Module):
    # This is a soft histogram function.
    # For details, see section "3.2. The Learnable Histogram Layer" of
    # "Learnable Histogram: Statistical Context Features for Deep Neural Networks".
    def __init__(self, device, bin_width=None, bin_centers=None):
        super(Histogram, self).__init__()
        self.device = device
        if bin_width is None:
            self.bin_width = None
            self.prism()
        else:
            self.bin_width = bin_width.to(self.device)
            self.bin_center = bin_centers.to(self.device)
            self.bin_num = self.bin_width.shape[0]

    def forward(self, vec):
        # Receive a vector and return its soft histogram: each element is compared
        # with each bin center, given a triangular membership 1 - |x - c| * w
        # (clipped at 0), and the memberships are summed per bin.
        score_vec = vec.view(vec.shape[0], 1, vec.shape[1]) - self.bin_center
        score_vec = 1 - torch.abs(score_vec) * self.bin_width
        score_vec = torch.relu(score_vec)
        return score_vec.sum(2)

    def prism(self):
        pass
--------------------------------------------------------------------------------
/hyperbolic_learning/hyperbolic_kmeans/models/enron_vectors:
--------------------------------------------------------------------------------
1 | 184 2 2 | 178 -0.1984460592553305 0.20025020743308602 3 | 63 -0.8768459006564906 0.29556627908296873 4 | 58 -0.7299118049269566 0.16657129640387844 5 | 169 -0.7051791844195993 -0.7002918631029453 6 | 146 -0.6055346913972595 0.2706896090945791 7 | 155 -0.7064508603396817 -0.7016183546486351 8 | 163 -0.5975938196436054 0.2447475146763857 9 | 82 0.25872055317872267 0.0012221461779773936 10 | 114 -0.7000061458994081 -0.6925091267501868 11 | 126 0.6304807994985938 -0.72538077321905 12 | 107 -0.026034350042001664 0.24362036768399034 13 | 27 -0.15055489685931822 0.2340519018858813 14 | 17 0.9665462652694211 -0.23123225104884604 15 | 118 0.6965449449375939 -0.6947517823429039 16 | 167 0.7288067165866203 0.6679498306770806 17 | 105 0.380225144309162 0.8985371209511979 18 | 33 0.7397009558328966 0.6692955630546585 19 | 162 -0.7062316326194856 -0.7014377773425183 20 | 110 -0.7064859179858186 -0.7015837427593166 21 | 145 -0.76039438860719 -0.20068487052906836 22 | 65 -0.6480395879283742 -0.6334467147327948 23 | 165 -0.7064323836009319 -0.7016707219389255 24 | 158 0.736198556126861 0.6637499348103867 25 | 38 -0.7204384787346725 -0.6007776678272059 26 | 39 0.7040991384103598 -0.6931063560439886 27 | 153 0.4538449433852943 0.7783609788494534 28 | 157 0.8829231348842763 -0.229826749543188 29 | 124 0.7438593300693214 0.6661775896766015 30 | 108 0.7432181103735975 0.6680904706903793 31 | 78 0.8165738152257263 -0.4008738018082216 32 | 92 -0.721753939317935 -0.6787196016918841 33 | 9 -0.07768515593465432 -0.9841464281634166 34 | 66 0.28640033306819396 0.6951570157572056 35 
| 50 -0.5319659296015956 -0.8140873868818096 36 | 140 0.5800876286675246 -0.7254901953322157 37 | 128 0.2724781961034251 0.865764472737214 38 | 34 -0.9155806094915181 0.36554808356557195 39 | 6 -0.2825992339599058 -0.6846301051005649 40 | 98 0.7444752886483734 0.6666396293881055 41 | 51 -0.015634720744960237 0.3227236234046733 42 | 41 0.8542724276196729 -0.4675345019539692 43 | 112 -0.6136820104031678 -0.38562199792160007 44 | 74 0.02615696880781361 0.9968126936920976 45 | 29 -0.6085994598871429 -0.7806801315506688 46 | 90 0.7231525165832555 -0.6687253207202188 47 | 141 0.7140408263750713 -0.6927412304539717 48 | 95 0.7415262874257744 0.6643382274596425 49 | 151 0.7438268349934722 0.6630956090496904 50 | 103 0.7429847909791781 0.6658866076815595 51 | 91 -0.0756242307436219 -0.9921673668412403 52 | 115 -0.5160121744274104 -0.8297138894534238 53 | 22 -0.6099509887162402 -0.7775817311537127 54 | 160 -0.6558130768912326 -0.7185223850421865 55 | 132 0.6313707217022275 -0.7238460741783902 56 | 88 0.961722074957934 -0.23020781694244202 57 | 147 0.046978762480358795 0.4355861000136077 58 | 133 0.770947337530867 0.29154706446186096 59 | 94 0.06759389583982814 0.6232194987397858 60 | 62 0.6141880434155943 -0.7148007110979967 61 | 176 0.6241296559732196 -0.7383659796263874 62 | 161 0.7137539842815538 0.6436367380887059 63 | 73 -0.10094960049698963 0.8011448149700179 64 | 93 0.5281507646034013 -0.6631019796531623 65 | 177 0.7453064899534685 0.66476395838847 66 | 148 -0.8645921394101197 0.1834446349383526 67 | 4 0.6424368287595502 -0.18952812381883982 68 | 80 0.8598387074862807 -0.45679647182267896 69 | 36 0.8447219371325186 -0.45636470485426034 70 | 130 0.7376183907430952 -0.6592840996878078 71 | 96 -0.06688689179674781 -0.07576425621476242 72 | 113 0.7448526721099663 0.666786959558637 73 | 136 0.6381359813588515 -0.7328282364274545 74 | 11 -0.7565496736430731 -0.6361154500991819 75 | 61 0.6876191295839719 -0.6793946361723444 76 | 13 -0.07552706018641474 -0.9925142396083999 77 | 89 0.5116015788588462 -0.53651557357767 78 | 26 0.7175817803368142 -0.6902026340620832 79 | 60 -0.6236710855511973 -0.7497258743401157 80 | 57 0.031285593778278636 0.9612039534429198 81 | 32 0.8082193516001687 0.5608097630614676 82 | 47 0.5823200041765539 -0.7036427759188144 83 | 159 0.5538028733995654 -0.366152039514512 84 | 156 0.9608880474312749 -0.23064830825207427 85 | 101 0.8143965416257015 -0.4147160649773038 86 | 56 0.6911760479042031 -0.21337254648899648 87 | 127 0.8247967481884825 -0.40912725997610566 88 | 99 -0.5845657446591559 -0.7767134508062978 89 | 37 -0.1348185288461578 -0.5182541837410765 90 | 23 0.7384309600116042 0.6574514814831182 91 | 120 0.6284013834737201 -0.7284935588937144 92 | 25 0.744043172620397 0.6675426068613411 93 | 21 0.7415071981363179 0.3358825773706985 94 | 172 -0.5028089538753167 -0.7151260755091373 95 | 67 0.10922035470992478 0.4487051785129445 96 | 70 0.9073351115937835 -0.07302264098641502 97 | 134 0.647945336874843 -0.7316958144552046 98 | 123 -0.7171546795615156 -0.5634023730426104 99 | 72 0.5312735318150076 0.812585779656446 100 | 109 -0.9166170851728149 0.25754941092266653 101 | 102 0.6399307809095691 -0.7347702573138541 102 | 30 -0.07250358476681787 -0.9742185631316638 103 | 84 0.8546235130506167 -0.21976037097398252 104 | 59 0.6019082439258644 -0.7459717226553203 105 | 8 0.7447992359177548 0.6669066599733855 106 | 69 0.030300541230837613 0.9736158583459044 107 | 100 0.8027707280833482 0.5797088347309399 108 | 1 0.6790014726077499 -0.16503206392307632 109 | 170 0.7443756924820587 
0.6666524652836443 110 | 16 0.4687250886125325 -0.8061910707104645 111 | 116 0.7626771903057379 -0.6038303766702513 112 | 2 0.911433290579809 -0.253211299007134 113 | 5 0.17483753142871677 0.8851465746862941 114 | 85 0.7700781294596543 0.301662371961469 115 | 183 0.7553374980636153 0.3097598762258683 116 | 152 -0.07950523421463689 -0.9799016171436151 117 | 12 -0.8944269932353397 0.3566579764600584 118 | 173 0.8210441169210266 -0.1218277120597871 119 | 46 0.5605094377654453 -0.019234636381405934 120 | 49 0.6614129655872623 -0.26615094770909165 121 | 174 -0.5197384501659923 -0.8358833831656475 122 | 143 0.5965993132778459 -0.7640588334737657 123 | 83 0.21441468638796796 -0.5631823825359735 124 | 48 -0.07935985616417207 -0.9774230573722491 125 | 43 0.7992986818004688 -0.11461894087589082 126 | 14 0.7399379803283387 -0.6631714166529837 127 | 171 -0.5188500244319026 -0.8351678408224611 128 | 81 0.9433813106573411 -0.23011526657855647 129 | 129 0.01901091701342024 -0.6116552089148477 130 | 10 0.8611162457157884 -0.23940421702090064 131 | 20 -0.0784074574283313 -0.971016496284033 132 | 75 0.9383758239372784 -0.2095432365703771 133 | 97 0.9048885561125388 -0.17135622018997354 134 | 154 0.9095894426224456 -0.2633687757053742 135 | 86 0.7985521452938219 0.554621835236311 136 | 53 0.4078552803784103 0.21467119581700642 137 | 139 0.8518690261514159 -0.341495723776223 138 | 104 -0.07551080587528547 -0.9920929288207555 139 | 125 0.7254446444642217 0.41628541938921115 140 | 77 0.8094199737020105 0.5644994725252644 141 | 121 0.8286287315420713 0.51884208321067 142 | 64 0.7847432086794491 0.48982956024005286 143 | 119 -0.031722878796517875 -0.8853397741868338 144 | 40 0.7865570352778919 0.5091947862564561 145 | 166 0.945922867726464 -0.22876053280150993 146 | 149 0.7266505023073662 0.35522054452787344 147 | 55 -0.07428504548165181 -0.982482138285088 148 | 137 0.2603575276389304 -0.8933425611245108 149 | 106 -0.3704552601867352 -0.770566267303802 150 | 7 0.7857314917484888 0.5331231811638022 151 | 180 0.9640803150852842 -0.031222041534763587 152 | 179 0.8640688622346819 -0.25484767535068703 153 | 19 0.5355401108396809 -0.7319959572397163 154 | 164 0.5781385955627504 -0.7339758684252553 155 | 142 -0.07540437870376432 -0.9920606970852331 156 | 28 0.16362604578038253 -0.9685312940027968 157 | 24 0.7647760456352481 -0.575346820360563 158 | 144 0.9318627436267385 -0.264704427977355 159 | 0 -0.08015456584350882 -0.9561731156718968 160 | 54 0.618263345432344 -0.025402547271672096 161 | 181 0.009842145980783099 0.8771639707144948 162 | 15 0.7555267841820941 -0.5133576297292706 163 | 35 0.7933164748287 0.5395131746046482 164 | 3 0.8890376582472552 -0.22753576586955485 165 | 175 0.8546244598670377 -0.11750208242712609 166 | 79 -0.06605393013025693 -0.9471331037834557 167 | 18 0.8182900047666142 0.5046262755802189 168 | 45 0.7988765104446538 -0.3167979010363044 169 | 68 0.7370619343003879 -0.5499844118785548 170 | 150 0.5985667034554907 -0.7858351877173674 171 | 76 0.8228230430156183 0.40193699677173367 172 | 135 0.8626328740941509 -0.5055468565496868 173 | 131 -0.0759718398757799 -0.9853402929951228 174 | 168 -0.44357491220812967 -0.8233122189199789 175 | 138 0.8027206633835832 0.5467449238901679 176 | 182 0.7309203004890116 -0.614574603937921 177 | 52 0.744038232154234 0.5359574318251807 178 | 44 0.8068526156161697 0.5600214269133212 179 | 122 0.7485690729218217 0.6526629998424012 180 | 87 0.23547559409358762 -0.9072741097474325 181 | 111 -0.8849556248279503 -0.3100743229055148 182 | 42 0.1808931267391387 
-0.9117614206913168 183 | 31 0.7604298795178338 0.6414131200183597 184 | 71 0.13590763269319187 -0.9883976510957799 185 | 117 0.10608887892485419 -0.9760126081457292 186 |
--------------------------------------------------------------------------------
/graph_evaluate/spectre_utils.py:
--------------------------------------------------------------------------------
import concurrent.futures
import numpy as np
import networkx as nx
from dist_helper import compute_mmd, gaussian_emd, gaussian, emd, gaussian_tv, disc
from torch_geometric.utils import to_networkx
from datetime import datetime
from scipy.linalg import eigvalsh

def degree_worker(G):
    return np.array(nx.degree_histogram(G))


def degree_stats(graph_ref_list, graph_pred_list, is_parallel=True, compute_emd=False):
    '''Compute the distance between the degree distributions of two unordered sets of graphs.
    Args:
        graph_ref_list, graph_pred_list: two lists of networkx graphs to be evaluated
    '''
    sample_ref = []
    sample_pred = []
    # in case an empty graph is generated
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if not G.number_of_nodes() == 0
    ]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for deg_hist in executor.map(degree_worker, graph_ref_list):
                sample_ref.append(deg_hist)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for deg_hist in executor.map(degree_worker, graph_pred_list_remove_empty):
                sample_pred.append(deg_hist)
    else:
        for i in range(len(graph_ref_list)):
            degree_temp = np.array(nx.degree_histogram(graph_ref_list[i]))
            sample_ref.append(degree_temp)
        for i in range(len(graph_pred_list_remove_empty)):
            degree_temp = np.array(
                nx.degree_histogram(graph_pred_list_remove_empty[i]))
            sample_pred.append(degree_temp)

    if compute_emd:
        # The EMD option uses the same computation as GraphRNN; the alternative is MMD as computed by GRAN.
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd)
    else:
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_tv)

    elapsed = datetime.now() - prev
    # if PRINT_TIME:
    #     print('Time computing degree mmd: ', elapsed)
    return mmd_dist
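# Quick usage sketch (added for exposition; the graphs and function name are made up):
# compare the degree distributions of two small graph sets with the default
# total-variation Gaussian kernel MMD path above.
def _degree_stats_demo():
    ref = [nx.erdos_renyi_graph(20, 0.20, seed=s) for s in range(4)]
    pred = [nx.erdos_renyi_graph(20, 0.25, seed=s) for s in range(4)]
    return degree_stats(ref, pred, is_parallel=False)  # small nonnegative MMD value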
def spectral_worker(G, n_eigvals=-1):
    # eigenvalues of the normalized graph Laplacian, binned into a spectral histogram
    try:
        eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
    except Exception:
        eigs = np.zeros(G.number_of_nodes())
    if n_eigvals > 0:
        eigs = eigs[1:n_eigvals + 1]
    spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False)
    spectral_pmf = spectral_pmf / spectral_pmf.sum()
    return spectral_pmf

def spectral_stats(graph_ref_list, graph_pred_list, is_parallel=True, n_eigvals=-1, compute_emd=False):
    '''Compute the distance between the Laplacian spectrum distributions of two unordered sets of graphs.
    Args:
        graph_ref_list, graph_pred_list: two lists of networkx graphs to be evaluated
    '''
    sample_ref = []
    sample_pred = []
    # in case an empty graph is generated
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if not G.number_of_nodes() == 0
    ]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for spectral_density in executor.map(spectral_worker, graph_ref_list,
                                                 [n_eigvals for i in graph_ref_list]):
                sample_ref.append(spectral_density)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # the n_eigvals list must match the predicted graphs, not the reference list
            for spectral_density in executor.map(spectral_worker, graph_pred_list_remove_empty,
                                                 [n_eigvals for i in graph_pred_list_remove_empty]):
                sample_pred.append(spectral_density)
    else:
        for i in range(len(graph_ref_list)):
            spectral_temp = spectral_worker(graph_ref_list[i], n_eigvals)
            sample_ref.append(spectral_temp)
        for i in range(len(graph_pred_list_remove_empty)):
            spectral_temp = spectral_worker(graph_pred_list_remove_empty[i], n_eigvals)
            sample_pred.append(spectral_temp)

    if compute_emd:
        # The EMD option uses the same computation as GraphRNN; the alternative is MMD as computed by GRAN.
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd)
    else:
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_tv)

    elapsed = datetime.now() - prev
    # if PRINT_TIME:
    #     print('Time computing spectral mmd: ', elapsed)
    return mmd_dist

def clustering_worker(param):
    G, bins = param
    clustering_coeffs_list = list(nx.clustering(G).values())
    hist, _ = np.histogram(
        clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
    return hist

def clustering_stats(graph_ref_list,
                     graph_pred_list,
                     bins=100,
                     is_parallel=True, compute_emd=False):
    sample_ref = []
    sample_pred = []
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if not G.number_of_nodes() == 0
    ]

    prev = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for clustering_hist in executor.map(clustering_worker,
                                                [(G, bins) for G in graph_ref_list]):
                sample_ref.append(clustering_hist)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for clustering_hist in executor.map(
                    clustering_worker, [(G, bins) for G in graph_pred_list_remove_empty]):
                sample_pred.append(clustering_hist)
    else:
        for i in range(len(graph_ref_list)):
            clustering_coeffs_list = list(nx.clustering(graph_ref_list[i]).values())
            hist, _ = np.histogram(
                clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
            sample_ref.append(hist)

        for i in range(len(graph_pred_list_remove_empty)):
            clustering_coeffs_list = list(
                nx.clustering(graph_pred_list_remove_empty[i]).values())
            hist, _ = np.histogram(
                clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
            sample_pred.append(hist)

    if compute_emd:
        # The EMD option uses the same computation as GraphRNN; the alternative is MMD as computed by GRAN.
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_emd, sigma=1.0 / 10, distance_scaling=bins)
    else:
        mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian_tv, sigma=1.0 / 10)

    elapsed = datetime.now() - prev
    # if PRINT_TIME:
    #     print('Time computing clustering mmd: ', elapsed)
    return mmd_dist
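# Illustrative sketch (added for exposition; the function name is hypothetical):
# the per-graph feature compared by clustering_stats is simply a bins-wide histogram
# of node clustering coefficients.
def _clustering_feature_demo():
    G = nx.karate_club_graph()
    coeffs = list(nx.clustering(G).values())
    hist, _ = np.histogram(coeffs, bins=100, range=(0.0, 1.0), density=False)
    return hist  # one such count vector per graph is what feeds compute_mmd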
def new_compute(reference_graphs, networkx_graphs):
    print("Computing degree stats...")
    degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=True,
                          compute_emd=True)
    print('Degree value:', degree)

    print("Computing spectral stats...")
    spectre = spectral_stats(reference_graphs, networkx_graphs, is_parallel=True, n_eigvals=-1,
                             compute_emd=True)
    print('Spectre value:', spectre)

    print("Computing clustering stats...")
    clustering = clustering_stats(reference_graphs, networkx_graphs, bins=100, is_parallel=True,
                                  compute_emd=True)
    print('Cluster value:', clustering)
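# End-to-end usage sketch (added for exposition; the graph lists are placeholders,
# not the repository's real datasets): run all three statistics on two graph sets.
if __name__ == "__main__":
    reference = [nx.barabasi_albert_graph(30, 2, seed=s) for s in range(8)]
    generated = [nx.barabasi_albert_graph(30, 2, seed=100 + s) for s in range(8)]
    new_compute(reference, generated)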