├── .gitignore
├── LICENSE
├── README.md
├── README.rst
├── TF_GEOMETRIC_LOGO.png
├── benchmarks
│   └── node_classification
│       ├── bench_node_cls_early_stop_appnp.py
│       ├── bench_node_cls_early_stop_gat.py
│       ├── bench_node_cls_early_stop_gcn.py
│       ├── bench_node_cls_early_stop_sgc.py
│       ├── bench_node_cls_early_stop_ssgc.py
│       ├── bench_report_results.py
│       └── run_multi_times.sh
├── demo
│   ├── demo_appnp.py
│   ├── demo_asap.py
│   ├── demo_chebynet.py
│   ├── demo_checkpoint.py
│   ├── demo_dgi.py
│   ├── demo_diff_pool.py
│   ├── demo_distributed_gcn.py
│   ├── demo_distributed_mean_pool.py
│   ├── demo_drop_edge_gcn.py
│   ├── demo_elegant_api.py
│   ├── demo_gae.py
│   ├── demo_gat.py
│   ├── demo_gcn.py
│   ├── demo_gin.py
│   ├── demo_graph_sage.py
│   ├── demo_graph_sage_func.py
│   ├── demo_mean_pool.py
│   ├── demo_min_cut_pool.py
│   ├── demo_model_net_dataset.py
│   ├── demo_sag_pool_h.py
│   ├── demo_sample_neighbors.py
│   ├── demo_save_and_load_model.py
│   ├── demo_set2set.py
│   ├── demo_sgc.py
│   ├── demo_sort_pool.py
│   ├── demo_sparse_node_features.py
│   ├── demo_ssgc.py
│   ├── demo_tagcn.py
│   └── demo_topk_pool.py
├── deploy.sh
├── doc
│   ├── Makefile
│   ├── build.sh
│   ├── make.bat
│   ├── requirements.txt
│   ├── source
│   │   ├── conf.py
│   │   ├── index.rst
│   │   ├── index_cn.rst
│   │   ├── modules
│   │   │   ├── datasets.rst
│   │   │   ├── layers.rst
│   │   │   ├── nn.rst
│   │   │   ├── root.rst
│   │   │   └── utils
│   │   │       └── graph_utils.rst
│   │   ├── wiki
│   │   │   ├── installation.rst
│   │   │   └── quickstart.rst
│   │   └── wiki_cn
│   │       ├── installation.rst
│   │       └── quickstart.rst
│   └── test.sh
├── readthedocs.yaml
├── setup.py
├── tf_geometric
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── graph.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── abnormal.py
│   │   ├── amazon_electronics.py
│   │   ├── blog_catalog.py
│   │   ├── coauthor.py
│   │   ├── csr_npz.py
│   │   ├── hgb.py
│   │   ├── model_net.py
│   │   ├── nars_academic.py
│   │   ├── ogb.py
│   │   ├── planetoid.py
│   │   ├── ppi.py
│   │   ├── reddit.py
│   │   ├── synthetic.py
│   │   └── tu.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── conv
│   │   │   ├── __init__.py
│   │   │   ├── appnp.py
│   │   │   ├── chebynet.py
│   │   │   ├── gat.py
│   │   │   ├── gcn.py
│   │   │   ├── gin.py
│   │   │   ├── graph_sage.py
│   │   │   ├── le_conv.py
│   │   │   ├── sgc.py
│   │   │   ├── ssgc.py
│   │   │   └── tagcn.py
│   │   ├── kernel
│   │   │   ├── __init__.py
│   │   │   └── map_reduce.py
│   │   ├── pool
│   │   │   ├── __init__.py
│   │   │   ├── asap.py
│   │   │   ├── common_pool.py
│   │   │   ├── diff_pool.py
│   │   │   ├── min_cut_pool.py
│   │   │   ├── sag_pool.py
│   │   │   ├── set2set.py
│   │   │   └── sort_pool.py
│   │   └── sampling
│   │       ├── __init__.py
│   │       └── drop_edge.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── conv
│   │   │   ├── __init__.py
│   │   │   ├── appnp.py
│   │   │   ├── chebynet.py
│   │   │   ├── gat.py
│   │   │   ├── gcn.py
│   │   │   ├── gin.py
│   │   │   ├── graph_sage.py
│   │   │   ├── le_conv.py
│   │   │   ├── sgc.py
│   │   │   ├── ssgc.py
│   │   │   └── tagcn.py
│   │   ├── kernel
│   │   │   ├── __init__.py
│   │   │   ├── map_reduce.py
│   │   │   └── segment.py
│   │   ├── pool
│   │   │   ├── __init__.py
│   │   │   ├── asap.py
│   │   │   ├── cluster_pool.py
│   │   │   ├── common_pool.py
│   │   │   ├── diff_pool.py
│   │   │   ├── min_cut_pool.py
│   │   │   ├── sag_pool.py
│   │   │   ├── set2set.py
│   │   │   ├── sort_pool.py
│   │   │   └── topk_pool.py
│   │   └── sampling
│   │       ├── __init__.py
│   │       └── drop_edge.py
│   └── utils
│       ├── __init__.py
│       ├── data_utils.py
│       ├── graph_utils.py
│       ├── tf_sparse_utils.py
│       ├── tf_utils.py
│       └── union_utils.py
└── tutorial_intro.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *.zip
3 | /test/
4 | /data/
5 | /figures/
6 | /demo/data/
7 | /demo/*/data/
8 | /test/keras/
9 | /test/keras/data/
10 | **/__pycache__/
11 | **/.idea/
12 | __pycache__/
13 | 
14 | /doc/build/
15 | /docs/build/
16 | /logs/
17 | /models/
18 | /demo/logs/
19 | /demo/models/
20 | 
21 | benchmarks/**/data/
22 | 
benchmarks/**/results.txt 23 | 24 | *.swp 25 | 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | pip-wheel-metadata/ 47 | share/python-wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | -------------------------------------------------------------------------------- /TF_GEOMETRIC_LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/TF_GEOMETRIC_LOGO.png -------------------------------------------------------------------------------- /benchmarks/node_classification/bench_node_cls_early_stop_appnp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tf_geometric as tfg 7 | import tensorflow as tf 8 | import time 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | dataset = "cora" 13 | # dataset = "citeseer" 14 | # dataset = "pubmed" 15 | 16 | graph, (train_index, valid_index, test_index) = tfg.datasets.PlanetoidDataset(dataset).load_data() 17 | 18 | 19 | num_steps = 401 20 | patience = 100 21 | 22 | 23 | num_classes = graph.y.max() + 1 24 | drop_rate = 0.5 25 | learning_rate = 1e-2 26 | # l2_coef = 5e-4 27 | l2_coef = 1e-3 28 | 29 | if dataset == "pubmed": 30 | l2_coef = 3e-3 31 | num_steps = 201 32 | 33 | model = tfg.layers.APPNP([64, num_classes], alpha=0.1, k=10, 34 | dense_drop_rate=drop_rate, edge_drop_rate=drop_rate) 35 | 36 | model.build_cache_for_graph(graph) 37 | 38 | 39 | # @tf_utils.function can speed up functions for TensorFlow 2.x 40 | @tf_utils.function 41 | def forward(graph, training=False): 42 | return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 43 | 44 | 45 | @tf_utils.function 46 | def compute_loss(logits, mask_index, vars): 47 | masked_logits = tf.gather(logits, mask_index) 48 | masked_labels = tf.gather(graph.y, mask_index) 49 | losses = tf.nn.softmax_cross_entropy_with_logits( 50 | logits=masked_logits, 51 | labels=tf.one_hot(masked_labels, depth=num_classes) 52 | ) 53 | 54 | kernel_vals = [var for var in vars if "kernel" in var.name] 55 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vals] 56 | cls_loss = tf.reduce_mean(losses) 57 | l2_loss = tf.add_n(l2_losses) 58 | return cls_loss + l2_loss * l2_coef, cls_loss, l2_loss 59 | 60 | 61 | @tf_utils.function 62 | def evaluate(current_test_index): 63 | with tf.GradientTape() as tape: 64 | logits = forward(graph) 65 | loss = compute_loss(logits, current_test_index, tape.watched_variables()) 66 | masked_logits = tf.gather(logits, current_test_index) 67 | masked_labels = tf.gather(graph.y, current_test_index) 68 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 69 | 70 | corrects = tf.equal(y_pred, masked_labels) 71 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 72 | return accuracy, loss 73 | 74 | 75 | @tf_utils.function 76 | def evaluate_test(): 77 | return evaluate(test_index) 78 | 79 | 80 | @tf_utils.function 81 | 
def evaluate_val(): 82 | return evaluate(valid_index) 83 | 84 | 85 | optimizer = tf.keras.optimizers.Adam(learning_rate=5e-3) 86 | 87 | 88 | @tf_utils.function 89 | def train_step(): 90 | with tf.GradientTape() as tape: 91 | logits = forward(graph, training=True) 92 | loss, _, _ = compute_loss(logits, train_index, tape.watched_variables()) 93 | 94 | vars = tape.watched_variables() 95 | grads = tape.gradient(loss, vars) 96 | optimizer.apply_gradients(zip(grads, vars)) 97 | return loss 98 | 99 | 100 | val_accuracy_list = [] 101 | test_accuracy_list = [] 102 | loss_list = [] 103 | 104 | best_val_accuracy = 0 105 | min_val_loss = 1000 106 | 107 | final_test_accuracy = None 108 | final_step = None 109 | 110 | patience_counter = 0 111 | 112 | for step in range(1, num_steps): 113 | 114 | loss = train_step() 115 | 116 | if step % 1 == 0: 117 | test_accuracy, _ = evaluate_test() 118 | val_accuracy, (_, val_loss, _) = evaluate_val() 119 | 120 | val_accuracy = val_accuracy.numpy() 121 | val_loss = val_loss.numpy() 122 | 123 | if val_accuracy > best_val_accuracy or val_loss < min_val_loss: 124 | patience_counter = 0 125 | else: 126 | patience_counter += 1 127 | if patience_counter > patience: 128 | break 129 | 130 | # if val_accuracy > best_val_accuracy and val_loss < min_val_loss: 131 | if val_accuracy > best_val_accuracy and val_loss < min_val_loss: 132 | final_test_accuracy = test_accuracy 133 | final_step = step 134 | 135 | best_val_accuracy = val_accuracy 136 | min_val_loss = val_loss 137 | 138 | val_accuracy_list.append(val_accuracy) 139 | test_accuracy_list.append(test_accuracy) 140 | loss_list.append(val_loss) 141 | 142 | print( 143 | "step = {}\tloss = {:.4f}\tval_accuracy = {:.4f}\tval_loss = {:.4f}\t" 144 | "test_accuracy = {:.4f}\tfinal_test_accuracy = {:.4f}\tfinal_step = {}" 145 | .format(step, loss, val_accuracy, val_loss, test_accuracy, final_test_accuracy, final_step)) 146 | print("patience_counter = {}".format(patience_counter)) 147 | 148 | print("final accuracy: {}\tfinal_step: {}".format(final_test_accuracy, final_step)) 149 | 150 | with open("results.txt", "a", encoding="utf-8") as f: 151 | f.write("{}\n".format(final_test_accuracy)) 152 | -------------------------------------------------------------------------------- /benchmarks/node_classification/bench_node_cls_early_stop_sgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tf_geometric as tfg 7 | import tensorflow as tf 8 | import time 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | dataset = "cora" 13 | # dataset = "citeseer" 14 | # dataset = "pubmed" 15 | 16 | graph, (train_index, valid_index, test_index) = tfg.datasets.PlanetoidDataset(dataset).load_data() 17 | 18 | num_classes = graph.y.max() + 1 19 | 20 | learning_rate = 0.2 21 | l2_coef = 5e-6 22 | num_steps = 201 23 | patience = 100 24 | 25 | if dataset == "citeseer": 26 | l2_coef = 1e-4 27 | elif dataset == "pubmed": 28 | l2_coef = 5e-5 29 | num_steps = 61 30 | 31 | model = tfg.layers.SGC(num_classes, k=2) 32 | model.build_cache_for_graph(graph) 33 | 34 | 35 | # @tf_utils.function can speed up functions for TensorFlow 2.x 36 | @tf_utils.function 37 | def forward(graph, training=False): 38 | return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 39 | 40 | 41 | @tf_utils.function 42 | def compute_loss(logits, mask_index, vars): 43 | 
masked_logits = tf.gather(logits, mask_index) 44 | masked_labels = tf.gather(graph.y, mask_index) 45 | losses = tf.nn.softmax_cross_entropy_with_logits( 46 | logits=masked_logits, 47 | labels=tf.one_hot(masked_labels, depth=num_classes) 48 | ) 49 | 50 | kernel_vals = [var for var in vars if "kernel" in var.name] 51 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vals] 52 | cls_loss = tf.reduce_mean(losses) 53 | l2_loss = tf.add_n(l2_losses) 54 | return cls_loss + l2_loss * l2_coef, cls_loss, l2_loss 55 | 56 | 57 | @tf_utils.function 58 | def evaluate(current_test_index): 59 | with tf.GradientTape() as tape: 60 | logits = forward(graph) 61 | loss = compute_loss(logits, current_test_index, tape.watched_variables()) 62 | masked_logits = tf.gather(logits, current_test_index) 63 | masked_labels = tf.gather(graph.y, current_test_index) 64 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 65 | 66 | corrects = tf.equal(y_pred, masked_labels) 67 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 68 | return accuracy, loss 69 | 70 | 71 | @tf_utils.function 72 | def evaluate_test(): 73 | return evaluate(test_index) 74 | 75 | 76 | @tf_utils.function 77 | def evaluate_val(): 78 | return evaluate(valid_index) 79 | 80 | 81 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 82 | 83 | 84 | @tf_utils.function 85 | def train_step(): 86 | with tf.GradientTape() as tape: 87 | logits = forward(graph, training=True) 88 | loss, _, _ = compute_loss(logits, train_index, tape.watched_variables()) 89 | 90 | vars = tape.watched_variables() 91 | grads = tape.gradient(loss, vars) 92 | optimizer.apply_gradients(zip(grads, vars)) 93 | return loss 94 | 95 | 96 | val_accuracy_list = [] 97 | test_accuracy_list = [] 98 | loss_list = [] 99 | 100 | best_val_accuracy = 0 101 | min_val_loss = 1000 102 | 103 | final_test_accuracy = None 104 | final_step = None 105 | 106 | patience_counter = 0 107 | 108 | for step in range(1, num_steps): 109 | 110 | loss = train_step() 111 | 112 | if step % 1 == 0: 113 | test_accuracy, _ = evaluate_test() 114 | val_accuracy, (_, val_loss, _) = evaluate_val() 115 | 116 | val_accuracy = val_accuracy.numpy() 117 | val_loss = val_loss.numpy() 118 | 119 | if val_accuracy > best_val_accuracy or val_loss < min_val_loss: 120 | patience_counter = 0 121 | else: 122 | patience_counter += 1 123 | if patience_counter > patience: 124 | break 125 | 126 | # if val_accuracy > best_val_accuracy and val_loss < min_val_loss: 127 | if val_accuracy > best_val_accuracy and val_loss < min_val_loss: 128 | final_test_accuracy = test_accuracy 129 | final_step = step 130 | 131 | best_val_accuracy = val_accuracy 132 | min_val_loss = val_loss 133 | 134 | val_accuracy_list.append(val_accuracy) 135 | test_accuracy_list.append(test_accuracy) 136 | loss_list.append(val_loss) 137 | 138 | print( 139 | "step = {}\tloss = {:.4f}\tval_accuracy = {:.4f}\tval_loss = {:.4f}\t" 140 | "test_accuracy = {:.4f}\tfinal_test_accuracy = {:.4f}\tfinal_step = {}" 141 | .format(step, loss, val_accuracy, val_loss, test_accuracy, final_test_accuracy, final_step)) 142 | print("patience_counter = {}".format(patience_counter)) 143 | 144 | print("final accuracy: {}\tfinal_step: {}".format(final_test_accuracy, final_step)) 145 | 146 | with open("results.txt", "a", encoding="utf-8") as f: 147 | f.write("{}\n".format(final_test_accuracy)) 148 | -------------------------------------------------------------------------------- /benchmarks/node_classification/bench_report_results.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | np.set_printoptions(precision=4) 4 | 5 | accuracy_list = [] 6 | with open("results.txt", "r", encoding="utf-8") as f: 7 | for line in f: 8 | accuracy = float(line.strip()) 9 | accuracy_list.append(accuracy) 10 | 11 | accuracy_list = np.array(accuracy_list) 12 | 13 | accuracy_mean = np.mean(accuracy_list) 14 | accuracy_std = np.std(accuracy_list) 15 | 16 | print("accuracy_list = {}".format(accuracy_list)) 17 | print("num_tests = {}".format(len(accuracy_list))) 18 | print("accuracy: mean = {:.4f}\tstd = {:.4f}".format(accuracy_mean, accuracy_std)) 19 | -------------------------------------------------------------------------------- /benchmarks/node_classification/run_multi_times.sh: -------------------------------------------------------------------------------- 1 | rm results.txt 2 | #SCRIPT=bench_node_cls_early_stop_gat.py 3 | #SCRIPT=bench_node_cls_early_stop_gcn.py 4 | SCRIPT=bench_node_cls_early_stop_appnp.py 5 | #SCRIPT=bench_node_cls_early_stop_sgc.py 6 | 7 | for i in $(seq 1 20) 8 | do 9 | python $SCRIPT 10 | python bench_report_results.py 11 | done 12 | 13 | -------------------------------------------------------------------------------- /demo/demo_appnp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tensorflow as tf 7 | import tf_geometric as tfg 8 | from tqdm import tqdm 9 | import time 10 | 11 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 12 | 13 | 14 | num_classes = graph.y.max() + 1 15 | drop_rate = 0.5 16 | learning_rate = 1e-2 17 | 18 | 19 | # APPNP Model 20 | class APPNPModel(tf.keras.Model): 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.appnp = tfg.layers.APPNP([64, num_classes], alpha=0.1, k=10, 25 | dense_drop_rate=drop_rate, edge_drop_rate=drop_rate) 26 | self.dropout = tf.keras.layers.Dropout(drop_rate) 27 | 28 | def call(self, inputs, training=None, mask=None, cache=None): 29 | x, edge_index, edge_weight = inputs 30 | h = self.dropout(x, training=training) 31 | h = self.appnp([h, edge_index, edge_weight], training=training, cache=cache) 32 | return h 33 | 34 | 35 | model = APPNPModel() 36 | 37 | 38 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 39 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache. 
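# (On TensorFlow 1.x the decorator is simply disabled, so the script still runs there, just without the speed-up; see the version check at the end of this file.)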
40 | @tf_utils.function
41 | def forward(graph, training=False):
42 |     return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache)
43 | 
44 | 
45 | # The following line is only necessary for using APPNP with @tf_utils.function
46 | # For usage without @tf_utils.function, you can comment out the following line, and the APPNP layer will automatically manage the cache
47 | model.appnp.build_cache_for_graph(graph)
48 | 
49 | 
50 | @tf_utils.function
51 | def compute_loss(logits, mask_index, vars):
52 |     masked_logits = tf.gather(logits, mask_index)
53 |     masked_labels = tf.gather(graph.y, mask_index)
54 |     losses = tf.nn.softmax_cross_entropy_with_logits(
55 |         logits=masked_logits,
56 |         labels=tf.one_hot(masked_labels, depth=num_classes)
57 |     )
58 | 
59 |     kernel_vars = [var for var in vars if "kernel" in var.name]
60 |     l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars]
61 | 
62 |     return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4
63 | 
64 | 
65 | @tf_utils.function
66 | def train_step():
67 |     with tf.GradientTape() as tape:
68 |         logits = forward(graph, training=True)
69 |         loss = compute_loss(logits, train_index, tape.watched_variables())
70 | 
71 |     vars = tape.watched_variables()
72 |     grads = tape.gradient(loss, vars)
73 |     optimizer.apply_gradients(zip(grads, vars))
74 |     return loss
75 | 
76 | 
77 | @tf_utils.function
78 | def evaluate():
79 |     logits = forward(graph)
80 |     masked_logits = tf.gather(logits, test_index)
81 |     masked_labels = tf.gather(graph.y, test_index)
82 | 
83 |     y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32)
84 | 
85 |     corrects = tf.equal(y_pred, masked_labels)
86 |     accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32))
87 |     return accuracy
88 | 
89 | 
90 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
91 | 
92 | for step in range(1, 401):
93 |     loss = train_step()
94 |     if step % 20 == 0:
95 |         accuracy = evaluate()
96 |         print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy))
97 | 
98 | print("\nstart speed test...")
99 | num_test_iterations = 1000
100 | start_time = time.time()
101 | for _ in tqdm(range(num_test_iterations)):
102 |     logits = forward(graph)
103 | end_time = time.time()
104 | print("mean forward time: {} seconds".format((end_time - start_time) / num_test_iterations))
105 | 
106 | if tf.__version__[0] == "1":
107 |     print("** @tf_utils.function is disabled in TensorFlow 1.x. "
108 |           "Upgrade to TensorFlow 2.x for 10X faster speed. **")
**") 109 | -------------------------------------------------------------------------------- /demo/demo_asap.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.layers import ASAP, GCN 5 | import tf_geometric as tfg 6 | import tensorflow as tf 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | from tqdm import tqdm 10 | 11 | # TU Datasets: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 12 | graph_dicts = tfg.datasets.TUDataset("NCI1").load_data() 13 | 14 | # Since a TU dataset may contain node_labels, node_attributes etc., each of which can be used as node features 15 | # We process each graph as a dict and return a list of dict for graphs 16 | # You can easily construct you Graph object with the data dict 17 | 18 | num_node_labels = np.max([np.max(graph_dict["node_labels"]) for graph_dict in graph_dicts]) + 1 19 | 20 | 21 | def convert_node_labels_to_one_hot(node_labels): 22 | num_nodes = len(node_labels) 23 | x = np.zeros([num_nodes, num_node_labels], dtype=np.float32) 24 | x[list(range(num_nodes)), node_labels] = 1.0 25 | return x 26 | 27 | 28 | def construct_graph(graph_dict): 29 | return tfg.Graph( 30 | x=convert_node_labels_to_one_hot(graph_dict["node_labels"]), 31 | edge_index=graph_dict["edge_index"], 32 | y=graph_dict["graph_label"] # graph_dict["graph_label"] is a list with one int element 33 | ) 34 | 35 | 36 | graphs = [construct_graph(graph_dict) for graph_dict in graph_dicts] 37 | num_classes = np.max([graph.y[0] for graph in graphs]) + 1 38 | 39 | train_graphs, test_graphs = train_test_split(graphs, test_size=0.1) 40 | 41 | 42 | def create_graph_generator(graphs, batch_size, infinite=False, shuffle=False): 43 | while True: 44 | dataset = tf.data.Dataset.range(len(graphs)) 45 | if shuffle: 46 | dataset = dataset.shuffle(2000) 47 | dataset = dataset.batch(batch_size) 48 | 49 | for batch_graph_index in dataset: 50 | batch_graph_list = [graphs[i] for i in batch_graph_index] 51 | 52 | batch_graph = tfg.BatchGraph.from_graphs(batch_graph_list) 53 | yield batch_graph 54 | 55 | if not infinite: 56 | break 57 | 58 | 59 | batch_size = 128 60 | 61 | 62 | class ASAPModel(tf.keras.Model): 63 | 64 | def __init__(self, *args, **kwargs): 65 | super().__init__(*args, **kwargs) 66 | 67 | self.gcns = [] 68 | self.asaps = [] 69 | 70 | for _ in range(3): 71 | self.gcns.append(GCN(64, activation=tf.nn.relu)) 72 | self.asaps.append(ASAP(ratio=0.5, drop_rate=0.1)) 73 | 74 | self.mlp = tf.keras.Sequential([ 75 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 76 | tf.keras.layers.Dropout(0.5), 77 | tf.keras.layers.Dense(num_classes) 78 | ]) 79 | 80 | def call(self, inputs, training=None, mask=None): 81 | 82 | x, edge_index, edge_weight, node_graph_index = inputs 83 | h = x 84 | 85 | outputs = [] 86 | for gcn, asap in zip(self.gcns, self.asaps): 87 | h = gcn([h, edge_index, edge_weight], training=training) 88 | h, edge_index, edge_weight, node_graph_index = asap([h, edge_index, edge_weight, node_graph_index], 89 | training=training) 90 | output = tf.concat([ 91 | tfg.nn.mean_pool(h, node_graph_index), 92 | tfg.nn.max_pool(h, node_graph_index) 93 | ], axis=-1) 94 | outputs.append(output) 95 | 96 | h = tf.reduce_sum(tf.stack(outputs, axis=1), axis=1) 97 | 98 | # Predict Graph Labels 99 | h = self.mlp(h, training=training) 100 | return h 101 | 102 | 103 | model = ASAPModel() 104 | 105 | 106 | def forward(batch_graph, 
training=False): 107 | return model([batch_graph.x, batch_graph.edge_index, batch_graph.edge_weight, batch_graph.node_graph_index], 108 | training=training) 109 | 110 | 111 | def evaluate(): 112 | accuracy_m = tf.keras.metrics.Accuracy() 113 | 114 | for test_batch_graph in create_graph_generator(test_graphs, batch_size, shuffle=False, infinite=False): 115 | logits = forward(test_batch_graph) 116 | preds = tf.argmax(logits, axis=-1) 117 | accuracy_m.update_state(test_batch_graph.y, preds) 118 | 119 | return accuracy_m.result().numpy() 120 | 121 | 122 | optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2) 123 | 124 | train_batch_generator = create_graph_generator(train_graphs, batch_size, shuffle=True, infinite=True) 125 | 126 | 127 | for step in tqdm(range(20000)): 128 | train_batch_graph = next(train_batch_generator) 129 | with tf.GradientTape() as tape: 130 | logits = forward(train_batch_graph, training=True) 131 | losses = tf.nn.softmax_cross_entropy_with_logits( 132 | logits=logits, 133 | labels=tf.one_hot(train_batch_graph.y, depth=num_classes) 134 | ) 135 | 136 | vars = tape.watched_variables() 137 | grads = tape.gradient(losses, vars) 138 | optimizer.apply_gradients(zip(grads, vars)) 139 | 140 | if step % 20 == 0: 141 | mean_loss = tf.reduce_mean(losses) 142 | accuracy = evaluate() 143 | print("step = {}\tloss = {}\taccuracy = {}".format(step, mean_loss, accuracy)) 144 | 145 | -------------------------------------------------------------------------------- /demo/demo_chebynet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tf_geometric as tfg 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | from tf_geometric.datasets import CoraDataset 10 | from tqdm import tqdm 11 | 12 | graph, (train_index, valid_index, test_index) = CoraDataset().load_data() 13 | 14 | num_classes = graph.y.max() + 1 15 | 16 | model = tfg.layers.ChebyNet(64, k=3, activation=tf.nn.relu) 17 | fc = tf.keras.Sequential([ 18 | keras.layers.Dropout(0.5), 19 | keras.layers.Dense(num_classes) 20 | ]) 21 | 22 | model.build_cache_for_graph(graph) 23 | 24 | 25 | # @tf_utils.function can speed up functions for TensorFlow 2.x 26 | @tf_utils.function 27 | def forward(graph, training=False): 28 | h = model([graph.x, graph.edge_index, graph.edge_weight], cache=graph.cache) 29 | h = fc(h, training=training) 30 | return h 31 | 32 | 33 | @tf_utils.function 34 | def compute_loss(logits, mask_index, vars): 35 | masked_logits = tf.gather(logits, mask_index) 36 | masked_labels = tf.gather(graph.y, mask_index) 37 | losses = tf.nn.softmax_cross_entropy_with_logits( 38 | logits=masked_logits, 39 | labels=tf.one_hot(masked_labels, depth=num_classes) 40 | ) 41 | 42 | kernel_vars = [var for var in vars if "kernel" in var.name] 43 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 44 | 45 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4 46 | 47 | 48 | optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2) 49 | 50 | 51 | @tf_utils.function 52 | def train_step(): 53 | with tf.GradientTape() as tape: 54 | logits = forward(graph, training=True) 55 | loss = compute_loss(logits, train_index, tape.watched_variables()) 56 | 57 | vars = tape.watched_variables() 58 | grads = tape.gradient(loss, vars) 59 | optimizer.apply_gradients(zip(grads, vars)) 60 | return loss 61 | 62 | 63 | @tf_utils.function 64 | def evaluate(): 65 | 
logits = forward(graph) 66 | masked_logits = tf.gather(logits, test_index) 67 | masked_labels = tf.gather(graph.y, test_index) 68 | 69 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 70 | 71 | corrects = tf.equal(y_pred, masked_labels) 72 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 73 | return accuracy 74 | 75 | 76 | best_test_acc = 0 77 | for step in tqdm(range(1, 101)): 78 | loss = train_step() 79 | 80 | test_acc = evaluate() 81 | if test_acc > best_test_acc: 82 | best_test_acc = test_acc 83 | print("step = {}\tloss = {}\tbest_test_acc = {}".format(step, loss, best_test_acc)) 84 | -------------------------------------------------------------------------------- /demo/demo_distributed_gcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | # multi-gpu ids 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 5 | import tf_geometric as tfg 6 | from tf_geometric.layers import GCN 7 | from tensorflow.keras.regularizers import L1L2 8 | import tensorflow as tf 9 | 10 | 11 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 12 | num_classes = graph.y.max() + 1 13 | 14 | drop_rate = 0.5 15 | learning_rate = 1e-2 16 | l2_coef = 5e-4 17 | 18 | 19 | # custom network 20 | class GCNNetwork(tf.keras.Model): 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.gcn0 = GCN(16, activation=tf.nn.relu, kernel_regularizer=L1L2(l2=l2_coef)) 25 | self.gcn1 = GCN(num_classes, kernel_regularizer=L1L2(l2=l2_coef)) 26 | self.dropout = tf.keras.layers.Dropout(drop_rate) 27 | 28 | def call(self, inputs, training=None, mask=None): 29 | x, edge_index = inputs 30 | h = self.dropout(x, training=training) 31 | h = self.gcn0([h, edge_index], training=training) 32 | h = self.dropout(h, training=training) 33 | h = self.gcn1([h, edge_index], training=training) 34 | return h 35 | 36 | 37 | # prepare a generator and a dataset for distributed training 38 | def create_batch_generator(): 39 | while True: 40 | yield (graph.x, graph.edge_index), graph.y 41 | 42 | 43 | def dataset_fn(ctx): 44 | dataset = tf.data.Dataset.from_generator( 45 | create_batch_generator, 46 | output_types=((tf.float32, tf.int32), tf.int32), 47 | output_shapes=((tf.TensorShape([None, graph.x.shape[1]]), tf.TensorShape([2, None])), tf.TensorShape([None])) 48 | ) 49 | return dataset 50 | 51 | 52 | strategy = tf.distribute.MirroredStrategy() 53 | distributed_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn) 54 | 55 | # The model will automatically use all seen GPUs defined by "CUDA_VISIBLE_DEVICES" for distributed training 56 | with strategy.scope(): 57 | model = GCNNetwork() 58 | 59 | 60 | # custom loss function 61 | def masked_cross_entropy(y_true, logits): 62 | y_true = tf.cast(y_true, tf.int32) 63 | masked_logits = tf.gather(logits, train_index) 64 | masked_labels = tf.gather(y_true, train_index) 65 | losses = tf.nn.softmax_cross_entropy_with_logits( 66 | logits=masked_logits, 67 | labels=tf.one_hot(masked_labels, depth=num_classes) 68 | ) 69 | 70 | return tf.reduce_mean(losses) 71 | 72 | 73 | model.compile( 74 | optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2), 75 | loss=masked_cross_entropy, 76 | # run_eagerly=True 77 | ) 78 | 79 | 80 | def evaluate(): 81 | logits = model([graph.x, graph.edge_index]) 82 | masked_logits = tf.gather(logits, test_index) 83 | masked_labels = tf.gather(graph.y, test_index) 84 | 85 | y_pred = tf.argmax(masked_logits, 
axis=-1, output_type=tf.int32) 86 | corrects = tf.cast(tf.equal(masked_labels, y_pred), tf.float32) 87 | accuracy = tf.reduce_mean(corrects) 88 | return accuracy.numpy() 89 | 90 | 91 | class EvaluationCallback(tf.keras.callbacks.Callback): 92 | def on_epoch_end(self, epoch, logs=None): 93 | if epoch % 20 == 0: 94 | test_accuracy = evaluate() 95 | print("epoch = {}\ttest_accuracy = {}".format(epoch, test_accuracy)) 96 | 97 | 98 | # The model will automatically use all seen GPUs defined by "CUDA_VISIBLE_DEVICES" for distributed training 99 | model.fit(distributed_dataset, steps_per_epoch=1, epochs=201, callbacks=[EvaluationCallback()], verbose=2) 100 | 101 | 102 | test_accuracy = evaluate() 103 | print("final test_accuracy = {}".format(test_accuracy)) -------------------------------------------------------------------------------- /demo/demo_drop_edge_gcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.layers import GCN, DropEdge 5 | from tensorflow.keras.layers import Dropout 6 | from tf_geometric.utils import tf_utils 7 | import tensorflow as tf 8 | import tf_geometric as tfg 9 | from tqdm import tqdm 10 | import time 11 | 12 | graph, (train_index, valid_index, test_index) = tfg.datasets.SupervisedCoraDataset().load_data() 13 | 14 | num_classes = graph.y.max() + 1 15 | num_gcns = 8 16 | drop_rate = 0.5 17 | edge_drop_rate = 0.5 18 | learning_rate = 5e-4 19 | l2_coe = 0.0 20 | 21 | units_list = [128] * (num_gcns - 1) + [num_classes] 22 | 23 | 24 | # Simple Multi-layer GCN Model 25 | class GCNModel(tf.keras.Model): 26 | 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | activations = [tf.nn.relu if i < len(units_list) - 1 else None for i in range(len(units_list))] 31 | self.gcns = [GCN(units=units, activation=activation) for units, activation in zip(units_list, activations)] 32 | self.dropout = Dropout(drop_rate) 33 | 34 | def call(self, inputs, training=None, mask=None): 35 | 36 | x, edge_index, edge_weight = inputs 37 | h = self.dropout(x, training=training) 38 | 39 | cache = {} 40 | for i in range(num_gcns): 41 | h = self.gcns[i]([h, edge_index, edge_weight], cache=cache) 42 | 43 | return h 44 | 45 | 46 | drop_edge = DropEdge(edge_drop_rate, force_undirected=True) 47 | model = GCNModel() 48 | 49 | 50 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 51 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache. 
52 | @tf_utils.function 53 | def forward(graph, training=False): 54 | 55 | # DropEdge: Towards Deep Graph Convolutional Networks on Node Classification 56 | edge_index, edge_weight = drop_edge([graph.edge_index, graph.edge_weight], training=training) 57 | 58 | return model([graph.x, edge_index, edge_weight], training=training) 59 | 60 | 61 | @tf_utils.function 62 | def compute_loss(logits, mask_index, vars): 63 | masked_logits = tf.gather(logits, mask_index) 64 | masked_labels = tf.gather(graph.y, mask_index) 65 | losses = tf.nn.softmax_cross_entropy_with_logits( 66 | logits=masked_logits, 67 | labels=tf.one_hot(masked_labels, depth=num_classes) 68 | ) 69 | 70 | kernel_vars = [var for var in vars if "kernel" in var.name] 71 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 72 | 73 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * l2_coe 74 | 75 | 76 | @tf_utils.function 77 | def evaluate(): 78 | logits = forward(graph) 79 | masked_logits = tf.gather(logits, test_index) 80 | masked_labels = tf.gather(graph.y, test_index) 81 | 82 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 83 | 84 | corrects = tf.equal(y_pred, masked_labels) 85 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 86 | return accuracy 87 | 88 | 89 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 90 | 91 | for step in range(1, 501): 92 | with tf.GradientTape() as tape: 93 | logits = forward(graph, training=True) 94 | loss = compute_loss(logits, train_index, tape.watched_variables()) 95 | 96 | vars = tape.watched_variables() 97 | grads = tape.gradient(loss, vars) 98 | optimizer.apply_gradients(zip(grads, vars)) 99 | 100 | if step % 20 == 0: 101 | accuracy = evaluate() 102 | print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy)) 103 | 104 | print("\nstart speed test...") 105 | num_test_iterations = 1000 106 | start_time = time.time() 107 | for _ in tqdm(range(num_test_iterations)): 108 | logits = forward(graph) 109 | end_time = time.time() 110 | print("mean forward time: {} seconds".format((end_time - start_time) / num_test_iterations)) 111 | 112 | if tf.__version__[0] == "1": 113 | print("** @tf_utils.function is disabled in TensorFlow 1.x. " 114 | "Upgrade to TensorFlow 2.x for 10X faster speed. 
**") 115 | -------------------------------------------------------------------------------- /demo/demo_elegant_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import tf_geometric as tfg 4 | import tensorflow as tf 5 | 6 | graph = tfg.Graph( 7 | x=np.random.randn(5, 20), # 5 nodes, 20 features, 8 | edge_index=[[0, 0, 1, 3], 9 | [1, 2, 2, 1]] # 4 undirected edges 10 | ) 11 | 12 | print("Graph Desc: \n", graph) 13 | 14 | graph = graph.to_directed() # pre-process edges 15 | print("Processed Graph Desc: \n", graph) 16 | print("Processed Edge Index:\n", graph.edge_index) 17 | 18 | # Multi-head Graph Attention Network (GAT) 19 | gat_layer = tfg.layers.GAT(units=4, num_heads=4, activation=tf.nn.relu) 20 | output = gat_layer([graph.x, graph.edge_index]) 21 | print("Output of GAT: \n", output) -------------------------------------------------------------------------------- /demo/demo_gae.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.utils import tf_utils 5 | import tf_geometric as tfg 6 | import tensorflow as tf 7 | from tf_geometric.utils.graph_utils import edge_train_test_split, negative_sampling 8 | from tqdm import tqdm 9 | 10 | 11 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 12 | 13 | 14 | # undirected edges can be used for evaluation 15 | undirected_train_edge_index, undirected_test_edge_index, _, _ = edge_train_test_split( 16 | edge_index=graph.edge_index, 17 | test_size=0.15 18 | ) 19 | 20 | # use negative_sampling with replace=False to create negative edges for test 21 | undirected_test_neg_edge_index = negative_sampling( 22 | num_samples=undirected_test_edge_index.shape[1], 23 | num_nodes=graph.num_nodes, 24 | edge_index=graph.edge_index, 25 | replace=False 26 | ) 27 | 28 | # for training, you should convert undirected edges to directed edges for correct GCN propagation 29 | train_graph = tfg.Graph(x=graph.x, edge_index=undirected_train_edge_index).to_directed() 30 | 31 | 32 | embedding_size = 16 33 | drop_rate = 0.2 34 | 35 | gcn0 = tfg.layers.GCN(32, activation=tf.nn.relu) 36 | gcn1 = tfg.layers.GCN(embedding_size) 37 | dropout = tf.keras.layers.Dropout(drop_rate) 38 | 39 | 40 | @tf_utils.function 41 | def encode(graph, training=False): 42 | h = gcn0([graph.x, graph.edge_index, graph.edge_weight], cache=graph.cache) 43 | h = dropout(h, training=training) 44 | h = gcn1([h, graph.edge_index, graph.edge_weight], cache=graph.cache) 45 | return h 46 | 47 | 48 | gcn0.build_cache_for_graph(graph) 49 | gcn0.build_cache_for_graph(train_graph) 50 | 51 | 52 | @tf_utils.function 53 | def predict_edge(embedded, edge_index): 54 | row, col = edge_index[0], edge_index[1] 55 | embedded_row = tf.gather(embedded, row) 56 | embedded_col = tf.gather(embedded, col) 57 | 58 | # dot product 59 | logits = tf.reduce_sum(embedded_row * embedded_col, axis=-1) 60 | return logits 61 | 62 | 63 | @tf_utils.function 64 | def compute_loss(pos_edge_logits, neg_edge_logits): 65 | pos_losses = tf.nn.sigmoid_cross_entropy_with_logits( 66 | logits=pos_edge_logits, 67 | labels=tf.ones_like(pos_edge_logits) 68 | ) 69 | 70 | neg_losses = tf.nn.sigmoid_cross_entropy_with_logits( 71 | logits=neg_edge_logits, 72 | labels=tf.zeros_like(neg_edge_logits) 73 | ) 74 | 75 | return tf.reduce_mean(pos_losses) + tf.reduce_mean(neg_losses) 76 | 77 | 78 | def evaluate(): 79 | 
embedded = encode(train_graph) 80 | 81 | pos_edge_logits = predict_edge(embedded, undirected_test_edge_index) 82 | neg_edge_logits = predict_edge(embedded, undirected_test_neg_edge_index) 83 | 84 | pos_edge_scores = tf.nn.sigmoid(pos_edge_logits) 85 | neg_edge_scores = tf.nn.sigmoid(neg_edge_logits) 86 | 87 | y_true = tf.concat([tf.ones_like(pos_edge_scores), tf.zeros_like(neg_edge_scores)], axis=0) 88 | y_pred = tf.concat([pos_edge_scores, neg_edge_scores], axis=0) 89 | 90 | auc_m = tf.keras.metrics.AUC() 91 | auc_m.update_state(y_true, y_pred) 92 | 93 | return auc_m.result().numpy() 94 | 95 | 96 | optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2) 97 | 98 | for step in tqdm(range(1000)): 99 | with tf.GradientTape() as tape: 100 | embedded = encode(train_graph, training=True) 101 | 102 | # negative sampling for training 103 | train_neg_edge_index = negative_sampling( 104 | train_graph.num_edges, 105 | graph.num_nodes, 106 | edge_index=None#train_graph.edge_index 107 | ) 108 | 109 | pos_edge_logits = predict_edge(embedded, train_graph.edge_index) 110 | neg_edge_logits = predict_edge(embedded, train_neg_edge_index) 111 | 112 | loss = compute_loss(pos_edge_logits, neg_edge_logits) 113 | 114 | vars = tape.watched_variables() 115 | grads = tape.gradient(loss, vars) 116 | optimizer.apply_gradients(zip(grads, vars)) 117 | 118 | if step % 20 == 0: 119 | auc_score = evaluate() 120 | print("step = {}\tloss = {}\tauc_score = {}".format(step, loss, auc_score)) -------------------------------------------------------------------------------- /demo/demo_gat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tf_geometric as tfg 7 | import tensorflow as tf 8 | import time 9 | from tqdm import tqdm 10 | 11 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 12 | 13 | num_classes = graph.y.max() + 1 14 | drop_rate = 0.6 15 | 16 | 17 | # Multi-layer GAT Model 18 | class GATModel(tf.keras.Model): 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.gat0 = tfg.layers.GAT(64, activation=tf.nn.relu, num_heads=8, attention_units=8, edge_drop_rate=drop_rate) 23 | self.gat1 = tfg.layers.GAT(num_classes, num_heads=1, attention_units=1, edge_drop_rate=drop_rate) 24 | 25 | # The GAT paper mentioned that: "Specially, if we perform multi-head attention on the final (prediction) layer of 26 | # the network, concatenation is no longer sensible - instead, we employ averaging". 27 | # In tf_geometric, if you want to set num_heads > 1 for the last output GAT layer, you can set split_value_heads=False 28 | # as follows to employ averaging instead of concatenation. 
29 | # self.gat1 = tfg.layers.GAT(num_classes, num_heads=8, attention_units=8, split_value_heads=False, edge_drop_rate=drop_rate) 30 | 31 | self.dropout = tf.keras.layers.Dropout(drop_rate) 32 | 33 | def call(self, inputs, training=None, mask=None, cache=None): 34 | x, edge_index = inputs 35 | h = self.dropout(x, training=training) 36 | h = self.gat0([h, edge_index], training=training) 37 | h = self.dropout(h, training=training) 38 | h = self.gat1([h, edge_index], training=training) 39 | return h 40 | 41 | 42 | model = GATModel() 43 | 44 | 45 | # @tf_utils.function can speed up functions for TensorFlow 2.x 46 | @tf_utils.function 47 | def forward(graph, training=False): 48 | return model([graph.x, graph.edge_index], training=training) 49 | 50 | 51 | @tf_utils.function 52 | def compute_loss(logits, mask_index, vars): 53 | masked_logits = tf.gather(logits, mask_index) 54 | masked_labels = tf.gather(graph.y, mask_index) 55 | losses = tf.nn.softmax_cross_entropy_with_logits( 56 | logits=masked_logits, 57 | labels=tf.one_hot(masked_labels, depth=num_classes) 58 | ) 59 | 60 | kernel_vars = [var for var in vars if "kernel" in var.name] 61 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 62 | 63 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4 64 | 65 | 66 | @tf_utils.function 67 | def train_step(): 68 | with tf.GradientTape() as tape: 69 | logits = forward(graph, training=True) 70 | loss = compute_loss(logits, train_index, tape.watched_variables()) 71 | 72 | vars = tape.watched_variables() 73 | grads = tape.gradient(loss, vars) 74 | optimizer.apply_gradients(zip(grads, vars)) 75 | return loss 76 | 77 | 78 | @tf_utils.function 79 | def evaluate(): 80 | logits = forward(graph) 81 | masked_logits = tf.gather(logits, test_index) 82 | masked_labels = tf.gather(graph.y, test_index) 83 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 84 | 85 | corrects = tf.equal(y_pred, masked_labels) 86 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 87 | return accuracy 88 | 89 | 90 | optimizer = tf.keras.optimizers.Adam(learning_rate=5e-3) 91 | 92 | for step in range(1, 401): 93 | loss = train_step() 94 | 95 | if step % 20 == 0: 96 | accuracy = evaluate() 97 | print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy)) 98 | 99 | print("\nstart speed test...") 100 | num_test_iterations = 1000 101 | start_time = time.time() 102 | for _ in tqdm(range(num_test_iterations)): 103 | logits = forward(graph) 104 | end_time = time.time() 105 | print("mean forward time: {} seconds".format((end_time - start_time) / num_test_iterations)) 106 | 107 | if tf.__version__[0] == "1": 108 | print("** @tf_utils.function is disabled in TensorFlow 1.x. " 109 | "Upgrade to TensorFlow 2.x for 10X faster speed. 
**") 110 | -------------------------------------------------------------------------------- /demo/demo_gcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.utils import tf_utils 5 | import tensorflow as tf 6 | import tf_geometric as tfg 7 | from tqdm import tqdm 8 | import time 9 | 10 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 11 | 12 | num_classes = graph.y.max() + 1 13 | drop_rate = 0.5 14 | learning_rate = 1e-2 15 | 16 | 17 | # Multi-layer GCN Model 18 | class GCNModel(tf.keras.Model): 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.gcn0 = tfg.layers.GCN(16, activation=tf.nn.relu) 23 | self.gcn1 = tfg.layers.GCN(num_classes) 24 | self.dropout = tf.keras.layers.Dropout(drop_rate) 25 | 26 | def call(self, inputs, training=None, mask=None, cache=None): 27 | x, edge_index, edge_weight = inputs 28 | h = self.dropout(x, training=training) 29 | h = self.gcn0([h, edge_index, edge_weight], cache=cache) 30 | h = self.dropout(h, training=training) 31 | h = self.gcn1([h, edge_index, edge_weight], cache=cache) 32 | return h 33 | 34 | 35 | model = GCNModel() 36 | 37 | 38 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 39 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache. 40 | @tf_utils.function 41 | def forward(graph, training=False): 42 | return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 43 | 44 | 45 | # The following line is only necessary for using GCN with @tf_utils.function 46 | # For usage without @tf_utils.function, you can commont the following line and GCN layers can automatically manager the cache 47 | model.gcn0.build_cache_for_graph(graph) 48 | 49 | 50 | @tf_utils.function 51 | def compute_loss(logits, mask_index, vars): 52 | masked_logits = tf.gather(logits, mask_index) 53 | masked_labels = tf.gather(graph.y, mask_index) 54 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 55 | logits=masked_logits, 56 | labels=masked_labels 57 | ) 58 | 59 | kernel_vars = [var for var in vars if "kernel" in var.name] 60 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 61 | 62 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4 63 | 64 | 65 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 66 | 67 | 68 | @tf_utils.function 69 | def train_step(): 70 | with tf.GradientTape() as tape: 71 | logits = forward(graph, training=True) 72 | loss = compute_loss(logits, train_index, tape.watched_variables()) 73 | 74 | vars = tape.watched_variables() 75 | grads = tape.gradient(loss, vars) 76 | optimizer.apply_gradients(zip(grads, vars)) 77 | return loss 78 | 79 | 80 | @tf_utils.function 81 | def evaluate(): 82 | logits = forward(graph) 83 | masked_logits = tf.gather(logits, test_index) 84 | masked_labels = tf.gather(graph.y, test_index) 85 | 86 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 87 | 88 | corrects = tf.equal(y_pred, masked_labels) 89 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 90 | return accuracy 91 | 92 | 93 | for step in range(1, 201): 94 | loss = train_step() 95 | if step % 20 == 0: 96 | accuracy = evaluate() 97 | print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy)) 98 | 99 | print("\nstart speed test...") 100 | num_test_iterations = 1000 101 | start_time = 
time.time() 102 | for _ in tqdm(range(num_test_iterations)): 103 | logits = forward(graph) 104 | end_time = time.time() 105 | print("mean forward time: {} seconds".format((end_time - start_time) / num_test_iterations)) 106 | 107 | if tf.__version__[0] == "1": 108 | print("** @tf_utils.function is disabled in TensorFlow 1.x. " 109 | "Upgrade to TensorFlow 2.x for 10X faster speed. **") 110 | -------------------------------------------------------------------------------- /demo/demo_graph_sage.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import tf_geometric as tfg 5 | from tf_geometric.datasets.ppi import PPIDataset 6 | from tf_geometric.utils.graph_utils import RandomNeighborSampler 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | import numpy as np 10 | from sklearn.metrics import f1_score 11 | from tqdm import tqdm 12 | 13 | train_graphs, valid_graphs, test_graphs = PPIDataset().load_data() 14 | 15 | # traverse all graphs 16 | for graph in train_graphs + valid_graphs + test_graphs: 17 | neighbor_sampler = RandomNeighborSampler(graph.edge_index) 18 | graph.cache["sampler"] = neighbor_sampler 19 | 20 | num_classes = train_graphs[0].y.shape[1] 21 | 22 | graph_sages = [ 23 | # tfg.layers.MaxPoolGraphSage(units=256, activation=tf.nn.relu, concat=True), 24 | # tfg.layers.MaxPoolGraphSage(units=256, activation=tf.nn.relu, concat=True) 25 | 26 | # tfg.layers.MeanPoolGraphSage(units=256, activation=tf.nn.relu, concat=True), 27 | # tfg.layers.MeanPoolGraphSage(units=256, activation=tf.nn.relu, concat=True) 28 | 29 | tfg.layers.MeanGraphSage(units=256, activation=tf.nn.relu, concat=True), 30 | tfg.layers.MeanGraphSage(units=256, activation=tf.nn.relu, concat=True) 31 | 32 | # tfg.layers.SumGraphSage(units=256, activation=tf.nn.relu, concat=True), 33 | # tfg.layers.SumGraphSage(units=256, activation=tf.nn.relu, concat=True) 34 | 35 | # tfg.layers.LSTMGraphSage(units=256, activation=tf.nn.relu, concat=True), 36 | # tfg.layers.LSTMGraphSage(units=256, activation=tf.nn.relu, concat=True) 37 | 38 | # tfg.layers.GCNGraphSage(units=256, activation=tf.nn.relu), 39 | # tfg.layers.GCNGraphSage(units=256, activation=tf.nn.relu) 40 | ] 41 | 42 | fc = tf.keras.Sequential([ 43 | keras.layers.Dropout(0.3), 44 | tf.keras.layers.Dense(num_classes) 45 | ]) 46 | 47 | num_sampled_neighbors_list = [25, 10] 48 | 49 | 50 | def forward(graph, training=False): 51 | neighbor_sampler = graph.cache["sampler"] 52 | h = graph.x 53 | for i, (graph_sage, num_sampled_neighbors) in enumerate(zip(graph_sages, num_sampled_neighbors_list)): 54 | sampled_edge_index, sampled_edge_weight = neighbor_sampler.sample(k=num_sampled_neighbors) 55 | h = graph_sage([h, sampled_edge_index, sampled_edge_weight], training=training) 56 | h = fc(h, training=training) 57 | return h 58 | 59 | 60 | def compute_loss(logits, vars): 61 | losses = tf.nn.sigmoid_cross_entropy_with_logits( 62 | logits=logits, 63 | labels=tf.convert_to_tensor(graph.y, dtype=tf.float32) 64 | ) 65 | 66 | kernel_vars = [var for var in vars if "kernel" in var.name] 67 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 68 | 69 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 1e-5 70 | 71 | 72 | def calc_f1(y_true, y_pred): 73 | y_pred[y_pred > 0] = 1 74 | y_pred[y_pred <= 0] = 0 75 | 76 | return f1_score(y_true, y_pred, average="micro") 77 | 78 | 79 | def evaluate(graphs): 80 | y_preds = [] 81 | y_true = [] 82 | 83 | 
    for graph in graphs:
84 |         y_true.append(graph.y)
85 |         logits = forward(graph)
86 |         y_preds.append(logits.numpy())
87 | 
88 |     y_pred = np.concatenate(y_preds, axis=0)
89 |     y = np.concatenate(y_true, axis=0)
90 | 
91 |     mic = calc_f1(y, y_pred)
92 | 
93 |     return mic
94 | 
95 | 
96 | optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
97 | 
98 | for epoch in tqdm(range(20)):
99 | 
100 |     for graph in train_graphs:
101 |         with tf.GradientTape() as tape:
102 |             logits = forward(graph, training=True)
103 |             loss = compute_loss(logits, tape.watched_variables())
104 | 
105 |         vars = tape.watched_variables()
106 |         grads = tape.gradient(loss, vars)
107 |         optimizer.apply_gradients(zip(grads, vars))
108 | 
109 |     if epoch % 1 == 0:
110 |         valid_f1_mic = evaluate(valid_graphs)
111 |         test_f1_mic = evaluate(test_graphs)
112 |         print("epoch = {}\tloss = {}\tvalid_f1_micro = {}".format(epoch, loss, valid_f1_mic))
113 |         print("epoch = {}\ttest_f1_micro = {}".format(epoch, test_f1_mic))
114 |     # test_f1_mic = evaluate(test_graphs)
115 |     # print("test_f1_micro = {}".format(test_f1_mic))
--------------------------------------------------------------------------------
/demo/demo_mean_pool.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import tf_geometric as tfg
5 | import tensorflow as tf
6 | import numpy as np
7 | from sklearn.model_selection import train_test_split
8 | 
9 | # TU Datasets: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
10 | graph_dicts = tfg.datasets.TUDataset("NCI1").load_data()
11 | 
12 | # Since a TU dataset may contain node_labels, node_attributes, etc., each of which can be used as node features,
13 | # we process each graph as a dict and return a list of dicts for the graphs.
14 | # You can easily construct your Graph object with the data dict.
15 | 
16 | num_node_labels = np.max([np.max(graph_dict["node_labels"]) for graph_dict in graph_dicts]) + 1
17 | 
18 | 
19 | def convert_node_labels_to_one_hot(node_labels):
20 |     num_nodes = len(node_labels)
21 |     x = np.zeros([num_nodes, num_node_labels], dtype=np.float32)
22 |     x[list(range(num_nodes)), node_labels] = 1.0
23 |     return x
24 | 
25 | 
26 | def construct_graph(graph_dict):
27 |     return tfg.Graph(
28 |         x=convert_node_labels_to_one_hot(graph_dict["node_labels"]),
29 |         edge_index=graph_dict["edge_index"],
30 |         y=graph_dict["graph_label"]  # graph_dict["graph_label"] is a list with one int element
31 |     )
32 | 
33 | 
34 | graphs = [construct_graph(graph_dict) for graph_dict in graph_dicts]
35 | num_classes = np.max([graph.y[0] for graph in graphs]) + 1
36 | 
37 | train_graphs, test_graphs = train_test_split(graphs, test_size=0.1)
38 | 
39 | 
40 | def create_graph_generator(graphs, batch_size, infinite=False, shuffle=False):
41 |     while True:
42 |         dataset = tf.data.Dataset.range(len(graphs))
43 |         if shuffle:
44 |             dataset = dataset.shuffle(2000)
45 |         dataset = dataset.batch(batch_size)
46 | 
47 |         for batch_graph_index in dataset:
48 |             batch_graph_list = [graphs[i] for i in batch_graph_index]
49 | 
50 |             batch_graph = tfg.BatchGraph.from_graphs(batch_graph_list)
51 |             yield batch_graph
52 | 
53 |         if not infinite:
54 |             break
55 | 
56 | 
57 | batch_size = 256
58 | 
59 | drop_rate = 0.4
60 | 
61 | 
62 | class MeanPoolNetwork(tf.keras.Model):
63 | 
64 |     def __init__(self, *args, **kwargs):
65 |         super().__init__(*args, **kwargs)
66 | 
67 |         self.gcn0 = tfg.layers.GCN(64, activation=tf.nn.relu)
68 |         self.gcn1 = tfg.layers.GCN(32, activation=tf.nn.relu)
69 |         self.dropout = tf.keras.layers.Dropout(drop_rate)
70 |         self.dense = tf.keras.layers.Dense(num_classes)
71 | 
72 |     # @tf_utils.function(experimental_relax_shapes=True)
73 |     def call(self, inputs, training=None, mask=None):
74 |         x, edge_index, node_graph_index = inputs
75 | 
76 |         # GCN Encoder
77 |         h = self.gcn0([x, edge_index])
78 |         h = self.dropout(h, training=training)
79 |         h = self.gcn1([h, edge_index])
80 | 
81 |         # Mean Pooling
82 |         h = tfg.nn.mean_pool(h, node_graph_index)
83 |         h = self.dropout(h, training=training)
84 | 
85 |         # Predict Graph Labels
86 |         h = self.dense(h)
87 |         return h
88 | 
89 | 
90 | model = MeanPoolNetwork()
91 | 
92 | 
93 | def forward(batch_graph, training=False):
94 |     return model([batch_graph.x, batch_graph.edge_index, batch_graph.node_graph_index], training=training)
95 | 
96 | 
97 | def evaluate():
98 |     accuracy_m = tf.keras.metrics.Accuracy()
99 | 
100 |     for test_batch_graph in create_graph_generator(test_graphs, batch_size, shuffle=False, infinite=False):
101 |         logits = forward(test_batch_graph)
102 |         preds = tf.argmax(logits, axis=-1)
103 |         accuracy_m.update_state(test_batch_graph.y, preds)
104 | 
105 |     return accuracy_m.result().numpy()
106 | 
107 | 
108 | optimizer = tf.keras.optimizers.Adam(learning_rate=5e-3)
109 | 
110 | train_batch_generator = create_graph_generator(train_graphs, batch_size, shuffle=True, infinite=True)
111 | 
112 | for step in range(2000):
113 |     train_batch_graph = next(train_batch_generator)
114 |     with tf.GradientTape() as tape:
115 |         logits = forward(train_batch_graph, training=True)
116 |         losses = tf.nn.softmax_cross_entropy_with_logits(
117 |             logits=logits,
118 |             labels=tf.one_hot(train_batch_graph.y, depth=num_classes)
119 |         )
120 | 
121 |     vars = tape.watched_variables()
122 |     grads = tape.gradient(losses, vars)
123 |     optimizer.apply_gradients(zip(grads, vars))
124 | 
125 |     if step % 20 == 0:
126 |         mean_loss = tf.reduce_mean(losses)
127 |         accuracy = evaluate()
128 |         print("step = {}\tloss = {}\taccuracy = {}".format(step, mean_loss, accuracy))
129 | 
--------------------------------------------------------------------------------
/demo/demo_model_net_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from tf_geometric.datasets import ModelNet10Dataset, ModelNet40Dataset
3 | 
4 | train_graphs, test_graphs, label_names = ModelNet10Dataset().load_data()
5 | 
6 | for graph in test_graphs:
7 |     print(graph.y)
8 | 
9 | 
10 | 
--------------------------------------------------------------------------------
/demo/demo_sample_neighbors.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | from tf_geometric.utils.graph_utils import reindex_sampled_edge_index, RandomNeighborSampler
5 | from tf_geometric.datasets import CoraDataset
6 | import time
7 | import numpy as np
8 | import tf_geometric as tfg
9 | import tf_sparse as tfs
10 | 
11 | edge_index = [
12 |     [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5],
13 |     [1, 2, 3, 4, 5, 0, 4, 7, 1, 7, 2, 3, 6, 9, 0, 2, 3, 4, 7, 8, 10]
14 | ]
15 | 
16 | neighbor_sampler = RandomNeighborSampler(edge_index)
17 | sampled_virtual_edge_index, sampled_virtual_edge_weight = neighbor_sampler.sample(k=5, sampled_node_index=([4, 2], [2, 6, 7, 8, 9, 10]), padding=False)
18 | print(sampled_virtual_edge_index)
19 | print(sampled_virtual_edge_weight)
20 | 
21 | 
22 | 
23 | graph, (train_index, valid_index, test_index) = CoraDataset().load_data()
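# Time neighbor sampling on the full Cora graph: each iteration below samples roughly half of
# every node's neighbors (ratio=0.5) and prints the sampled virtual edges and the elapsed time.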
24 | neighbor_sampler = RandomNeighborSampler(graph.edge_index) 25 | 26 | for _ in range(100): 27 | start = time.time() 28 | sampled_virtual_edge_index, sampled_virtual_edge_weight = neighbor_sampler.sample(ratio=0.5) 29 | print(sampled_virtual_edge_index) 30 | print(sampled_virtual_edge_weight) 31 | print(time.time() - start) 32 | 33 | 34 | for _ in range(100): 35 | start = time.time() 36 | print("sample for sampled nodes: ") 37 | sampled_node_index = np.arange(100, 200) 38 | sampled_virtual_edge_index, sampled_virtual_edge_weight = neighbor_sampler.sample(ratio=0.5, sampled_node_index=sampled_node_index) 39 | print("sampled_node_index: ", sampled_node_index) 40 | print("sampled_virtual_edge_index: \n", sampled_virtual_edge_index) 41 | print("sampled_virtual_edge_weight: \n", sampled_virtual_edge_weight) 42 | 43 | # print("reindex sampled nodes and edges to construct edges for a subgraph: ") 44 | # reindexed_edge_index = reindex_sampled_edge_index(sampled_virtual_edge_index, sampled_node_index) 45 | # print("reindexed_edge_index: \n", reindexed_edge_index) 46 | # print(time.time() - start, "\n") -------------------------------------------------------------------------------- /demo/demo_save_and_load_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tf_geometric as tfg 7 | import tensorflow as tf 8 | 9 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 10 | 11 | num_classes = graph.y.max() + 1 12 | drop_rate = 0.6 13 | checkpoint_dir = "./models" 14 | checkpoint_prefix = os.path.join(checkpoint_dir, "gat") 15 | 16 | 17 | # Multi-layer GAT Model 18 | class GATModel(tf.keras.Model): 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.gat0 = tfg.layers.GAT(64, activation=tf.nn.relu, num_heads=8, drop_rate=drop_rate, attention_units=8) 23 | self.gat1 = tfg.layers.GAT(num_classes, drop_rate=drop_rate, attention_units=1) 24 | self.dropout = tf.keras.layers.Dropout(drop_rate) 25 | 26 | def call(self, inputs, training=None, mask=None, cache=None): 27 | x, edge_index = inputs 28 | h = self.dropout(x, training=training) 29 | h = self.gat0([h, edge_index], training=training) 30 | h = self.dropout(h, training=training) 31 | h = self.gat1([h, edge_index], training=training) 32 | return h 33 | 34 | 35 | # Model/Layer objects in TensorFlow may delay the creation of variables to their first call, when input shapes are available. 36 | # Therefore, you must call the model at least once before writing checkpoints.
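# Editor's note (hedged illustration): because variable creation is deferred,
# one explicit way to build the variables right after construction is a dummy
# forward pass, e.g.:
#
# model = GATModel()
# _ = model([graph.x, graph.edge_index], training=False)  # forces variable creation
#
# The demo below relies on the training loop calling the model before
# model.save_weights is reached, which has the same effect.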
37 | model = GATModel() 38 | 39 | 40 | # @tf_utils.function can speed up functions for TensorFlow 2.x 41 | @tf_utils.function 42 | def forward(graph, training=False): 43 | return model([graph.x, graph.edge_index], training=training) 44 | 45 | 46 | @tf_utils.function 47 | def compute_loss(logits, mask_index, vars): 48 | masked_logits = tf.gather(logits, mask_index) 49 | masked_labels = tf.gather(graph.y, mask_index) 50 | losses = tf.nn.softmax_cross_entropy_with_logits( 51 | logits=masked_logits, 52 | labels=tf.one_hot(masked_labels, depth=num_classes) 53 | ) 54 | 55 | kernel_vars = [var for var in vars if "kernel" in var.name] 56 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 57 | 58 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4 59 | 60 | 61 | @tf_utils.function 62 | def evaluate(): 63 | logits = forward(graph) 64 | masked_logits = tf.gather(logits, test_index) 65 | masked_labels = tf.gather(graph.y, test_index) 66 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 67 | 68 | corrects = tf.equal(y_pred, masked_labels) 69 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 70 | return accuracy 71 | 72 | 73 | optimizer = tf.keras.optimizers.Adam(learning_rate=5e-3) 74 | 75 | 76 | @tf_utils.function 77 | def train_step(): 78 | with tf.GradientTape() as tape: 79 | logits = forward(graph, training=True) 80 | loss = compute_loss(logits, train_index, tape.watched_variables()) 81 | 82 | vars = tape.watched_variables() 83 | grads = tape.gradient(loss, vars) 84 | optimizer.apply_gradients(zip(grads, vars)) 85 | 86 | return loss 87 | 88 | 89 | for step in range(1, 401): 90 | 91 | loss = train_step() 92 | 93 | if step % 20 == 0: 94 | accuracy = evaluate() 95 | print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy)) 96 | 97 | # save model 98 | # Unlike tf.train.Checkpoint.save, model.save_weights only keeps one checkpoint. 99 | # That is, each call to model.save_weights overwrites the last checkpoint file. 100 | model.save_weights(checkpoint_prefix) 101 | print("save model at step {}".format(step)) 102 | 103 | # create new model and restore it from the checkpoint 104 | restored_model = GATModel() 105 | 106 | # https://www.tensorflow.org/guide/checkpoint#delayed_restorations 107 | # Layer/Model objects in TensorFlow may delay the creation of variables to their first call, when input shapes are available. 108 | # For example, the shape of a Dense layer's kernel depends on both the layer's input and output shapes, 109 | # and so the output shape required as a constructor argument is not enough information to create the variable on its own. 110 | # Since calling a Layer/Model also reads the variable's value, a restore must happen between the variable's creation and its first use. 111 | # To support this idiom, tf.train.Checkpoint queues restores which don't yet have a matching variable. 112 | # In this case, some variables, such as model.gat0.kernel and model.gat0.bias, will not be immediately restored after calling checkpoint.restore. 113 | # They will be automatically restored during the first call of restored_model.
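# Editor's note (a hedged sanity check, not in the original demo): after the
# first call of restored_model triggers the queued restorations, the weights
# can be compared against the source model, e.g.:
#
# _ = restored_model([graph.x, graph.edge_index], training=False)
# import numpy as np  # numpy is not imported elsewhere in this demo
# for w_old, w_new in zip(model.get_weights(), restored_model.get_weights()):
#     np.testing.assert_allclose(w_old, w_new)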
114 | restored_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)) 115 | 116 | 117 | # @tf_utils.function can speed up functions for TensorFlow 2.x 118 | @tf_utils.function 119 | def forward_by_restored_model(graph, training=False): 120 | return restored_model([graph.x, graph.edge_index], training=training) 121 | 122 | 123 | print("\ninfer with model:") 124 | print(forward(graph)) 125 | 126 | print("\ninfer with restored_model:") 127 | print(forward_by_restored_model(graph)) 128 | -------------------------------------------------------------------------------- /demo/demo_set2set.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.layers import Set2Set 5 | import tf_geometric as tfg 6 | import tensorflow as tf 7 | from tensorflow import keras 8 | import numpy as np 9 | from sklearn.model_selection import train_test_split 10 | from tqdm import tqdm 11 | 12 | # TU Datasets: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 13 | graph_dicts = tfg.datasets.TUDataset("NCI1").load_data() 14 | 15 | # Since a TU dataset may contain node_labels, node_attributes, etc., each of which can be used as node features, 16 | # we process each graph as a dict and return a list of dicts for the graphs. 17 | # You can easily construct your own Graph object with each data dict 18 | 19 | num_node_labels = np.max([np.max(graph_dict["node_labels"]) for graph_dict in graph_dicts]) + 1 20 | 21 | 22 | def convert_node_labels_to_one_hot(node_labels): 23 | num_nodes = len(node_labels) 24 | x = np.zeros([num_nodes, num_node_labels], dtype=np.float32) 25 | x[list(range(num_nodes)), node_labels] = 1.0 26 | return x 27 | 28 | 29 | def construct_graph(graph_dict): 30 | return tfg.Graph( 31 | x=convert_node_labels_to_one_hot(graph_dict["node_labels"]), 32 | edge_index=graph_dict["edge_index"], 33 | y=graph_dict["graph_label"] # graph_dict["graph_label"] is a list with one int element 34 | ) 35 | 36 | 37 | graphs = [construct_graph(graph_dict) for graph_dict in graph_dicts] 38 | num_classes = np.max([graph.y[0] for graph in graphs]) + 1 39 | 40 | train_graphs, test_graphs = train_test_split(graphs, test_size=0.1) 41 | 42 | 43 | def create_graph_generator(graphs, batch_size, infinite=False, shuffle=False): 44 | while True: 45 | dataset = tf.data.Dataset.range(len(graphs)) 46 | if shuffle: 47 | dataset = dataset.shuffle(2000) 48 | dataset = dataset.batch(batch_size) 49 | 50 | for batch_graph_index in dataset: 51 | batch_graph_list = [graphs[i] for i in batch_graph_index] 52 | 53 | batch_graph = tfg.BatchGraph.from_graphs(batch_graph_list) 54 | yield batch_graph 55 | 56 | if not infinite: 57 | break 58 | 59 | 60 | batch_size = 200 61 | 62 | 63 | class Set2SetModel(tf.keras.Model): 64 | 65 | def __init__(self, *args, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | 68 | self.set2set = Set2Set(num_iterations=4) 69 | self.graph_sage0 = tfg.layers.MeanGraphSage(64, activation=tf.nn.relu) 70 | self.graph_sage1 = tfg.layers.MeanGraphSage(64, activation=tf.nn.relu) 71 | 72 | self.mlp = tf.keras.Sequential([ 73 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 74 | tf.keras.layers.Dropout(0.5), 75 | tf.keras.layers.Dense(num_classes) 76 | ]) 77 | 78 | def call(self, inputs, training=None, mask=None): 79 | x, edge_index, edge_weight, node_graph_index = inputs 80 | h = self.graph_sage0([x, edge_index, edge_weight], training=training) 81 | h = self.graph_sage1([h, edge_index,
edge_weight], training=training) 82 | h = self.set2set([h, node_graph_index], training=training) 83 | logits = self.mlp(h, training=training) 84 | 85 | return logits 86 | 87 | 88 | 89 | 90 | 91 | model = Set2SetModel() 92 | 93 | 94 | def forward(batch_graph, training=False): 95 | return model([batch_graph.x, batch_graph.edge_index, batch_graph.edge_weight, batch_graph.node_graph_index], 96 | training=training) 97 | 98 | 99 | def evaluate(): 100 | accuracy_m = keras.metrics.Accuracy() 101 | 102 | for test_batch_graph in create_graph_generator(test_graphs, batch_size, shuffle=False, infinite=False): 103 | logits = forward(test_batch_graph) 104 | preds = tf.argmax(logits, axis=-1) 105 | accuracy_m.update_state(test_batch_graph.y, preds) 106 | 107 | return accuracy_m.result().numpy() 108 | 109 | 110 | optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4) 111 | 112 | train_batch_generator = create_graph_generator(train_graphs, batch_size, shuffle=True, infinite=True) 113 | 114 | for step in tqdm(range(20000)): 115 | train_batch_graph = next(train_batch_generator) 116 | with tf.GradientTape() as tape: 117 | logits = forward(train_batch_graph, training=True) 118 | losses = tf.nn.softmax_cross_entropy_with_logits( 119 | logits=logits, 120 | labels=tf.one_hot(train_batch_graph.y, depth=num_classes) 121 | ) 122 | 123 | vars = tape.watched_variables() 124 | grads = tape.gradient(losses, vars) 125 | optimizer.apply_gradients(zip(grads, vars)) 126 | 127 | if step % 20 == 0: 128 | mean_loss = tf.reduce_mean(losses) 129 | accuracy = evaluate() 130 | print("step = {}\tloss = {}\taccuracy = {}".format(step, mean_loss, accuracy)) 131 | -------------------------------------------------------------------------------- /demo/demo_sgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.utils import tf_utils 5 | import tf_geometric as tfg 6 | import tensorflow as tf 7 | from tf_geometric.datasets import CoraDataset 8 | 9 | graph, (train_index, valid_index, test_index) = CoraDataset().load_data() 10 | 11 | num_classes = graph.y.max() + 1 12 | 13 | model = tfg.layers.SGC(num_classes, k=2) 14 | 15 | 16 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 17 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache.
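# Editor's note (background sketch): SGC removes the nonlinearities of a
# K-layer GCN, so the k=2 model above is conceptually
#     logits = A_norm @ A_norm @ X @ W
# where A_norm is the renormalized adjacency matrix with self-loops. Since
# A_norm^k @ X is constant for a static graph, it can be computed once and
# cached, which is what graph.cache and build_cache_for_graph below enable.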
18 | @tf_utils.function 19 | def forward(graph, training=False): 20 | h = model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 21 | return h 22 | 23 | 24 | # The following line is only necessary for using SGC with @tf_utils.function 25 | # For usage without @tf_utils.function, you can comment out the following line and the SGC layer will automatically manage the cache 26 | model.build_cache_for_graph(graph) 27 | 28 | 29 | @tf_utils.function 30 | def compute_loss(logits, mask_index, vars): 31 | masked_logits = tf.gather(logits, mask_index) 32 | masked_labels = tf.gather(graph.y, mask_index) 33 | 34 | losses = tf.nn.softmax_cross_entropy_with_logits( 35 | logits=masked_logits, 36 | labels=tf.one_hot(masked_labels, depth=num_classes) 37 | ) 38 | 39 | kernel_vars = [var for var in vars if "kernel" in var.name] 40 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 41 | 42 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-6 43 | 44 | 45 | @tf_utils.function 46 | def evaluate(mask): 47 | logits = forward(graph) 48 | masked_logits = tf.gather(logits, mask) 49 | masked_labels = tf.gather(graph.y, mask) 50 | 51 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 52 | 53 | corrects = tf.equal(y_pred, masked_labels) 54 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 55 | return accuracy 56 | 57 | 58 | optimizer = tf.keras.optimizers.Adam(learning_rate=0.2) 59 | 60 | for step in range(1, 201): 61 | with tf.GradientTape() as tape: 62 | logits = forward(graph, training=True) 63 | loss = compute_loss(logits, train_index, tape.watched_variables()) 64 | 65 | vars = tape.watched_variables() 66 | grads = tape.gradient(loss, vars) 67 | optimizer.apply_gradients(zip(grads, vars)) 68 | 69 | if step % 20 == 0: 70 | valid_acc = evaluate(valid_index) 71 | test_acc = evaluate(test_index) 72 | 73 | print("step = {}\tloss = {}\tvalid_acc = {}\ttest_acc = {}".format(step, loss, valid_acc, test_acc)) 74 | -------------------------------------------------------------------------------- /demo/demo_sparse_node_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | # Use GPU0 4 | from tf_geometric.utils import tf_utils 5 | 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | import tensorflow as tf 8 | import tf_geometric as tfg 9 | import tf_sparse as tfs 10 | 11 | 12 | num_nodes = 5 13 | edge_index = [ 14 | [0, 0, 1, 2, 3, 3, 4], 15 | [1, 2, 4, 3, 1, 4, 2] 16 | ] 17 | 18 | # Sparse node features 19 | # tfs.eye (from tf_sparse) creates a two-dimensional sparse matrix with ones along the diagonal 20 | # x is the one-hot encoding of node ids (from 0 to num_nodes - 1) in the form of a sparse matrix 21 | # This is usually used for feature-less cases, such as recommendation systems.
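# Editor's note (hedged arithmetic): the sparse one-hot encoding stores only
# num_nodes nonzero entries instead of a dense [num_nodes, num_nodes]
# identity. For instance, a dense float32 identity for 1,000,000 nodes would
# need 10^12 * 4 bytes (about 4 TB), while the sparse form stays linear in
# the number of nodes.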
22 | x = tfs.eye(num_nodes) 23 | print("Sparse (One-hot) Node Features: ") 24 | print(x.to_dense()) 25 | 26 | # tf.sparse.SparseTensor can be used as node features (x) 27 | graph = tfg.Graph(x, edge_index).to_directed() 28 | print("\nConstructed Graph:") 29 | print(graph) 30 | 31 | # create a one-layer GNN model 32 | model = tfg.layers.GCN(4) 33 | # model = tfg.layers.SGC(4, k=3) 34 | # model = tfg.layers.ChebyNet(4, k=4) 35 | # model = tfg.layers.TAGCN(4, k=4) 36 | # model = tfg.layers.APPNP([4, 4], tf.nn.relu, k=10) 37 | 38 | # predict with the GCN model 39 | @tf_utils.function 40 | def forward(graph): 41 | return model([graph.x, graph.edge_index]) 42 | 43 | logits = forward(graph) 44 | print("\nModel Output:") 45 | print(logits) 46 | 47 | # tfg.Graph objects with sparse node features can also be combined into a tfg.BatchGraph object 48 | batch_graph = tfg.BatchGraph.from_graphs([graph, graph]) 49 | print("\nCombined Batch Graph") 50 | print(batch_graph) 51 | -------------------------------------------------------------------------------- /demo/demo_ssgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | from tf_geometric.utils import tf_utils 6 | import tensorflow as tf 7 | import tf_geometric as tfg 8 | from tqdm import tqdm 9 | import time 10 | 11 | graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data() 12 | 13 | num_classes = graph.y.max() + 1 14 | drop_rate = 0.5 15 | learning_rate = 5e-3 16 | l2_coef = 1e-3 17 | 18 | 19 | # SSGC Model 20 | class SSGCModel(tf.keras.Model): 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.ssgc = tfg.layers.SSGC([64, num_classes], k=10, alpha=0.1, 25 | dense_drop_rate=drop_rate, edge_drop_rate=drop_rate) 26 | self.dropout = tf.keras.layers.Dropout(drop_rate) 27 | 28 | def call(self, inputs, training=None, mask=None, cache=None): 29 | x, edge_index, edge_weight = inputs 30 | h = self.dropout(x, training=training) 31 | h = self.ssgc([h, edge_index, edge_weight], training=training, cache=cache) 32 | return h 33 | 34 | 35 | model = SSGCModel() 36 | 37 | 38 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 39 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache. 
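# Editor's note (background sketch, normalization details assumed): SSGC
# ("Simple Spectral Graph Convolution") averages K propagation steps while
# mixing the initial features back in, roughly
#     H = (1/K) * sum_{l=1..K} ((1 - alpha) * A_norm^l @ X + alpha * X) @ W
# which is what the k=10, alpha=0.1 arguments of the SSGC layer above control;
# verify the exact form against tfg.layers.SSGC before relying on it.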
40 | @tf_utils.function 41 | def forward(graph, training=False): 42 | return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 43 | 44 | 45 | # The following line is only necessary for using SSGC with @tf_utils.function 46 | # For usage without @tf_utils.function, you can comment out the following line and the SSGC layer will automatically manage the cache 47 | model.ssgc.build_cache_for_graph(graph) 48 | 49 | 50 | @tf_utils.function 51 | def compute_loss(logits, mask_index, vars): 52 | masked_logits = tf.gather(logits, mask_index) 53 | masked_labels = tf.gather(graph.y, mask_index) 54 | losses = tf.nn.softmax_cross_entropy_with_logits( 55 | logits=masked_logits, 56 | labels=tf.one_hot(masked_labels, depth=num_classes) 57 | ) 58 | 59 | kernel_vars = [var for var in vars if "kernel" in var.name] 60 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 61 | 62 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * l2_coef 63 | 64 | 65 | @tf_utils.function 66 | def train_step(): 67 | with tf.GradientTape() as tape: 68 | logits = forward(graph, training=True) 69 | loss = compute_loss(logits, train_index, tape.watched_variables()) 70 | 71 | vars = tape.watched_variables() 72 | grads = tape.gradient(loss, vars) 73 | optimizer.apply_gradients(zip(grads, vars)) 74 | return loss 75 | 76 | 77 | @tf_utils.function 78 | def evaluate(): 79 | logits = forward(graph) 80 | masked_logits = tf.gather(logits, test_index) 81 | masked_labels = tf.gather(graph.y, test_index) 82 | 83 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 84 | 85 | corrects = tf.equal(y_pred, masked_labels) 86 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 87 | return accuracy 88 | 89 | 90 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 91 | 92 | for step in range(1, 401): 93 | loss = train_step() 94 | if step % 20 == 0: 95 | accuracy = evaluate() 96 | print("step = {}\tloss = {}\taccuracy = {}".format(step, loss, accuracy)) 97 | 98 | print("\nstart speed test...") 99 | num_test_iterations = 1000 100 | start_time = time.time() 101 | for _ in tqdm(range(num_test_iterations)): 102 | logits = forward(graph) 103 | end_time = time.time() 104 | print("mean forward time: {} seconds".format((end_time - start_time) / num_test_iterations)) 105 | 106 | if tf.__version__[0] == "1": 107 | print("** @tf_utils.function is disabled in TensorFlow 1.x. " 108 | "Upgrade to TensorFlow 2.x for 10X faster speed. **") 109 | -------------------------------------------------------------------------------- /demo/demo_tagcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | from tf_geometric.utils import tf_utils 5 | import tf_geometric as tfg 6 | import tensorflow as tf 7 | import numpy as np 8 | from tf_geometric.datasets import CoraDataset 9 | 10 | graph, (train_index, valid_index, test_index) = CoraDataset().load_data() 11 | 12 | num_classes = graph.y.max() + 1 13 | drop_rate = 0.3 14 | 15 | 16 | # Multi-layer TAGCN Model 17 | class TAGCNModel(tf.keras.Model): 18 | 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | 22 | self.tagcn0 = tfg.layers.TAGCN(16, activation=tf.nn.relu) 23 | self.tagcn1 = tfg.layers.TAGCN(num_classes) 24 | self.dropout = tf.keras.layers.Dropout(drop_rate) 25 | 26 | def call(self, inputs, training=None, mask=None, cache=None): 27 | x, edge_index, edge_weight = inputs 28 | h = self.tagcn0([x, edge_index, edge_weight], cache=cache) 29 | h = self.dropout(h, training=training) 30 | h = self.tagcn1([h, edge_index, edge_weight], cache=cache) 31 | return h 32 | 33 | 34 | model = TAGCNModel() 35 | 36 | 37 | # @tf_utils.function can speed up functions for TensorFlow 2.x. 38 | # @tf_utils.function is not compatible with TensorFlow 1.x and dynamic graph.cache. 39 | @tf_utils.function 40 | def forward(graph, training=False): 41 | return model([graph.x, graph.edge_index, graph.edge_weight], training=training, cache=graph.cache) 42 | 43 | 44 | # The following line is only necessary for using TAGCN with @tf_utils.function 45 | # For usage without @tf_utils.function, you can comment out the following line and the TAGCN layers will automatically manage the cache 46 | model.tagcn0.build_cache_for_graph(graph) 47 | 48 | 49 | @tf_utils.function 50 | def compute_loss(logits, mask_index, vars): 51 | masked_logits = tf.gather(logits, mask_index) 52 | masked_labels = tf.gather(graph.y, mask_index) 53 | losses = tf.nn.softmax_cross_entropy_with_logits( 54 | logits=masked_logits, 55 | labels=tf.one_hot(masked_labels, depth=num_classes) 56 | ) 57 | 58 | kernel_vars = [var for var in vars if "kernel" in var.name] 59 | l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vars] 60 | 61 | return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4 62 | 63 | 64 | @tf_utils.function 65 | def evaluate(mask): 66 | logits = forward(graph) 67 | masked_logits = tf.gather(logits, mask) 68 | masked_labels = tf.gather(graph.y, mask) 69 | 70 | y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32) 71 | 72 | corrects = tf.equal(y_pred, masked_labels) 73 | accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 74 | return accuracy 75 | 76 | 77 | optimizer = tf.keras.optimizers.Adam(learning_rate=0.01) 78 | 79 | best_test_acc = tmp_valid_acc = 0 80 | for step in range(1, 101): 81 | with tf.GradientTape() as tape: 82 | logits = forward(graph, training=True) 83 | loss = compute_loss(logits, train_index, tape.watched_variables()) 84 | 85 | vars = tape.watched_variables() 86 | grads = tape.gradient(loss, vars) 87 | optimizer.apply_gradients(zip(grads, vars)) 88 | 89 | valid_acc = evaluate(valid_index) 90 | test_acc = evaluate(test_index) 91 | if test_acc > best_test_acc: 92 | best_test_acc = test_acc 93 | tmp_valid_acc = valid_acc 94 | print("step = {}\tloss = {}\tvalid_acc = {}\tbest_test_acc = {}".format(step, loss, tmp_valid_acc, best_test_acc)) 95 | 
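# Editor's note (hedged): TAGCN filters are learnable polynomials of the
# normalized adjacency, so each TAGCN layer above conceptually computes
#     H = sum_{l=0..K} A_norm^l @ X @ W_l
# with a separate weight matrix per hop; a larger k widens the receptive
# field without stacking layers. The default k of tfg.layers.TAGCN is not
# spelled out in this demo; consult the layer's documentation.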
-------------------------------------------------------------------------------- /demo/demo_topk_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | 5 | import tf_geometric as tfg 6 | 7 | source_index = [0, 0, 1, 1, 1, 2, 2] 8 | score = [0.2, 0.4, 0.1, 0.3, 0.5, 0.4, 0.1] 9 | 10 | topk_node_index = tfg.nn.topk_pool(source_index, score, k=1) 11 | print(topk_node_index) 12 | 13 | 14 | source_index = [1, 0, 0, 2, 1, 2, 1] 15 | score = [0.2, 0.4, 0.1, 0.3, 0.5, 0.4, 0.1] 16 | 17 | topk_node_index = tfg.nn.topk_pool(source_index, score, k=1) 18 | print(topk_node_index) 19 | -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | #m2r README.md 2 | rm -rf tf_geometric.egg-info 3 | rm -rf dist 4 | python setup.py sdist 5 | twine upload dist/* --verbose 6 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/build.sh: -------------------------------------------------------------------------------- 1 | make html 2 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | tf_sparse 2 | numpy >= 1.17.4 3 | networkx >= 2.1 4 | scipy >= 1.1.0 5 | tensorflow == 2.4.1 6 | scikit-learn >= 0.22 7 | ogb_lite >= 0.0.3 8 | tqdm 9 | Sphinx == 3.5.4 10 | Jinja2<3.1 11 | sphinx_rtd_theme -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../')) 16 | import sphinx_rtd_theme 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'tf_geometric' 22 | copyright = '2020, Jun Hu' 23 | author = 'Jun Hu' 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx_rtd_theme", 33 | "sphinx.ext.autodoc" 34 | ] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = [] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 49 | # 50 | #html_theme = 'alabaster' 51 | html_theme = "sphinx_rtd_theme" 52 | 53 | html_show_sourcelink = False 54 | master_doc = 'index' 55 | 56 | # Add any paths that contain custom static files (such as style sheets) here, 57 | # relative to this directory. They are copied after the builtin static files, 58 | # so a file named "default.css" will overwrite the builtin "default.css". 59 | html_static_path = ['_static'] 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /doc/source/modules/datasets.rst: -------------------------------------------------------------------------------- 1 | tf_geometric.datasets 2 | =================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | 8 | Planetoid 9 | -------------------- 10 | 11 | .. 
autoclass:: tf_geometric.datasets.PlanetoidDataset 12 | :special-members: __init__ 13 | :members: 14 | 15 | Cora 16 | -------------------- 17 | 18 | .. autoclass:: tf_geometric.datasets.CoraDataset 19 | :special-members: __init__ 20 | :members: 21 | 22 | Citeseer 23 | -------------------- 24 | 25 | .. autoclass:: tf_geometric.datasets.CiteseerDataset 26 | :special-members: __init__ 27 | :members: 28 | 29 | Pubmed 30 | -------------------- 31 | 32 | .. autoclass:: tf_geometric.datasets.PubmedDataset 33 | :special-members: __init__ 34 | :members: 35 | 36 | 37 | SupervisedCora 38 | -------------------- 39 | 40 | .. autoclass:: tf_geometric.datasets.SupervisedCoraDataset 41 | :special-members: __init__ 42 | :members: 43 | 44 | 45 | SupervisedCiteseer 46 | -------------------- 47 | 48 | .. autoclass:: tf_geometric.datasets.SupervisedCiteseerDataset 49 | :special-members: __init__ 50 | :members: 51 | 52 | 53 | SupervisedPubmed 54 | -------------------- 55 | 56 | .. autoclass:: tf_geometric.datasets.SupervisedPubmedDataset 57 | :special-members: __init__ 58 | :members: 59 | 60 | 61 | 62 | BlogCatalog 63 | -------------------- 64 | 65 | .. autoclass:: tf_geometric.datasets.MultiLabelBlogCatalogDataset 66 | :special-members: __init__ 67 | :members: 68 | 69 | 70 | 71 | 72 | PPI 73 | -------------------- 74 | 75 | .. autoclass:: tf_geometric.datasets.PPIDataset 76 | :special-members: __init__ 77 | :members: 78 | 79 | 80 | 81 | Reddit 82 | -------------------- 83 | 84 | .. autoclass:: tf_geometric.datasets.TransductiveRedditDataset 85 | :special-members: __init__ 86 | :members: 87 | 88 | .. autoclass:: tf_geometric.datasets.InductiveRedditDataset 89 | :special-members: __init__ 90 | :members: 91 | 92 | 93 | TU 94 | -------------------- 95 | 96 | .. autoclass:: tf_geometric.datasets.TUDataset 97 | :special-members: __init__ 98 | :members: 99 | 100 | 101 | OGB 102 | -------------------- 103 | 104 | .. autoclass:: tf_geometric.datasets.OGBNodePropPredDataset 105 | :special-members: __init__ 106 | :members: 107 | 108 | 109 | 110 | ModelNet 111 | -------------------- 112 | 113 | .. autoclass:: tf_geometric.datasets.ModelNet10Dataset 114 | :special-members: __init__ 115 | :members: 116 | 117 | 118 | .. autoclass:: tf_geometric.datasets.ModelNet40Dataset 119 | :special-members: __init__ 120 | :members: 121 | 122 | 123 | 124 | 125 | AmazonElectronics 126 | -------------------- 127 | 128 | 129 | .. autoclass:: tf_geometric.datasets.AmazonElectronicsDataset 130 | :special-members: __init__ 131 | :members: 132 | 133 | 134 | .. autoclass:: tf_geometric.datasets.AmazonComputersDataset 135 | :special-members: __init__ 136 | :members: 137 | 138 | 139 | .. autoclass:: tf_geometric.datasets.AmazonPhotoDataset 140 | :special-members: __init__ 141 | :members: 142 | 143 | 144 | 145 | Coauthor 146 | -------------------- 147 | 148 | 149 | .. autoclass:: tf_geometric.datasets.CoauthorDataset 150 | :special-members: __init__ 151 | :members: 152 | 153 | 154 | .. autoclass:: tf_geometric.datasets.CoauthorCSDataset 155 | :special-members: __init__ 156 | :members: 157 | 158 | 159 | .. autoclass:: tf_geometric.datasets.CoauthorPhysicsDataset 160 | :special-members: __init__ 161 | :members: 162 | 163 | 164 | 165 | 166 | Fraud Detection 167 | -------------------- 168 | 169 | 170 | .. autoclass:: tf_geometric.datasets.FDAmazonDataset 171 | :special-members: __init__ 172 | :members: 173 | 174 | 175 | .. 
autoclass:: tf_geometric.datasets.FDYelpChiDataset 176 | :special-members: __init__ 177 | :members: 178 | 179 | 180 | 181 | 182 | HGB 183 | -------------------- 184 | 185 | 186 | .. autoclass:: tf_geometric.datasets.HGBDataset 187 | :special-members: __init__ 188 | :members: 189 | 190 | 191 | .. autoclass:: tf_geometric.datasets.HGBACMDataset 192 | :special-members: __init__ 193 | :members: 194 | 195 | 196 | .. autoclass:: tf_geometric.datasets.HGBDBLPDataset 197 | :special-members: __init__ 198 | :members: 199 | 200 | 201 | .. autoclass:: tf_geometric.datasets.HGBFreebaseDataset 202 | :special-members: __init__ 203 | :members: 204 | 205 | .. autoclass:: tf_geometric.datasets.HGBIMDBDataset 206 | :special-members: __init__ 207 | :members: 208 | 209 | 210 | 211 | 212 | NARS Academic 213 | -------------------- 214 | 215 | 216 | .. autoclass:: tf_geometric.datasets.NARSACMDataset 217 | :special-members: __init__ 218 | :members: 219 | 220 | 221 | -------------------------------------------------------------------------------- /doc/source/modules/layers.rst: -------------------------------------------------------------------------------- 1 | tf_geometric.layers (OOP API) 2 | =================== 3 | 4 | 5 | .. contents:: Contents 6 | :local: 7 | 8 | 9 | GCN 10 | -------------------- 11 | 12 | .. autoclass:: tf_geometric.layers.GCN 13 | :special-members: __init__ 14 | :members: call, build_cache_by_adj, build_cache_for_graph, cache_normed_edge 15 | 16 | 17 | GAT 18 | -------------------- 19 | 20 | .. autoclass:: tf_geometric.layers.GAT 21 | :special-members: __init__ 22 | :members: call 23 | 24 | 25 | APPNP 26 | -------------------- 27 | 28 | .. autoclass:: tf_geometric.layers.APPNP 29 | :special-members: __init__ 30 | :members: call, build_cache_by_adj, build_cache_for_graph, cache_normed_edge 31 | 32 | 33 | GIN 34 | -------------------- 35 | 36 | .. autoclass:: tf_geometric.layers.GIN 37 | :special-members: __init__ 38 | :members: call 39 | 40 | 41 | SGC 42 | -------------------- 43 | 44 | .. autoclass:: tf_geometric.layers.SGC 45 | :special-members: __init__ 46 | :members: call, build_cache_by_adj, build_cache_for_graph, cache_normed_edge 47 | 48 | 49 | 50 | SSGC 51 | -------------------- 52 | 53 | .. autoclass:: tf_geometric.layers.SSGC 54 | :special-members: __init__ 55 | :members: call, build_cache_by_adj, build_cache_for_graph, cache_normed_edge 56 | 57 | 58 | 59 | TAGCN 60 | -------------------- 61 | 62 | .. autoclass:: tf_geometric.layers.TAGCN 63 | :special-members: __init__ 64 | :members: call, build_cache_by_adj, build_cache_for_graph, cache_normed_edge 65 | 66 | 67 | GraphSage 68 | -------------------- 69 | 70 | .. autoclass:: tf_geometric.layers.MeanGraphSage 71 | :special-members: __init__ 72 | :members: call 73 | 74 | 75 | .. autoclass:: tf_geometric.layers.SumGraphSage 76 | :special-members: __init__ 77 | :members: call 78 | 79 | 80 | .. autoclass:: tf_geometric.layers.MeanPoolGraphSage 81 | :special-members: __init__ 82 | :members: call 83 | 84 | 85 | .. autoclass:: tf_geometric.layers.MaxPoolGraphSage 86 | :special-members: __init__ 87 | :members: call 88 | 89 | 90 | .. autoclass:: tf_geometric.layers.GCNGraphSage 91 | :special-members: __init__ 92 | :members: call 93 | 94 | .. autoclass:: tf_geometric.layers.LSTMGraphSage 95 | :special-members: __init__ 96 | :members: call 97 | 98 | 99 | 100 | ChebyNet 101 | -------------------- 102 | .. 
autoclass:: tf_geometric.layers.ChebyNet 103 | :special-members: __init__ 104 | :members: call, build_cache_for_graph, cache_normed_edge 105 | 106 | 107 | DropEdge 108 | -------------------- 109 | .. autoclass:: tf_geometric.layers.DropEdge 110 | :special-members: __init__ 111 | :members: call 112 | 113 | 114 | CommonPool 115 | -------------------- 116 | .. autoclass:: tf_geometric.layers.MeanPool 117 | :special-members: __init__ 118 | :members: call 119 | 120 | 121 | .. autoclass:: tf_geometric.layers.MinPool 122 | :special-members: __init__ 123 | :members: call 124 | 125 | 126 | .. autoclass:: tf_geometric.layers.MaxPool 127 | :special-members: __init__ 128 | :members: call 129 | 130 | 131 | .. autoclass:: tf_geometric.layers.SumPool 132 | :special-members: __init__ 133 | :members: call 134 | 135 | 136 | 137 | 138 | DiffPool 139 | -------------------- 140 | .. autoclass:: tf_geometric.layers.DiffPool 141 | :special-members: __init__ 142 | :members: call 143 | 144 | 145 | 146 | Set2Set 147 | -------------------- 148 | .. autoclass:: tf_geometric.layers.Set2Set 149 | :special-members: __init__ 150 | :members: call 151 | 152 | 153 | 154 | SAGPool 155 | -------------------- 156 | .. autoclass:: tf_geometric.layers.SAGPool 157 | :special-members: __init__ 158 | :members: call 159 | 160 | 161 | 162 | ASAP 163 | -------------------- 164 | .. autoclass:: tf_geometric.layers.ASAP 165 | :special-members: __init__ 166 | :members: call 167 | 168 | 169 | .. autoclass:: tf_geometric.layers.LEConv 170 | :special-members: __init__ 171 | :members: call 172 | 173 | 174 | 175 | 176 | SortPool 177 | -------------------- 178 | .. autoclass:: tf_geometric.layers.SortPool 179 | :special-members: __init__ 180 | :members: call 181 | 182 | 183 | 184 | MinCutPool 185 | -------------------- 186 | .. autoclass:: tf_geometric.layers.MinCutPool 187 | :special-members: __init__ 188 | :members: call 189 | 190 | 191 | -------------------------------------------------------------------------------- /doc/source/modules/nn.rst: -------------------------------------------------------------------------------- 1 | tf_geometric.nn (Functional API) 2 | =============== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | gcn 8 | -------------------- 9 | 10 | .. autofunction:: tf_geometric.nn.gcn 11 | 12 | .. autofunction:: tf_geometric.nn.gcn_norm_adj 13 | 14 | .. autofunction:: tf_geometric.nn.gcn_build_cache_by_adj 15 | 16 | .. autofunction:: tf_geometric.nn.gcn_build_cache_for_graph 17 | 18 | .. autofunction:: tf_geometric.nn.gcn_norm_edge 19 | 20 | .. autofunction:: tf_geometric.nn.gcn_cache_normed_edge 21 | 22 | 23 | 24 | 25 | gat 26 | -------------------- 27 | 28 | .. autofunction:: tf_geometric.nn.gat 29 | 30 | 31 | appnp 32 | ------------------------------------------------------------------ 33 | 34 | .. autofunction:: tf_geometric.nn.appnp 35 | 36 | 37 | gin 38 | ------------------------------ 39 | 40 | .. autofunction:: tf_geometric.nn.gin 41 | 42 | 43 | 44 | sgc 45 | ------------------------------ 46 | 47 | .. autofunction:: tf_geometric.nn.sgc 48 | 49 | 50 | ssgc 51 | ------------------------------ 52 | 53 | .. autofunction:: tf_geometric.nn.ssgc 54 | 55 | 56 | tagcn 57 | ----------------------------------------------------- 58 | 59 | .. autofunction:: tf_geometric.nn.tagcn 60 | 61 | 62 | graph_sage 63 | -------------------------------------------------------------- 64 | 65 | .. autofunction:: tf_geometric.nn.mean_graph_sage 66 | 67 | .. autofunction:: tf_geometric.nn.sum_graph_sage 68 | 69 | .. 
autofunction:: tf_geometric.nn.mean_pool_graph_sage 70 | 71 | .. autofunction:: tf_geometric.nn.max_pool_graph_sage 72 | 73 | .. autofunction:: tf_geometric.nn.gcn_graph_sage 74 | 75 | .. autofunction:: tf_geometric.nn.lstm_graph_sage 76 | 77 | 78 | chebynet 79 | ----------------------------------------------------------------------------------------- 80 | 81 | .. autofunction:: tf_geometric.nn.chebynet 82 | 83 | .. autofunction:: tf_geometric.nn.chebynet_norm_edge 84 | 85 | 86 | drop_edge 87 | -------------------------------------------------------------- 88 | 89 | .. autofunction:: tf_geometric.nn.drop_edge 90 | 91 | 92 | mean_pool 93 | -------------------------------------------------------------- 94 | 95 | .. autofunction:: tf_geometric.nn.mean_pool 96 | 97 | max_pool 98 | -------------------------------------------------------------- 99 | 100 | .. autofunction:: tf_geometric.nn.max_pool 101 | 102 | 103 | min_pool 104 | -------------------------------------------------------------- 105 | 106 | .. autofunction:: tf_geometric.nn.min_pool 107 | 108 | 109 | topk_pool 110 | -------------------------------------------------------------- 111 | 112 | .. autofunction:: tf_geometric.nn.topk_pool 113 | 114 | 115 | diff_pool 116 | -------------------------------------------------------------- 117 | 118 | .. autofunction:: tf_geometric.nn.diff_pool 119 | 120 | .. autofunction:: tf_geometric.nn.diff_pool_coarsen 121 | 122 | 123 | 124 | set2set 125 | -------------------------------------------------------------- 126 | 127 | .. autofunction:: tf_geometric.nn.set2set 128 | 129 | 130 | 131 | cluster_pool 132 | -------------------------------------------------------------- 133 | 134 | .. autofunction:: tf_geometric.nn.cluster_pool 135 | 136 | 137 | 138 | 139 | sag_pool 140 | -------------------------------------------------------------- 141 | 142 | .. autofunction:: tf_geometric.nn.sag_pool 143 | 144 | 145 | 146 | asap 147 | -------------------------------------------------------------- 148 | 149 | .. autofunction:: tf_geometric.nn.asap 150 | 151 | .. autofunction:: tf_geometric.nn.le_conv 152 | 153 | 154 | 155 | 156 | sort_pool 157 | -------------------------------------------------------------- 158 | 159 | .. autofunction:: tf_geometric.nn.sort_pool 160 | 161 | 162 | min_cut_pool 163 | -------------------------------------------------------------- 164 | 165 | .. autofunction:: tf_geometric.nn.min_cut_pool 166 | 167 | .. autofunction:: tf_geometric.nn.min_cut_pool_coarsen 168 | 169 | .. autofunction:: tf_geometric.nn.min_cut_pool_compute_losses 170 | 171 | 172 | -------------------------------------------------------------------------------- /doc/source/modules/root.rst: -------------------------------------------------------------------------------- 1 | tf_geometric 2 | ============ 3 | 4 | 5 | Graph (Data Structure for a Single Graph) 6 | ------------------------------------- 7 | 8 | .. autoclass:: tf_geometric.Graph 9 | :special-members: __init__ 10 | :members: 11 | 12 | 13 | 14 | BatchGraph (Data Structure for a Batch of Graphs) 15 | ------------------------------------- 16 | 17 | .. 
autoclass:: tf_geometric.BatchGraph 18 | :special-members: __init__ 19 | :members: 20 | -------------------------------------------------------------------------------- /doc/source/modules/utils/graph_utils.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/doc/source/modules/utils/graph_utils.rst -------------------------------------------------------------------------------- /doc/source/wiki/installation.rst: -------------------------------------------------------------------------------- 1 | .. _wiki-installation: 2 | 3 | Installation 4 | ============ 5 | 6 | :ref:`(中文版)` 7 | 8 | 9 | Requirements 10 | ------------ 11 | 12 | * Operating System: Windows / Linux / Mac OS 13 | * Python: version >= 3.5 and version != 3.6 14 | * Python Packages: 15 | 16 | * tensorflow/tensorflow-gpu: >= 1.15.0 or >= 2.3.0 17 | * numpy >= 1.17.4 18 | * networkx >= 2.1 19 | * scipy >= 1.1.0 20 | 21 | 22 | Install with pip 23 | ---------------- 24 | 25 | Use any one of the following commands: 26 | 27 | .. code-block:: bash 28 | 29 | pip install -U tf_geometric # this will not install the tensorflow/tensorflow-gpu package 30 | 31 | pip install -U tf_geometric[tf1-cpu] # this will install TensorFlow 1.x CPU version 32 | 33 | pip install -U tf_geometric[tf1-gpu] # this will install TensorFlow 1.x GPU version 34 | 35 | pip install -U tf_geometric[tf2-cpu] # this will install TensorFlow 2.x CPU version 36 | 37 | pip install -U tf_geometric[tf2-gpu] # this will install TensorFlow 2.x GPU version 38 | -------------------------------------------------------------------------------- /doc/source/wiki_cn/installation.rst: -------------------------------------------------------------------------------- 1 | .. _wiki_cn-installation: 2 | 3 | Installation 4 | ============ 5 | 6 | :ref:`(English Version)` 7 | 8 | 9 | Requirements and Dependencies 10 | ----------------------------- 11 | 12 | * Operating System: Windows / Linux / Mac OS 13 | * Python: version >= 3.5 and version != 3.6 14 | * Python Packages: 15 | 16 | * tensorflow/tensorflow-gpu: >= 1.15.0 or >= 2.3.0 17 | * numpy >= 1.17.4 18 | * networkx >= 2.1 19 | * scipy >= 1.1.0 20 | 21 | 22 | 23 | Install tf_geometric and Its Dependencies with pip 24 | -------------------------------------------------- 25 | 26 | Install with any one of the following pip commands: 27 | 28 | 29 | .. code-block:: bash 30 | 31 | pip install -U tf_geometric # will not additionally install the tensorflow or tensorflow-gpu package 32 | 33 | pip install -U tf_geometric[tf1-cpu] # will additionally install the TensorFlow 1.x CPU version 34 | 35 | pip install -U tf_geometric[tf1-gpu] # will additionally install the TensorFlow 1.x GPU version 36 | 37 | pip install -U tf_geometric[tf2-cpu] # will additionally install the TensorFlow 2.x CPU version 38 | 39 | pip install -U tf_geometric[tf2-gpu] # will additionally install the TensorFlow 2.x GPU version 40 | -------------------------------------------------------------------------------- /doc/test.sh: -------------------------------------------------------------------------------- 1 | cd build/html 2 | python -m http.server 3 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.8" 7 | 8 | python: 9 | install: 10 | - requirements: doc/requirements.txt 11 | - method: setuptools 12 | path: .
13 | 14 | formats: [] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name="tf_geometric", 7 | python_requires='>3.5.0', 8 | version="0.1.7", 9 | author="Jun Hu", 10 | author_email="hujunxianligong@gmail.com", 11 | packages=find_packages( 12 | exclude=[ 13 | 'benchmarks', 14 | 'data', 15 | 'demo', 16 | 'dist', 17 | 'doc', 18 | 'docs', 19 | 'logs', 20 | 'models', 21 | 'test' 22 | ] 23 | ), 24 | install_requires=[ 25 | "tf_sparse >= 0.0.17", 26 | "numpy >= 1.17.4", 27 | "networkx >= 2.1", 28 | "scipy >= 1.1.0", 29 | "scikit-learn >= 0.22", 30 | "ogb_lite >= 0.0.3", 31 | "tqdm" 32 | ], 33 | extras_require={ 34 | 'tf1-cpu': ["tensorflow >= 1.15.0,<2.0.0"], 35 | 'tf1-gpu': ["tensorflow-gpu >= 1.15.0,<2.0.0"], 36 | 'tf2-cpu': ["tensorflow >= 2.4.0"], 37 | 'tf2-gpu': ["tensorflow >= 2.4.0"] 38 | }, 39 | description="Efficient and Friendly Graph Neural Network Library for TensorFlow 1.x and 2.x.", 40 | license="GNU General Public License v3.0 (See LICENSE)", 41 | long_description=open("README.rst", "r", encoding="utf-8").read(), 42 | url="https://github.com/CrawlScript/tf_geometric" 43 | ) 44 | -------------------------------------------------------------------------------- /tf_geometric/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import warnings 4 | 5 | warnings.simplefilter('always', DeprecationWarning) 6 | 7 | import tensorflow as tf 8 | 9 | if tf.__version__[0] == "1": 10 | tf.enable_eager_execution() 11 | 12 | from . import nn, utils, data, datasets, layers 13 | from .data.graph import Graph, BatchGraph, HeteroGraph, HeteroBatchGraph 14 | # from .sparse.sparse_adj import SparseAdj 15 | -------------------------------------------------------------------------------- /tf_geometric/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/data/__init__.py -------------------------------------------------------------------------------- /tf_geometric/data/dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | from tensorflow.python.keras.utils.data_utils import _extract_archive 5 | from tf_geometric.utils.data_utils import download_file, load_cache, save_cache 6 | from shutil import copy 7 | 8 | DEFAULT_DATASETS_ROOT = "data" 9 | 10 | 11 | def get_dataset_root_path(dataset_root_path=None, dataset_name=None, 12 | datasets_root_path=DEFAULT_DATASETS_ROOT, mkdir=False): 13 | if dataset_root_path is None: 14 | dataset_root_path = os.path.join(datasets_root_path, dataset_name) 15 | dataset_root_path = os.path.abspath(dataset_root_path) 16 | 17 | if mkdir: 18 | os.makedirs(dataset_root_path, exist_ok=True) 19 | return dataset_root_path 20 | 21 | 22 | class Dataset(object): 23 | pass 24 | 25 | 26 | class DownloadableDataset(object): 27 | 28 | def __init__(self, 29 | dataset_name, 30 | download_urls=None, 31 | download_file_name=None, 32 | cache_name="cache.p", 33 | dataset_root_path=None 34 | ): 35 | self.dataset_name = dataset_name 36 | self.dataset_root_path = get_dataset_root_path(dataset_root_path, dataset_name) 37 | self.download_urls = download_urls 38 | self.download_file_name = 
download_file_name 39 | 40 | self.download_root_path = os.path.join(self.dataset_root_path, "download") 41 | self.raw_root_path = os.path.join(self.dataset_root_path, "raw") 42 | self.processed_root_path = os.path.join(self.dataset_root_path, "processed") 43 | 44 | if download_urls is not None: 45 | if download_file_name is None: 46 | download_file_name = "{}.zip".format(dataset_name) 47 | self.download_file_path = os.path.join(self.download_root_path, download_file_name) 48 | else: 49 | self.download_file_path = None 50 | 51 | self.cache_path = None if cache_name is None else os.path.join(self.processed_root_path, cache_name) 52 | 53 | self.build_dirs() 54 | 55 | @property 56 | def cache_enabled(self): 57 | return self.cache_path is not None 58 | 59 | def build_dirs(self): 60 | os.makedirs(self.download_root_path, exist_ok=True) 61 | os.makedirs(self.raw_root_path, exist_ok=True) 62 | os.makedirs(self.processed_root_path, exist_ok=True) 63 | 64 | def download(self): 65 | download_file(self.download_file_path, self.download_urls) 66 | 67 | def extract_raw(self): 68 | if len(os.listdir(self.raw_root_path)) == 0: 69 | if self.download_file_path.endswith(".npz"): 70 | copy(self.download_file_path, os.path.join(self.raw_root_path, self.download_file_name)) 71 | else: 72 | _extract_archive(self.download_file_path, self.raw_root_path, archive_format="auto") 73 | else: 74 | print("raw data exists: {}, ignore".format(self.raw_root_path)) 75 | 76 | def process(self): 77 | pass 78 | 79 | def load_data(self): 80 | if self.cache_enabled and os.path.exists(self.cache_path): 81 | print("cache file exists: {}, read cache".format(self.cache_path)) 82 | return load_cache(self.cache_path) 83 | 84 | if self.download_urls is not None: 85 | self.download() 86 | self.extract_raw() 87 | else: 88 | print("downloading and extraction are ignored due to None download_urls") 89 | 90 | processed = self.process() 91 | 92 | if self.cache_enabled: 93 | print("save processed data to cache: ", self.cache_path) 94 | save_cache(processed, self.cache_path) 95 | 96 | return processed 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /tf_geometric/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from tf_geometric.datasets.ppi import PPIDataset 3 | from tf_geometric.datasets.tu import TUDataset 4 | from tf_geometric.datasets.planetoid import PlanetoidDataset, CoraDataset, CiteseerDataset, PubmedDataset, \ 5 | SupervisedCoraDataset, SupervisedCiteseerDataset, SupervisedPubmedDataset 6 | from tf_geometric.datasets.blog_catalog import MultiLabelBlogCatalogDataset 7 | from tf_geometric.datasets.reddit import TransductiveRedditDataset, InductiveRedditDataset 8 | from tf_geometric.datasets.ogb import OGBNodePropPredDataset 9 | from tf_geometric.datasets.model_net import ModelNet10Dataset, ModelNet40Dataset 10 | from tf_geometric.datasets.csr_npz import CSRNPZDataset 11 | from tf_geometric.datasets.amazon_electronics import AmazonElectronicsDataset, AmazonComputersDataset, \ 12 | AmazonPhotoDataset 13 | from tf_geometric.datasets.coauthor import CoauthorDataset, CoauthorCSDataset, CoauthorPhysicsDataset 14 | from tf_geometric.datasets.abnormal import FDAmazonDataset, FDYelpChiDataset 15 | from tf_geometric.datasets.hgb import HGBDataset, HGBACMDataset, HGBDBLPDataset, HGBFreebaseDataset, HGBIMDBDataset 16 | 
from tf_geometric.datasets.nars_academic import NARSACMDataset 17 | 18 | -------------------------------------------------------------------------------- /tf_geometric/datasets/abnormal.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | from tf_sparse import SparseMatrix 5 | from tf_geometric.data.dataset import DownloadableDataset 6 | import os 7 | from scipy.io import loadmat 8 | 9 | 10 | def _csc_to_edge_index(x): 11 | x = x.tocoo() 12 | return np.stack([x.row, x.col], axis=0) 13 | 14 | 15 | def _csc_to_sparse_matrix(x): 16 | x = x.tocoo() 17 | index = np.stack([x.row, x.col], axis=0) 18 | value = x.data.astype(np.float64) 19 | return SparseMatrix(index, value, shape=x.shape) 20 | 21 | 22 | class _BaseAbnormalMATDataset(DownloadableDataset): 23 | def __init__(self, dataset_name, dataset_root_path=None): 24 | super().__init__(dataset_name, 25 | download_urls=["https://github.com/CrawlScript/gnn_datasets/raw/master/Abnormal/{}.zip".format( 26 | dataset_name)], 27 | download_file_name="{}.zip".format(dataset_name), 28 | cache_name=None, dataset_root_path=dataset_root_path) 29 | 30 | def process(self): 31 | mat_path = os.path.join(self.raw_root_path, "{}.mat".format(self.dataset_name)) 32 | data = loadmat(mat_path) 33 | 34 | # x = _csc_to_sparse_matrix(data["features"]) 35 | x = data["features"].tocoo().astype(np.float64) 36 | y = data["label"][0].astype(np.int64) 37 | 38 | edge_index_dict = {} 39 | 40 | for key, value in data.items(): 41 | if key.startswith("net_") or key == "homo": 42 | edge_index = _csc_to_edge_index(value) 43 | edge_index_dict[key] = edge_index 44 | 45 | return x, edge_index_dict, y 46 | 47 | 48 | class FDYelpChiDataset(_BaseAbnormalMATDataset): 49 | def __init__(self, dataset_root_path=None): 50 | super().__init__("fd_yelp_chi", dataset_root_path) 51 | 52 | 53 | class FDAmazonDataset(_BaseAbnormalMATDataset): 54 | def __init__(self, dataset_root_path=None): 55 | super().__init__("fd_amazon", dataset_root_path) 56 | -------------------------------------------------------------------------------- /tf_geometric/datasets/amazon_electronics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from tf_geometric.datasets.csr_npz import CSRNPZDataset 3 | 4 | 5 | class AmazonElectronicsDataset(CSRNPZDataset): 6 | 7 | def __init__(self, dataset_name, dataset_root_path=None): 8 | """ 9 | 10 | :param dataset_name: "amazon-computers" | "amazon-photo" 11 | :param dataset_root_path: 12 | """ 13 | super().__init__(dataset_name=dataset_name, 14 | download_urls=[ 15 | "https://github.com/CrawlScript/gnn_datasets/raw/master/AmazonElectronics/{}.zip".format(dataset_name), 16 | "http://cdn.zhuanzhi.ai/github/{}.zip".format(dataset_name) 17 | ], 18 | download_file_name="{}.zip".format(dataset_name), 19 | cache_name=None, 20 | dataset_root_path=dataset_root_path, 21 | ) 22 | 23 | 24 | class AmazonComputersDataset(AmazonElectronicsDataset): 25 | 26 | def __init__(self, dataset_root_path=None): 27 | super().__init__("amazon-computers", dataset_root_path=dataset_root_path) 28 | 29 | 30 | class AmazonPhotoDataset(AmazonElectronicsDataset): 31 | 32 | def __init__(self, dataset_root_path=None): 33 | super().__init__("amazon-photo", dataset_root_path=dataset_root_path) 34 | -------------------------------------------------------------------------------- /tf_geometric/datasets/blog_catalog.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | from tf_geometric.data.dataset import DownloadableDataset 5 | import os 6 | import scipy.io as scio 7 | 8 | 9 | class MultiLabelBlogCatalogDataset(DownloadableDataset): 10 | 11 | def __init__(self, dataset_root_path=None): 12 | super().__init__(dataset_name="MultiLabelBlogCatalog", 13 | download_urls=[ 14 | "https://github.com/CrawlScript/gnn_datasets/raw/master/BlogCatalog/multi_label_blog_catalog.zip", 15 | "http://cdn.zhuanzhi.ai/github/multi_label_blog_catalog.zip" 16 | ], 17 | download_file_name="multi_label_blog_catalog.zip", 18 | cache_name="cache.p", 19 | dataset_root_path=dataset_root_path, 20 | ) 21 | 22 | def process(self): 23 | data_path = os.path.join(self.raw_root_path, "multi_label_blog_catalog.mat") 24 | data = scio.loadmat(data_path) 25 | 26 | adj = data['network'].tocoo() 27 | edge_index = np.stack([adj.row, adj.col], axis=0) 28 | 29 | y = data['group'].tocoo().toarray().astype(np.float32) 30 | 31 | return edge_index, y 32 | -------------------------------------------------------------------------------- /tf_geometric/datasets/coauthor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from tf_geometric.datasets.csr_npz import CSRNPZDataset 3 | 4 | 5 | class CoauthorDataset(CSRNPZDataset): 6 | 7 | def __init__(self, dataset_name, dataset_root_path=None): 8 | """ 9 | 10 | :param dataset_name: "coauthor-cs" | "coauthor-physics" 11 | :param dataset_root_path: 12 | """ 13 | super().__init__(dataset_name=dataset_name, 14 | download_urls=[ 15 | "https://github.com/CrawlScript/gnn_datasets/raw/master/Coauthor/{}.zip".format(dataset_name), 16 | "http://cdn.zhuanzhi.ai/github/{}.zip".format(dataset_name) 17 | ], 18 | download_file_name="{}.zip".format(dataset_name), 19 | cache_name=None, 20 | dataset_root_path=dataset_root_path, 21 | ) 22 | 23 | 24 | class CoauthorCSDataset(CoauthorDataset): 25 | 26 | def __init__(self, dataset_root_path=None): 27 | super().__init__("coauthor-cs", dataset_root_path=dataset_root_path) 28 | 29 | 30 | class CoauthorPhysicsDataset(CoauthorDataset): 31 | 32 | def __init__(self, dataset_root_path=None): 33 | super().__init__("coauthor-physics", dataset_root_path=dataset_root_path) 34 | -------------------------------------------------------------------------------- /tf_geometric/datasets/csr_npz.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | from tf_geometric.data.graph import Graph 4 | import scipy.sparse as sp 5 | import os 6 | 7 | from tf_geometric.utils.graph_utils import convert_edge_to_directed, remove_self_loop_edge 8 | from tf_geometric.data.dataset import DownloadableDataset 9 | 10 | 11 | class CSRNPZDataset(DownloadableDataset): 12 | 13 | def process(self): 14 | 15 | # npz_path = os.path.join(self.raw_root_path, "amazon_electronics_{}.npz".format(self.dataset_name.replace("amazon-", ""))) 16 | npz_name = [fname for fname in os.listdir(self.raw_root_path) if fname.endswith(".npz")][0] 17 | npz_path = os.path.join(self.raw_root_path, npz_name) 18 | 19 | with np.load(npz_path) as data: 20 | 21 | x = sp.csr_matrix((data["attr_data"], data["attr_indices"], data["attr_indptr"]), data["attr_shape"]).todense().astype(np.float32) 22 | x[x > 0.0] = 1.0 23 | 24 | adj = sp.csr_matrix((data["adj_data"], data["adj_indices"], data["adj_indptr"]), data["adj_shape"]).tocoo() 25 | edge_index = 
np.stack([adj.row, adj.col], axis=0).astype(np.int32) 26 | edge_index, _ = remove_self_loop_edge(edge_index) 27 | edge_index, _ = convert_edge_to_directed(edge_index, merge_modes="max") 28 | 29 | y = data["labels"].astype(np.int32) 30 | 31 | graph = Graph(x=x, edge_index=edge_index, y=y) 32 | 33 | return graph 34 | -------------------------------------------------------------------------------- /tf_geometric/datasets/nars_academic.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os.path 3 | import json 4 | from collections import defaultdict 5 | import numpy as np 6 | from scipy.io import loadmat 7 | 8 | from tf_geometric.data.graph import HeteroGraph 9 | from tf_geometric.data.dataset import DownloadableDataset 10 | 11 | 12 | class _NARSAcademicDataset(DownloadableDataset): 13 | 14 | def __init__(self, dataset_name, dataset_root_path=None): 15 | """ 16 | 17 | :param dataset_name: "nars_academic_acm" 18 | :param dataset_root_path: 19 | """ 20 | self.sub_dataset_name = dataset_name.split("_")[-1] 21 | 22 | super().__init__(dataset_name=dataset_name, 23 | download_urls=[ 24 | "https://github.com/CrawlScript/gnn_datasets/raw/master/nars_academic/{}.zip".format(self.sub_dataset_name) 25 | ], 26 | download_file_name="{}.zip".format(self.sub_dataset_name), 27 | cache_name=None, 28 | dataset_root_path=dataset_root_path, 29 | ) 30 | 31 | # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/hgb_dataset.py 32 | def process(self): 33 | from scipy import io as sio 34 | 35 | data = loadmat(os.path.join(self.raw_root_path, "acm.mat")) 36 | p_vs_l = data['PvsL'] # paper-field? 37 | p_vs_a = data['PvsA'] # paper-author 38 | p_vs_t = data['PvsT'] # paper-term, bag of words 39 | p_vs_c = data['PvsC'] # paper-conference, labels come from that 40 | 41 | # We assign 42 | # (1) KDD papers as class 0 (data mining), 43 | # (2) SIGMOD and VLDB papers as class 1 (database), 44 | # (3) SIGCOMM and MOBICOMM papers as class 2 (communication) 45 | conf_ids = [0, 1, 9, 10, 13] 46 | label_ids = [0, 1, 2, 2, 1] 47 | 48 | p_vs_c_filter = p_vs_c[:, conf_ids] 49 | p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0] 50 | p_vs_l = p_vs_l[p_selected] 51 | p_vs_a = p_vs_a[p_selected] 52 | p_vs_t = p_vs_t[p_selected] 53 | p_vs_c = p_vs_c[p_selected] 54 | 55 | # pa = dgl.bipartite(p_vs_a, 'paper', 'pa', 'author') 56 | # pl = dgl.bipartite(p_vs_l, 'paper', 'pf', 'field') 57 | 58 | p_vs_a = p_vs_a.tocoo() 59 | p_vs_l = p_vs_l.tocoo() 60 | 61 | # pa = dgl.heterograph({('paper', 'pa', 'author'): (p_vs_a.col, p_vs_a.row)}) 62 | # pl = dgl.heterograph({('paper', 'pf', 'field'): (p_vs_l.col, p_vs_l.row)}) 63 | 64 | # gs = [pa, pl] 65 | # hg = dgl.hetero_from_relations(gs) 66 | # hg = dgl.heterograph({ 67 | # ('paper', 'pa', 'author'): (p_vs_a.row, p_vs_a.col), 68 | # ('paper', 'pf', 'field'): (p_vs_l.row, p_vs_l.col) 69 | # }) 70 | 71 | 72 | 73 | edge_index_dict = { 74 | ('paper', 'pa', 'author'): np.stack([p_vs_a.row, p_vs_a.col], axis=0).astype(np.int64), 75 | ('paper', 'pf', 'field'): np.stack([p_vs_l.row, p_vs_l.col], axis=0).astype(np.int64) 76 | } 77 | 78 | # features = torch.FloatTensor(p_vs_t.toarray()) 79 | # features = p_vs_t.toarray().astype(np.float64) 80 | 81 | num_authors = p_vs_a.col.max() + 1 82 | num_fields = p_vs_l.col.max() + 1 83 | 84 | x_dict = { 85 | "paper": p_vs_t.toarray().astype(np.float64), 86 | "author": np.zeros([num_authors, 1], dtype=np.float32), 87 | "field": np.zeros([num_fields, 1], dtype=np.float32), 
88 | } 89 | 90 | pc_p, pc_c = p_vs_c.nonzero() 91 | labels = np.zeros(len(p_selected), dtype=np.int64) 92 | for conf_id, label_id in zip(conf_ids, label_ids): 93 | labels[pc_p[pc_c == conf_id]] = label_id 94 | 95 | # labels = torch.LongTensor(labels) 96 | labels = np.array(labels, dtype=np.int64) 97 | 98 | y_dict = {"paper": labels} 99 | 100 | # num_classes = 3 101 | 102 | float_mask = np.zeros(len(pc_p)) 103 | for conf_id in conf_ids: 104 | pc_c_mask = (pc_c == conf_id) 105 | float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum())) 106 | train_index = np.where(float_mask <= 0.2)[0] 107 | valid_index = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0] 108 | test_index = np.where(float_mask > 0.3)[0] 109 | 110 | # hg.nodes["paper"].data["feat"] = features 111 | 112 | hetero_graph = HeteroGraph(x_dict=x_dict, edge_index_dict=edge_index_dict, y_dict=y_dict) 113 | 114 | target_node_type = "paper" 115 | 116 | return hetero_graph, target_node_type, (train_index, valid_index, test_index) 117 | 118 | # return hg, labels, num_classes, train_idx, val_idx, test_idx 119 | 120 | 121 | class NARSACMDataset(_NARSAcademicDataset): 122 | def __init__(self, dataset_root_path=None): 123 | super().__init__("nars_academic_acm", dataset_root_path) 124 | 125 | -------------------------------------------------------------------------------- /tf_geometric/datasets/ogb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from tf_geometric.data.graph import Graph 4 | from tf_geometric.data.dataset import DownloadableDataset 5 | from tf_geometric.utils.graph_utils import convert_edge_to_directed 6 | from ogb_lite.nodeproppred import NodePropPredDataset 7 | import numpy as np 8 | 9 | 10 | class OGBNodePropPredDataset(DownloadableDataset): 11 | """ 12 | OGB Node Property Prediction Datasets: https://ogb.stanford.edu/docs/nodeprop/ 13 | """ 14 | 15 | def __init__(self, dataset_name, dataset_root_path=None): 16 | """ 17 | OGB Node Property Prediction Datasets: https://ogb.stanford.edu/docs/nodeprop/ 18 | 19 | :param dataset_name: "ogbn-arxiv" | "ogbn-products" | "ogbn-proteins" | "ogbn-papers100M" | "ogbn-mag" 20 | :param dataset_root_path: 21 | """ 22 | 23 | super().__init__(dataset_name=dataset_name, 24 | download_urls=None, 25 | download_file_name=None, 26 | cache_name="cache.p", 27 | dataset_root_path=dataset_root_path, 28 | ) 29 | 30 | # https://github.com/tkipf/gcn/blob/master/gcn/utils.py 31 | def process(self): 32 | dataset = NodePropPredDataset(name=self.dataset_name, root=self.download_root_path) 33 | 34 | graph, label = dataset[0] # graph: library-agnostic graph object 35 | 36 | x = graph["node_feat"] 37 | edge_index = graph["edge_index"] 38 | 39 | # convert edge_index to directed 40 | edge_index, _ = convert_edge_to_directed(edge_index, None) 41 | 42 | label = label.flatten().astype(np.int32) 43 | graph = Graph(x=x, edge_index=edge_index, y=label) 44 | 45 | split_index = dataset.get_idx_split() 46 | train_index, valid_index, test_index = split_index["train"], split_index["valid"], split_index["test"] 47 | 48 | return graph, (train_index, valid_index, test_index) 49 | 50 | 51 | -------------------------------------------------------------------------------- /tf_geometric/datasets/ppi.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | import os 5 | 6 | import networkx as nx 7 | from tf_geometric.data.graph import Graph 8 | from 
tf_geometric.data.dataset import DownloadableDataset 9 | import json 10 | 11 | from tf_geometric.utils.data_utils import load_cache 12 | from tf_geometric.utils.graph_utils import convert_edge_to_directed 13 | 14 | 15 | class PPIDataset(DownloadableDataset): 16 | 17 | def __init__(self, dataset_root_path=None): 18 | super().__init__(dataset_name="PPI", 19 | download_urls=[ 20 | "https://data.dgl.ai/dataset/ppi.zip", 21 | "https://github.com/CrawlScript/gnn_datasets/raw/master/PPI/ppi.zip" 22 | ], 23 | download_file_name="ppi.zip", 24 | cache_name="cache.p", 25 | dataset_root_path=dataset_root_path, 26 | ) 27 | 28 | def process(self): 29 | 30 | splits = ["train", "valid", "test"] 31 | 32 | split_data_dict = { 33 | split: [] for split in splits 34 | } 35 | 36 | for split in split_data_dict.keys(): 37 | split_graph_ids = np.load(os.path.join(self.raw_root_path, "{}_graph_id.npy".format(split))) 38 | split_features = np.load(os.path.join(self.raw_root_path, "{}_feats.npy".format(split))).astype(np.float32) 39 | split_labels = np.load(os.path.join(self.raw_root_path, "{}_labels.npy".format(split))).astype(np.int32) 40 | 41 | nx_graph_path = os.path.join(self.raw_root_path, "{}_graph.json".format(split)) 42 | with open(nx_graph_path, "r", encoding="utf-8") as f: 43 | nx_graph = nx.DiGraph(nx.json_graph.node_link_graph(json.load(f))) 44 | 45 | split_unique_graph_ids = sorted(set(split_graph_ids)) 46 | 47 | for graph_id in split_unique_graph_ids: 48 | mask_indices = np.where(split_graph_ids == graph_id)[0] 49 | 50 | min_node_index = np.min(mask_indices) 51 | 52 | edge_index = nx_graph.subgraph(mask_indices).edges 53 | edge_index = np.array(edge_index).T - min_node_index 54 | 55 | edge_index, _ = convert_edge_to_directed(edge_index) 56 | 57 | graph = Graph( 58 | x=split_features[mask_indices], 59 | edge_index=edge_index, 60 | y=split_labels[mask_indices] 61 | ) 62 | split_data_dict[split].append(graph) 63 | # print("split: ", split) 64 | 65 | processed_data = [split_data_dict[split] for split in splits] 66 | return processed_data 67 | -------------------------------------------------------------------------------- /tf_geometric/datasets/reddit.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import numpy as np 4 | import os 5 | 6 | import networkx as nx 7 | from tf_geometric.data.graph import Graph 8 | from tf_geometric.data.dataset import DownloadableDataset 9 | import json 10 | import scipy.sparse as sp 11 | from tf_geometric.utils.data_utils import load_cache 12 | from tf_geometric.utils.graph_utils import convert_edge_to_directed 13 | 14 | 15 | class _BaseRedditDataset(DownloadableDataset): 16 | 17 | def __init__(self, dataset_root_path=None, cache_name=None): 18 | super().__init__(dataset_name="reddit", 19 | download_urls=[ 20 | "https://data.dgl.ai/dataset/reddit.zip" 21 | ], 22 | download_file_name="reddit.zip", 23 | cache_name=cache_name, 24 | dataset_root_path=dataset_root_path, 25 | ) 26 | 27 | def process(self): 28 | 29 | common_data_path = os.path.join(self.raw_root_path, "reddit_data.npz") 30 | common_data = np.load(common_data_path) 31 | 32 | x = common_data["feature"] 33 | y = common_data["label"] 34 | 35 | mask = common_data["node_types"] 36 | full_index = np.arange(len(x), dtype=np.int32) 37 | train_index = full_index[mask == 1] 38 | valid_index = full_index[mask == 2] 39 | test_index = full_index[mask == 3] 40 | 41 | graph_data_path = os.path.join(self.raw_root_path, "reddit_graph.npz") 42 | 43 | adj = 
sp.load_npz(graph_data_path) 44 | edge_index = np.stack([adj.row, adj.col], axis=0) 45 | 46 | graph = Graph(x=x, y=y, edge_index=edge_index) 47 | return graph, (train_index, valid_index, test_index) 48 | 49 | 50 | class TransductiveRedditDataset(_BaseRedditDataset): 51 | 52 | def __init__(self, dataset_root_path=None): 53 | super().__init__(dataset_root_path=dataset_root_path, cache_name="transductive_cache.p") 54 | 55 | 56 | class InductiveRedditDataset(_BaseRedditDataset): 57 | 58 | def __init__(self, dataset_root_path=None): 59 | super().__init__(dataset_root_path=dataset_root_path, cache_name="inductive_cache.p") 60 | 61 | def process(self): 62 | graph, (train_index, valid_index, test_index) = super().process() 63 | train_graph = graph.sample_new_graph_by_node_index(train_index) 64 | valid_graph = graph.sample_new_graph_by_node_index(valid_index) 65 | test_graph = graph.sample_new_graph_by_node_index(test_index) 66 | return train_graph, valid_graph, test_graph 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /tf_geometric/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from .conv.gcn import GCN 3 | from .conv.gat import GAT 4 | from .conv.gin import GIN 5 | from .conv.graph_sage import MeanGraphSage, SumGraphSage, MeanPoolGraphSage, MaxPoolGraphSage, GCNGraphSage, LSTMGraphSage 6 | from .conv.sgc import SGC 7 | from .conv.tagcn import TAGCN 8 | from .conv.chebynet import ChebyNet 9 | from .conv.appnp import APPNP 10 | from .conv.le_conv import LEConv 11 | from .conv.ssgc import SSGC 12 | 13 | 14 | from .sampling.drop_edge import DropEdge 15 | 16 | from .kernel.map_reduce import MapReduceGNN 17 | 18 | from .pool.common_pool import MeanPool, MinPool, MaxPool, SumPool 19 | from .pool.diff_pool import DiffPool 20 | from .pool.set2set import Set2Set 21 | from .pool.sag_pool import SAGPool 22 | from .pool.asap import ASAP 23 | from .pool.sort_pool import SortPool 24 | from .pool.min_cut_pool import MinCutPool 25 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/layers/conv/__init__.py -------------------------------------------------------------------------------- /tf_geometric/layers/conv/chebynet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import warnings 4 | import tensorflow as tf 5 | from tf_geometric.nn.conv.chebynet import chebynet, chebynet_cache_normed_edge 6 | from tf_geometric.nn.conv.gcn import gcn_build_cache_by_adj 7 | 8 | 9 | class ChebyNet(tf.keras.Model): 10 | """ 11 | The Chebyshev spectral graph convolutional operator from the 12 | `"Convolutional Neural Networks on Graphs with Fast Localized Spectral 13 | Filtering" <https://arxiv.org/abs/1606.09375>`_ paper 14 | 15 | """ 16 | 17 | def __init__(self, units, k, activation=None, use_bias=True, normalization_type="sym", 18 | use_dynamic_lambda_max=False, 19 | kernel_regularizer=None, bias_regularizer=None, 20 | *args, **kwargs): 21 | """ 22 | 23 | :param units: Positive integer, dimensionality of the output space. 24 | :param k: Chebyshev filter size (default: 3). 25 | 26 | :param use_bias: Boolean, whether the layer uses a bias vector.
27 | :param activation: Activation function to use. 28 | :param normalization_type: The normalization scheme for the graph 29 | Laplacian (default: :obj:`"sym"`) 30 | :param use_dynamic_lambda_max: If True, compute the maximum eigenvalue on each forward pass; 31 | otherwise, use 2.0 as the maximum eigenvalue. 32 | :param kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. 33 | :param bias_regularizer: Regularizer function applied to the bias vector. 34 | """ 35 | super().__init__(*args, **kwargs) 36 | self.units = units 37 | 38 | assert k > 0 39 | assert normalization_type in [None, 'sym', 'rw'], 'Invalid normalization' 40 | 41 | self.k = k 42 | 43 | self.use_bias = use_bias 44 | 45 | self.kernels = [] 46 | self.bias = None 47 | 48 | self.activation = activation 49 | self.normalization_type = normalization_type 50 | self.use_dynamic_lambda_max = use_dynamic_lambda_max 51 | 52 | self.kernel_regularizer = kernel_regularizer 53 | self.bias_regularizer = bias_regularizer 54 | 55 | def build(self, input_shapes): 56 | x_shape = input_shapes[0] 57 | num_features = x_shape[-1] 58 | 59 | for k in range(self.k): 60 | kernel = self.add_weight("kernel{}".format(k), shape=[num_features, self.units], 61 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 62 | self.kernels.append(kernel) 63 | 64 | # self.kernel = self.add_weight("kernel", shape=[self.K, num_features, self.units], 65 | # initializer="glorot_uniform", regularizer=self.kernel_regularizer) 66 | if self.use_bias: 67 | self.bias = self.add_weight("bias", shape=[self.units], 68 | initializer="zeros", regularizer=self.bias_regularizer) 69 | 70 | def build_cache_for_graph(self, graph, override=False): 71 | """ 72 | Manually compute the normed edge based on this layer's normalization configuration (self.normalization_type and self.use_dynamic_lambda_max) and put it in graph.cache. 73 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 74 | 75 | :param graph: tfg.Graph, the input graph. 76 | :param override: Whether to override existing cached normed edge. 77 | :return: None 78 | """ 79 | chebynet_cache_normed_edge(graph, self.normalization_type, 80 | use_dynamic_lambda_max=self.use_dynamic_lambda_max, override=override) 81 | 82 | def cache_normed_edge(self, graph, override=False): 83 | """ 84 | Manually compute the normed edge based on this layer's normalization configuration (self.normalization_type and self.use_dynamic_lambda_max) and put it in graph.cache. 85 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 86 | 87 | :param graph: tfg.Graph, the input graph. 88 | :param override: Whether to override existing cached normed edge. 89 | :return: None 90 | 91 | .. deprecated:: 0.0.56 92 | Use ``build_cache_for_graph`` instead. 93 | """ 94 | warnings.warn( 95 | "'ChebyNet.cache_normed_edge(graph, override)' is deprecated, use 'ChebyNet.build_cache_for_graph(graph, override)' instead", 96 | DeprecationWarning) 97 | return self.build_cache_for_graph(graph, override=override) 98 | 99 | def call(self, inputs, cache=None, training=None, mask=None): 100 | """ 101 | 102 | :param inputs: List of graph info: [x, edge_index, edge_weight] 103 | :param cache: A dict for caching the normed edges for ChebyNet. Different graphs should not share the same cache dict.
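Example (a minimal usage sketch; `graph` is assumed to be a tfg.Graph):

    chebynet_layer = ChebyNet(units=64, k=3, activation=tf.nn.relu)
    chebynet_layer.build_cache_for_graph(graph)  # caches the normed edges in graph.cache
    h = chebynet_layer([graph.x, graph.edge_index, graph.edge_weight], cache=graph.cache)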
104 | :return: Updated node features (x), shape: [num_nodes, units] 105 | """ 106 | 107 | if len(inputs) == 3: 108 | x, edge_index, edge_weight = inputs 109 | else: 110 | x, edge_index = inputs 111 | edge_weight = None 112 | 113 | return chebynet(x, edge_index, edge_weight, self.k, self.kernels, self.bias, self.activation, 114 | self.normalization_type, use_dynamic_lambda_max=self.use_dynamic_lambda_max, cache=cache) 115 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/gat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | 4 | from tf_geometric.nn.conv.gat import gat 5 | 6 | 7 | class GAT(tf.keras.Model): 8 | 9 | def __init__(self, units, 10 | attention_units=None, 11 | activation=None, 12 | use_bias=True, 13 | num_heads=1, 14 | split_value_heads=True, 15 | query_activation=tf.nn.relu, 16 | key_activation=tf.nn.relu, 17 | edge_drop_rate=0.0, 18 | kernel_regularizer=None, 19 | bias_regularizer=None, 20 | *args, **kwargs): 21 | """ 22 | 23 | :param units: Positive integer, dimensionality of the output space. 24 | :param attention_units: Positive integer, dimensionality of the output space for Q and K in attention. 25 | :param activation: Activation function to use. 26 | :param use_bias: Boolean, whether the layer uses a bias vector. 27 | :param num_heads: Number of attention heads. 28 | :param split_value_heads: Boolean. If True, split V into value attention heads and concatenate them as output. 29 | Otherwise, num_heads different V are used as value attention heads, and their mean is used as output. 30 | :param query_activation: Activation function for Q in attention. 31 | :param key_activation: Activation function for K in attention. 32 | :param edge_drop_rate: Dropout rate of attention weights. 33 | :param kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. 34 | :param bias_regularizer: Regularizer function applied to the bias vector.
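Example (a minimal usage sketch; `graph` is assumed to be a tfg.Graph):

    gat_layer = GAT(units=64, num_heads=8, activation=tf.nn.relu)
    h = gat_layer([graph.x, graph.edge_index], training=True)  # shape: [num_nodes, 64]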
35 | """ 36 | super().__init__(*args, **kwargs) 37 | self.units = units 38 | self.attention_units = units if attention_units is None else attention_units 39 | self.edge_drop_rate = edge_drop_rate 40 | 41 | self.query_kernel = None 42 | self.query_bias = None 43 | self.query_activation = query_activation 44 | 45 | self.key_kernel = None 46 | self.key_bias = None 47 | self.key_activation = key_activation 48 | 49 | self.kernel = None 50 | self.bias = None 51 | 52 | self.activation = activation 53 | self.use_bias = use_bias 54 | self.num_heads = num_heads 55 | self.split_value_heads = split_value_heads 56 | 57 | self.kernel_regularizer = kernel_regularizer 58 | self.bias_regularizer = bias_regularizer 59 | 60 | def build(self, input_shapes): 61 | x_shape = input_shapes[0] 62 | num_features = x_shape[-1] 63 | 64 | self.query_kernel = self.add_weight("query_kernel", shape=[num_features, self.attention_units], 65 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 66 | self.query_bias = self.add_weight("query_bias", shape=[self.attention_units], 67 | initializer="zeros", regularizer=self.bias_regularizer) 68 | 69 | self.key_kernel = self.add_weight("key_kernel", shape=[num_features, self.attention_units], 70 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 71 | self.key_bias = self.add_weight("key_bias", shape=[self.attention_units], 72 | initializer="zeros", regularizer=self.bias_regularizer) 73 | 74 | if self.split_value_heads: 75 | self.kernel = self.add_weight("kernel", shape=[num_features, self.units], 76 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 77 | else: 78 | self.kernel = self.add_weight("kernel", shape=[num_features, self.units * self.num_heads], 79 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 80 | 81 | if self.use_bias: 82 | self.bias = self.add_weight("bias", shape=[self.units], 83 | initializer="zeros", regularizer=self.bias_regularizer) 84 | 85 | def call(self, inputs, training=None, mask=None): 86 | """ 87 | 88 | :param inputs: List of graph info: [x, edge_index] or [x, edge_index, edge_weight]. 89 | Note that the edge_weight will not be used. 90 | :return: Updated node features (x), shape: [num_nodes, units] 91 | """ 92 | x, edge_index = inputs[0], inputs[1] 93 | 94 | return gat(x, edge_index, 95 | self.query_kernel, self.query_bias, self.query_activation, 96 | self.key_kernel, self.key_bias, self.key_activation, 97 | self.kernel, self.bias, self.activation, 98 | num_heads=self.num_heads, 99 | split_value_heads=self.split_value_heads, 100 | edge_drop_rate=self.edge_drop_rate, 101 | training=training) 102 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/gin.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.nn.conv.gin import gin 4 | 5 | 6 | class GIN(tf.keras.Model): 7 | """ 8 | Graph Isomorphism Network Layer 9 | """ 10 | 11 | def __init__(self, mlp_model, eps=0, train_eps=False, *args, **kwargs): 12 | """ 13 | :param mlp_model: A neural network (multi-layer perceptrons). 14 | :param eps: float, optional, (default: :obj:`0.`). 15 | :param train_eps: Boolean, Whether the eps is trained. 16 | :param activation: Activation function to use. 
17 | """ 18 | super().__init__(*args, **kwargs) 19 | self.mlp_model = mlp_model 20 | 21 | self.eps = eps 22 | if train_eps: 23 | self.eps = self.add_weight("eps", shape=[], initializer="zeros") 24 | 25 | 26 | def call(self, inputs, cache=None, training=None, mask=None): 27 | """ 28 | 29 | :param inputs: List of graph info: [x, edge_index, edge_weight] 30 | :param cache: A dict for caching A' for GIN. Different graph should not share the same cache dict. 31 | :return: Updated node features (x), shape: [num_nodes, units] 32 | """ 33 | 34 | if len(inputs) == 3: 35 | x, edge_index, _ = inputs 36 | else: 37 | x, edge_index = inputs 38 | 39 | return gin(x, edge_index, self.mlp_model, self.eps, training=training) 40 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/le_conv.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | from tf_geometric.nn.conv.le_conv import le_conv 5 | 6 | 7 | class LEConv(tf.keras.Model): 8 | """ 9 | Graph Convolutional Layer 10 | """ 11 | 12 | def build(self, input_shapes): 13 | x_shape = input_shapes[0] 14 | num_features = x_shape[-1] 15 | 16 | self.self_kernel = self.add_weight("self_kernel", shape=[num_features, self.units], 17 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 18 | if self.self_use_bias: 19 | self.self_bias = self.add_weight("self_bias", shape=[self.units], 20 | initializer="zeros", regularizer=self.bias_regularizer) 21 | 22 | self.aggr_self_kernel = self.add_weight("aggr_self_kernel", shape=[num_features, self.units], 23 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 24 | if self.aggr_self_use_bias: 25 | self.aggr_self_bias = self.add_weight("aggr_self_bias", shape=[self.units], 26 | initializer="zeros", regularizer=self.bias_regularizer) 27 | 28 | self.aggr_neighbor_kernel = self.add_weight("aggr_neighbor_kernel", shape=[num_features, self.units], 29 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 30 | if self.aggr_neighbor_use_bias: 31 | self.aggr_neighbor_bias = self.add_weight("aggr_neighbor_bias", shape=[self.units], 32 | initializer="zeros", regularizer=self.bias_regularizer) 33 | 34 | def __init__(self, units, activation=None, 35 | self_use_bias=True, 36 | aggr_self_use_bias=True, 37 | aggr_neighbor_use_bias=False, 38 | kernel_regularizer=None, bias_regularizer=None, *args, **kwargs): 39 | """ 40 | 41 | :param units: Positive integer, dimensionality of the output space. 42 | :param activation: Activation function to use. 43 | :param use_bias: Boolean, whether the layer uses a bias vector. 44 | :param renorm: Whether use renormalization trick (https://arxiv.org/pdf/1609.02907.pdf). 45 | :param improved: Whether use improved GCN or not. 46 | :param kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. 47 | :param bias_regularizer: Regularizer function applied to the bias vector. 
48 | """ 49 | super().__init__(*args, **kwargs) 50 | self.units = units 51 | 52 | self.activation = activation 53 | 54 | self.self_use_bias = self_use_bias 55 | self.aggr_self_use_bias = aggr_self_use_bias 56 | self.aggr_neighbor_use_bias = aggr_neighbor_use_bias 57 | 58 | self.kernel_regularizer = kernel_regularizer 59 | self.bias_regularizer = bias_regularizer 60 | 61 | self.self_kernel = None 62 | self.self_bias = None 63 | self.aggr_self_kernel = None 64 | self.aggr_self_bias = None 65 | self.aggr_neighbor_kernel = None 66 | self.aggr_neighbor_bias = None 67 | 68 | def call(self, inputs, training=None, mask=None): 69 | """ 70 | 71 | :param inputs: List of graph info: [x, edge_index, edge_weight] 72 | :return: Updated node features (x), shape: [num_nodes, units] 73 | """ 74 | 75 | if len(inputs) == 3: 76 | x, edge_index, edge_weight = inputs 77 | else: 78 | x, edge_index = inputs 79 | edge_weight = None 80 | 81 | return le_conv(x, edge_index, edge_weight, 82 | self.self_kernel, self.self_bias, 83 | self.aggr_self_kernel, self.aggr_self_bias, 84 | self.aggr_neighbor_kernel, self.aggr_neighbor_bias, 85 | activation=self.activation) 86 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/sgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import warnings 3 | from tf_geometric.nn.conv.sgc import sgc 4 | from tf_geometric.nn.conv.gcn import gcn_build_cache_for_graph, gcn_build_cache_by_adj 5 | import tensorflow as tf 6 | 7 | 8 | class SGC(tf.keras.Model): 9 | """ 10 | The simple graph convolutional operator from the `"Simplifying Graph 11 | Convolutional Networks" `_ paper 12 | """ 13 | 14 | def __init__(self, units, k=1, activation=None, use_bias=True, renorm=True, improved=False, 15 | kernel_regularizer=None, bias_regularizer=None, 16 | *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | """ 19 | :param units: Size of each output sample.. 20 | :param k: Number of hops.(default: :obj:`1`) 21 | :param activation: Activation function to use. 22 | :param use_bias: Boolean, whether the layer uses a bias vector. 23 | :param renorm: Whether use renormalization trick (https://arxiv.org/pdf/1609.02907.pdf). 24 | :param improved: Whether use improved GCN or not. 25 | :param kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. 26 | :param bias_regularizer: Regularizer function applied to the bias vector. 27 | """ 28 | 29 | self.units = units 30 | self.use_bias = use_bias 31 | self.renorm = renorm 32 | self.improved = improved 33 | self.k = k 34 | self.activation = activation 35 | self.kernel = None 36 | self.bias = None 37 | 38 | self.kernel_regularizer = kernel_regularizer 39 | self.bias_regularizer = bias_regularizer 40 | 41 | def build(self, input_shape): 42 | 43 | x_shape = input_shape[0] 44 | num_features = x_shape[-1] 45 | 46 | self.kernel = self.add_weight("kernel", shape=[num_features, self.units], 47 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 48 | if self.use_bias: 49 | self.bias = self.add_weight("bias", shape=[self.units], 50 | initializer="zeros", regularizer=self.bias_regularizer) 51 | 52 | 53 | def build_cache_by_adj(self, sparse_adj, override=False, cache=None): 54 | """ 55 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in graph.cache. 
56 | If the normed edge already exists in the cache and the override parameter is False, this method will do nothing. 57 | 58 | :param sparse_adj: tf_sparse.SparseMatrix, the input sparse adjacency matrix. 59 | :param override: Whether to override the existing cached normed edge. 60 | :return: None 61 | """ 62 | return gcn_build_cache_by_adj(sparse_adj, self.renorm, self.improved, override=override, cache=cache) 63 | 64 | def build_cache_for_graph(self, graph, override=False): 65 | """ 66 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in graph.cache. 67 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 68 | 69 | :param graph: tfg.Graph, the input graph. 70 | :param override: Whether to override existing cached normed edge. 71 | :return: None 72 | """ 73 | gcn_build_cache_for_graph(graph, renorm=self.renorm, improved=self.improved, override=override) 74 | 75 | def cache_normed_edge(self, graph, override=False): 76 | """ 77 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in graph.cache. 78 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 79 | 80 | :param graph: tfg.Graph, the input graph. 81 | :param override: Whether to override existing cached normed edge. 82 | :return: None 83 | 84 | .. deprecated:: 0.0.56 85 | Use ``build_cache_for_graph`` instead. 86 | """ 87 | warnings.warn("'SGC.cache_normed_edge(graph, override)' is deprecated, use 'SGC.build_cache_for_graph(graph, override)' instead", DeprecationWarning) 88 | return self.build_cache_for_graph(graph, override=override) 89 | 90 | 91 | def call(self, inputs, cache=None, training=None, mask=None): 92 | """ 93 | :param inputs: List of graph info: [x, edge_index, edge_weight] 94 | :param cache: A dict for caching the normed adjacency A'. Different graphs should not share the same cache dict. 95 | :return: Updated node features (x), shape: [num_nodes, num_units] 96 | """ 97 | 98 | if len(inputs) == 3: 99 | x, edge_index, edge_weight = inputs 100 | else: 101 | x, edge_index = inputs 102 | edge_weight = None 103 | 104 | return sgc(x, edge_index, edge_weight, self.k, self.kernel, 105 | bias=self.bias, activation=self.activation, 106 | renorm=self.renorm, improved=self.improved, cache=cache) 107 | -------------------------------------------------------------------------------- /tf_geometric/layers/conv/tagcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import warnings 4 | import tensorflow as tf 5 | from tf_geometric.nn.conv.gcn import gcn_build_cache_for_graph, gcn_build_cache_by_adj 6 | 7 | from tf_geometric.nn.conv.tagcn import tagcn 8 | 9 | 10 | class TAGCN(tf.keras.Model): 11 | """ 12 | The topology adaptive graph convolutional networks operator from the 13 | `"Topology Adaptive Graph Convolutional Networks" 14 | <https://arxiv.org/abs/1710.10370>`_ paper 15 | """ 16 | 17 | def __init__(self, units, k=3, activation=None, use_bias=True, 18 | renorm=False, improved=False, 19 | kernel_regularizer=None, bias_regularizer=None, 20 | *args, **kwargs): 21 | """ 22 | 23 | :param units: Positive integer, dimensionality of the output space. 24 | :param k: Number of hops (default: 3). 25 | :param activation: Activation function to use. 26 | :param use_bias: Boolean, whether the layer uses a bias vector.
27 | :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf). 28 | :param improved: Whether to use the improved GCN normalization. 29 | :param kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. 30 | :param bias_regularizer: Regularizer function applied to the bias vector. 31 | """ 32 | super().__init__(*args, **kwargs) 33 | self.units = units 34 | assert k > 0 35 | 36 | self.k = k 37 | 38 | self.activation = activation 39 | self.use_bias = use_bias 40 | 41 | self.kernel = None 42 | self.bias = None 43 | 44 | self.kernel_regularizer = kernel_regularizer 45 | self.bias_regularizer = bias_regularizer 46 | 47 | self.renorm = renorm 48 | self.improved = improved 49 | 50 | def build(self, input_shapes): 51 | x_shape = input_shapes[0] 52 | num_features = x_shape[-1] 53 | 54 | self.kernel = self.add_weight("kernel", shape=[num_features * (self.k + 1), self.units], 55 | initializer="glorot_uniform", regularizer=self.kernel_regularizer) 56 | if self.use_bias: 57 | self.bias = self.add_weight("bias", shape=[self.units], 58 | initializer="zeros", regularizer=self.bias_regularizer) 59 | 60 | def build_cache_by_adj(self, sparse_adj, override=False, cache=None): 61 | """ 62 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in the given cache dict. 63 | If the normed edge already exists in the cache and the override parameter is False, this method will do nothing. 64 | 65 | :param sparse_adj: tf_sparse.SparseMatrix, the input sparse adjacency matrix. 66 | :param override: Whether to override the existing cached normed edge. 67 | :return: None 68 | """ 69 | return gcn_build_cache_by_adj(sparse_adj, self.renorm, self.improved, override=override, cache=cache) 70 | 71 | def build_cache_for_graph(self, graph, override=False): 72 | """ 73 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in graph.cache. 74 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 75 | 76 | :param graph: tfg.Graph, the input graph. 77 | :param override: Whether to override existing cached normed edge. 78 | :return: None 79 | """ 80 | gcn_build_cache_for_graph(graph, renorm=self.renorm, improved=self.improved, override=override) 81 | 82 | def cache_normed_edge(self, graph, override=False): 83 | """ 84 | Manually compute the normed edge based on this layer's GCN normalization configuration (self.renorm and self.improved) and put it in graph.cache. 85 | If the normed edge already exists in graph.cache and the override parameter is False, this method will do nothing. 86 | 87 | :param graph: tfg.Graph, the input graph. 88 | :param override: Whether to override existing cached normed edge. 89 | :return: None 90 | 91 | .. deprecated:: 0.0.56 92 | Use ``build_cache_for_graph`` instead. 93 | """ 94 | warnings.warn("'TAGCN.cache_normed_edge(graph, override)' is deprecated, use 'TAGCN.build_cache_for_graph(graph, override)' instead", DeprecationWarning) 95 | return self.build_cache_for_graph(graph, override=override) 96 | 97 | 98 | def call(self, inputs, cache=None, training=None, mask=None): 99 | """ 100 | 101 | :param inputs: List of graph info: [x, edge_index, edge_weight] 102 | :param cache: A dict for caching the normed adjacency A'. Different graphs should not share the same cache dict.
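Example (a minimal usage sketch, mirroring the cache pattern of the other GCN-normalized layers; `graph` is assumed to be a tfg.Graph):

    tagcn_layer = TAGCN(units=64, k=3, activation=tf.nn.relu)
    tagcn_layer.build_cache_for_graph(graph)
    h = tagcn_layer([graph.x, graph.edge_index, graph.edge_weight], cache=graph.cache)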
103 | :return: Updated node features (x), shape: [num_nodes, units] 104 | """ 105 | 106 | if len(inputs) == 3: 107 | x, edge_index, edge_weight = inputs 108 | else: 109 | x, edge_index = inputs 110 | edge_weight = None 111 | 112 | return tagcn(x, edge_index, edge_weight, self.k, self.kernel, 113 | bias=self.bias, activation=self.activation, renorm=self.renorm, 114 | improved=self.improved, cache=cache) 115 | -------------------------------------------------------------------------------- /tf_geometric/layers/kernel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/layers/kernel/__init__.py -------------------------------------------------------------------------------- /tf_geometric/layers/kernel/map_reduce.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.nn import aggregate_neighbors 4 | 5 | 6 | class MapReduceGNN(tf.keras.Model): 7 | 8 | def map(self, repeated_x, neighbor_x, edge_weight=None): 9 | pass 10 | 11 | def reduce(self, neighbor_msg, node_index, num_nodes=None): 12 | pass 13 | 14 | def update(self, x, reduced_neighbor_msg): 15 | pass 16 | 17 | def get_mapper(self): 18 | def mapper(repeated_x, neighbor_x, edge_weight=None): 19 | return self.map(repeated_x, neighbor_x, edge_weight) 20 | return mapper 21 | 22 | def get_reducer(self): 23 | def reducer(neighbor_msg, node_index, num_nodes=None): 24 | return self.reduce(neighbor_msg, node_index, num_nodes) 25 | return reducer 26 | 27 | def get_updater(self): 28 | def updater(x, reduced_neighbor_msg): 29 | return self.update(x, reduced_neighbor_msg) 30 | return updater 31 | 32 | def call(self, inputs, training=None, mask=None): 33 | x, edge_index, edge_weight = inputs 34 | return aggregate_neighbors( 35 | x, 36 | edge_index, 37 | edge_weight, 38 | self.get_mapper(), 39 | self.get_reducer(), 40 | self.get_updater() 41 | ) -------------------------------------------------------------------------------- /tf_geometric/layers/pool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/layers/pool/__init__.py -------------------------------------------------------------------------------- /tf_geometric/layers/pool/common_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.nn.pool.common_pool import mean_pool, min_pool, max_pool, sum_pool 4 | 5 | 6 | class CommonPool(tf.keras.Model): 7 | 8 | def __init__(self, pool_func, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.pool_func = pool_func 11 | 12 | def call(self, inputs, training=None, mask=None): 13 | if len(inputs) == 2: 14 | x, node_graph_index = inputs 15 | num_graphs = None 16 | else: 17 | x, node_graph_index, num_graphs = inputs 18 | 19 | return self.pool_func(x, node_graph_index, num_graphs) 20 | 21 | 22 | class MeanPool(CommonPool): 23 | def __init__(self, *args, **kwargs): 24 | super().__init__(mean_pool, *args, **kwargs) 25 | 26 | 27 | class MinPool(CommonPool): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(min_pool, *args, **kwargs) 30 | 31 | 32 | class MaxPool(CommonPool): 33 | def __init__(self, *args, **kwargs): 34 |
super().__init__(max_pool, *args, **kwargs) 35 | 36 | 37 | class SumPool(CommonPool): 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(sum_pool, *args, **kwargs) -------------------------------------------------------------------------------- /tf_geometric/layers/pool/diff_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | import tensorflow as tf 5 | 6 | from tf_geometric.nn.pool.diff_pool import diff_pool 7 | 8 | 9 | class DiffPool(tf.keras.Model): 10 | """ 11 | OOP API for DiffPool: "Hierarchical graph representation learning with differentiable pooling" 12 | """ 13 | 14 | def __init__(self, feature_gnn, assign_gnn, units, num_clusters, activation=None, use_bias=True, 15 | bias_regularizer=None, *args, **kwargs): 16 | """ 17 | DiffPool 18 | 19 | :param feature_gnn: A GNN model to learn pooled node features, [x, edge_index, edge_weight] => updated_x, 20 | where updated_x corresponds to high-order node features. 21 | :param assign_gnn: A GNN model to learn cluster assignment for the pooling, [x, edge_index, edge_weight] => updated_x, 22 | where updated_x corresponds to the cluster assignment matrix. 23 | :param units: Positive integer, dimensionality of the output space. It must be provided if you set use_bias=True. 24 | :param num_clusters: Number of clusters for pooling. 25 | :param activation: Activation function to use. 26 | :param use_bias: Boolean, whether the layer uses a bias vector. If True, the "units" parameter must be provided. 27 | :param bias_regularizer: Regularizer function applied to the bias vector. 28 | """ 29 | super().__init__(*args, **kwargs) 30 | self.feature_gnn = feature_gnn 31 | self.assign_gnn = assign_gnn 32 | 33 | self.num_clusters = num_clusters 34 | self.activation = activation 35 | 36 | if use_bias and units is None: 37 | raise Exception("The \"units\" parameter is required when you set use_bias=True.") 38 | 39 | self.bias = None 40 | if use_bias: 41 | self.bias = self.add_weight("bias", shape=[units], initializer="zeros", regularizer=bias_regularizer) 42 | 43 | def call(self, inputs, cache=None, training=None, mask=None): 44 | """ 45 | 46 | :param inputs: List of graph info: [x, edge_index, edge_weight, node_graph_index] 47 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict.
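Example (a sketch; `batch_graph` is assumed to be a tfg.BatchGraph, and the GCN sub-layers and sizes below are illustrative choices, not requirements):

    import tf_geometric as tfg
    diff_pool_layer = DiffPool(
        feature_gnn=tfg.layers.GCN(64),   # learns the pooled node features
        assign_gnn=tfg.layers.GCN(10),    # learns the cluster assignment
        units=64, num_clusters=10
    )
    pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index = diff_pool_layer(
        [batch_graph.x, batch_graph.edge_index, batch_graph.edge_weight, batch_graph.node_graph_index])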
48 | :return: Pooled graph: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 49 | """ 50 | x, edge_index, edge_weight, node_graph_index = inputs 51 | 52 | return diff_pool(x, edge_index, edge_weight, node_graph_index, 53 | self.feature_gnn, self.assign_gnn, self.num_clusters, 54 | bias=self.bias, activation=self.activation, training=training, cache=cache) 55 | -------------------------------------------------------------------------------- /tf_geometric/layers/pool/min_cut_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | import tensorflow as tf 5 | 6 | from tf_geometric.nn.pool.min_cut_pool import min_cut_pool 7 | 8 | 9 | class MinCutPool(tf.keras.Model): 10 | """ 11 | OOP API for MinCutPool: "Spectral Clustering with Graph Neural Networks for Graph Pooling" 12 | """ 13 | 14 | def __init__(self, feature_gnn, assign_gnn, units, num_clusters, activation=None, use_bias=True, 15 | gnn_use_normed_edge=True, 16 | bias_regularizer=None, *args, **kwargs): 17 | """ 18 | MinCutPool 19 | 20 | :param feature_gnn: A GNN model to learn pooled node features, [x, edge_index, edge_weight] => updated_x, 21 | where updated_x corresponds to high-order node features. 22 | :param assign_gnn: A GNN model to learn cluster assignment for the pooling, [x, edge_index, edge_weight] => updated_x, 23 | where updated_x corresponds to the cluster assignment matrix. 24 | :param units: Positive integer, dimensionality of the output space. It must be provided if you set use_bias=True. 25 | :param num_clusters: Number of clusters for pooling. 26 | :param activation: Activation function to use. 27 | :param use_bias: Boolean, whether the layer uses a bias vector. If True, the "units" parameter must be provided. 28 | :param gnn_use_normed_edge: Boolean. Whether to use normalized edge for feature_gnn and assign_gnn. 29 | :param bias_regularizer: Regularizer function applied to the bias vector. 30 | """ 31 | super().__init__(*args, **kwargs) 32 | self.feature_gnn = feature_gnn 33 | self.assign_gnn = assign_gnn 34 | 35 | self.num_clusters = num_clusters 36 | self.activation = activation 37 | 38 | if use_bias and units is None: 39 | raise Exception("The \"units\" parameter is required when you set use_bias=True.") 40 | 41 | self.bias = None 42 | if use_bias: 43 | self.bias = self.add_weight("bias", shape=[units], initializer="zeros", regularizer=bias_regularizer) 44 | 45 | self.gnn_use_normed_edge = gnn_use_normed_edge 46 | 47 | def call(self, inputs, cache=None, training=None, mask=None, return_loss_func=False, return_losses=False): 48 | """ 49 | 50 | :param inputs: List of graph info: [x, edge_index, edge_weight, node_graph_index] 51 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 52 | :param return_loss_func: Boolean. If True, return (outputs, loss_func), where loss_func is a callable function 53 | that returns a list of losses. 54 | :param return_losses: Boolean. If True, return (outputs, losses), where losses is a list of losses.
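Example (a sketch of consuming the auxiliary pooling losses; `min_cut_pool_layer` and `task_loss` are illustrative names, and `batch_graph` is assumed to be a tfg.BatchGraph):

    outputs, losses = min_cut_pool_layer(
        [batch_graph.x, batch_graph.edge_index, batch_graph.edge_weight, batch_graph.node_graph_index],
        return_losses=True)
    pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index = outputs
    loss = task_loss + tf.add_n(losses)  # combine the pooling losses with the task loss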
55 | :return: Pooled graph: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 56 | """ 57 | 58 | if return_loss_func and return_losses: 59 | raise Exception("return_loss_func and return_losses cannot be set to True at the same time") 60 | 61 | x, edge_index, edge_weight, node_graph_index = inputs 62 | 63 | outputs, loss_func = min_cut_pool(x, edge_index, edge_weight, node_graph_index, 64 | self.feature_gnn, self.assign_gnn, self.num_clusters, 65 | bias=self.bias, activation=self.activation, 66 | gnn_use_normed_edge=self.gnn_use_normed_edge, 67 | training=training, cache=cache, 68 | return_loss_func=True) 69 | self.add_loss(loss_func) 70 | 71 | if return_loss_func: 72 | return outputs, loss_func 73 | elif return_losses: 74 | losses = loss_func() 75 | return outputs, losses 76 | else: 77 | return outputs 78 | 79 | -------------------------------------------------------------------------------- /tf_geometric/layers/pool/sag_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | 5 | from tf_geometric.nn.pool.sag_pool import sag_pool 6 | 7 | 8 | class SAGPool(tf.keras.Model): 9 | """ 10 | OOP API for SAGPool 11 | """ 12 | 13 | def __init__(self, score_gnn, k=None, ratio=None, score_activation=None, *args, **kwargs): 14 | """ 15 | SAGPool 16 | 17 | :param score_gnn: A GNN model to score nodes for the pooling, [x, edge_index, edge_weight] => node_score. 18 | :param k: Keep top k targets for each source 19 | :param ratio: Keep num_targets * ratio targets for each source 20 | :param score_activation: Activation to use for node_score before multiplying node_features with node_score 21 | """ 22 | super().__init__(*args, **kwargs) 23 | self.score_gnn = score_gnn 24 | self.k = k 25 | self.ratio = ratio 26 | self.score_activation = score_activation 27 | 28 | def call(self, inputs, cache=None, training=None, mask=None): 29 | """ 30 | 31 | :param inputs: List of graph info: [x, edge_index, edge_weight, node_graph_index] 32 | :param cache: A dict for caching A' for GCN. Different graph should not share the same cache dict. 33 | :return: Pooled graph: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 34 | """ 35 | x, edge_index, edge_weight, node_graph_index = inputs 36 | 37 | return sag_pool(x, edge_index, edge_weight, node_graph_index, self.score_gnn, 38 | k=self.k, ratio=self.ratio, score_activation=self.score_activation, 39 | training=training, cache=cache) 40 | -------------------------------------------------------------------------------- /tf_geometric/layers/pool/set2set.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | from tf_geometric.nn.pool.set2set import set2set 5 | 6 | 7 | class Set2Set(tf.keras.Model): 8 | """ 9 | OOP API for Set2Set 10 | """ 11 | 12 | def __init__(self, num_iterations=4, *args, **kwargs): 13 | """ 14 | Set2Set 15 | 16 | :param num_iterations: Number of iterations for attention. 
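Example (a minimal usage sketch; `batch_graph` is assumed to be a tfg.BatchGraph):

    set2set_layer = Set2Set(num_iterations=4)
    graph_h = set2set_layer([batch_graph.x, batch_graph.node_graph_index])
    # graph_h shape: [num_graphs, num_node_features * 2]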
17 | """ 18 | super().__init__(*args, **kwargs) 19 | self.num_iterations = num_iterations 20 | self.lstm = None 21 | 22 | def build(self, input_shapes): 23 | x_shape = input_shapes[0] 24 | num_features = x_shape[-1] 25 | 26 | self.lstm = tf.keras.layers.LSTM(num_features, return_sequences=True, return_state=True) 27 | 28 | def call(self, inputs, cache=None, training=None, mask=None): 29 | """ 30 | 31 | :param inputs: List of graph info: [x, node_graph_index] 32 | :param cache: A dict for caching A' for GCN. Different graph should not share the same cache dict. 33 | :return: Graph features, shape: [num_graphs, num_node_features * 2] 34 | """ 35 | x, node_graph_index = inputs 36 | 37 | return set2set(x, node_graph_index, self.lstm, self.num_iterations, training=training) 38 | -------------------------------------------------------------------------------- /tf_geometric/layers/pool/sort_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | 5 | from tf_geometric.nn.pool.sort_pool import sort_pool 6 | 7 | 8 | class SortPool(tf.keras.Model): 9 | """ 10 | OOP API for SortPool "An End-to-End Deep Learning Architecture for Graph Classification" 11 | """ 12 | 13 | def __init__(self, k=None, ratio=None, sort_index=-1, *args, **kwargs): 14 | """ 15 | SAGPool 16 | 17 | :param score_gnn: A GNN model to score nodes for the pooling, [x, edge_index, edge_weight] => node_score. 18 | :param k: Keep top k targets for each source 19 | :param ratio: Keep num_targets * ratio targets for each source 20 | :param sort_index: The sort_index_th index of the last axis will used for sort. 21 | """ 22 | super().__init__(*args, **kwargs) 23 | self.k = k 24 | self.ratio = ratio 25 | self.sort_index = sort_index 26 | 27 | def call(self, inputs, training=None, mask=None): 28 | """ 29 | 30 | :param inputs: List of graph info: [x, edge_index, edge_weight, node_graph_index] 31 | :return: Pooled grpah: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 32 | """ 33 | x, edge_index, edge_weight, node_graph_index = inputs 34 | 35 | return sort_pool(x, edge_index, edge_weight, node_graph_index, 36 | k=self.k, ratio=self.ratio, sort_index=self.sort_index, training=training) 37 | -------------------------------------------------------------------------------- /tf_geometric/layers/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/layers/sampling/__init__.py -------------------------------------------------------------------------------- /tf_geometric/layers/sampling/drop_edge.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | from tf_geometric.nn.sampling.drop_edge import drop_edge 5 | 6 | 7 | class DropEdge(tf.keras.Model): 8 | def __init__(self, rate=0.5, force_undirected: bool = False): 9 | """ 10 | DropEdge: Towards Deep Graph Convolutional Networks on Node Classification 11 | https://openreview.net/forum?id=Hkx1qkrKPr 12 | 13 | :param rate: dropout rate 14 | :param force_undirected: If set to `True`, will either 15 | drop or keep both edges of an undirected edge. 16 | """ 17 | super().__init__() 18 | self.rate = rate 19 | self.force_undirected = force_undirected 20 | 21 | if self.rate < 0. 
or self.rate > 1.: 22 | raise ValueError('Dropout probability has to be between 0 and 1, ' 23 | 'but got {}'.format(self.rate)) 24 | 25 | def call(self, inputs, training=None, mask=None): 26 | return drop_edge(inputs=inputs, rate=self.rate, 27 | force_undirected=self.force_undirected, training=training) 28 | -------------------------------------------------------------------------------- /tf_geometric/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # nn package contains functional APIs for tf_geometric 4 | 5 | 6 | from .kernel.map_reduce import identity_mapper, neighbor_count_mapper, sum_reducer, sum_updater, identity_updater, mean_reducer, max_reducer, aggregate_neighbors 7 | from .conv.gcn import gcn, gcn_norm_adj, gcn_build_cache_by_adj, gcn_build_cache_for_graph, gcn_norm_edge, gcn_cache_normed_edge 8 | from .conv.gat import gat 9 | from .conv.chebynet import chebynet, chebynet_norm_edge 10 | from .conv.sgc import sgc 11 | from .conv.tagcn import tagcn 12 | from .conv.graph_sage import mean_graph_sage, sum_graph_sage, mean_pool_graph_sage, max_pool_graph_sage, gcn_graph_sage, lstm_graph_sage 13 | from .conv.appnp import appnp 14 | from .conv.gin import gin 15 | from .conv.le_conv import le_conv 16 | from .conv.ssgc import ssgc 17 | 18 | 19 | from .sampling.drop_edge import drop_edge 20 | 21 | from .pool.common_pool import mean_pool, min_pool, max_pool, sum_pool 22 | from .pool.topk_pool import topk_pool 23 | from .pool.diff_pool import diff_pool, diff_pool_coarsen 24 | from .pool.set2set import set2set 25 | from .pool.cluster_pool import cluster_pool 26 | from .pool.sag_pool import sag_pool 27 | from .pool.asap import asap 28 | from .pool.sort_pool import sort_pool 29 | from .pool.min_cut_pool import min_cut_pool, min_cut_pool_coarsen, min_cut_pool_compute_losses 30 | -------------------------------------------------------------------------------- /tf_geometric/nn/conv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/nn/conv/__init__.py -------------------------------------------------------------------------------- /tf_geometric/nn/conv/appnp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | # from tf_geometric.sparse.sparse_adj import SparseAdj 5 | from tf_sparse import SparseMatrix 6 | 7 | from tf_geometric.nn.conv.gcn import gcn_norm_adj 8 | import tf_sparse as tfs 9 | 10 | 11 | def appnp(x, edge_index, edge_weight, kernels, biases, 12 | dense_activation=tf.nn.relu, activation=None, 13 | k=10, alpha=0.1, 14 | dense_drop_rate=0.0, last_dense_drop_rate=0.0, edge_drop_rate=0.0, 15 | cache=None, training=False): 16 | 17 | """ 18 | Functional API for Approximate Personalized Propagation of Neural Predictions (APPNP). 19 | 20 | :param x: Tensor, shape: [num_nodes, num_features], node features 21 | :param edge_index: Tensor, shape: [2, num_edges], edge information 22 | :param edge_weight: Tensor or None, shape: [num_edges] 23 | :param kernels: List[Tensor], shape of each Tensor: [num_features, num_output_features], weights 24 | :param biases: List[Tensor], shape of each Tensor: [num_output_features], biases 25 | :param dense_activation: Activation function to use for the dense layers, 26 | except for the last dense layer, which will not be activated. 
27 | :param activation: Activation function to use for the output. 28 | :param k: Number of propagation power iterations. 29 | :param alpha: Teleport probability. 30 | :param dense_drop_rate: Dropout rate for the output of every dense layer (except the last one). 31 | :param last_dense_drop_rate: Dropout rate for the output of the last dense layer. 32 | last_dense_drop_rate is usually set to 0.0 for classification tasks. 33 | :param edge_drop_rate: Dropout rate for the edges/adj used for propagation. 34 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 35 | To use @tf_utils.function with gcn, you should cache the normed edge information before the first call of gcn. 36 | 37 | - (1) If you're using OOP APIs tfg.layers.GCN: 38 | 39 | gcn_layer.build_cache_for_graph(graph) 40 | 41 | - (2) If you're using functional API tfg.nn.gcn: 42 | 43 | from tf_geometric.nn.conv.gcn import gcn_build_cache_for_graph 44 | gcn_build_cache_for_graph(graph) 45 | 46 | :param training: Python boolean indicating whether the layer should behave in 47 | training mode (adding dropout) or in inference mode (doing nothing). 48 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 49 | """ 50 | 51 | num_nodes = tfs.shape(x)[0] 52 | # updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, num_nodes, edge_weight, cache=cache) 53 | sparse_adj = SparseMatrix(edge_index, edge_weight, [num_nodes, num_nodes]) 54 | normed_sparse_adj = gcn_norm_adj(sparse_adj, cache=cache)\ 55 | .dropout(edge_drop_rate, training=training) 56 | 57 | num_dense_layers = len(kernels) if kernels is not None else 0  # guard: kernels may be None (no MLP encoder) 58 | 59 | h = x 60 | 61 | # MLP Encoder 62 | if kernels is not None: 63 | 64 | for i, (kernel, bias) in enumerate(zip(kernels, biases)): 65 | # SparseTensor is usually used for one-hot node features (For example, feature-less nodes.)
66 | if isinstance(h, tf.sparse.SparseTensor): 67 | h = tf.sparse.sparse_dense_matmul(h, kernel) 68 | else: 69 | h = h @ kernel 70 | 71 | if bias is not None: 72 | h += bias 73 | 74 | if i < num_dense_layers - 1: 75 | if dense_activation is not None: 76 | h = dense_activation(h) 77 | if training and dense_drop_rate > 0.0: 78 | h = tf.compat.v2.nn.dropout(h, dense_drop_rate) 79 | else: 80 | if training and last_dense_drop_rate > 0.0: 81 | h = tf.compat.v2.nn.dropout(h, last_dense_drop_rate) 82 | 83 | output = h 84 | 85 | for i in range(k): 86 | output = normed_sparse_adj @ output 87 | output = output * (1.0 - alpha) + h * alpha  # personalized-PageRank-style propagation: output <- (1 - alpha) * A_hat @ output + alpha * h 88 | 89 | if activation is not None: 90 | output = activation(output) 91 | 92 | return output 93 | 94 | -------------------------------------------------------------------------------- /tf_geometric/nn/conv/gat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | 4 | # from tf_geometric.sparse.sparse_adj import SparseAdj 5 | from tf_sparse import SparseMatrix 6 | 7 | from tf_geometric.utils.graph_utils import add_self_loop_edge 8 | import tf_sparse as tfs 9 | 10 | 11 | # follow Transformer-Style Attention 12 | # Attention is all you need 13 | def gat(x, edge_index, 14 | query_kernel, query_bias, query_activation, 15 | key_kernel, key_bias, key_activation, 16 | kernel, bias=None, activation=None, num_heads=1, 17 | split_value_heads=True, edge_drop_rate=0.0, training=False): 18 | """ 19 | 20 | :param x: Tensor, shape: [num_nodes, num_features], node features 21 | :param edge_index: Tensor, shape: [2, num_edges], edge information 22 | :param query_kernel: Tensor, shape: [num_features, num_query_features], weight for Q in attention 23 | :param query_bias: Tensor, shape: [num_query_features], bias for Q in attention 24 | :param query_activation: Activation function for Q in attention. 25 | :param key_kernel: Tensor, shape: [num_features, num_key_features], weight for K in attention 26 | :param key_bias: Tensor, shape: [num_key_features], bias for K in attention 27 | :param key_activation: Activation function for K in attention. 28 | :param kernel: Tensor, shape: [num_features, num_output_features], weight 29 | :param bias: Tensor, shape: [num_output_features], bias 30 | :param activation: Activation function to use. 31 | :param num_heads: Number of attention heads. 32 | :param split_value_heads: Boolean. If True, split V as value attention heads, and then concatenate them as output. 33 | Else, num_heads different V are used as value attention heads, and the mean of them is used as output. 34 | :param edge_drop_rate: Dropout rate of attention weights. 35 | :param training: Python boolean indicating whether the layer should behave in 36 | training mode (adding dropout) or in inference mode (doing nothing).
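        Example (added sketch, not part of the original docstring): a hypothetical
        single-head call, with query/key/value kernels and biases created elsewhere
        (e.g. by the OOP layer tfg.layers.GAT):

            h = gat(x, edge_index,
                    query_kernel, query_bias, tf.nn.relu,
                    key_kernel, key_bias, tf.nn.relu,
                    kernel, bias=bias, num_heads=1, training=True)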
37 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 38 | """ 39 | 40 | num_nodes = tfs.shape(x)[0] 41 | 42 | # self-attention 43 | edge_index, edge_weight = add_self_loop_edge(edge_index, num_nodes) 44 | 45 | row, col = edge_index[0], edge_index[1] 46 | 47 | x_is_sparse = isinstance(x, tf.sparse.SparseTensor) 48 | 49 | if x_is_sparse: 50 | Q = tf.sparse.sparse_dense_matmul(x, query_kernel) 51 | else: 52 | Q = x @ query_kernel 53 | Q += query_bias 54 | if query_activation is not None: 55 | Q = query_activation(Q) 56 | Q = tf.gather(Q, row) 57 | 58 | if x_is_sparse: 59 | K = tf.sparse.sparse_dense_matmul(x, key_kernel) 60 | else: 61 | K = x @ key_kernel 62 | K += key_bias 63 | if key_activation is not None: 64 | K = key_activation(K) 65 | K = tf.gather(K, col) 66 | 67 | if x_is_sparse: 68 | V = tf.sparse.sparse_dense_matmul(x, kernel) 69 | else: 70 | V = x @ kernel 71 | 72 | # names with a trailing underscore (xxx_) denote multi-head tensors 73 | Q_ = tf.concat(tf.split(Q, num_heads, axis=-1), axis=0) 74 | K_ = tf.concat(tf.split(K, num_heads, axis=-1), axis=0) 75 | # split queries and keys are modeled as virtual vertices 76 | qk_edge_index_ = tf.concat([edge_index + i * num_nodes for i in range(num_heads)], axis=1) 77 | 78 | scale = tf.math.sqrt(tf.cast(tf.shape(Q_)[-1], tf.float32)) 79 | att_score_ = tf.reduce_sum(Q_ * K_, axis=-1) / scale 80 | 81 | # new implementation based on SparseAdj 82 | num_nodes_ = num_nodes * num_heads 83 | sparse_att_adj = SparseMatrix(qk_edge_index_, att_score_, [num_nodes_, num_nodes_]) \ 84 | .segment_softmax(axis=-1) \ 85 | .dropout(edge_drop_rate, training=training) 86 | 87 | V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0) 88 | 89 | h_ = sparse_att_adj @ V_ 90 | 91 | # old implementation 92 | # normed_att_score_ = segment_softmax(att_score_, qk_edge_index_[0], num_nodes * num_heads) 93 | # 94 | # if training and drop_rate > 0.0: 95 | # normed_att_score_ = tf.compat.v2.nn.dropout(normed_att_score_, drop_rate) 96 | # 97 | # if split_value_heads: 98 | # V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0) 99 | # edge_index_ = qk_edge_index_ 100 | # else: 101 | # V_ = V 102 | # edge_index_ = tf.tile(edge_index, [1, num_heads]) 103 | # 104 | # h_ = aggregate_neighbors( 105 | # V_, edge_index_, normed_att_score_, 106 | # gcn_mapper, 107 | # sum_reducer, 108 | # identity_updater 109 | # ) 110 | 111 | if split_value_heads: 112 | h = tf.concat(tf.split(h_, num_heads, axis=0), axis=-1) 113 | else: 114 | h = tf.add_n(tf.split(h_, num_heads, axis=0)) / num_heads 115 | 116 | if bias is not None: 117 | h += bias 118 | 119 | if activation is not None: 120 | h = activation(h) 121 | 122 | return h 123 | -------------------------------------------------------------------------------- /tf_geometric/nn/conv/gin.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | # from tf_geometric.sparse.sparse_adj import SparseAdj 4 | from tf_sparse import SparseMatrix 5 | 6 | 7 | def gin_updater(x, reduced_neighbor_msg, eps): 8 | return x * (1.0 + eps) + reduced_neighbor_msg 9 | 10 | 11 | def gin(x, edge_index, mlp_model, eps=0.0, training=None): 12 | """ 13 | 14 | :param x: Tensor, shape: [num_nodes, num_features], node features 15 | :param edge_index: Tensor, shape: [2, num_edges], edge information 16 | :param mlp_model: A neural network (multi-layer perceptrons). 17 | :param eps: float, optional, (default: :obj:`0.`).
18 | :param training: Whether currently executing in training or inference mode. 19 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 20 | """ 21 | 22 | # h = aggregate_neighbors( 23 | # x, edge_index, None, 24 | # identity_mapper, 25 | # sum_reducer, 26 | # identity_updater 27 | # ) 28 | 29 | # h = gin_updater(x, h, eps) 30 | 31 | num_nodes = tf.shape(x)[0] 32 | sparse_adj = SparseMatrix(edge_index, shape=[num_nodes, num_nodes]) 33 | 34 | neighbor_h = sparse_adj @ x 35 | h = x * (1.0 + eps) + neighbor_h 36 | h = mlp_model(h, training=training) 37 | 38 | return h 39 | -------------------------------------------------------------------------------- /tf_geometric/nn/conv/le_conv.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | 4 | 5 | def le_conv(x, edge_index, edge_weight, 6 | self_kernel, self_bias, 7 | aggr_self_kernel, aggr_self_bias, 8 | aggr_neighbor_kernel, aggr_neighbor_bias, activation=None): 9 | """ 10 | Functional API for LeConv in ASAP. 11 | 12 | h_i = activation(x_i @ self_kernel + \sum_{j} (x_i @ aggr_self_kernel - x_j @ aggr_neighbor_kernel)) 13 | 14 | :param x: Tensor, shape: [num_nodes, num_features], node features 15 | :param edge_index: Tensor, shape: [2, num_edges], edge information 16 | :param edge_weight: Tensor or None, shape: [num_edges] 17 | :param self_kernel: Please look at the formula above. 18 | :param aggr_self_kernel: Please look at the formula above. 19 | :param aggr_neighbor_kernel: Please look at the formula above. 20 | :param activation: Activation function to use. 21 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 22 | """ 23 | 24 | if edge_weight is None: 25 | edge_weight = tf.ones([tf.shape(edge_index)[1]], dtype=tf.float32) 26 | 27 | num_nodes = tf.shape(x)[0] 28 | self_h = x @ self_kernel 29 | if self_bias is not None: 30 | self_h += self_bias 31 | 32 | aggr_self_h = x @ aggr_self_kernel 33 | if aggr_self_bias is not None: 34 | aggr_self_h += aggr_self_bias 35 | 36 | aggr_neighbor_h = x @ aggr_neighbor_kernel 37 | if aggr_neighbor_bias is not None: 38 | aggr_neighbor_h += aggr_neighbor_bias 39 | 40 | row, col = edge_index[0], edge_index[1] 41 | 42 | repeated_aggr_self_h = tf.gather(aggr_self_h, row)  # gather by row (the receiving node i) to match x_i in the formula above 43 | repeated_aggr_neighbor_h = tf.gather(aggr_neighbor_h, col) 44 | repeated_aggr_h = (repeated_aggr_self_h - repeated_aggr_neighbor_h) * tf.expand_dims(edge_weight, axis=-1) 45 | aggr_h = tf.math.unsorted_segment_sum(repeated_aggr_h, row, num_nodes) 46 | 47 | h = self_h + aggr_h 48 | 49 | if activation is not None: 50 | h = activation(h) 51 | 52 | return h 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /tf_geometric/nn/conv/sgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | # from tf_geometric.sparse.sparse_adj import SparseAdj 5 | from tf_sparse import SparseMatrix 6 | 7 | from tf_geometric.nn.conv.gcn import gcn_norm_adj 8 | import tf_sparse as tfs 9 | 10 | def sgc(x, edge_index, edge_weight, k, kernel, bias=None, activation=None, renorm=True, improved=False, cache=None): 11 | """ 12 | Functional API for Simple Graph Convolution (SGC).
13 | 14 | :param x: Tensor, shape: [num_nodes, num_features], node features 15 | :param edge_index: Tensor, shape: [2, num_edges], edge information 16 | :param edge_weight: Tensor or None, shape: [num_edges] 17 | :param k: Number of hops. (default: :obj:`1`) 18 | :param kernel: Tensor, shape: [num_features, num_output_features], weight. 19 | :param bias: Tensor, shape: [num_output_features], bias. 20 | :param activation: Activation function to use. 21 | :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf). 22 | :param improved: Whether to use the improved GCN or not. 23 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 24 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 25 | """ 26 | num_nodes = tfs.shape(x)[0] 27 | sparse_adj = SparseMatrix(edge_index, edge_weight, [num_nodes, num_nodes]) 28 | normed_sparse_adj = gcn_norm_adj(sparse_adj, renorm=renorm, improved=improved, cache=cache) 29 | 30 | # SparseTensor is usually used for one-hot node features (For example, feature-less nodes.) 31 | if isinstance(x, tf.sparse.SparseTensor): 32 | h = tf.sparse.sparse_dense_matmul(x, kernel) 33 | else: 34 | h = x @ kernel 35 | 36 | for _ in range(k): 37 | h = normed_sparse_adj @ h 38 | 39 | # updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, x.shape[0], edge_weight, 40 | # renorm, improved, cache) 41 | # 42 | # h = x 43 | # for _ in range(k): 44 | # h = aggregate_neighbors( 45 | # h, 46 | # updated_edge_index, 47 | # normed_edge_weight, 48 | # gcn_mapper, 49 | # sum_reducer, 50 | # identity_updater 51 | # ) 52 | 53 | # h = h @ kernel 54 | 55 | if bias is not None: 56 | h += bias 57 | 58 | if activation is not None: 59 | h = activation(h) 60 | 61 | return h 62 | 63 |
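# --- Added note (not part of the original source) ---
# SGC collapses k GCN propagation steps and a single linear layer: with the
# normalized adjacency A_hat built by gcn_norm_adj, the code above computes
# h = A_hat^k (x @ kernel) (+ bias, + activation). A hypothetical call, assuming
# `kernel` has shape [num_features, num_classes] and `graph` is a tfg.Graph whose
# cache dict is exposed as graph.cache:
#
#     logits = sgc(graph.x, graph.edge_index, graph.edge_weight,
#                  k=2, kernel=kernel, cache=graph.cache)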
-------------------------------------------------------------------------------- /tf_geometric/nn/conv/ssgc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | # from tf_geometric.sparse.sparse_adj import SparseAdj 5 | from tf_sparse import SparseMatrix 6 | 7 | from tf_geometric.nn.conv.gcn import gcn_norm_adj 8 | import tf_sparse as tfs 9 | 10 | 11 | def ssgc(x, edge_index, edge_weight, 12 | kernels=None, biases=None, 13 | k=10, alpha=0.1, 14 | dense_activation=tf.nn.relu, 15 | activation=None, 16 | dense_drop_rate=0.0, 17 | last_dense_drop_rate=0.0, 18 | edge_drop_rate=0.0, 19 | cache=None, training=False): 20 | 21 | """ 22 | Functional API for Simple Spectral Graph Convolution (SSGC / S^2GC). 23 | Paper URL: https://openreview.net/forum?id=CYO5T-YjWZV 24 | 25 | :param x: Tensor, shape: [num_nodes, num_features], node features 26 | :param edge_index: Tensor, shape: [2, num_edges], edge information 27 | :param edge_weight: Tensor or None, shape: [num_edges] 28 | :param kernels: List[Tensor], shape of each Tensor: [num_features, num_output_features], weights 29 | :param biases: List[Tensor], shape of each Tensor: [num_output_features], biases 30 | :param dense_activation: Activation function to use for the dense layers, 31 | except for the last dense layer, which will not be activated. 32 | :param activation: Activation function to use for the output. 33 | :param k: Number of propagation power iterations. 34 | :param alpha: Teleport probability. 35 | :param dense_drop_rate: Dropout rate for the output of every dense layer (except the last one). 36 | :param last_dense_drop_rate: Dropout rate for the output of the last dense layer. 37 | last_dense_drop_rate is usually set to 0.0 for classification tasks. 38 | :param edge_drop_rate: Dropout rate for the edges/adj used for propagation. 39 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 40 | To use @tf_utils.function with gcn, you should cache the normed edge information before the first call of gcn. 41 | 42 | - (1) If you're using OOP APIs tfg.layers.GCN: 43 | 44 | gcn_layer.build_cache_for_graph(graph) 45 | 46 | - (2) If you're using functional API tfg.nn.gcn: 47 | 48 | from tf_geometric.nn.conv.gcn import gcn_build_cache_for_graph 49 | gcn_build_cache_for_graph(graph) 50 | 51 | :param training: Python boolean indicating whether the layer should behave in 52 | training mode (adding dropout) or in inference mode (doing nothing). 53 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 54 | """ 55 | 56 | num_nodes = tfs.shape(x)[0] 57 | 58 | # updated_edge_index, normed_edge_weight = gcn_norm_edge(edge_index, num_nodes, edge_weight, cache=cache) 59 | sparse_adj = SparseMatrix(edge_index, edge_weight, [num_nodes, num_nodes]) 60 | normed_sparse_adj = gcn_norm_adj(sparse_adj, cache=cache)\ 61 | .dropout(edge_drop_rate, training=training) 62 | 63 | h = x 64 | 65 | # MLP Encoder 66 | if kernels is not None: 67 | 68 | num_dense_layers = len(kernels) 69 | 70 | for i, (kernel, bias) in enumerate(zip(kernels, biases)): 71 | # SparseTensor is usually used for one-hot node features (For example, feature-less nodes.) 72 | if isinstance(h, tf.sparse.SparseTensor): 73 | h = tf.sparse.sparse_dense_matmul(h, kernel) 74 | else: 75 | h = h @ kernel 76 | 77 | if bias is not None: 78 | h += bias 79 | 80 | if i < num_dense_layers - 1: 81 | if dense_activation is not None: 82 | h = dense_activation(h) 83 | if training and dense_drop_rate > 0.0: 84 | h = tf.compat.v2.nn.dropout(h, dense_drop_rate) 85 | else: 86 | if training and last_dense_drop_rate > 0.0: 87 | h = tf.compat.v2.nn.dropout(h, last_dense_drop_rate) 88 | 89 | # propagation 90 | output = h * alpha 91 | 92 | for _ in range(k): 93 | h = normed_sparse_adj @ h 94 | output += (1 - alpha) * h / k 95 | 96 | if activation is not None: 97 | output = activation(output) 98 | 99 | return output 100 | 101 |
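# --- Added note (not part of the original source) ---
# The propagation loop above implements the S^2GC closed form: with h the encoded
# features and A_hat the normalized adjacency,
#
#     output = alpha * h + (1 - alpha) / k * sum_{l=1..k} A_hat^l @ h
#
# i.e. an average over all propagation depths 1..k, plus an alpha-weighted
# residual of the unpropagated features (compare APPNP, which recurses instead).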
-------------------------------------------------------------------------------- /tf_geometric/nn/conv/tagcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | # from tf_geometric.sparse.sparse_adj import SparseAdj 4 | from tf_sparse import SparseMatrix 5 | 6 | from tf_geometric.nn.conv.gcn import gcn_norm_adj 7 | import tf_sparse as tfs 8 | 9 | 10 | def tagcn(x, edge_index, edge_weight, k, kernel, bias=None, activation=None, renorm=False, improved=False, cache=None): 11 | """ 12 | Functional API for Topology Adaptive Graph Convolutional Network (TAGCN). 13 | 14 | :param x: Tensor, shape: [num_nodes, num_features], node features. 15 | :param edge_index: Tensor, shape: [2, num_edges], edge information. 16 | :param edge_weight: Tensor or None, shape: [num_edges]. 17 | :param k: Number of hops. (default: :obj:`3`) 18 | :param kernel: Tensor, shape: [num_features * (k + 1), num_output_features], weight. 19 | :param bias: Tensor, shape: [num_output_features], bias. 20 | :param activation: Activation function to use. 21 | :param renorm: Whether to use the renormalization trick (https://arxiv.org/pdf/1609.02907.pdf). 22 | :param improved: Whether to use the improved GCN or not. 23 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 24 | :return: Updated node features (x), shape: [num_nodes, num_output_features] 25 | """ 26 | 27 | num_nodes = tfs.shape(x)[0] 28 | 29 | sparse_adj = SparseMatrix(edge_index, edge_weight, [num_nodes, num_nodes]) 30 | normed_sparse_adj = gcn_norm_adj(sparse_adj, renorm=renorm, improved=improved, cache=cache) 31 | 32 | if isinstance(x, tf.sparse.SparseTensor): 33 | x = tf.sparse.to_dense(x) 34 | elif isinstance(x, tfs.SparseMatrix): 35 | x = x.to_dense() 36 | 37 | xs = [x] 38 | for _ in range(k): 39 | h = normed_sparse_adj @ xs[-1] 40 | xs.append(h) 41 | 42 | h = tf.concat(xs, axis=-1) 43 | 44 | out = h @ kernel 45 | if bias is not None: 46 | out += bias 47 | 48 | if activation is not None: 49 | out = activation(out) 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /tf_geometric/nn/kernel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/nn/kernel/__init__.py -------------------------------------------------------------------------------- /tf_geometric/nn/kernel/map_reduce.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | from tf_geometric.nn.kernel.segment import segment_op_with_pad 5 | 6 | 7 | def identity_mapper(repeated_x, neighbor_x, edge_weight=None): 8 | return neighbor_x 9 | 10 | 11 | def neighbor_count_mapper(repeated_x, neighbor_x, edge_weight=None): 12 | return tf.ones([neighbor_x.shape[0], 1]) 13 | 14 | 15 | def sum_reducer(neighbor_msg, node_index, num_nodes=None): 16 | return tf.math.unsorted_segment_sum(neighbor_msg, node_index, num_segments=num_nodes) 17 | 18 | 19 | def sum_updater(x, reduced_neighbor_msg): 20 | return x + reduced_neighbor_msg 21 | 22 | 23 | def identity_updater(x, reduced_neighbor_msg): 24 | return reduced_neighbor_msg 25 | 26 | 27 | def mean_reducer(neighbor_msg, node_index, num_nodes=None): 28 | return tf.math.unsorted_segment_mean(neighbor_msg, node_index, num_segments=num_nodes) 29 | 30 | 31 | if tf.__version__[0] == "1": 32 | def max_reducer(neighbor_msg, node_index, num_nodes=None): 33 | if num_nodes is None: 34 | num_nodes = tf.reduce_max(node_index) + 1 35 | max_neighbor_msg = segment_op_with_pad(tf.math.segment_max, neighbor_msg, node_index, num_segments=num_nodes) 36 | return max_neighbor_msg 37 | else: 38 | def max_reducer(neighbor_msg, node_index, num_nodes=None): 39 | if num_nodes is None: 40 | num_nodes = tf.reduce_max(node_index) + 1 41 | max_neighbor_msg = tf.math.unsorted_segment_max(neighbor_msg, node_index, num_segments=num_nodes) 42 | return max_neighbor_msg 43 | 44 | 45 | def aggregate_neighbors(x, edge_index, edge_weight=None, mapper=identity_mapper, 46 | reducer=sum_reducer, updater=sum_updater, num_nodes=None): 47 | """ 48 | :param x: Tensor, shape: [num_nodes, num_features], node features 49 | :param edge_index: Tensor, shape: [2, num_edges], edge information 50 | :param mapper: (features_of_node, features_of_neighbor_node, edge_weight) => neighbor_msg 51 | :param reducer: (neighbor_msg, node_index) => reduced_neighbor_msg 52 | :param updater: (features_of_node, reduced_neighbor_msg) => updated_node_features 53 | :param num_nodes: Number of nodes.
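    Example (added sketch, not part of the original docstring): with the default
    mapper/reducer/updater, each node's features are summed with its neighbors'
    features. For a 2-node graph with a single edge 0 -> 1 (node 0 aggregates
    from node 1):

        x = tf.constant([[1.0], [2.0]])
        edge_index = tf.constant([[0], [1]])
        aggregate_neighbors(x, edge_index)  # -> [[3.0], [2.0]]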
54 | :return: Updated node features, shape: [num_nodes, num_features] 55 | """ 56 | 57 | if tf.shape(edge_index)[-1] == 0:  # no edges: nothing to aggregate 58 | return x 59 | 60 | row, col = edge_index[0], edge_index[1] 61 | 62 | repeated_x = tf.gather(x, row) 63 | neighbor_x = tf.gather(x, col) 64 | 65 | neighbor_msg = mapper(repeated_x, neighbor_x, edge_weight=edge_weight) 66 | 67 | if num_nodes is None: 68 | num_nodes = tf.shape(x)[0] 69 | 70 | reduced_msg = reducer(neighbor_msg, row, num_nodes=num_nodes) 71 | updated = updater(x, reduced_msg) 72 | 73 | return updated 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /tf_geometric/nn/kernel/segment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | 4 | 5 | def segment_op_with_pad(segment_op, data, segment_ids, num_segments): 6 | 7 | sort_index = tf.argsort(segment_ids) 8 | sorted_segment_ids = tf.gather(segment_ids, sort_index) 9 | sorted_data = tf.gather(data, sort_index) 10 | 11 | reduced_data = segment_op(sorted_data, sorted_segment_ids) 12 | num_paddings = num_segments - tf.shape(reduced_data)[0] 13 | 14 | pad_shape = tf.concat([ 15 | [num_paddings], 16 | tf.shape(data)[1:] 17 | ], axis=0) 18 | pads = tf.zeros(pad_shape, dtype=reduced_data.dtype) 19 | outputs = tf.concat( 20 | [reduced_data, pads], 21 | axis=0 22 | ) 23 | return outputs 24 | 25 | 26 | def segment_softmax(data, segment_ids, num_segments): 27 | max_values = tf.math.unsorted_segment_max(data, segment_ids, num_segments=num_segments) 28 | gathered_max_values = tf.gather(max_values, segment_ids) 29 | exp = tf.exp(data - tf.stop_gradient(gathered_max_values)) 30 | denominator = tf.math.unsorted_segment_sum(exp, segment_ids, num_segments=num_segments) + 1e-8 31 | gathered_denominator = tf.gather(denominator, segment_ids) 32 | score = exp / gathered_denominator 33 | return score 34 | 35 | 36 | def segment_count(index, num_segments=None): 37 | data = tf.ones_like(index) 38 | if num_segments is None: 39 | num_segments = tf.reduce_max(index) + 1 40 | return tf.math.unsorted_segment_sum(data, index, num_segments=num_segments) 41 |
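# --- Added example (not part of the original source) ---
# A tiny worked example of segment_softmax, assuming eager execution:
#
#     data = tf.constant([1.0, 2.0, 3.0])
#     segment_ids = tf.constant([0, 0, 1])
#     segment_softmax(data, segment_ids, num_segments=2)
#     # segment 0: softmax([1, 2]) ~= [0.269, 0.731]; segment 1: softmax([3]) = [1.0]
#     # -> [0.269, 0.731, 1.0] (up to the 1e-8 stabilizer in the denominator)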
-------------------------------------------------------------------------------- /tf_geometric/nn/pool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/nn/pool/__init__.py -------------------------------------------------------------------------------- /tf_geometric/nn/pool/common_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | 4 | from tf_geometric.nn.kernel.segment import segment_count, segment_op_with_pad 5 | 6 | 7 | def mean_pool(x, node_graph_index, num_graphs=None): 8 | if num_graphs is None: 9 | num_graphs = tf.reduce_max(node_graph_index) + 1 10 | num_nodes_of_graphs = segment_count(node_graph_index, num_segments=num_graphs) 11 | sum_x = tf.math.unsorted_segment_sum(x, node_graph_index, num_segments=num_graphs) 12 | return sum_x / (tf.cast(tf.expand_dims(num_nodes_of_graphs, -1), tf.float32) + 1e-8) 13 | 14 | 15 | def sum_pool(x, node_graph_index, num_graphs=None): 16 | if num_graphs is None: 17 | num_graphs = tf.reduce_max(node_graph_index) + 1 18 | sum_x = tf.math.unsorted_segment_sum(x, node_graph_index, num_segments=num_graphs) 19 | return sum_x 20 | 21 | 22 | if tf.__version__[0] == "1": 23 | 24 | def max_pool(x, node_graph_index, num_graphs=None): 25 | if num_graphs is None: 26 | num_graphs = tf.reduce_max(node_graph_index) + 1 27 | # max_x = tf.math.unsorted_segment_max(x, node_graph_index, num_segments=num_graphs) 28 | max_x = segment_op_with_pad(tf.math.segment_max, x, node_graph_index, num_segments=num_graphs) 29 | return max_x 30 | 31 | 32 | def min_pool(x, node_graph_index, num_graphs=None): 33 | if num_graphs is None: 34 | num_graphs = tf.reduce_max(node_graph_index) + 1 35 | # min_x = tf.math.unsorted_segment_min(x, node_graph_index, num_segments=num_graphs) 36 | min_x = segment_op_with_pad(tf.math.segment_min, x, node_graph_index, num_segments=num_graphs) 37 | return min_x 38 | 39 | else: 40 | 41 | def max_pool(x, node_graph_index, num_graphs=None): 42 | if num_graphs is None: 43 | num_graphs = tf.reduce_max(node_graph_index) + 1 44 | max_x = tf.math.unsorted_segment_max(x, node_graph_index, num_segments=num_graphs) 45 | return max_x 46 | 47 | 48 | def min_pool(x, node_graph_index, num_graphs=None): 49 | if num_graphs is None: 50 | num_graphs = tf.reduce_max(node_graph_index) + 1 51 | min_x = tf.math.unsorted_segment_min(x, node_graph_index, num_segments=num_graphs) 52 | return min_x 53 | 54 | 55 | 56 | 57 | 58 | 59 |
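# --- Added example (not part of the original source) ---
# These readouts reduce node features into one vector per graph, e.g.:
#
#     x = tf.constant([[1.0], [3.0], [2.0]])
#     node_graph_index = tf.constant([0, 0, 1])
#     sum_pool(x, node_graph_index)   # -> [[4.0], [2.0]]
#     mean_pool(x, node_graph_index)  # -> [[2.0], [2.0]] (approx., due to the 1e-8 term)
#     max_pool(x, node_graph_index)   # -> [[3.0], [2.0]]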
-------------------------------------------------------------------------------- /tf_geometric/nn/pool/sag_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.data.graph import BatchGraph 4 | from tf_geometric.nn.pool.topk_pool import topk_pool 5 | 6 | 7 | def sag_pool(x, edge_index, edge_weight, node_graph_index, 8 | score_gnn, k=None, ratio=None, 9 | score_activation=None, training=None, cache=None): 10 | """ 11 | Functional API for SAGPool 12 | 13 | :param x: Tensor, shape: [num_nodes, num_features], node features 14 | :param edge_index: Tensor, shape: [2, num_edges], edge information 15 | :param edge_weight: Tensor or None, shape: [num_edges] 16 | :param node_graph_index: Tensor/NDArray, shape: [num_nodes], graph index for each node 17 | :param score_gnn: A GNN model to score nodes for the pooling, [x, edge_index, edge_weight] => node_score. 18 | :param k: Keep the top k nodes for each graph 19 | :param ratio: Keep num_nodes * ratio nodes for each graph 20 | :param score_activation: Activation to use for node_score before multiplying node_features with node_score 21 | :param training: Python boolean indicating whether the layer should behave in 22 | training mode (adding dropout) or in inference mode (doing nothing). 23 | :param cache: A dict for caching A' for GCN. Different graphs should not share the same cache dict. 24 | :return: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 25 | """ 26 | 27 | if cache is None: 28 | node_score = score_gnn([x, edge_index, edge_weight], training=training) 29 | else: 30 | node_score = score_gnn([x, edge_index, edge_weight], training=training, cache=cache) 31 | 32 | topk_node_index = topk_pool(node_graph_index, node_score, k=k, ratio=ratio) 33 | 34 | if score_activation is not None: 35 | node_score = score_activation(node_score) 36 | 37 | pooled_graph = BatchGraph( 38 | x=x * node_score, 39 | edge_index=edge_index, 40 | node_graph_index=node_graph_index, 41 | edge_graph_index=None, 42 | edge_weight=edge_weight 43 | ).sample_new_graph_by_node_index(topk_node_index) 44 | 45 | return pooled_graph.x, pooled_graph.edge_index, pooled_graph.edge_weight, pooled_graph.node_graph_index 46 | -------------------------------------------------------------------------------- /tf_geometric/nn/pool/set2set.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.nn.kernel.segment import segment_softmax 4 | 5 | from tf_geometric.utils.union_utils import union_len 6 | 7 | 8 | def set2set(x, node_graph_index, lstm, num_iterations, training=None): 9 | """ 10 | Functional API for Set2Set 11 | 12 | :param x: Tensor, shape: [num_nodes, num_features], node features 13 | :param node_graph_index: Tensor/NDArray, shape: [num_nodes], graph index for each node 14 | :param lstm: A Keras LSTM built with return_sequences=True and return_state=True. 15 | :param num_iterations: Number of iterations for attention. 16 | :param training: Python boolean indicating whether the layer should behave in 17 | training mode (adding dropout) or in inference mode (doing nothing). 18 | :return: Graph features, shape: [num_graphs, num_node_features * 2] 19 | """ 20 | 21 | num_graphs = tf.reduce_max(node_graph_index) + 1 22 | 23 | lstm_units = tf.shape(x)[-1] 24 | 25 | h = tf.zeros([num_graphs, lstm_units * 2], dtype=tf.float32) 26 | initial_state = [tf.zeros([1, lstm_units], dtype=tf.float32), tf.zeros([1, lstm_units], dtype=tf.float32)] 27 | 28 | for _ in range(num_iterations): 29 | 30 | h = tf.expand_dims(h, axis=0) 31 | h, state_h, state_c = lstm(h, initial_state=initial_state, training=training) 32 | initial_state = [state_h, state_c] 33 | h = tf.squeeze(h, axis=0) 34 | 35 | repeated_h = tf.gather(h, node_graph_index) 36 | # attention 37 | att_score = tf.reduce_sum(x * repeated_h, axis=-1, keepdims=True) 38 | normed_att_score = segment_softmax(att_score, node_graph_index, num_graphs) 39 | att_h = tf.math.unsorted_segment_sum(x * normed_att_score, node_graph_index, num_graphs) 40 | h = tf.concat([h, att_h], axis=-1) 41 | 42 | return h 43 |
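# --- Added usage sketch (not part of the original source) ---
# Set2Set produces a permutation-invariant graph readout of width
# 2 * num_node_features. A hypothetical call, mirroring how the OOP layer
# (tf_geometric/layers/pool/set2set.py) builds its LSTM:
#
#     lstm = tf.keras.layers.LSTM(num_features, return_sequences=True, return_state=True)
#     graph_h = set2set(x, node_graph_index, lstm, num_iterations=4)
#     # graph_h has shape [num_graphs, num_features * 2]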
-------------------------------------------------------------------------------- /tf_geometric/nn/pool/sort_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.data.graph import BatchGraph 4 | from tf_geometric.nn.pool.topk_pool import topk_pool 5 | 6 | 7 | def sort_pool(x, edge_index, edge_weight, node_graph_index, 8 | k=None, ratio=None, 9 | sort_index=-1, training=None): 10 | """ 11 | Functional API for SortPool "An End-to-End Deep Learning Architecture for Graph Classification" 12 | 13 | :param x: Tensor, shape: [num_nodes, num_features], node features 14 | :param edge_index: Tensor, shape: [2, num_edges], edge information 15 | :param edge_weight: Tensor or None, shape: [num_edges] 16 | :param node_graph_index: Tensor/NDArray, shape: [num_nodes], graph index for each node 17 | :param k: Keep the top k nodes for each graph 18 | :param ratio: Keep num_nodes * ratio nodes for each graph 19 | :param sort_index: The sort_index-th feature (along the last axis) will be used for sorting. 20 | :param training: Python boolean indicating whether the layer should behave in 21 | training mode (adding dropout) or in inference mode (doing nothing). 22 | :return: [pooled_x, pooled_edge_index, pooled_edge_weight, pooled_node_graph_index] 23 | """ 24 | 25 | score = x[:, sort_index] 26 | topk_node_index = topk_pool(node_graph_index, score, k=k, ratio=ratio) 27 | 28 | pooled_graph = BatchGraph( 29 | x=x, 30 | edge_index=edge_index, 31 | node_graph_index=node_graph_index, 32 | edge_graph_index=None, 33 | edge_weight=edge_weight 34 | ).sample_new_graph_by_node_index(topk_node_index) 35 | 36 | return pooled_graph.x, pooled_graph.edge_index, pooled_graph.edge_weight, pooled_graph.node_graph_index -------------------------------------------------------------------------------- /tf_geometric/nn/pool/topk_pool.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from tf_geometric.utils.union_utils import union_len 4 | 5 | 6 | def topk_pool(source_index, score, k=None, ratio=None): 7 | """ 8 | 9 | :param source_index: index of source node (of edge) or source graph (of node) 10 | :param score: 1-D Array 11 | :param k: Keep top k targets for each source 12 | :param ratio: Keep num_targets * ratio targets for each source 13 | :return: topk_index: indices of the kept targets (positions in the input score/source_index arrays) 14 | """ 15 | 16 | if k is None and ratio is None: 17 | raise Exception("you should provide either k or ratio for topk_pool") 18 | elif k is not None and ratio is not None: 19 | raise Exception("you should provide either k or ratio for topk_pool, not both of them") 20 | 21 | # currently, we assume source_index is not sorted 22 | # the option is preserved for future performance optimization 23 | source_index_sorted = False 24 | 25 | if source_index_sorted: 26 | sorted_source_index = source_index 27 | # sort score by source_index 28 | sorted_score = score 29 | else: 30 | source_index_perm = tf.argsort(source_index) 31 | sorted_source_index = tf.gather(source_index, source_index_perm) 32 | sorted_score = tf.gather(score, source_index_perm) 33 | 34 | sorted_score = tf.reshape(sorted_score, [-1]) 35 | 36 | num_targets = tf.shape(sorted_source_index)[0] 37 | target_ones = tf.ones([num_targets], dtype=tf.int32) 38 | num_targets_for_sources = tf.math.segment_sum(target_ones, sorted_source_index) 39 | # number of columns for score matrix 40 | num_cols = tf.reduce_max(num_targets_for_sources) 41 | 42 | # max index of source + 1 43 | num_seen_sources = tf.shape(num_targets_for_sources)[0] 44 | 45 | min_score = tf.reduce_min(sorted_score) 46 | 47 | num_targets_before = tf.concat([ 48 | tf.zeros([1], dtype=tf.int32), 49 | tf.math.cumsum(num_targets_for_sources)[:-1] 50 | ], axis=0) 51 | 52 | target_index_for_source = tf.range(0, num_targets) - tf.gather(num_targets_before, sorted_source_index) 53 | 54 | score_matrix = tf.cast(tf.fill([num_seen_sources, num_cols], min_score - 1.0), dtype=tf.float32) 55 | score_index = tf.stack([sorted_source_index, target_index_for_source], axis=1) 56 | score_matrix = tf.tensor_scatter_nd_update(score_matrix, score_index, sorted_score) 57 | 58 | sort_index = tf.argsort(score_matrix, axis=-1, direction="DESCENDING") 59 | 60 | if 
k is not None: 61 | node_k = tf.math.minimum( 62 | tf.cast(tf.fill([num_seen_sources], k), dtype=tf.int32), 63 | num_targets_for_sources 64 | ) 65 | else: 66 | node_k = tf.cast( 67 | tf.math.ceil(tf.cast(num_targets_for_sources, dtype=tf.float32) * tf.cast(ratio, dtype=tf.float32)), 68 | dtype=tf.int32 69 | ) 70 | 71 | row, col = tf.meshgrid(tf.range(num_seen_sources), tf.range(tf.reduce_max(node_k)), indexing="ij") 72 | row = tf.reshape(row, [-1]) 73 | col = tf.reshape(col, [-1]) 74 | repeated_k = tf.gather(node_k, row) 75 | k_mask = tf.less(col, repeated_k) 76 | 77 | row = tf.boolean_mask(row, k_mask) 78 | col = tf.boolean_mask(col, k_mask) 79 | 80 | sample_col_index = tf.gather_nd(sort_index, tf.stack([row, col], axis=1)) 81 | 82 | topk_index = tf.gather(num_targets_before, row) + sample_col_index 83 | 84 | if source_index_sorted: 85 | return topk_index 86 | else: 87 | return tf.gather(source_index_perm, topk_index) 88 | 
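# --- Added example (not part of the original source) ---
# topk_pool selects, per source (e.g. per graph), the indices of the highest
# scores. A tiny worked example, assuming eager execution:
#
#     node_graph_index = tf.constant([0, 0, 0, 1, 1])
#     score = tf.constant([0.1, 0.9, 0.5, 0.3, 0.7])
#     topk_pool(node_graph_index, score, k=2)
#     # graph 0 keeps nodes 1 (0.9) and 2 (0.5); graph 1 keeps nodes 4 (0.7) and 3 (0.3)
#     # -> [1, 2, 4, 3] (per source, in descending score order)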
-------------------------------------------------------------------------------- /tf_geometric/nn/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrawlScript/tf_geometric/b7c40d9005c60b27a7e18b919d32418c5548252e/tf_geometric/nn/sampling/__init__.py -------------------------------------------------------------------------------- /tf_geometric/nn/sampling/drop_edge.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | 6 | def drop_edge(inputs, rate=0.5, force_undirected=False, training=None): 7 | """ 8 | 9 | :param inputs: List of edge_index and other edge attributes [edge_index, edge_attr, ...] 10 | :param rate: dropout rate 11 | :param force_undirected: If set to `True`, will either 12 | drop or keep both edges of an undirected edge. 13 | :param training: Python boolean indicating whether the layer should behave in 14 | training mode (adding dropout) or in inference mode (doing nothing). 15 | :return: List of dropped edge_index and other dropped edge attributes 16 | """ 17 | 18 | if not training: 19 | return inputs 20 | 21 | if rate < 0.0 or rate > 1.0: 22 | raise ValueError('Dropout probability has to be between 0 and 1, ' 23 | 'but got {}'.format(rate)) 24 | 25 | edge_index, *edge_attrs = inputs 26 | 27 | edge_index_is_tensor = tf.is_tensor(edge_index) 28 | if not edge_index_is_tensor: 29 | edge_index = tf.convert_to_tensor(edge_index) 30 | 31 | row, col = edge_index[0], edge_index[1] 32 | if force_undirected: 33 | index = tf.where(tf.less(row, col)) 34 | index = tf.boolean_mask(index, tf.greater(tf.nn.dropout(tf.ones_like(index, dtype=tf.float32), rate), 0)) 35 | dropped_edge_index = tf.gather(edge_index, index, axis=-1) 36 | dropped_edge_index = tf.concat([dropped_edge_index, tf.gather(dropped_edge_index, [1, 0])], axis=-1) 37 | index = tf.concat([index, index], axis=-1) 38 | else: 39 | index = tf.boolean_mask(tf.range(0, tf.shape(row)[0]), 40 | tf.greater(tf.nn.dropout(tf.ones_like(row, dtype=tf.float32), rate), 0)) 41 | dropped_edge_index = tf.gather(edge_index, index, axis=-1) 42 | 43 | for i in range(len(edge_attrs)): 44 | if tf.is_tensor(edge_attrs[i]): 45 | edge_attrs[i] = tf.gather(edge_attrs[i], index, axis=-1) 46 | else: 47 | edge_attrs[i] = np.take(edge_attrs[i], index, axis=-1) 48 | 49 | if not edge_index_is_tensor: 50 | dropped_edge_index = dropped_edge_index.numpy() 51 | 52 | return [dropped_edge_index] + edge_attrs 53 | -------------------------------------------------------------------------------- /tf_geometric/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | -------------------------------------------------------------------------------- /tf_geometric/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | import pickle 4 | 5 | 6 | 7 | def download_file(path, url_or_urls): 8 | if not isinstance(url_or_urls, list): 9 | urls = [url_or_urls] 10 | else: 11 | urls = url_or_urls 12 | 13 | last_except = None 14 | for url in urls: 15 | try: 16 | return tf.keras.utils.get_file(path, origin=url) 17 | except Exception as e: 18 | last_except = e 19 | print(e) 20 | 21 | raise last_except 22 | 23 | 24 | def save_cache(obj, path): 25 | with open(path, "wb") as f: 26 | pickle.dump(obj, f, protocol=4) 27 | 28 | 29 | def load_cache(path): 30 | # if not os.path.exists(path): 31 | # return None 32 | 33 | with open(path, "rb") as f: 34 | return pickle.load(f) -------------------------------------------------------------------------------- /tf_geometric/utils/tf_sparse_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import tensorflow as tf 4 | import tf_sparse as tfs 5 | import numpy as np 6 | 7 | 8 | def sparse_tensor_gather_sub(sparse_tensor, sub_index, axis=0): 9 | gather_index = sparse_tensor.indices[:, axis] 10 | if axis in [0, -2]: 11 | other_axis = 1 12 | else: 13 | other_axis = 0 14 | 15 | dense_shape = tf.shape(sparse_tensor) 16 | 17 | index_mask = tf.scatter_nd(tf.expand_dims(sub_index, axis=-1), tf.ones_like(sub_index), [dense_shape[axis]]) 18 | index_mask = tf.cast(index_mask, tf.bool) 19 | 20 | gather_mask = tf.gather(index_mask, gather_index) 21 | 22 | masked_values = tf.boolean_mask(sparse_tensor.values, gather_mask) 23 | masked_indices = tf.boolean_mask(sparse_tensor.indices, gather_mask) 24 | 25 | reverse_index = 
tf.cast(tf.fill([tf.reduce_max(sub_index) + 1], -1), tf.int32) 26 | reverse_index = tf.tensor_scatter_nd_update(reverse_index, tf.expand_dims(sub_index, axis=-1), 27 | tf.range(tf.shape(sub_index)[0])) 28 | 29 | masked_gather_index = masked_indices[:, axis] 30 | masked_gather_index = tf.gather(reverse_index, masked_gather_index) 31 | masked_gather_index = tf.cast(masked_gather_index, dtype=tf.int64) 32 | 33 | masked_other_index = masked_indices[:, other_axis] 34 | 35 | new_indices = [None, None] 36 | new_indices[axis] = masked_gather_index 37 | new_indices[other_axis] = masked_other_index 38 | 39 | new_shape = [None, None] 40 | new_shape[axis] = tf.shape(sub_index)[0] 41 | new_shape[other_axis] = dense_shape[other_axis] 42 | new_shape = tf.cast(new_shape, tf.int64) 43 | 44 | new_indices = tf.stack(new_indices, axis=1) 45 | 46 | new_sparse_tensor = tf.sparse.SparseTensor( 47 | indices=new_indices, 48 | values=masked_values, 49 | dense_shape=new_shape 50 | ) 51 | new_sparse_tensor = tf.sparse.reorder(new_sparse_tensor) 52 | return new_sparse_tensor 53 | 54 | 55 | def sparse_gather_sub(x, sub_index, axis=0): 56 | is_sparse_matrix = isinstance(x, tfs.SparseMatrix) 57 | 58 | if is_sparse_matrix: 59 | sparse_tensor = x.to_sparse_tensor() 60 | else: 61 | sparse_tensor = x 62 | 63 | output_sparse_tensor = sparse_tensor_gather_sub(sparse_tensor, sub_index, axis) 64 | 65 | if is_sparse_matrix: 66 | return x.__class__.from_sparse_tensor(output_sparse_tensor) 67 | else: 68 | return output_sparse_tensor 69 | 70 | 71 | def compute_num_or_size_splits(num_h_features, num_splits): 72 | 73 | if num_splits is None or num_splits == 1: 74 | num_or_size_splits = None 75 | elif num_h_features % num_splits == 0: 76 | num_or_size_splits = num_splits 77 | else: 78 | split_size = np.ceil(num_h_features / num_splits).astype(np.int64) 79 | 80 | num_pre_splits = np.floor(num_h_features / split_size).astype(np.int64) 81 | last_split_size = num_h_features % split_size 82 | 83 | num_or_size_splits = [split_size] * num_pre_splits + ([last_split_size] if last_split_size > 0 else []) 84 | 85 | if len(num_or_size_splits) != num_splits: 86 | raise Exception( 87 | "cannot split H of shape [None, {}] into {} matrices, please provide a valid num_splits" 88 | .format(num_h_features, num_splits)) 89 | 90 | return num_or_size_splits 91 | 
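# --- Added example (not part of the original source) ---
# compute_num_or_size_splits produces an argument suitable for tf.split:
#
#     compute_num_or_size_splits(12, 3)  # -> 3 (12 features split evenly into 3 parts)
#     compute_num_or_size_splits(10, 3)  # -> [4, 4, 2] (ceil(10/3) = 4, and 4 + 4 + 2 = 10)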
" 9 | "Upgrade your TensorFlow to 2.x and the performance can be improved dramatically with @tf_utils.function") 10 | 11 | 12 | def tf_func_warn(*args, **kwargs): 13 | if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): 14 | func = args[0] 15 | warn_tf1() 16 | return func 17 | else: 18 | def decorate(func): 19 | warn_tf1() 20 | return func 21 | 22 | return decorate 23 | 24 | 25 | # disable tf.function for tf 1 26 | if tf.__version__[0] == "1": 27 | function = tf_func_warn 28 | else: 29 | function = tf.function 30 | -------------------------------------------------------------------------------- /tf_geometric/utils/union_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | import numpy as np 4 | import tf_sparse as tfs 5 | 6 | def convert_union_to_numpy(data, dtype=None): 7 | if data is None: 8 | return data 9 | 10 | if tf.is_tensor(data): 11 | np_data = data.numpy() 12 | elif isinstance(data, list): 13 | np_data = np.array(data) 14 | else: 15 | np_data = data 16 | 17 | if dtype is not None: 18 | np_data = np_data.astype(dtype) 19 | 20 | return np_data 21 | 22 | 23 | def union_len(data): 24 | if tf.is_tensor(data): 25 | return tfs.shape(data)[0] 26 | else: 27 | return len(data) --------------------------------------------------------------------------------