├── .github ├── CODEOWNERS ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── documentation.yml │ ├── feature-request.yml │ ├── installation.yml │ └── refactor.yml ├── actions │ └── setup │ │ └── action.yml ├── dependabot.yml ├── labeler.yml └── workflows │ ├── auto-merge.yml │ ├── building_nightly.yml │ ├── changelog.yml │ ├── documentation.yml │ ├── examples.yml │ ├── labeler.yml │ ├── linting.yml │ ├── testing_full.yml │ ├── testing_full_gpu.yml │ ├── testing_latest.yml │ ├── testing_minimal.yml │ ├── testing_nightly.yml │ ├── testing_prev.yml │ └── testing_rag.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CITATION.cff ├── LICENSE ├── README.md ├── benchmark ├── README.md ├── citation │ ├── README.md │ ├── __init__.py │ ├── appnp.py │ ├── arma.py │ ├── cheb.py │ ├── datasets.py │ ├── gat.py │ ├── gcn.py │ ├── inference.sh │ ├── run.sh │ ├── sgc.py │ ├── statistics.py │ └── train_eval.py ├── inference │ ├── README.md │ └── inference_benchmark.py ├── kernel │ ├── README.md │ ├── __init__.py │ ├── asap.py │ ├── datasets.py │ ├── diff_pool.py │ ├── edge_pool.py │ ├── gcn.py │ ├── gin.py │ ├── global_attention.py │ ├── graclus.py │ ├── graph_sage.py │ ├── main.py │ ├── main_performance.py │ ├── sag_pool.py │ ├── set2set.py │ ├── sort_pool.py │ ├── statistics.py │ ├── top_k.py │ └── train_eval.py ├── loader │ └── neighbor_loader.py ├── multi_gpu │ └── training │ │ ├── README.md │ │ ├── common.py │ │ ├── training_benchmark_cuda.py │ │ └── training_benchmark_xpu.py ├── points │ ├── README.md │ ├── __init__.py │ ├── datasets.py │ ├── edge_cnn.py │ ├── mpnn.py │ ├── point_cnn.py │ ├── point_net.py │ ├── spline_cnn.py │ ├── statistics.py │ └── train_eval.py ├── runtime │ ├── README.md │ ├── __init__.py │ ├── dgl │ │ ├── gat.py │ │ ├── gcn.py │ │ ├── hidden.py │ │ ├── main.py │ │ ├── rgcn.py │ │ └── train.py │ ├── gat.py │ ├── gcn.py │ ├── main.py │ ├── rgcn.py │ └── train.py ├── setup.py ├── training │ ├── README.md │ 
└── training_benchmark.py └── utils │ ├── __init__.py │ ├── hetero_gat.py │ ├── hetero_sage.py │ └── utils.py ├── codecov.yml ├── docker ├── Dockerfile ├── Dockerfile.xpu ├── README.md └── singularity ├── docs ├── Makefile ├── README.md ├── requirements.txt └── source │ ├── .gitignore │ ├── _figures │ ├── .gitignore │ ├── architecture.pdf │ ├── architecture.svg │ ├── build.sh │ ├── dist_part.png │ ├── dist_proc.png │ ├── dist_sampling.png │ ├── graph.svg │ ├── graph.tex │ ├── graphgps_layer.png │ ├── graphgym_design_space.png │ ├── graphgym_evaluation.png │ ├── graphgym_results.png │ ├── hg_example.svg │ ├── hg_example.tex │ ├── intel_kumo.png │ ├── meshcnn_edge_adjacency.svg │ ├── point_cloud1.png │ ├── point_cloud2.png │ ├── point_cloud3.png │ ├── point_cloud4.png │ ├── remote_1.png │ ├── remote_2.png │ ├── remote_3.png │ ├── shallow_node_embeddings.png │ ├── to_hetero.svg │ ├── to_hetero.tex │ ├── to_hetero_with_bases.svg │ ├── to_hetero_with_bases.tex │ └── training_affinity.png │ ├── _static │ ├── js │ │ └── version_alert.js │ └── thumbnails │ │ ├── create_dataset.png │ │ ├── create_gnn.png │ │ ├── dataset_splitting.png │ │ ├── distributed_pyg.png │ │ ├── explain.png │ │ ├── graph_transformer.png │ │ ├── heterogeneous.png │ │ ├── load_csv.png │ │ ├── multi_gpu_vanilla.png │ │ ├── neighbor_loader.png │ │ ├── point_cloud.png │ │ └── shallow_node_embeddings.png │ ├── _templates │ └── autosummary │ │ ├── class.rst │ │ ├── inherited_class.rst │ │ ├── metrics.rst │ │ ├── nn.rst │ │ └── only_class.rst │ ├── advanced │ ├── batching.rst │ ├── compile.rst │ ├── cpu_affinity.rst │ ├── graphgym.rst │ ├── hgam.rst │ ├── jit.rst │ ├── remote.rst │ └── sparse_tensor.rst │ ├── cheatsheet │ ├── data_cheatsheet.rst │ └── gnn_cheatsheet.rst │ ├── conf.py │ ├── external │ └── resources.rst │ ├── get_started │ ├── colabs.rst │ └── introduction.rst │ ├── index.rst │ ├── install │ ├── installation.rst │ └── quick-start.html │ ├── modules │ ├── contrib.rst │ ├── data.rst │ ├── 
datasets.rst │ ├── distributed.rst │ ├── explain.rst │ ├── graphgym.rst │ ├── llm.rst │ ├── loader.rst │ ├── metrics.rst │ ├── nn.rst │ ├── profile.rst │ ├── root.rst │ ├── sampler.rst │ ├── transforms.rst │ └── utils.rst │ ├── notes │ ├── batching.rst │ ├── cheatsheet.rst │ ├── colabs.rst │ ├── create_dataset.rst │ ├── create_gnn.rst │ ├── data_cheatsheet.rst │ ├── explain.rst │ ├── graphgym.rst │ ├── heterogeneous.rst │ ├── installation.rst │ ├── introduction.rst │ ├── jit.rst │ ├── load_csv.rst │ ├── remote.rst │ ├── resources.rst │ └── sparse_tensor.rst │ └── tutorial │ ├── application.rst │ ├── compile.rst │ ├── create_dataset.rst │ ├── create_gnn.rst │ ├── dataset.rst │ ├── dataset_splitting.rst │ ├── distributed.rst │ ├── distributed_pyg.rst │ ├── explain.rst │ ├── gnn_design.rst │ ├── graph_transformer.rst │ ├── heterogeneous.rst │ ├── load_csv.rst │ ├── multi_gpu_vanilla.rst │ ├── multi_node_multi_gpu_vanilla.rst │ ├── neighbor_loader.rst │ ├── point_cloud.rst │ └── shallow_node_embeddings.rst ├── examples ├── README.md ├── agnn.py ├── ar_link_pred.py ├── argva_node_clustering.py ├── arma.py ├── attentive_fp.py ├── autoencoder.py ├── cluster_gcn_ppi.py ├── cluster_gcn_reddit.py ├── colors_topk_pool.py ├── compile │ ├── gcn.py │ └── gin.py ├── contrib │ ├── README.md │ ├── pgm_explainer_graph_classification.py │ ├── pgm_explainer_node_classification.py │ ├── rbcd_attack.py │ └── rbcd_attack_poisoning.py ├── cora.py ├── correct_and_smooth.py ├── cpp │ ├── CMakeLists.txt │ ├── README.md │ ├── main.cpp │ └── save_model.py ├── datapipe.py ├── dgcnn_classification.py ├── dgcnn_segmentation.py ├── dir_gnn.py ├── distributed │ ├── README.md │ ├── graphlearn_for_pytorch │ │ ├── README.md │ │ ├── dist_train_sage_sup_config.yml │ │ ├── dist_train_sage_supervised.py │ │ ├── launch.py │ │ └── partition_ogbn_dataset.py │ ├── kuzu │ │ ├── README.md │ │ └── papers_100M │ │ │ ├── README.md │ │ │ ├── prepare_data.py │ │ │ └── train.py │ └── pyg │ │ ├── README.md │ │ ├── 
launch.py │ │ ├── node_ogb_cpu.py │ │ ├── partition_graph.py │ │ ├── run_dist.sh │ │ └── temporal_link_movielens_cpu.py ├── dna.py ├── egc.py ├── equilibrium_median.py ├── explain │ ├── README.md │ ├── captum_explainer.py │ ├── captum_explainer_hetero_link.py │ ├── gnn_explainer.py │ ├── gnn_explainer_ba_shapes.py │ ├── gnn_explainer_link_pred.py │ └── graphmask_explainer.py ├── faust.py ├── film.py ├── gat.py ├── gcn.py ├── gcn2_cora.py ├── gcn2_ppi.py ├── geniepath.py ├── glnn.py ├── gpse.py ├── graph_gps.py ├── graph_sage_unsup.py ├── graph_sage_unsup_ppi.py ├── graph_saint.py ├── graph_unet.py ├── hetero │ ├── README.md │ ├── bipartite_sage.py │ ├── bipartite_sage_unsup.py │ ├── dmgi_unsup.py │ ├── han_imdb.py │ ├── hetero_conv_dblp.py │ ├── hetero_link_pred.py │ ├── hgt_dblp.py │ ├── hierarchical_sage.py │ ├── load_csv.py │ ├── metapath2vec.py │ ├── recommender_system.py │ ├── temporal_link_pred.py │ └── to_hetero_mag.py ├── hierarchical_sampling.py ├── infomax_inductive.py ├── infomax_transductive.py ├── jit │ ├── README.md │ ├── film.py │ ├── gat.py │ ├── gcn.py │ └── gin.py ├── kge_fb15k_237.py ├── label_prop.py ├── lcm_aggr_2nd_min.py ├── lightgcn.py ├── link_pred.py ├── linkx.py ├── llm │ ├── README.md │ ├── g_retriever.py │ ├── git_mol.py │ ├── glem.py │ ├── molecule_gpt.py │ ├── nvtx_examples │ │ ├── README.md │ │ ├── nvtx_rag_backend_example.py │ │ ├── nvtx_run.sh │ │ └── nvtx_webqsp_example.py │ ├── protein_mpnn.py │ └── txt2kg_rag.py ├── lpformer.py ├── mem_pool.py ├── mixhop.py ├── mnist_graclus.py ├── mnist_nn_conv.py ├── mnist_voxel_grid.py ├── multi_gpu │ ├── README.md │ ├── data_parallel.py │ ├── distributed_batching.py │ ├── distributed_sampling.py │ ├── distributed_sampling_multinode.py │ ├── distributed_sampling_multinode.sbatch │ ├── distributed_sampling_xpu.py │ ├── mag240m_graphsage.py │ ├── model_parallel.py │ ├── ogbn_train_cugraph.py │ ├── papers100m_gcn.py │ ├── papers100m_gcn_cugraph_multinode.py │ ├── papers100m_gcn_multinode.py │ 
├── pcqm4m_ogb.py │ └── taobao.py ├── mutag_gin.py ├── node2vec.py ├── ogbn_proteins_deepgcn.py ├── ogbn_train.py ├── ogbn_train_cugraph.py ├── ogc.py ├── pmlp.py ├── pna.py ├── point_transformer_classification.py ├── point_transformer_segmentation.py ├── pointnet2_classification.py ├── pointnet2_segmentation.py ├── ppi.py ├── proteins_diff_pool.py ├── proteins_dmon_pool.py ├── proteins_gmt.py ├── proteins_mincut_pool.py ├── proteins_topk_pool.py ├── pytorch_ignite │ ├── README.md │ └── gin.py ├── pytorch_lightning │ ├── README.md │ ├── gin.py │ ├── graph_sage.py │ └── relational_gnn.py ├── qm9_nn_conv.py ├── qm9_pretrained_dimenet.py ├── qm9_pretrained_schnet.py ├── quiver │ ├── README.md │ ├── multi_gpu_quiver.py │ └── single_gpu_quiver.py ├── randlanet_classification.py ├── randlanet_segmentation.py ├── rdl.py ├── rect.py ├── reddit.py ├── renet.py ├── rev_gnn.py ├── rgat.py ├── rgcn.py ├── rgcn_link_pred.py ├── seal_link_pred.py ├── sgc.py ├── shadow.py ├── sign.py ├── signed_gcn.py ├── super_gat.py ├── tagcn.py ├── tensorboard_logging.py ├── tgn.py ├── triangles_sag_pool.py ├── unimp_arxiv.py ├── upfd.py └── wl_kernel.py ├── graphgym ├── agg_batch.py ├── configs │ ├── example.yaml │ └── pyg │ │ ├── example_graph.yaml │ │ ├── example_link.yaml │ │ └── example_node.yaml ├── configs_gen.py ├── custom_graphgym │ ├── __init__.py │ ├── act │ │ ├── __init__.py │ │ └── example.py │ ├── config │ │ ├── __init__.py │ │ └── example.py │ ├── encoder │ │ ├── __init__.py │ │ └── example.py │ ├── head │ │ ├── __init__.py │ │ └── example.py │ ├── layer │ │ ├── __init__.py │ │ └── example.py │ ├── loader │ │ ├── __init__.py │ │ └── example.py │ ├── loss │ │ ├── __init__.py │ │ └── example.py │ ├── network │ │ ├── __init__.py │ │ └── example.py │ ├── optimizer │ │ ├── __init__.py │ │ └── example.py │ ├── pooling │ │ ├── __init__.py │ │ └── example.py │ ├── stage │ │ ├── __init__.py │ │ └── example.py │ ├── train │ │ ├── __init__.py │ │ └── example.py │ └── transform │ │ └── 
__init__.py ├── grids │ ├── example.txt │ └── pyg │ │ └── example.txt ├── main.py ├── parallel.sh ├── results │ └── example_node_grid_example │ │ └── agg │ │ ├── test.csv │ │ ├── test_best.csv │ │ ├── test_bestepoch.csv │ │ ├── train.csv │ │ ├── train_best.csv │ │ ├── train_bestepoch.csv │ │ ├── val.csv │ │ ├── val_best.csv │ │ └── val_bestepoch.csv ├── run_batch.sh ├── run_single.sh └── sample │ ├── dimensions.txt │ └── dimensionsatt.txt ├── pyproject.toml ├── readthedocs.yml ├── test ├── conftest.py ├── contrib │ ├── explain │ │ └── test_pgm_explainer.py │ └── nn │ │ └── models │ │ └── test_rbcd_attack.py ├── data │ ├── lightning │ │ └── test_datamodule.py │ ├── test_batch.py │ ├── test_data.py │ ├── test_database.py │ ├── test_datapipes.py │ ├── test_dataset.py │ ├── test_dataset_summary.py │ ├── test_feature_store.py │ ├── test_graph_store.py │ ├── test_hetero_data.py │ ├── test_hypergraph_data.py │ ├── test_inherit.py │ ├── test_on_disk_dataset.py │ ├── test_remote_backend_utils.py │ ├── test_storage.py │ ├── test_temporal.py │ └── test_view.py ├── datasets │ ├── graph_generator │ │ ├── test_ba_graph.py │ │ ├── test_er_graph.py │ │ ├── test_grid_graph.py │ │ └── test_tree_graph.py │ ├── motif_generator │ │ ├── test_custom_motif.py │ │ ├── test_cycle_motif.py │ │ ├── test_grid_motif.py │ │ └── test_house_motif.py │ ├── test_ba_shapes.py │ ├── test_bzr.py │ ├── test_elliptic.py │ ├── test_enzymes.py │ ├── test_explainer_dataset.py │ ├── test_fake.py │ ├── test_git_mol_dataset.py │ ├── test_imdb_binary.py │ ├── test_infection_dataset.py │ ├── test_karate.py │ ├── test_medshapenet.py │ ├── test_molecule_gpt_dataset.py │ ├── test_mutag.py │ ├── test_planetoid.py │ ├── test_protein_mpnn_dataset.py │ ├── test_snap_dataset.py │ ├── test_suite_sparse.py │ ├── test_tag_dataset.py │ ├── test_teeth3ds.py │ └── test_web_qsp_dataset.py ├── distributed │ ├── test_dist_link_neighbor_loader.py │ ├── test_dist_link_neighbor_sampler.py │ ├── test_dist_neighbor_loader.py │ ├── 
test_dist_neighbor_sampler.py │ ├── test_dist_utils.py │ ├── test_local_feature_store.py │ ├── test_local_graph_store.py │ ├── test_partition.py │ └── test_rpc.py ├── explain │ ├── algorithm │ │ ├── test_attention_explainer.py │ │ ├── test_captum.py │ │ ├── test_captum_explainer.py │ │ ├── test_captum_hetero.py │ │ ├── test_explain_algorithm_utils.py │ │ ├── test_gnn_explainer.py │ │ ├── test_graphmask_explainer.py │ │ └── test_pg_explainer.py │ ├── conftest.py │ ├── metric │ │ ├── test_basic_metric.py │ │ ├── test_faithfulness.py │ │ └── test_fidelity.py │ ├── test_explain_config.py │ ├── test_explainer.py │ ├── test_explanation.py │ ├── test_hetero_explainer.py │ └── test_hetero_explanation.py ├── graphgym │ ├── example_node.yml │ ├── test_config.py │ ├── test_graphgym.py │ ├── test_logger.py │ └── test_register.py ├── io │ ├── example1.off │ ├── example2.off │ ├── test_fs.py │ └── test_off.py ├── llm │ ├── models │ │ ├── test_g_retriever.py │ │ ├── test_git_mol.py │ │ ├── test_glem.py │ │ ├── test_llm.py │ │ ├── test_molecule_gpt.py │ │ ├── test_protein_mpnn.py │ │ ├── test_sentence_transformer.py │ │ └── test_vision_transformer.py │ ├── test_large_graph_indexer.py │ ├── test_rag_loader.py │ └── utils │ │ ├── test_rag_backend_utils.py │ │ ├── test_rag_feature_store.py │ │ ├── test_rag_graph_store.py │ │ └── test_vectorrag.py ├── loader │ ├── test_cache.py │ ├── test_cluster.py │ ├── test_dataloader.py │ ├── test_dynamic_batch_sampler.py │ ├── test_graph_saint.py │ ├── test_hgt_loader.py │ ├── test_ibmb_loader.py │ ├── test_imbalanced_sampler.py │ ├── test_link_neighbor_loader.py │ ├── test_mixin.py │ ├── test_neighbor_loader.py │ ├── test_neighbor_sampler.py │ ├── test_prefetch.py │ ├── test_random_node_loader.py │ ├── test_shadow.py │ ├── test_temporal_dataloader.py │ ├── test_utils.py │ └── test_zip_loader.py ├── metrics │ └── test_link_pred_metric.py ├── my_config.yaml ├── nn │ ├── aggr │ │ ├── test_aggr_utils.py │ │ ├── test_attention.py │ │ ├── 
test_basic.py │ │ ├── test_deep_sets.py │ │ ├── test_equilibrium.py │ │ ├── test_fused.py │ │ ├── test_gmt.py │ │ ├── test_gru.py │ │ ├── test_lcm.py │ │ ├── test_lstm.py │ │ ├── test_mlp_aggr.py │ │ ├── test_multi.py │ │ ├── test_patch_transformer.py │ │ ├── test_quantile.py │ │ ├── test_scaler.py │ │ ├── test_set2set.py │ │ ├── test_set_transformer.py │ │ ├── test_sort.py │ │ └── test_variance_preserving.py │ ├── attention │ │ ├── test_performer_attention.py │ │ ├── test_polynormer_attention.py │ │ └── test_qformer.py │ ├── conv │ │ ├── cugraph │ │ │ ├── test_cugraph_gat_conv.py │ │ │ ├── test_cugraph_rgcn_conv.py │ │ │ └── test_cugraph_sage_conv.py │ │ ├── test_agnn_conv.py │ │ ├── test_antisymmetric_conv.py │ │ ├── test_appnp.py │ │ ├── test_arma_conv.py │ │ ├── test_cg_conv.py │ │ ├── test_cheb_conv.py │ │ ├── test_cluster_gcn_conv.py │ │ ├── test_create_gnn.py │ │ ├── test_dir_gnn_conv.py │ │ ├── test_dna_conv.py │ │ ├── test_edge_conv.py │ │ ├── test_eg_conv.py │ │ ├── test_fa_conv.py │ │ ├── test_feast_conv.py │ │ ├── test_film_conv.py │ │ ├── test_fused_gat_conv.py │ │ ├── test_gat_conv.py │ │ ├── test_gated_graph_conv.py │ │ ├── test_gatv2_conv.py │ │ ├── test_gcn2_conv.py │ │ ├── test_gcn_conv.py │ │ ├── test_gen_conv.py │ │ ├── test_general_conv.py │ │ ├── test_gin_conv.py │ │ ├── test_gmm_conv.py │ │ ├── test_gps_conv.py │ │ ├── test_graph_conv.py │ │ ├── test_gravnet_conv.py │ │ ├── test_han_conv.py │ │ ├── test_heat_conv.py │ │ ├── test_hetero_conv.py │ │ ├── test_hgt_conv.py │ │ ├── test_hypergraph_conv.py │ │ ├── test_le_conv.py │ │ ├── test_lg_conv.py │ │ ├── test_meshcnn_conv.py │ │ ├── test_message_passing.py │ │ ├── test_mf_conv.py │ │ ├── test_mixhop_conv.py │ │ ├── test_nn_conv.py │ │ ├── test_pan_conv.py │ │ ├── test_pdn_conv.py │ │ ├── test_pna_conv.py │ │ ├── test_point_conv.py │ │ ├── test_point_gnn_conv.py │ │ ├── test_point_transformer_conv.py │ │ ├── test_ppf_conv.py │ │ ├── test_res_gated_graph_conv.py │ │ ├── test_rgat_conv.py │ │ 
├── test_rgcn_conv.py │ │ ├── test_sage_conv.py │ │ ├── test_sg_conv.py │ │ ├── test_signed_conv.py │ │ ├── test_simple_conv.py │ │ ├── test_spline_conv.py │ │ ├── test_ssg_conv.py │ │ ├── test_static_graph.py │ │ ├── test_supergat_conv.py │ │ ├── test_tag_conv.py │ │ ├── test_transformer_conv.py │ │ ├── test_wl_conv.py │ │ ├── test_wl_conv_continuous.py │ │ ├── test_x_conv.py │ │ └── utils │ │ │ └── test_gnn_cheatsheet.py │ ├── dense │ │ ├── test_dense_gat_conv.py │ │ ├── test_dense_gcn_conv.py │ │ ├── test_dense_gin_conv.py │ │ ├── test_dense_graph_conv.py │ │ ├── test_dense_sage_conv.py │ │ ├── test_diff_pool.py │ │ ├── test_dmon_pool.py │ │ ├── test_linear.py │ │ └── test_mincut_pool.py │ ├── functional │ │ ├── test_bro.py │ │ └── test_gini.py │ ├── kge │ │ ├── test_complex.py │ │ ├── test_distmult.py │ │ ├── test_rotate.py │ │ └── test_transe.py │ ├── models │ │ ├── test_attentive_fp.py │ │ ├── test_attract_repel.py │ │ ├── test_autoencoder.py │ │ ├── test_basic_gnn.py │ │ ├── test_correct_and_smooth.py │ │ ├── test_deep_graph_infomax.py │ │ ├── test_deepgcn.py │ │ ├── test_dimenet.py │ │ ├── test_gnnff.py │ │ ├── test_gpse.py │ │ ├── test_graph_mixer.py │ │ ├── test_graph_unet.py │ │ ├── test_jumping_knowledge.py │ │ ├── test_label_prop.py │ │ ├── test_lightgcn.py │ │ ├── test_linkx.py │ │ ├── test_lpformer.py │ │ ├── test_mask_label.py │ │ ├── test_meta.py │ │ ├── test_metapath2vec.py │ │ ├── test_mlp.py │ │ ├── test_neural_fingerprint.py │ │ ├── test_node2vec.py │ │ ├── test_pmlp.py │ │ ├── test_polynormer.py │ │ ├── test_re_net.py │ │ ├── test_rect.py │ │ ├── test_rev_gnn.py │ │ ├── test_schnet.py │ │ ├── test_sgformer.py │ │ ├── test_signed_gcn.py │ │ ├── test_tgn.py │ │ └── test_visnet.py │ ├── norm │ │ ├── test_batch_norm.py │ │ ├── test_diff_group_norm.py │ │ ├── test_graph_norm.py │ │ ├── test_graph_size_norm.py │ │ ├── test_instance_norm.py │ │ ├── test_layer_norm.py │ │ ├── test_mean_subtraction_norm.py │ │ ├── test_msg_norm.py │ │ └── 
test_pair_norm.py │ ├── pool │ │ ├── connect │ │ │ └── test_filter_edges.py │ │ ├── select │ │ │ └── test_select_topk.py │ │ ├── test_approx_knn.py │ │ ├── test_asap.py │ │ ├── test_avg_pool.py │ │ ├── test_cluster_pool.py │ │ ├── test_consecutive.py │ │ ├── test_decimation.py │ │ ├── test_edge_pool.py │ │ ├── test_glob.py │ │ ├── test_graclus.py │ │ ├── test_knn.py │ │ ├── test_max_pool.py │ │ ├── test_mem_pool.py │ │ ├── test_pan_pool.py │ │ ├── test_pool.py │ │ ├── test_sag_pool.py │ │ ├── test_topk_pool.py │ │ └── test_voxel_grid.py │ ├── test_compile_basic.py │ ├── test_compile_conv.py │ ├── test_compile_dynamic.py │ ├── test_data_parallel.py │ ├── test_encoding.py │ ├── test_fvcore.py │ ├── test_fx.py │ ├── test_inits.py │ ├── test_model_hub.py │ ├── test_model_summary.py │ ├── test_module_dict.py │ ├── test_parameter_dict.py │ ├── test_reshape.py │ ├── test_resolver.py │ ├── test_sequential.py │ ├── test_to_fixed_size_transformer.py │ ├── test_to_hetero_module.py │ ├── test_to_hetero_transformer.py │ ├── test_to_hetero_with_bases_transformer.py │ └── unpool │ │ └── test_knn_interpolate.py ├── profile │ ├── test_benchmark.py │ ├── test_nvtx.py │ ├── test_profile.py │ ├── test_profile_utils.py │ └── test_profiler.py ├── sampler │ ├── test_sampler_base.py │ └── test_sampler_neighbor_sampler.py ├── test_config_mixin.py ├── test_config_store.py ├── test_debug.py ├── test_edge_index.py ├── test_experimental.py ├── test_hash_tensor.py ├── test_home.py ├── test_index.py ├── test_inspector.py ├── test_isinstance.py ├── test_onnx.py ├── test_seed.py ├── test_typing.py ├── test_warnings.py ├── testing │ └── test_decorators.py ├── transforms │ ├── test_add_gpse.py │ ├── test_add_metapaths.py │ ├── test_add_positional_encoding.py │ ├── test_add_remaining_self_loops.py │ ├── test_add_self_loops.py │ ├── test_cartesian.py │ ├── test_center.py │ ├── test_compose.py │ ├── test_constant.py │ ├── test_delaunay.py │ ├── test_distance.py │ ├── test_face_to_edge.py │ ├── 
test_feature_propagation.py │ ├── test_fixed_points.py │ ├── test_gcn_norm.py │ ├── test_gdc.py │ ├── test_generate_mesh_normals.py │ ├── test_grid_sampling.py │ ├── test_half_hop.py │ ├── test_knn_graph.py │ ├── test_laplacian_lambda_max.py │ ├── test_largest_connected_components.py │ ├── test_line_graph.py │ ├── test_linear_transformation.py │ ├── test_local_cartesian.py │ ├── test_local_degree_profile.py │ ├── test_mask_transform.py │ ├── test_node_property_split.py │ ├── test_normalize_features.py │ ├── test_normalize_rotation.py │ ├── test_normalize_scale.py │ ├── test_one_hot_degree.py │ ├── test_pad.py │ ├── test_point_pair_features.py │ ├── test_polar.py │ ├── test_radius_graph.py │ ├── test_random_flip.py │ ├── test_random_jitter.py │ ├── test_random_link_split.py │ ├── test_random_node_split.py │ ├── test_random_rotate.py │ ├── test_random_scale.py │ ├── test_random_shear.py │ ├── test_remove_duplicated_edges.py │ ├── test_remove_isolated_nodes.py │ ├── test_remove_self_loops.py │ ├── test_remove_training_classes.py │ ├── test_rooted_subgraph.py │ ├── test_sample_points.py │ ├── test_sign.py │ ├── test_spherical.py │ ├── test_svd_feature_reduction.py │ ├── test_target_indegree.py │ ├── test_to_dense.py │ ├── test_to_device.py │ ├── test_to_sparse_tensor.py │ ├── test_to_superpixels.py │ ├── test_to_undirected.py │ ├── test_two_hop.py │ └── test_virtual_node.py ├── utils │ ├── conftest.py │ ├── test_assortativity.py │ ├── test_augmentation.py │ ├── test_coalesce.py │ ├── test_convert.py │ ├── test_cross_entropy.py │ ├── test_degree.py │ ├── test_dropout.py │ ├── test_embedding.py │ ├── test_functions.py │ ├── test_geodesic.py │ ├── test_grid.py │ ├── test_hetero.py │ ├── test_homophily.py │ ├── test_index_sort.py │ ├── test_isolated.py │ ├── test_laplacian.py │ ├── test_lexsort.py │ ├── test_loop.py │ ├── test_map.py │ ├── test_mask.py │ ├── test_mesh_laplacian.py │ ├── test_negative_sampling.py │ ├── test_nested.py │ ├── test_noise_scheduler.py │ ├── 
test_normalize_edge_index.py │ ├── test_normalized_cut.py │ ├── test_num_nodes.py │ ├── test_one_hot.py │ ├── test_ppr.py │ ├── test_random.py │ ├── test_repeat.py │ ├── test_scatter.py │ ├── test_segment.py │ ├── test_select.py │ ├── test_smiles.py │ ├── test_softmax.py │ ├── test_sort_edge_index.py │ ├── test_sparse.py │ ├── test_spmm.py │ ├── test_subgraph.py │ ├── test_to_dense_adj.py │ ├── test_to_dense_batch.py │ ├── test_total_influence.py │ ├── test_train_test_split_edges.py │ ├── test_tree_decomposition.py │ ├── test_trim_to_layer.py │ ├── test_unbatch.py │ └── test_undirected.py └── visualization │ ├── test_graph_visualization.py │ └── test_influence.py └── torch_geometric ├── __init__.py ├── _compile.py ├── _onnx.py ├── backend.py ├── config_mixin.py ├── config_store.py ├── contrib ├── __init__.py ├── datasets │ └── __init__.py ├── explain │ ├── __init__.py │ └── pgm_explainer.py ├── nn │ ├── __init__.py │ ├── conv │ │ └── __init__.py │ └── models │ │ ├── __init__.py │ │ └── rbcd_attack.py └── transforms │ └── __init__.py ├── data ├── __init__.py ├── batch.py ├── collate.py ├── data.py ├── database.py ├── datapipes.py ├── dataset.py ├── download.py ├── extract.py ├── feature_store.py ├── graph_store.py ├── hetero_data.py ├── hypergraph_data.py ├── in_memory_dataset.py ├── lightning │ ├── __init__.py │ └── datamodule.py ├── makedirs.py ├── on_disk_dataset.py ├── remote_backend_utils.py ├── separate.py ├── storage.py ├── summary.py ├── temporal.py └── view.py ├── datasets ├── __init__.py ├── actor.py ├── airfrans.py ├── airports.py ├── amazon.py ├── amazon_book.py ├── amazon_products.py ├── aminer.py ├── aqsol.py ├── attributed_graph_dataset.py ├── ba2motif_dataset.py ├── ba_multi_shapes.py ├── ba_shapes.py ├── bitcoin_otc.py ├── brca_tgca.py ├── citation_full.py ├── city.py ├── coauthor.py ├── coma.py ├── cornell.py ├── dblp.py ├── dbp15k.py ├── deezer_europe.py ├── dgraph.py ├── dynamic_faust.py ├── elliptic.py ├── elliptic_temporal.py ├── 
email_eu_core.py ├── entities.py ├── explainer_dataset.py ├── facebook.py ├── fake.py ├── faust.py ├── flickr.py ├── freebase.py ├── gdelt.py ├── gdelt_lite.py ├── ged_dataset.py ├── gemsec.py ├── geometry.py ├── git_mol_dataset.py ├── github.py ├── gnn_benchmark_dataset.py ├── graph_generator │ ├── __init__.py │ ├── ba_graph.py │ ├── base.py │ ├── er_graph.py │ ├── grid_graph.py │ └── tree_graph.py ├── heterophilous_graph_dataset.py ├── hgb_dataset.py ├── hm.py ├── hydro_net.py ├── icews.py ├── igmc_dataset.py ├── imdb.py ├── infection_dataset.py ├── instruct_mol_dataset.py ├── jodie.py ├── karate.py ├── last_fm.py ├── lastfm_asia.py ├── linkx_dataset.py ├── lrgb.py ├── malnet_tiny.py ├── md17.py ├── medshapenet.py ├── mixhop_synthetic_dataset.py ├── mnist_superpixels.py ├── modelnet.py ├── molecule_gpt_dataset.py ├── molecule_net.py ├── motif_generator │ ├── __init__.py │ ├── base.py │ ├── custom.py │ ├── cycle.py │ ├── grid.py │ └── house.py ├── movie_lens.py ├── movie_lens_100k.py ├── movie_lens_1m.py ├── myket.py ├── nell.py ├── neurograph.py ├── ogb_mag.py ├── omdb.py ├── opf.py ├── ose_gvcs.py ├── particle.py ├── pascal.py ├── pascal_pf.py ├── pcpnet_dataset.py ├── pcqm4m.py ├── planetoid.py ├── polblogs.py ├── ppi.py ├── protein_mpnn_dataset.py ├── qm7.py ├── qm9.py ├── rcdd.py ├── reddit.py ├── reddit2.py ├── rel_link_pred_dataset.py ├── s3dis.py ├── sbm_dataset.py ├── shapenet.py ├── shrec2016.py ├── snap_dataset.py ├── suite_sparse.py ├── tag_dataset.py ├── taobao.py ├── teeth3ds.py ├── tosca.py ├── tu_dataset.py ├── twitch.py ├── upfd.py ├── utils │ ├── __init__.py │ └── cheatsheet.py ├── web_qsp_dataset.py ├── webkb.py ├── wikics.py ├── wikidata.py ├── wikipedia_network.py ├── willow_object_class.py ├── word_net.py ├── yelp.py └── zinc.py ├── debug.py ├── deprecation.py ├── device.py ├── distributed ├── __init__.py ├── dist_context.py ├── dist_link_neighbor_loader.py ├── dist_loader.py ├── dist_neighbor_loader.py ├── dist_neighbor_sampler.py ├── 
event_loop.py ├── local_feature_store.py ├── local_graph_store.py ├── partition.py ├── rpc.py └── utils.py ├── edge_index.py ├── experimental.py ├── explain ├── __init__.py ├── algorithm │ ├── __init__.py │ ├── attention_explainer.py │ ├── base.py │ ├── captum.py │ ├── captum_explainer.py │ ├── dummy_explainer.py │ ├── gnn_explainer.py │ ├── graphmask_explainer.py │ ├── pg_explainer.py │ └── utils.py ├── config.py ├── explainer.py ├── explanation.py └── metric │ ├── __init__.py │ ├── basic.py │ ├── faithfulness.py │ └── fidelity.py ├── graphgym ├── __init__.py ├── benchmark.py ├── checkpoint.py ├── cmd_args.py ├── config.py ├── contrib │ ├── __init__.py │ ├── act │ │ └── __init__.py │ ├── config │ │ └── __init__.py │ ├── encoder │ │ └── __init__.py │ ├── head │ │ └── __init__.py │ ├── layer │ │ ├── __init__.py │ │ └── generalconv.py │ ├── loader │ │ └── __init__.py │ ├── loss │ │ └── __init__.py │ ├── network │ │ └── __init__.py │ ├── optimizer │ │ └── __init__.py │ ├── pooling │ │ └── __init__.py │ ├── stage │ │ └── __init__.py │ ├── train │ │ └── __init__.py │ └── transform │ │ └── __init__.py ├── imports.py ├── init.py ├── loader.py ├── logger.py ├── loss.py ├── model_builder.py ├── models │ ├── __init__.py │ ├── act.py │ ├── encoder.py │ ├── gnn.py │ ├── head.py │ ├── layer.py │ ├── pooling.py │ └── transform.py ├── optim.py ├── register.py ├── train.py └── utils │ ├── LICENSE │ ├── __init__.py │ ├── agg_runs.py │ ├── comp_budget.py │ ├── device.py │ ├── epoch.py │ ├── io.py │ ├── plot.py │ └── tools.py ├── hash_tensor.py ├── home.py ├── index.py ├── inspector.py ├── io ├── __init__.py ├── fs.py ├── npz.py ├── obj.py ├── off.py ├── planetoid.py ├── ply.py ├── sdf.py ├── tu.py └── txt_array.py ├── isinstance.py ├── lazy_loader.py ├── llm ├── __init__.py ├── large_graph_indexer.py ├── models │ ├── __init__.py │ ├── g_retriever.py │ ├── git_mol.py │ ├── glem.py │ ├── llm.py │ ├── llm_judge.py │ ├── molecule_gpt.py │ ├── protein_mpnn.py │ ├── 
sentence_transformer.py │ ├── txt2kg.py │ └── vision_transformer.py ├── rag_loader.py └── utils │ ├── __init__.py │ ├── backend_utils.py │ ├── feature_store.py │ ├── graph_store.py │ └── vectorrag.py ├── loader ├── __init__.py ├── base.py ├── cache.py ├── cluster.py ├── data_list_loader.py ├── dataloader.py ├── dense_data_loader.py ├── dynamic_batch_sampler.py ├── graph_saint.py ├── hgt_loader.py ├── ibmb_loader.py ├── imbalanced_sampler.py ├── link_loader.py ├── link_neighbor_loader.py ├── mixin.py ├── neighbor_loader.py ├── neighbor_sampler.py ├── node_loader.py ├── prefetch.py ├── random_node_loader.py ├── shadow.py ├── temporal_dataloader.py ├── utils.py └── zip_loader.py ├── logging.py ├── metrics ├── __init__.py └── link_pred.py ├── nn ├── __init__.py ├── aggr │ ├── __init__.py │ ├── attention.py │ ├── base.py │ ├── basic.py │ ├── deep_sets.py │ ├── equilibrium.py │ ├── fused.py │ ├── gmt.py │ ├── gru.py │ ├── lcm.py │ ├── lstm.py │ ├── mlp.py │ ├── multi.py │ ├── patch_transformer.py │ ├── quantile.py │ ├── scaler.py │ ├── set2set.py │ ├── set_transformer.py │ ├── sort.py │ ├── utils.py │ └── variance_preserving.py ├── attention │ ├── __init__.py │ ├── performer.py │ ├── polynormer.py │ ├── qformer.py │ └── sgformer.py ├── conv │ ├── __init__.py │ ├── agnn_conv.py │ ├── antisymmetric_conv.py │ ├── appnp.py │ ├── arma_conv.py │ ├── cg_conv.py │ ├── cheb_conv.py │ ├── cluster_gcn_conv.py │ ├── collect.jinja │ ├── cugraph │ │ ├── __init__.py │ │ ├── base.py │ │ ├── gat_conv.py │ │ ├── rgcn_conv.py │ │ └── sage_conv.py │ ├── dir_gnn_conv.py │ ├── dna_conv.py │ ├── edge_conv.py │ ├── edge_updater.jinja │ ├── eg_conv.py │ ├── fa_conv.py │ ├── feast_conv.py │ ├── film_conv.py │ ├── fused_gat_conv.py │ ├── gat_conv.py │ ├── gated_graph_conv.py │ ├── gatv2_conv.py │ ├── gcn2_conv.py │ ├── gcn_conv.py │ ├── gen_conv.py │ ├── general_conv.py │ ├── gin_conv.py │ ├── gmm_conv.py │ ├── gps_conv.py │ ├── graph_conv.py │ ├── gravnet_conv.py │ ├── han_conv.py │ ├── 
heat_conv.py │ ├── hetero_conv.py │ ├── hgt_conv.py │ ├── hypergraph_conv.py │ ├── le_conv.py │ ├── lg_conv.py │ ├── meshcnn_conv.py │ ├── message_passing.py │ ├── mf_conv.py │ ├── mixhop_conv.py │ ├── nn_conv.py │ ├── pan_conv.py │ ├── pdn_conv.py │ ├── pna_conv.py │ ├── point_conv.py │ ├── point_gnn_conv.py │ ├── point_transformer_conv.py │ ├── ppf_conv.py │ ├── propagate.jinja │ ├── res_gated_graph_conv.py │ ├── rgat_conv.py │ ├── rgcn_conv.py │ ├── sage_conv.py │ ├── sg_conv.py │ ├── signed_conv.py │ ├── simple_conv.py │ ├── spline_conv.py │ ├── ssg_conv.py │ ├── supergat_conv.py │ ├── tag_conv.py │ ├── transformer_conv.py │ ├── utils │ │ ├── __init__.py │ │ └── cheatsheet.py │ ├── wl_conv.py │ ├── wl_conv_continuous.py │ └── x_conv.py ├── data_parallel.py ├── dense │ ├── __init__.py │ ├── dense_gat_conv.py │ ├── dense_gcn_conv.py │ ├── dense_gin_conv.py │ ├── dense_graph_conv.py │ ├── dense_sage_conv.py │ ├── diff_pool.py │ ├── dmon_pool.py │ ├── linear.py │ └── mincut_pool.py ├── encoding.py ├── functional │ ├── __init__.py │ ├── bro.py │ └── gini.py ├── fx.py ├── glob.py ├── inits.py ├── kge │ ├── __init__.py │ ├── base.py │ ├── complex.py │ ├── distmult.py │ ├── loader.py │ ├── rotate.py │ └── transe.py ├── lr_scheduler.py ├── model_hub.py ├── models │ ├── __init__.py │ ├── attentive_fp.py │ ├── attract_repel.py │ ├── autoencoder.py │ ├── basic_gnn.py │ ├── captum.py │ ├── correct_and_smooth.py │ ├── deep_graph_infomax.py │ ├── deepgcn.py │ ├── dimenet.py │ ├── dimenet_utils.py │ ├── gnnff.py │ ├── gpse.py │ ├── graph_mixer.py │ ├── graph_unet.py │ ├── jumping_knowledge.py │ ├── label_prop.py │ ├── lightgcn.py │ ├── linkx.py │ ├── lpformer.py │ ├── mask_label.py │ ├── meta.py │ ├── metapath2vec.py │ ├── mlp.py │ ├── neural_fingerprint.py │ ├── node2vec.py │ ├── pmlp.py │ ├── polynormer.py │ ├── re_net.py │ ├── rect.py │ ├── rev_gnn.py │ ├── schnet.py │ ├── sgformer.py │ ├── signed_gcn.py │ ├── tgn.py │ └── visnet.py ├── module_dict.py ├── norm │ ├── 
__init__.py │ ├── batch_norm.py │ ├── diff_group_norm.py │ ├── graph_norm.py │ ├── graph_size_norm.py │ ├── instance_norm.py │ ├── layer_norm.py │ ├── mean_subtraction_norm.py │ ├── msg_norm.py │ └── pair_norm.py ├── parameter_dict.py ├── pool │ ├── __init__.py │ ├── approx_knn.py │ ├── asap.py │ ├── avg_pool.py │ ├── cluster_pool.py │ ├── connect │ │ ├── __init__.py │ │ ├── base.py │ │ └── filter_edges.py │ ├── consecutive.py │ ├── decimation.py │ ├── edge_pool.py │ ├── glob.py │ ├── graclus.py │ ├── knn.py │ ├── max_pool.py │ ├── mem_pool.py │ ├── pan_pool.py │ ├── pool.py │ ├── sag_pool.py │ ├── select │ │ ├── __init__.py │ │ ├── base.py │ │ └── topk.py │ ├── topk_pool.py │ └── voxel_grid.py ├── reshape.py ├── resolver.py ├── sequential.jinja ├── sequential.py ├── summary.py ├── to_fixed_size_transformer.py ├── to_hetero_module.py ├── to_hetero_transformer.py ├── to_hetero_with_bases_transformer.py └── unpool │ ├── __init__.py │ └── knn_interpolate.py ├── profile ├── __init__.py ├── benchmark.py ├── nvtx.py ├── profile.py ├── profiler.py └── utils.py ├── resolver.py ├── sampler ├── __init__.py ├── base.py ├── hgt_sampler.py ├── neighbor_sampler.py └── utils.py ├── seed.py ├── template.py ├── testing ├── __init__.py ├── asserts.py ├── data.py ├── decorators.py ├── distributed.py ├── feature_store.py └── graph_store.py ├── transforms ├── __init__.py ├── add_gpse.py ├── add_metapaths.py ├── add_positional_encoding.py ├── add_remaining_self_loops.py ├── add_self_loops.py ├── base_transform.py ├── cartesian.py ├── center.py ├── compose.py ├── constant.py ├── delaunay.py ├── distance.py ├── face_to_edge.py ├── feature_propagation.py ├── fixed_points.py ├── gcn_norm.py ├── gdc.py ├── generate_mesh_normals.py ├── grid_sampling.py ├── half_hop.py ├── knn_graph.py ├── laplacian_lambda_max.py ├── largest_connected_components.py ├── line_graph.py ├── linear_transformation.py ├── local_cartesian.py ├── local_degree_profile.py ├── mask.py ├── node_property_split.py ├── 
normalize_features.py ├── normalize_rotation.py ├── normalize_scale.py ├── one_hot_degree.py ├── pad.py ├── point_pair_features.py ├── polar.py ├── radius_graph.py ├── random_flip.py ├── random_jitter.py ├── random_link_split.py ├── random_node_split.py ├── random_rotate.py ├── random_scale.py ├── random_shear.py ├── remove_duplicated_edges.py ├── remove_isolated_nodes.py ├── remove_self_loops.py ├── remove_training_classes.py ├── rooted_subgraph.py ├── sample_points.py ├── sign.py ├── spherical.py ├── svd_feature_reduction.py ├── target_indegree.py ├── to_dense.py ├── to_device.py ├── to_sparse_tensor.py ├── to_superpixels.py ├── to_undirected.py ├── two_hop.py └── virtual_node.py ├── typing.py ├── utils ├── __init__.py ├── _assortativity.py ├── _coalesce.py ├── _degree.py ├── _grid.py ├── _homophily.py ├── _index_sort.py ├── _lexsort.py ├── _negative_sampling.py ├── _normalize_edge_index.py ├── _normalized_cut.py ├── _one_hot.py ├── _scatter.py ├── _segment.py ├── _select.py ├── _softmax.py ├── _sort_edge_index.py ├── _spmm.py ├── _subgraph.py ├── _to_dense_adj.py ├── _to_dense_batch.py ├── _train_test_split_edges.py ├── _tree_decomposition.py ├── _trim_to_layer.py ├── _unbatch.py ├── augmentation.py ├── convert.py ├── cross_entropy.py ├── dropout.py ├── embedding.py ├── functions.py ├── geodesic.py ├── hetero.py ├── influence.py ├── isolated.py ├── laplacian.py ├── loop.py ├── map.py ├── mask.py ├── mesh_laplacian.py ├── mixin.py ├── nested.py ├── noise_scheduler.py ├── num_nodes.py ├── ppr.py ├── random.py ├── repeat.py ├── smiles.py ├── sparse.py └── undirected.py ├── visualization ├── __init__.py ├── graph.py └── influence.py └── warnings.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners 2 | 3 | * @rusty1s @akihironitta 4 | 5 | *.py @rusty1s @wsad1 @akihironitta 6 | 7 
| /.github/ @rusty1s @akihironitta 8 | 9 | /.github/CODEOWNERS @rusty1s 10 | 11 | /torch_geometric/data/ @rusty1s @mananshah99 @akihironitta 12 | 13 | /torch_geometric/loader/ @rusty1s @mananshah99 @akihironitta 14 | 15 | /torch_geometric/sampler/ @rusty1s @mananshah99 @akihironitta 16 | 17 | /docs/ @rusty1s @akihironitta 18 | 19 | /torch_geometric/nn/conv/cugraph @tingyu66 20 | 21 | /examples/llm @puririshi98 22 | 23 | /torch_geometric/llm @puririshi98 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 🙏 Ask a Question 4 | url: https://github.com/pyg-team/pytorch_geometric/discussions/new 5 | about: Ask and answer PyG related questions 6 | - name: 💬 Slack 7 | url: https://join.slack.com/t/torchgeometricco/shared_invite/zt-p6br3yuo-BxRoe36OHHLF6jYU8xHtBA 8 | about: Chat with our community 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: "📚 Typos and Doc Fixes" 2 | description: "Tell us about how we can improve our documentation" 3 | labels: documentation 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 📚 Describe the documentation issue 9 | description: | 10 | A clear and concise description of the issue. 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Suggest a potential alternative/fix 16 | description: | 17 | Tell us how we could improve the documentation in this regard. 
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "🚀 Feature Request" 2 | description: "Propose a new PyG feature" 3 | labels: feature 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 🚀 The feature, motivation and pitch 9 | description: > 10 | A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Alternatives 16 | description: > 17 | A description of any alternative solutions or features you've considered, if any. 18 | - type: textarea 19 | attributes: 20 | label: Additional context 21 | description: > 22 | Add any other context or screenshots about the feature request. 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/refactor.yml: -------------------------------------------------------------------------------- 1 | name: "🛠 Refactor" 2 | description: "Suggest a code refactor or deprecation" 3 | labels: refactor 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 🛠 Proposed Refactor 9 | description: | 10 | A clear and concise description of the refactor proposal. Please outline the motivation for the proposal. If this is related to another GitHub issue, please link here too. 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Suggest a potential alternative/fix 16 | description: | 17 | Tell us how we could improve the code in this regard. 
18 | validations: 19 | required: true 20 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directories: 6 | - "/" 7 | - "/.github/actions/setup" 8 | schedule: 9 | interval: "daily" 10 | time: "00:00" 11 | labels: 12 | - "ci" 13 | - "skip-changelog" 14 | pull-request-branch-name: 15 | separator: "-" 16 | open-pull-requests-limit: 10 17 | -------------------------------------------------------------------------------- /.github/workflows/auto-merge.yml: -------------------------------------------------------------------------------- 1 | name: Auto-merge Bot PRs 2 | 3 | on: # yamllint disable-line rule:truthy 4 | pull_request_target: 5 | types: [opened, reopened] 6 | 7 | permissions: 8 | contents: write 9 | pull-requests: write 10 | 11 | jobs: 12 | auto-merge: 13 | runs-on: ubuntu-latest 14 | if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }} 15 | steps: 16 | - uses: actions/checkout@v5 17 | 18 | - name: Label bot PRs 19 | run: gh pr edit --add-label "ci,skip-changelog" ${{ github.event.pull_request.html_url }} 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.PAT }} 22 | 23 | - name: Auto-approve 24 | uses: hmarr/auto-approve-action@v4 25 | with: 26 | github-token: ${{ secrets.PAT }} 27 | 28 | - name: Enable auto-merge 29 | run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }} 30 | env: 31 | GITHUB_TOKEN: ${{ secrets.PAT }} 32 | -------------------------------------------------------------------------------- /.github/workflows/changelog.yml: -------------------------------------------------------------------------------- 1 | name: Changelog Enforcer 2 | 3 | 
on: # yamllint disable-line rule:truthy 4 | pull_request: 5 | types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled] 6 | 7 | jobs: 8 | 9 | changelog: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Enforce changelog entry 14 | uses: dangoslen/changelog-enforcer@v3 15 | with: 16 | skipLabels: skip-changelog 17 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: PR Labeler 2 | 3 | on: # yamllint disable-line rule:truthy 4 | pull_request: 5 | 6 | jobs: 7 | 8 | assign-labels: 9 | if: github.repository == 'pyg-team/pytorch_geometric' 10 | runs-on: ubuntu-latest 11 | 12 | permissions: 13 | contents: read 14 | pull-requests: write 15 | 16 | steps: 17 | - name: Add PR labels 18 | uses: actions/labeler@v6 19 | continue-on-error: true 20 | with: 21 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 22 | sync-labels: true 23 | 24 | assign-author: 25 | if: github.repository == 'pyg-team/pytorch_geometric' 26 | runs-on: ubuntu-latest 27 | 28 | steps: 29 | - name: Add PR author 30 | uses: samspills/assign-pr-to-author@v1.0 31 | continue-on-error: true 32 | if: github.event_name == 'pull_request' 33 | with: 34 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | .DS_Store 4 | data/ 5 | build/ 6 | dist/ 7 | alpha/ 8 | runs/ 9 | wandb/ 10 | .cache/ 11 | .eggs/ 12 | lightning_logs/ 13 | outputs/ 14 | graphgym/datasets/ 15 | graphgym/results/ 16 | *.egg-info/ 17 | .ipynb_checkpoints 18 | .coverage 19 | .coverage.* 20 | coverage.xml 21 | .vscode 22 | .idea 23 | .venv 24 | *.out 25 | *.pt 26 | *.onnx 27 | examples/**/*.png 28 | examples/**/*.pdf 29 | benchmark/results/ 30 | .mypy_cache/ 31 | uv.lock 32 
| 33 | !torch_geometric/data/ 34 | !test/data/ 35 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | # PyG Benchmark Suite 2 | 3 | This benchmark suite provides evaluation scripts for **[semi-supervised node classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/citation)**, **[graph classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/kernel)**, and **[point cloud classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/points)** and **[runtimes](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/runtime)** in order to compare various methods in homogeneous evaluation scenarios. 4 | In particular, we take care to avoid to perform hyperparameter and model selection on the test set and instead use an additional validation set. 5 | 6 | ## Installation 7 | 8 | ``` 9 | $ pip install -e . 
10 | ``` 11 | -------------------------------------------------------------------------------- /benchmark/citation/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import get_planetoid_dataset 2 | from .train_eval import random_planetoid_splits, run 3 | 4 | __all__ = [ 5 | 'get_planetoid_dataset', 6 | 'random_planetoid_splits', 7 | 'run', 8 | ] 9 | -------------------------------------------------------------------------------- /benchmark/citation/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch_geometric.transforms as T 4 | from torch_geometric.datasets import Planetoid 5 | 6 | 7 | def get_planetoid_dataset(name, normalize_features=False, transform=None): 8 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) 9 | dataset = Planetoid(path, name) 10 | 11 | if transform is not None and normalize_features: 12 | dataset.transform = T.Compose([T.NormalizeFeatures(), transform]) 13 | elif normalize_features: 14 | dataset.transform = T.NormalizeFeatures() 15 | elif transform is not None: 16 | dataset.transform = transform 17 | 18 | return dataset 19 | -------------------------------------------------------------------------------- /benchmark/citation/statistics.py: -------------------------------------------------------------------------------- 1 | from citation import get_planetoid_dataset 2 | 3 | 4 | def print_dataset(dataset): 5 | data = dataset[0] 6 | print('Name', dataset) 7 | print('Nodes', data.num_nodes) 8 | print('Edges', data.num_edges // 2) 9 | print('Features', dataset.num_features) 10 | print('Classes', dataset.num_classes) 11 | print('Label rate', data.train_mask.sum().item() / data.num_nodes) 12 | print() 13 | 14 | 15 | for name in ['Cora', 'CiteSeer', 'PubMed']: 16 | print_dataset(get_planetoid_dataset(name)) 17 | 
-------------------------------------------------------------------------------- /benchmark/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import get_dataset 2 | from .train_eval import cross_validation_with_val_set 3 | 4 | __all__ = [ 5 | 'get_dataset', 6 | 'cross_validation_with_val_set', 7 | ] 8 | -------------------------------------------------------------------------------- /benchmark/kernel/statistics.py: -------------------------------------------------------------------------------- 1 | from kernel.datasets import get_dataset 2 | 3 | 4 | def print_dataset(dataset): 5 | num_nodes = num_edges = 0 6 | for data in dataset: 7 | num_nodes += data.num_nodes 8 | num_edges += data.num_edges 9 | 10 | print('Name', dataset) 11 | print('Graphs', len(dataset)) 12 | print('Nodes', num_nodes / len(dataset)) 13 | print('Edges', (num_edges // 2) / len(dataset)) 14 | print('Features', dataset.num_features) 15 | print('Classes', dataset.num_classes) 16 | print() 17 | 18 | 19 | for name in ['MUTAG', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']: 20 | print_dataset(get_dataset(name)) 21 | -------------------------------------------------------------------------------- /benchmark/points/README.md: -------------------------------------------------------------------------------- 1 | # Point Cloud classification 2 | 3 | Evaluation scripts for various methods on the ModelNet10 dataset: 4 | 5 | - **[MPNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/mpnn.py)**: `python mpnn.py` 6 | - **[PointNet++](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_net.py)**: `python point_net.py` 7 | - **[EdgeCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/edge_cnn.py)**: `python edge_cnn.py` 8 | - **[SplineCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/spline_cnn.py)**: `python spline_cnn.py` 9 
| - **[PointCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)**: `python point_cnn.py` 10 | -------------------------------------------------------------------------------- /benchmark/points/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import get_dataset 2 | from .train_eval import run 3 | 4 | __all__ = [ 5 | 'get_dataset', 6 | 'run', 7 | ] 8 | -------------------------------------------------------------------------------- /benchmark/points/datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch_geometric.transforms as T 4 | from torch_geometric.datasets import ModelNet 5 | 6 | 7 | def get_dataset(num_points): 8 | name = 'ModelNet10' 9 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name) 10 | pre_transform = T.NormalizeScale() 11 | transform = T.SamplePoints(num_points) 12 | 13 | train_dataset = ModelNet(path, name='10', train=True, transform=transform, 14 | pre_transform=pre_transform) 15 | test_dataset = ModelNet(path, name='10', train=False, transform=transform, 16 | pre_transform=pre_transform) 17 | 18 | return train_dataset, test_dataset 19 | -------------------------------------------------------------------------------- /benchmark/points/statistics.py: -------------------------------------------------------------------------------- 1 | from points.datasets import get_dataset 2 | 3 | from torch_geometric.transforms import RadiusGraph 4 | 5 | 6 | def print_dataset(train_dataset, test_dataset): 7 | num_nodes = num_edges = 0 8 | for data in train_dataset: 9 | data = RadiusGraph(0.2)(data) 10 | num_nodes += data.num_nodes 11 | num_edges += data.num_edges 12 | for data in test_dataset: 13 | data = RadiusGraph(0.2)(data) 14 | num_nodes += data.num_nodes 15 | num_edges += data.num_edges 16 | 17 | num_graphs = len(train_dataset) + len(test_dataset) 18 
| print('Graphs', num_graphs) 19 | print('Nodes', num_nodes / num_graphs) 20 | print('Edges', (num_edges // 2) / num_graphs) 21 | print('Label rate', len(train_dataset) / num_graphs) 22 | print() 23 | 24 | 25 | print_dataset(*get_dataset(num_points=1024)) 26 | -------------------------------------------------------------------------------- /benchmark/runtime/README.md: -------------------------------------------------------------------------------- 1 | # Runtimes 2 | 3 | Run the test suite for PyG via 4 | 5 | ``` 6 | python main.py 7 | ``` 8 | 9 | Install `dgl` and run the test suite for DGL via 10 | 11 | ``` 12 | cd dgl 13 | python main.py 14 | ``` 15 | -------------------------------------------------------------------------------- /benchmark/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import train_runtime 2 | 3 | __all__ = [ 4 | 'train_runtime', 5 | ] 6 | -------------------------------------------------------------------------------- /benchmark/runtime/dgl/hidden.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | 5 | warnings.filterwarnings('ignore') 6 | 7 | 8 | class HiddenPrint: 9 | def __enter__(self): 10 | self._original_stdout = sys.stdout 11 | sys.stdout = open(os.devnull, 'w') 12 | 13 | def __exit__(self, exc_type, exc_val, exc_tb): 14 | sys.stdout.close() 15 | sys.stdout = self._original_stdout 16 | -------------------------------------------------------------------------------- /benchmark/runtime/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch_geometric.nn import GATConv 5 | 6 | 7 | class GAT(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels): 9 | super().__init__() 10 | self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6) 11 | self.conv2 = GATConv(8 * 8, 
out_channels, dropout=0.6) 12 | 13 | def forward(self, data): 14 | x, edge_index = data.x, data.edge_index 15 | x = F.dropout(x, p=0.6, training=self.training) 16 | x = F.elu(self.conv1(x, edge_index)) 17 | x = F.dropout(x, p=0.6, training=self.training) 18 | x = self.conv2(x, edge_index) 19 | return F.log_softmax(x, dim=1) 20 | -------------------------------------------------------------------------------- /benchmark/runtime/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch_geometric.nn import GCNConv 5 | 6 | 7 | class GCN(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels): 9 | super().__init__() 10 | self.conv1 = GCNConv(in_channels, 16, cached=True) 11 | self.conv2 = GCNConv(16, out_channels, cached=True) 12 | 13 | def forward(self, data): 14 | x, edge_index = data.x, data.edge_index 15 | x = F.relu(self.conv1(x, edge_index)) 16 | x = F.dropout(x, training=self.training) 17 | x = self.conv2(x, edge_index) 18 | return F.log_softmax(x, dim=1) 19 | -------------------------------------------------------------------------------- /benchmark/runtime/rgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch_geometric.nn import FastRGCNConv 5 | 6 | 7 | class RGCN(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels, num_relations): 9 | super().__init__() 10 | self.conv1 = FastRGCNConv(in_channels, 16, num_relations, num_bases=30) 11 | self.conv2 = FastRGCNConv(16, out_channels, num_relations, 12 | num_bases=30) 13 | 14 | def forward(self, data): 15 | edge_index, edge_type = data.edge_index, data.edge_type 16 | x = F.relu(self.conv1(None, edge_index, edge_type)) 17 | x = self.conv2(x, edge_index, edge_type) 18 | return F.log_softmax(x, dim=1) 19 | -------------------------------------------------------------------------------- 
/benchmark/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='torch_geometric_benchmark', 5 | version='0.1.0', 6 | description='PyG Benchmark Suite', 7 | author='Matthias Fey', 8 | author_email='matthias.fey@tu-dortmund.de', 9 | url='https://github.com/pyg-team/pytorch_geometric_benchmark', 10 | install_requires=['scikit-learn'], 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /benchmark/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import emit_itt 2 | from .utils import get_dataset, get_dataset_with_transformation 3 | from .utils import get_model 4 | from .utils import get_split_masks 5 | from .utils import save_benchmark_data, write_to_csv 6 | from .utils import test 7 | 8 | __all__ = [ 9 | 'emit_itt', 10 | 'get_dataset', 11 | 'get_dataset_with_transformation', 12 | 'get_model', 13 | 'get_split_masks', 14 | 'save_benchmark_data', 15 | 'write_to_csv', 16 | 'test', 17 | ] 18 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # See: https://docs.codecov.io/docs/codecov-yaml 2 | coverage: 3 | range: 80..100 4 | round: down 5 | precision: 2 6 | status: 7 | project: 8 | default: 9 | target: 80% 10 | threshold: 1% 11 | patch: 12 | default: 13 | target: 80% 14 | threshold: 1% 15 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04 2 | 3 | # Based on NGC PyG 24.09 image: 4 | # https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09 5 | 6 | # install pip 7 | 
RUN apt-get update && apt-get install -y python3-pip 8 | 9 | # install PyTorch - latest stable version 10 | RUN pip install torch torchvision torchaudio 11 | 12 | # install graphviz - latest stable version 13 | RUN apt-get install -y graphviz graphviz-dev 14 | RUN pip install pygraphviz 15 | 16 | # install python packages with NGC PyG 24.09 image versions 17 | RUN pip install torch_geometric==2.6.0 18 | RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5 19 | 20 | # install cugraph 21 | RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD = sphinx-build 2 | SPHINXPROJ = pytorch_geometric 3 | SOURCEDIR = source 4 | BUILDDIR = build 5 | 6 | .PHONY: help Makefile 7 | 8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(0) 10 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Building Documentation 2 | 3 | To build the documentation: 4 | 5 | 1. [Build and install](https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md#developing-pytorch-geometric) PyG from source. 6 | 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via 7 | ``` 8 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git 9 | ``` 10 | 1. Generate the documentation file via: 11 | ``` 12 | cd docs 13 | make html 14 | ``` 15 | 16 | The documentation is now available to view by opening `docs/build/html/index.html`. 
17 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl 2 | numpy>=1.19.5 3 | git+https://github.com/pyg-team/pyg_sphinx_theme.git 4 | -------------------------------------------------------------------------------- /docs/source/.gitignore: -------------------------------------------------------------------------------- 1 | generated/ 2 | -------------------------------------------------------------------------------- /docs/source/_figures/.gitignore: -------------------------------------------------------------------------------- 1 | *.aux 2 | *.log 3 | *.pdf 4 | -------------------------------------------------------------------------------- /docs/source/_figures/architecture.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/architecture.pdf -------------------------------------------------------------------------------- /docs/source/_figures/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for filename in *.tex; do 4 | basename=$(basename $filename .tex) 5 | pdflatex "$basename.tex" 6 | pdf2svg "$basename.pdf" "$basename.svg" 7 | done 8 | -------------------------------------------------------------------------------- /docs/source/_figures/dist_part.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_part.png -------------------------------------------------------------------------------- /docs/source/_figures/dist_proc.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_proc.png -------------------------------------------------------------------------------- /docs/source/_figures/dist_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_sampling.png -------------------------------------------------------------------------------- /docs/source/_figures/graph.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | \usepackage{tikz} 4 | 5 | \begin{document} 6 | 7 | \begin{tikzpicture} 8 | \node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0}; 9 | \node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1}; 10 | \node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2}; 11 | 12 | \path[draw] (0) -- (1); 13 | \path[draw] (1) -- (2); 14 | \end{tikzpicture} 15 | 16 | \end{document} 17 | -------------------------------------------------------------------------------- /docs/source/_figures/graphgps_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgps_layer.png -------------------------------------------------------------------------------- /docs/source/_figures/graphgym_design_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_design_space.png -------------------------------------------------------------------------------- 
/docs/source/_figures/graphgym_evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_evaluation.png -------------------------------------------------------------------------------- /docs/source/_figures/graphgym_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_results.png -------------------------------------------------------------------------------- /docs/source/_figures/hg_example.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | \usepackage{tikz} 4 | 5 | \begin{document} 6 | 7 | \begin{tikzpicture} 8 | \node[draw,rectangle, align=center] (0) at (0, 0) {\textbf{Author}\\ $1,134,649$ nodes}; 9 | \node[draw,rectangle, align=center] (1) at (4, 2) {\textbf{Paper}\\ $736,389$ nodes}; 10 | \node[draw,rectangle, align=center] (2) at (8, 0) {\textbf{Institution}\\ $8,740$ nodes}; 11 | \node[draw,rectangle, align=center] (3) at (4, 4) {\textbf{Field of Study}\\ $59,965$ nodes}; 12 | 13 | \path[->,>=stealth] (0) edge [above left] node[align=center] {\textbf{writes}\\$7,145,660$ edges} (1.south); 14 | \path[->,>=stealth] (0) edge [below] node[align=center] {\textbf{affiliated with}\\$1,043,998$ edges} (2); 15 | \path[->,>=stealth,every loop/.style={looseness=3}] (1) edge [out=350, in=10, loop, right] node[align=center] {\textbf{cites}\\$5,416,271$ edges} (1); 16 | \path[->,>=stealth] (1) edge [left] node[align=center] {\textbf{has topic}\\$7,505,078$ edges} (3); 17 | \end{tikzpicture} 18 | 19 | \end{document} 20 | -------------------------------------------------------------------------------- /docs/source/_figures/intel_kumo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/intel_kumo.png -------------------------------------------------------------------------------- /docs/source/_figures/point_cloud1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud1.png -------------------------------------------------------------------------------- /docs/source/_figures/point_cloud2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud2.png -------------------------------------------------------------------------------- /docs/source/_figures/point_cloud3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud3.png -------------------------------------------------------------------------------- /docs/source/_figures/point_cloud4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud4.png -------------------------------------------------------------------------------- /docs/source/_figures/remote_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_1.png 
-------------------------------------------------------------------------------- /docs/source/_figures/remote_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_2.png -------------------------------------------------------------------------------- /docs/source/_figures/remote_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_3.png -------------------------------------------------------------------------------- /docs/source/_figures/shallow_node_embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/shallow_node_embeddings.png -------------------------------------------------------------------------------- /docs/source/_figures/training_affinity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/training_affinity.png -------------------------------------------------------------------------------- /docs/source/_static/js/version_alert.js: -------------------------------------------------------------------------------- 1 | function warnOnLatestVersion() { 2 | if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") { 3 | return; // not on ReadTheDocs and not latest. 4 | } 5 | 6 | var note = document.createElement('div'); 7 | note.setAttribute('class', 'admonition note'); 8 | note.innerHTML = "

Note

" + 9 | "

" + 10 | "This documentation is for an unreleased development version. " + 11 | "Click here to access the documentation of the current stable release." + 12 | "

"; 13 | 14 | var parent = document.querySelector('#pyg-documentation'); 15 | if (parent) 16 | parent.insertBefore(note, parent.querySelector('h1')); 17 | } 18 | 19 | document.addEventListener('DOMContentLoaded', warnOnLatestVersion); 20 | -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/create_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/create_dataset.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/create_gnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/create_gnn.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/dataset_splitting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/dataset_splitting.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/distributed_pyg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/distributed_pyg.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/explain.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/explain.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/graph_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/graph_transformer.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/heterogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/heterogeneous.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/load_csv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/load_csv.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/multi_gpu_vanilla.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/multi_gpu_vanilla.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/neighbor_loader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/neighbor_loader.png 
-------------------------------------------------------------------------------- /docs/source/_static/thumbnails/point_cloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/point_cloud.png -------------------------------------------------------------------------------- /docs/source/_static/thumbnails/shallow_node_embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/shallow_node_embeddings.png -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/inherited_class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :inherited-members: 9 | :special-members: __cat_dim__, __inc__ 10 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/metrics.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: update, compute, reset 8 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/nn.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | {% if objname != "MessagePassing" %} 6 | .. autoclass:: {{ objname }} 7 | :show-inheritance: 8 | :members: 9 | :exclude-members: forward, reset_parameters, message, message_and_aggregate, edge_update, aggregate, update 10 | 11 | .. automethod:: forward 12 | .. automethod:: reset_parameters 13 | {% else %} 14 | .. autoclass:: {{ objname }} 15 | :show-inheritance: 16 | :members: 17 | {% endif %} 18 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/only_class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/modules/distributed.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.distributed 2 | =========================== 3 | 4 | .. warning:: 5 | ``torch_geometric.distributed`` has been deprecated since 2.7.0 and will 6 | no longer be maintained. For distributed training, refer to :ref:`our 7 | tutorials on distributed training ` or `cuGraph 8 | examples `_. 9 | 10 | .. currentmodule:: torch_geometric.distributed 11 | 12 | .. autosummary:: 13 | :nosignatures: 14 | {% for cls in torch_geometric.distributed.classes %} 15 | {{ cls }} 16 | {% endfor %} 17 | 18 | .. 
automodule:: torch_geometric.distributed 19 | :members: 20 | -------------------------------------------------------------------------------- /docs/source/modules/llm.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.llm 2 | ======================= 3 | 4 | .. currentmodule:: torch_geometric.llm 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | {% for cls in torch_geometric.llm.classes %} 9 | {{ cls }} 10 | {% endfor %} 11 | 12 | .. automodule:: torch_geometric.llm 13 | :members: 14 | 15 | 16 | Models 17 | ---------------- 18 | 19 | .. currentmodule:: torch_geometric.llm.models 20 | 21 | .. autosummary:: 22 | :nosignatures: 23 | :toctree: ../generated 24 | 25 | {% for name in torch_geometric.llm.models.classes %} 26 | {{ name }} 27 | {% endfor %} 28 | 29 | Utils 30 | ---------------- 31 | 32 | .. currentmodule:: torch_geometric.llm.utils 33 | 34 | .. autosummary:: 35 | :nosignatures: 36 | :toctree: ../generated 37 | 38 | {% for name in torch_geometric.llm.utils.classes %} 39 | {{ name }} 40 | {% endfor %} 41 | -------------------------------------------------------------------------------- /docs/source/modules/loader.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.loader 2 | ====================== 3 | 4 | .. currentmodule:: torch_geometric.loader 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | {% for cls in torch_geometric.loader.classes %} 9 | {{ cls }} 10 | {% endfor %} 11 | 12 | .. automodule:: torch_geometric.loader 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/modules/metrics.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.metrics 2 | ======================= 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Link Prediction Metrics 8 | ----------------------- 9 | 10 | .. currentmodule:: torch_geometric.metrics 11 | 12 | .. 
autosummary:: 13 | :nosignatures: 14 | :toctree: ../generated 15 | :template: autosummary/metrics.rst 16 | 17 | {% for name in torch_geometric.metrics.link_pred_metrics %} 18 | {{ name }} 19 | {% endfor %} 20 | -------------------------------------------------------------------------------- /docs/source/modules/profile.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.profile 2 | ======================= 3 | 4 | .. currentmodule:: torch_geometric.profile 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | {% for cls in torch_geometric.profile.classes %} 9 | {{ cls }} 10 | {% endfor %} 11 | 12 | .. automodule:: torch_geometric.profile 13 | :members: 14 | :undoc-members: 15 | -------------------------------------------------------------------------------- /docs/source/modules/root.rst: -------------------------------------------------------------------------------- 1 | torch_geometric 2 | =============== 3 | 4 | Tensor Objects 5 | -------------- 6 | 7 | .. currentmodule:: torch_geometric 8 | 9 | .. autosummary:: 10 | :nosignatures: 11 | :toctree: ../generated 12 | 13 | Index 14 | EdgeIndex 15 | HashTensor 16 | 17 | Functions 18 | --------- 19 | 20 | .. automodule:: torch_geometric.seed 21 | :members: 22 | 23 | .. automodule:: torch_geometric.home 24 | :members: 25 | 26 | .. automodule:: torch_geometric._compile 27 | :members: 28 | :exclude-members: compile 29 | 30 | .. automodule:: torch_geometric.debug 31 | :members: 32 | 33 | .. automodule:: torch_geometric.experimental 34 | :members: 35 | -------------------------------------------------------------------------------- /docs/source/modules/sampler.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.sampler 2 | ======================= 3 | 4 | .. currentmodule:: torch_geometric.sampler 5 | 6 | .. 
autosummary:: 7 | :nosignatures: 8 | {% for cls in torch_geometric.sampler.classes %} 9 | {{ cls }} 10 | {% endfor %} 11 | 12 | .. autoclass:: torch_geometric.sampler.BaseSampler 13 | :members: 14 | 15 | .. automodule:: torch_geometric.sampler 16 | :members: 17 | :exclude-members: sample_from_nodes, sample_from_edges, edge_permutation, BaseSampler 18 | -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | torch_geometric.utils 2 | ===================== 3 | 4 | .. currentmodule:: torch_geometric.utils 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | {% for cls in torch_geometric.utils.classes %} 9 | {{ cls }} 10 | {% endfor %} 11 | 12 | .. automodule:: torch_geometric.utils 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/notes/batching.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/batching.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/colabs.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../get_started/colabs.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/create_dataset.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../tutorial/create_dataset.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/create_gnn.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. 
include:: ../tutorial/create_gnn.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/explain.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../tutorial/explain.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/graphgym.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/graphgym.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/heterogeneous.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../tutorial/heterogeneous.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/installation.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../install/installation.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/introduction.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../get_started/introduction.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/jit.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/jit.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/load_csv.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. 
include:: ../tutorial/load_csv.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/remote.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/remote.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/resources.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../external/resources.rst 4 | -------------------------------------------------------------------------------- /docs/source/notes/sparse_tensor.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/sparse_tensor.rst 4 | -------------------------------------------------------------------------------- /docs/source/tutorial/application.rst: -------------------------------------------------------------------------------- 1 | Use-Cases & Applications 2 | ======================== 3 | 4 | .. nbgallery:: 5 | :name: rst-gallery 6 | 7 | neighbor_loader 8 | point_cloud 9 | explain 10 | shallow_node_embeddings 11 | graph_transformer 12 | -------------------------------------------------------------------------------- /docs/source/tutorial/compile.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. include:: ../advanced/compile.rst 4 | -------------------------------------------------------------------------------- /docs/source/tutorial/dataset.rst: -------------------------------------------------------------------------------- 1 | Working with Graph Datasets 2 | =========================== 3 | 4 | .. 
nbgallery:: 5 | :name: rst-gallery 6 | 7 | create_dataset 8 | load_csv 9 | dataset_splitting 10 | -------------------------------------------------------------------------------- /docs/source/tutorial/distributed.rst: -------------------------------------------------------------------------------- 1 | .. _distributed_tutorials: 2 | 3 | Distributed Training 4 | ==================== 5 | 6 | .. nbgallery:: 7 | :name: rst-gallery 8 | 9 | multi_gpu_vanilla 10 | multi_node_multi_gpu_vanilla 11 | distributed_pyg 12 | -------------------------------------------------------------------------------- /docs/source/tutorial/gnn_design.rst: -------------------------------------------------------------------------------- 1 | Design of Graph Neural Networks 2 | =============================== 3 | 4 | .. nbgallery:: 5 | :name: rst-gallery 6 | 7 | create_gnn 8 | heterogeneous 9 | -------------------------------------------------------------------------------- /examples/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(hello-world) 3 | 4 | # The first thing to do is to tell CMake to find the TorchScatter 5 | # and TorchSparse libraries. The package pulls in all the necessary 6 | # torch libraries, so there is no need to add `find_package(Torch)`. 7 | find_package(TorchScatter REQUIRED) 8 | find_package(TorchSparse REQUIRED) 9 | 10 | find_package(Python3 COMPONENTS Development) 11 | 12 | add_executable(hello-world main.cpp) 13 | 14 | # We now need to link the TorchScatter and TorchSparse libraries 15 | # to our executable. We can do that by using the 16 | # TorchScatter::TorchScatter and TorchSparse::TorchSparse targets, 17 | # which also add all the necessary torch dependencies.
18 | target_compile_features(hello-world PUBLIC cxx_range_for) 19 | target_link_libraries(hello-world TorchScatter::TorchScatter) 20 | target_link_libraries(hello-world TorchSparse::TorchSparse) 21 | target_link_libraries(hello-world ${CUDA_cusparse_LIBRARY}) 22 | set_property(TARGET hello-world PROPERTY CXX_STANDARD 14) 23 | -------------------------------------------------------------------------------- /examples/cpp/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | int main(int argc, const char *argv[]) { 8 | if (argc != 2) { 9 | std::cerr << "usage: hello-world \n"; 10 | return -1; 11 | } 12 | 13 | torch::jit::script::Module model; 14 | try { 15 | model = torch::jit::load(argv[1]); 16 | } catch (const c10::Error &e) { 17 | std::cerr << "error loading the model\n"; 18 | return -1; 19 | } 20 | 21 | auto x = torch::randn({5, 32}); 22 | auto edge_index = torch::tensor({ 23 | {0, 1, 1, 2, 2, 3, 3, 4}, 24 | {1, 0, 2, 1, 3, 2, 4, 3}, 25 | }); 26 | 27 | std::vector inputs; 28 | inputs.push_back(x); 29 | inputs.push_back(edge_index); 30 | 31 | auto out = model.forward(inputs).toTensor(); 32 | std::cout << "output tensor shape: " << out.sizes() << std::endl; 33 | } 34 | -------------------------------------------------------------------------------- /examples/distributed/README.md: -------------------------------------------------------------------------------- 1 | # Examples for Distributed Graph Learning 2 | 3 | This directory contains examples for distributed graph learning. 4 | The examples are organized into three subdirectories: 5 | 6 | 1. [`pyg`](./pyg): Distributed training via PyG's own `torch_geometric.distributed` package (deprecated). 7 | 1. [`graphlearn_for_pytorch`](./graphlearn_for_pytorch): Distributed training via the external [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) package. 8 | 1.
[`kuzu`](./kuzu): Remote backend via the [Kùzu](https://kuzudb.com/) graph database. 9 | -------------------------------------------------------------------------------- /examples/jit/README.md: -------------------------------------------------------------------------------- 1 | # JIT Examples 2 | 3 | This directory contains examples demonstrating the use of Just-In-Time (JIT) compilation in different GNN models. 4 | 5 | | Example | Description | 6 | | ---------------------- | ----------------------------------------------------------------- | 7 | | [`gcn.py`](./gcn.py) | JIT compilation in `GCN` | 8 | | [`gat.py`](./gat.py) | JIT compilation in `GAT` | 9 | | [`gin.py`](./gin.py) | JIT compilation in `GIN` | 10 | | [`film.py`](./film.py) | JIT compilation in [`GNN-FiLM`](https://arxiv.org/abs/1906.12192) | 11 | -------------------------------------------------------------------------------- /examples/llm/nvtx_examples/nvtx_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Check if the user provided a Python file 4 | if [ -z "$1" ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | # Check if the provided file exists 10 | if [[ ! -f "$1" ]]; then 11 | echo "Error: File '$1' does not exist." 12 | exit 1 13 | fi 14 | 15 | # Check if the provided file is a Python file 16 | if [[ ! "$1" == *.py ]]; then 17 | echo "Error: '$1' is not a Python file." 
18 | exit 1 19 | fi 20 | 21 | # Get the base name of the Python file 22 | python_file=$(basename "$1") 23 | 24 | # Run nsys profile on the Python file 25 | nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" 26 | 27 | echo "Profile data saved as profile_${python_file%.py}.nsys-rep" 28 | -------------------------------------------------------------------------------- /examples/llm/nvtx_examples/nvtx_webqsp_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from torch_geometric.datasets import web_qsp_dataset 6 | from torch_geometric.profile import nvtxit 7 | 8 | # Apply Patches 9 | web_qsp_dataset.retrieval_via_pcst = nvtxit()( 10 | web_qsp_dataset.retrieval_via_pcst) 11 | web_qsp_dataset.WebQSPDataset.process = nvtxit()( 12 | web_qsp_dataset.WebQSPDataset.process) 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--capture-torch-kernels", "-k", action="store_true") 17 | args = parser.parse_args() 18 | if args.capture_torch_kernels: 19 | with torch.autograd.profiler.emit_nvtx(): 20 | ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') 21 | else: 22 | ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') 23 | -------------------------------------------------------------------------------- /examples/pytorch_ignite/README.md: -------------------------------------------------------------------------------- 1 | # Examples for PyTorch Ignite 2 | 3 | This directory provides examples showcasing the integration of PyG with [PyTorch Ignite](https://pytorch.org/ignite/index.html).
4 | 5 | | Example | Description | 6 | | -------------------- | ---------------------------------------------------------------- | 7 | | [`gin.py`](./gin.py) | Demonstrates how to implement the GIN model using PyTorch Ignite | 8 | -------------------------------------------------------------------------------- /examples/pytorch_lightning/README.md: -------------------------------------------------------------------------------- 1 | # Examples for PyTorch Lightning 2 | 3 | This directory provides examples showcasing the integration of PyG with [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). 4 | 5 | | Example | Description | 6 | | ------------------------------------------ | ------------------------------------------------------------------------------------ | 7 | | [`graph_sage.py`](./graph_sage.py) | Combines PyG and PyTorch Lightning for node classification via the `GraphSAGE` model | 8 | | [`gin.py`](./gin.py) | Combines PyG and PyTorch Lightning for graph classification via the `GIN` model | 9 | | [`relational_gnn.py`](./relational_gnn.py) | Combines PyG and PyTorch Lightning for heterogeneous node classification | 10 | -------------------------------------------------------------------------------- /graphgym/agg_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torch_geometric.graphgym.utils.agg_runs import agg_batch 4 | 5 | 6 | def parse_args(): 7 | """Parses the arguments.""" 8 | parser = argparse.ArgumentParser( 9 | description='Train a classification model') 10 | parser.add_argument('--dir', dest='dir', help='Dir for batch of results', 11 | required=True, type=str) 12 | parser.add_argument('--metric', dest='metric', 13 | help='metric to select best epoch', required=False, 14 | type=str, default='auto') 15 | return parser.parse_args() 16 | 17 | 18 | args = parse_args() 19 | agg_batch(args.dir, args.metric) 20 | 
-------------------------------------------------------------------------------- /graphgym/configs/example.yaml: -------------------------------------------------------------------------------- 1 | # The recommended basic settings for GNN 2 | out_dir: results 3 | dataset: 4 | format: PyG 5 | name: Cora 6 | task: node 7 | task_type: classification 8 | transductive: true 9 | split: [0.8, 0.2] 10 | transform: none 11 | train: 12 | batch_size: 32 13 | eval_period: 20 14 | ckpt_period: 100 15 | model: 16 | type: gnn 17 | loss_fun: cross_entropy 18 | edge_decoding: dot 19 | graph_pooling: add 20 | gnn: 21 | layers_pre_mp: 1 22 | layers_mp: 2 23 | layers_post_mp: 1 24 | dim_inner: 256 25 | layer_type: generalconv 26 | stage_type: stack 27 | batchnorm: true 28 | act: prelu 29 | dropout: 0.0 30 | agg: add 31 | normalize_adj: false 32 | optim: 33 | optimizer: adam 34 | base_lr: 0.01 35 | max_epoch: 400 36 | -------------------------------------------------------------------------------- /graphgym/configs/pyg/example_graph.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbg-molhiv 5 | task: graph 6 | task_type: classification 7 | node_encoder: true 8 | node_encoder_name: Atom 9 | edge_encoder: true 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 1 23 | layers_mp: 2 24 | layers_post_mp: 1 25 | dim_inner: 300 26 | layer_type: generalconv 27 | stage_type: stack 28 | batchnorm: true 29 | act: prelu 30 | dropout: 0.0 31 | agg: mean 32 | normalize_adj: false 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 100 37 | -------------------------------------------------------------------------------- /graphgym/configs/pyg/example_link.yaml: 
-------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: OGB 4 | name: ogbl-collab 5 | task: link_pred 6 | task_type: classification 7 | node_encoder: false 8 | node_encoder_name: Atom 9 | edge_encoder: false 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 1 23 | layers_mp: 2 24 | layers_post_mp: 1 25 | dim_inner: 300 26 | layer_type: gcnconv 27 | stage_type: stack 28 | batchnorm: true 29 | act: prelu 30 | dropout: 0.0 31 | agg: mean 32 | normalize_adj: false 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 100 37 | -------------------------------------------------------------------------------- /graphgym/configs/pyg/example_node.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | dataset: 3 | format: PyG 4 | name: Cora 5 | task: node 6 | task_type: classification 7 | node_encoder: false 8 | node_encoder_name: Atom 9 | edge_encoder: false 10 | edge_encoder_name: Bond 11 | train: 12 | batch_size: 128 13 | eval_period: 1 14 | ckpt_period: 100 15 | sampler: full_batch 16 | model: 17 | type: gnn 18 | loss_fun: cross_entropy 19 | edge_decoding: dot 20 | graph_pooling: add 21 | gnn: 22 | layers_pre_mp: 0 23 | layers_mp: 2 24 | layers_post_mp: 1 25 | dim_inner: 16 26 | layer_type: gcnconv 27 | stage_type: stack 28 | batchnorm: false 29 | act: prelu 30 | dropout: 0.1 31 | agg: mean 32 | normalize_adj: false 33 | optim: 34 | optimizer: adam 35 | base_lr: 0.01 36 | max_epoch: 200 37 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * # noqa 2 | 
from .config import * # noqa 3 | from .encoder import * # noqa 4 | from .head import * # noqa 5 | from .layer import * # noqa 6 | from .loader import * # noqa 7 | from .loss import * # noqa 8 | from .network import * # noqa 9 | from .optimizer import * # noqa 10 | from .pooling import * # noqa 11 | from .stage import * # noqa 12 | from .train import * # noqa 13 | from .transform import * # noqa 14 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/act/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/act/example.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch_geometric.graphgym.config import cfg 7 | from torch_geometric.graphgym.register import register_act 8 | 9 | 10 | class SWISH(nn.Module): 11 | def __init__(self, inplace=False): 12 | super().__init__() 13 | self.inplace = inplace 14 | 15 | def forward(self, x): 16 | if self.inplace: 17 | x.mul_(torch.sigmoid(x)) 18 | return x 19 | else: 20 | return x * torch.sigmoid(x) 21 | 22 | 23 | register_act('swish', partial(SWISH, inplace=cfg.mem.inplace)) 24 | register_act('lrelu_03', partial(nn.LeakyReLU, 0.3, inplace=cfg.mem.inplace)) 25 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/config/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = 
glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/config/example.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | from torch_geometric.graphgym.register import register_config 4 | 5 | 6 | @register_config('example') 7 | def set_cfg_example(cfg): 8 | r"""This function sets the default config value for customized options 9 | :return: customized configuration use by the experiment. 10 | """ 11 | # ----------------------------------------------------------------------- # 12 | # Customized options 13 | # ----------------------------------------------------------------------- # 14 | 15 | # example argument 16 | cfg.example_arg = 'example' 17 | 18 | # example argument group 19 | cfg.example_group = CN() 20 | 21 | # then argument can be specified within the group 22 | cfg.example_group.example_arg = 'example' 23 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/head/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not 
f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/head/example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_geometric.graphgym.register import register_head 4 | 5 | 6 | @register_head('head') 7 | class ExampleNodeHead(nn.Module): 8 | r"""Head of GNN for node prediction.""" 9 | def __init__(self, dim_in, dim_out): 10 | super().__init__() 11 | self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True) 12 | 13 | def _apply_index(self, batch): 14 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]: 15 | return batch.x[batch.node_label_index], batch.node_label 16 | else: 17 | return batch.x[batch.node_label_index], \ 18 | batch.node_label[batch.node_label_index] 19 | 20 | def forward(self, batch): 21 | batch = self.layer_post_mp(batch) 22 | pred, label = self._apply_index(batch) 23 | return pred, label 24 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- 
/graphgym/custom_graphgym/loader/example.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import QM7b 2 | from torch_geometric.graphgym.register import register_loader 3 | 4 | 5 | @register_loader('example') 6 | def load_dataset_example(format, name, dataset_dir): 7 | dataset_dir = f'{dataset_dir}/{name}' 8 | if format == 'PyG': 9 | if name == 'QM7b': 10 | dataset_raw = QM7b(dataset_dir) 11 | return dataset_raw 12 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/loss/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_loss 5 | 6 | 7 | @register_loss('smoothl1') 8 | def loss_example(pred, true): 9 | if cfg.model.loss_fun == 'smoothl1': 10 | l1_loss = torch.nn.SmoothL1Loss() 11 | loss = l1_loss(pred, true) 12 | return loss, pred 13 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/network/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | 
-------------------------------------------------------------------------------- /graphgym/custom_graphgym/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/optimizer/example.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | from torch.nn import Parameter 4 | from torch.optim import Adagrad, Optimizer 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | 7 | import torch_geometric.graphgym.register as register 8 | 9 | 10 | @register.register_optimizer('adagrad') 11 | def adagrad_optimizer(params: Iterator[Parameter], base_lr: float, 12 | weight_decay: float) -> Adagrad: 13 | return Adagrad(params, lr=base_lr, weight_decay=weight_decay) 14 | 15 | 16 | @register.register_scheduler('pleateau') 17 | def plateau_scheduler(optimizer: Optimizer, patience: int, 18 | lr_decay: float) -> ReduceLROnPlateau: 19 | return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay) 20 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/pooling/example.py: 
-------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_pooling 2 | from torch_geometric.utils import scatter 3 | 4 | 5 | @register_pooling('example') 6 | def global_example_pool(x, batch, size=None): 7 | size = batch.max().item() + 1 if size is None else size 8 | return scatter(x, batch, dim=0, dim_size=size, reduce='sum') 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/stage/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/train/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/custom_graphgym/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /graphgym/grids/example.txt: -------------------------------------------------------------------------------- 1 | # Format for each 
row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | # (1) dataset configurations 8 | dataset.format format ['PyG'] 9 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS'] 10 | dataset.task task ['graph'] 11 | dataset.transductive trans [False] 12 | # (2) The recommended GNN design space, 96 models in total 13 | gnn.layers_pre_mp l_pre [1,2] 14 | gnn.layers_mp l_mp [2,4,6,8] 15 | gnn.layers_post_mp l_post [2,3] 16 | gnn.stage_type stage ['skipsum','skipconcat'] 17 | gnn.agg agg ['add','mean','max'] 18 | -------------------------------------------------------------------------------- /graphgym/grids/pyg/example.txt: -------------------------------------------------------------------------------- 1 | # Format for each row: name in config.py; alias; range to search 2 | # No spaces, except between these 3 fields 3 | # Line breaks are used to union different grid search spaces 4 | # Feel free to add '#' to add comments 5 | 6 | 7 | gnn.layers_pre_mp l_pre [1,2] 8 | gnn.layers_mp l_mp [2,4,6] 9 | gnn.layers_post_mp l_post [1,2] 10 | gnn.stage_type stage ['stack','skipsum','skipconcat'] 11 | gnn.dim_inner dim [64] 12 | optim.base_lr lr [0.01] 13 | optim.max_epoch epoch [200] 14 | -------------------------------------------------------------------------------- /graphgym/parallel.sh: -------------------------------------------------------------------------------- 1 | CONFIG_DIR=$1 2 | REPEAT=$2 3 | MAX_JOBS=${3:-2} 4 | SLEEP=${4:-1} 5 | MAIN=${5:-main} 6 | 7 | ( 8 | trap 'kill 0' SIGINT 9 | CUR_JOBS=0 10 | for CONFIG in "$CONFIG_DIR"/*.yaml; do 11 | if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then 12 | ((CUR_JOBS >= MAX_JOBS)) && wait -n 13 | python $MAIN.py --cfg $CONFIG --repeat $REPEAT --mark_done & 14 | echo $CONFIG 15 | sleep $SLEEP 16 | ((++CUR_JOBS)) 17 | fi 18 | done 19 | 20 | wait 21 | ) 22 | 
-------------------------------------------------------------------------------- /graphgym/run_single.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Test for running a single experiment. --repeat means run how many different random seeds. 4 | python main.py --cfg configs/pyg/example_node.yaml --repeat 3 # node classification 5 | python main.py --cfg configs/pyg/example_link.yaml --repeat 3 # link prediction 6 | python main.py --cfg configs/pyg/example_graph.yaml --repeat 3 # graph classification 7 | -------------------------------------------------------------------------------- /graphgym/sample/dimensions.txt: -------------------------------------------------------------------------------- 1 | act bn drop agg l_mp l_pre l_post stage batch lr optim epoch 2 | -------------------------------------------------------------------------------- /graphgym/sample/dimensionsatt.txt: -------------------------------------------------------------------------------- 1 | l_tw 2 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | build: 7 | os: ubuntu-24.04 8 | tools: 9 | python: "3.10" 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 
16 | 17 | formats: [] 18 | -------------------------------------------------------------------------------- /test/datasets/graph_generator/test_ba_graph.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.graph_generator import BAGraph 2 | 3 | 4 | def test_ba_graph(): 5 | graph_generator = BAGraph(num_nodes=300, num_edges=5) 6 | assert str(graph_generator) == 'BAGraph(num_nodes=300, num_edges=5)' 7 | 8 | data = graph_generator() 9 | assert len(data) == 2 10 | assert data.num_nodes == 300 11 | assert data.num_edges <= 2 * 300 * 5 12 | -------------------------------------------------------------------------------- /test/datasets/graph_generator/test_er_graph.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.graph_generator import ERGraph 2 | 3 | 4 | def test_er_graph(): 5 | graph_generator = ERGraph(num_nodes=300, edge_prob=0.1) 6 | assert str(graph_generator) == 'ERGraph(num_nodes=300, edge_prob=0.1)' 7 | 8 | data = graph_generator() 9 | assert len(data) == 2 10 | assert data.num_nodes == 300 11 | assert data.num_edges >= 300 * 300 * 0.05 12 | assert data.num_edges <= 300 * 300 * 0.15 13 | -------------------------------------------------------------------------------- /test/datasets/graph_generator/test_grid_graph.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.graph_generator import GridGraph 2 | 3 | 4 | def test_grid_graph(): 5 | graph_generator = GridGraph(height=10, width=10) 6 | assert str(graph_generator) == 'GridGraph(height=10, width=10)' 7 | 8 | data = graph_generator() 9 | assert len(data) == 2 10 | assert data.num_nodes == 100 11 | assert data.num_edges == 784 12 | -------------------------------------------------------------------------------- /test/datasets/graph_generator/test_tree_graph.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_geometric.datasets.graph_generator import TreeGraph 4 | 5 | 6 | @pytest.mark.parametrize('undirected', [False, True]) 7 | def test_tree_graph(undirected): 8 | graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected) 9 | assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, ' 10 | f'undirected={undirected})') 11 | 12 | data = graph_generator() 13 | assert len(data) == 3 14 | assert data.num_nodes == 7 15 | assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2] 16 | if not undirected: 17 | assert data.edge_index.tolist() == [ 18 | [0, 0, 1, 1, 2, 2], 19 | [1, 2, 3, 4, 5, 6], 20 | ] 21 | else: 22 | assert data.edge_index.tolist() == [ 23 | [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6], 24 | [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2], 25 | ] 26 | -------------------------------------------------------------------------------- /test/datasets/motif_generator/test_cycle_motif.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.motif_generator import CycleMotif 2 | 3 | 4 | def test_cycle_motif(): 5 | motif_generator = CycleMotif(5) 6 | assert str(motif_generator) == 'CycleMotif(5)' 7 | 8 | motif = motif_generator() 9 | assert len(motif) == 2 10 | assert motif.num_nodes == 5 11 | assert motif.num_edges == 10 12 | assert motif.edge_index.tolist() == [ 13 | [0, 0, 1, 1, 2, 2, 3, 3, 4, 4], 14 | [1, 4, 0, 2, 1, 3, 2, 4, 0, 3], 15 | ] 16 | -------------------------------------------------------------------------------- /test/datasets/motif_generator/test_grid_motif.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.motif_generator import GridMotif 2 | 3 | 4 | def test_grid_motif(): 5 | motif_generator = GridMotif() 6 | assert str(motif_generator) == 'GridMotif()' 7 | 8 | motif = motif_generator() 9 | assert len(motif) == 3 10 
| assert motif.num_nodes == 9 11 | assert motif.num_edges == 24 12 | assert motif.edge_index.size() == (2, 24) 13 | assert motif.edge_index.min() == 0 14 | assert motif.edge_index.max() == 8 15 | assert motif.y.size() == (9, ) 16 | assert motif.y.min() == 0 17 | assert motif.y.max() == 2 18 | -------------------------------------------------------------------------------- /test/datasets/motif_generator/test_house_motif.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets.motif_generator import HouseMotif 2 | 3 | 4 | def test_house_motif(): 5 | motif_generator = HouseMotif() 6 | assert str(motif_generator) == 'HouseMotif()' 7 | 8 | motif = motif_generator() 9 | assert len(motif) == 3 10 | assert motif.num_nodes == 5 11 | assert motif.num_edges == 12 12 | assert motif.y.min() == 0 and motif.y.max() == 2 13 | -------------------------------------------------------------------------------- /test/datasets/test_ba_shapes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_ba_shapes(get_dataset): 5 | with pytest.warns(UserWarning, match="is deprecated"): 6 | dataset = get_dataset(name='BAShapes') 7 | 8 | assert str(dataset) == 'BAShapes()' 9 | assert len(dataset) == 1 10 | assert dataset.num_features == 10 11 | assert dataset.num_classes == 4 12 | 13 | data = dataset[0] 14 | assert len(data) == 5 15 | assert data.edge_index.size(1) >= 1120 16 | assert data.x.size() == (700, 10) 17 | assert data.y.size() == (700, ) 18 | assert data.expl_mask.sum() == 60 19 | assert data.edge_label.sum() == 960 20 | -------------------------------------------------------------------------------- /test/datasets/test_bzr.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.testing import onlyFullTest, onlyOnline 2 | 3 | 4 | @onlyOnline 5 | @onlyFullTest 6 | def test_bzr(get_dataset): 7 | dataset = 
get_dataset(name='BZR') 8 | assert len(dataset) == 405 9 | assert dataset.num_features == 53 10 | assert dataset.num_node_labels == 53 11 | assert dataset.num_node_attributes == 3 12 | assert dataset.num_classes == 2 13 | assert str(dataset) == 'BZR(405)' 14 | assert len(dataset[0]) == 3 15 | 16 | 17 | @onlyOnline 18 | @onlyFullTest 19 | def test_bzr_with_node_attr(get_dataset): 20 | dataset = get_dataset(name='BZR', use_node_attr=True) 21 | assert dataset.num_features == 56 22 | assert dataset.num_node_labels == 53 23 | assert dataset.num_node_attributes == 3 24 | -------------------------------------------------------------------------------- /test/datasets/test_git_mol_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | 5 | from torch_geometric.datasets import GitMolDataset 6 | from torch_geometric.testing import onlyFullTest, withPackage 7 | 8 | 9 | @onlyFullTest 10 | @withPackage('torchvision', 'rdkit', 'PIL') 11 | @pytest.mark.parametrize('split', [ 12 | (0, 3610), 13 | (1, 451), 14 | (2, 451), 15 | ]) 16 | def test_git_mol_dataset(split: Tuple[int, int]) -> None: 17 | dataset = GitMolDataset(root='./data/GITMol', split=split[0]) 18 | 19 | assert len(dataset) == split[1] 20 | assert dataset[0].image.size() == (1, 3, 224, 224) 21 | assert dataset[0].num_node_features == 9 22 | assert dataset[0].num_edge_features == 3 23 | -------------------------------------------------------------------------------- /test/datasets/test_imdb_binary.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.testing import onlyFullTest, onlyOnline 2 | 3 | 4 | @onlyOnline 5 | @onlyFullTest 6 | def test_imdb_binary(get_dataset): 7 | dataset = get_dataset(name='IMDB-BINARY') 8 | assert len(dataset) == 1000 9 | assert dataset.num_features == 0 10 | assert dataset.num_classes == 2 11 | assert str(dataset) == 'IMDB-BINARY(1000)' 12 | 13 | 
data = dataset[0] 14 | assert len(data) == 3 15 | assert data.edge_index.size() == (2, 146) 16 | assert data.y.size() == (1, ) 17 | assert data.num_nodes == 20 18 | -------------------------------------------------------------------------------- /test/datasets/test_karate.py: -------------------------------------------------------------------------------- 1 | def test_karate(get_dataset): 2 | dataset = get_dataset(name='KarateClub') 3 | assert str(dataset) == 'KarateClub()' 4 | assert len(dataset) == 1 5 | assert dataset.num_features == 34 6 | assert dataset.num_classes == 4 7 | 8 | assert len(dataset[0]) == 4 9 | assert dataset[0].edge_index.size() == (2, 156) 10 | assert dataset[0].x.size() == (34, 34) 11 | assert dataset[0].y.size() == (34, ) 12 | assert dataset[0].train_mask.size() == (34, ) 13 | assert dataset[0].train_mask.sum().item() == 4 14 | -------------------------------------------------------------------------------- /test/datasets/test_medshapenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.data import Data 4 | from torch_geometric.datasets import MedShapeNet 5 | from torch_geometric.testing import withPackage 6 | 7 | 8 | @withPackage('MedShapeNet') 9 | def test_medshapenet(): 10 | dataset = MedShapeNet(root="./data/MedShapeNet", size=1) 11 | 12 | assert str(dataset) == f'MedShapeNet({len(dataset)})' 13 | 14 | assert isinstance(dataset[0], Data) 15 | assert dataset.num_classes == 8 16 | 17 | assert isinstance(dataset[0].pos, torch.Tensor) 18 | assert len(dataset[0].pos) > 0 19 | 20 | assert isinstance(dataset[0].face, torch.Tensor) 21 | assert len(dataset[0].face) == 3 22 | 23 | assert isinstance(dataset[0].y, torch.Tensor) 24 | assert len(dataset[0].y) == 1 25 | -------------------------------------------------------------------------------- /test/datasets/test_molecule_gpt_dataset.py: 
-------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import MoleculeGPTDataset 2 | from torch_geometric.testing import onlyOnline, withPackage 3 | 4 | 5 | @onlyOnline 6 | @withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit') 7 | def test_molecule_gpt_dataset(): 8 | dataset = MoleculeGPTDataset( 9 | root='./data/MoleculeGPT', 10 | num_units=10, 11 | ) 12 | assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})' 13 | assert dataset.num_edge_features == 4 14 | assert dataset.num_node_features == 6 15 | -------------------------------------------------------------------------------- /test/datasets/test_mutag.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.testing import onlyOnline 2 | 3 | 4 | @onlyOnline 5 | def test_mutag(get_dataset): 6 | dataset = get_dataset(name='MUTAG') 7 | assert len(dataset) == 188 8 | assert dataset.num_features == 7 9 | assert dataset.num_classes == 2 10 | assert str(dataset) == 'MUTAG(188)' 11 | 12 | assert len(dataset[0]) == 4 13 | assert dataset[0].edge_attr.size(1) == 4 14 | 15 | 16 | @onlyOnline 17 | def test_mutag_with_node_attr(get_dataset): 18 | dataset = get_dataset(name='MUTAG', use_node_attr=True) 19 | assert dataset.num_features == 7 20 | -------------------------------------------------------------------------------- /test/datasets/test_protein_mpnn_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import ProteinMPNNDataset 2 | from torch_geometric.testing import onlyOnline, withPackage 3 | 4 | 5 | @onlyOnline 6 | @withPackage('pandas') 7 | def test_protein_mpnn_dataset(): 8 | dataset = ProteinMPNNDataset(root='./data/ProteinMPNN') 9 | 10 | assert len(dataset) == 150 11 | assert dataset[0].x.size() == (229, 4, 3) 12 | assert dataset[0].chain_seq_label.size() == (229, ) 13 | assert dataset[0].mask.size() == (229, 
) 14 | assert dataset[0].chain_mask_all.size() == (229, ) 15 | assert dataset[0].residue_idx.size() == (229, ) 16 | assert dataset[0].chain_encoding_all.size() == (229, ) 17 | -------------------------------------------------------------------------------- /test/datasets/test_suite_sparse.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.testing import onlyFullTest, onlyOnline 2 | 3 | 4 | @onlyOnline 5 | @onlyFullTest 6 | def test_suite_sparse_dataset(get_dataset): 7 | dataset = get_dataset(group='DIMACS10', name='citationCiteseer') 8 | assert str(dataset) == ('SuiteSparseMatrixCollection(' 9 | 'group=DIMACS10, name=citationCiteseer)') 10 | assert len(dataset) == 1 11 | 12 | 13 | @onlyOnline 14 | @onlyFullTest 15 | def test_illc1850_suite_sparse_dataset(get_dataset): 16 | dataset = get_dataset(group='HB', name='illc1850') 17 | assert str(dataset) == ('SuiteSparseMatrixCollection(' 18 | 'group=HB, name=illc1850)') 19 | assert len(dataset) == 1 20 | -------------------------------------------------------------------------------- /test/datasets/test_tag_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import TAGDataset 2 | from torch_geometric.testing import onlyFullTest, withPackage 3 | 4 | 5 | @onlyFullTest 6 | @withPackage('ogb') 7 | def test_tag_dataset() -> None: 8 | from ogb.nodeproppred import PygNodePropPredDataset 9 | 10 | root = './data/ogb' 11 | hf_model = 'prajjwal1/bert-tiny' 12 | token_on_disk = True 13 | 14 | dataset = PygNodePropPredDataset('ogbn-arxiv', root=root) 15 | tag_dataset = TAGDataset(root, dataset, hf_model, 16 | token_on_disk=token_on_disk) 17 | 18 | assert 169343 == tag_dataset[0].num_nodes \ 19 | == len(tag_dataset.text) \ 20 | == len(tag_dataset.llm_explanation) \ 21 | == len(tag_dataset.llm_prediction) 22 | assert 1166243 == tag_dataset[0].num_edges 23 | 
-------------------------------------------------------------------------------- /test/datasets/test_teeth3ds.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | from torch_geometric.datasets import Teeth3DS 3 | from torch_geometric.testing import withPackage 4 | 5 | 6 | @withPackage('trimesh', 'fpsample') 7 | def test_teeth3ds(tmp_path) -> None: 8 | dataset = Teeth3DS(root=tmp_path, split='sample', train=True) 9 | 10 | assert len(dataset) > 0 11 | data = dataset[0] 12 | assert isinstance(data, Data) 13 | assert data.pos.size(1) == 3 14 | assert data.x.size(0) == data.pos.size(0) 15 | assert data.y.size(0) == data.pos.size(0) 16 | assert isinstance(data.jaw, str) 17 | -------------------------------------------------------------------------------- /test/graphgym/example_node.yml: -------------------------------------------------------------------------------- 1 | tensorboard_each_run: false 2 | tensorboard_agg: false 3 | dataset: 4 | format: PyG 5 | name: Cora 6 | task: node 7 | task_type: classification 8 | node_encoder: false 9 | node_encoder_name: Atom 10 | edge_encoder: false 11 | edge_encoder_name: Bond 12 | train: 13 | batch_size: 128 14 | eval_period: 2 15 | ckpt_period: 100 16 | enable_ckpt: false 17 | skip_train_eval: true 18 | sampler: full_batch 19 | model: 20 | type: gnn 21 | loss_fun: cross_entropy 22 | edge_decoding: dot 23 | graph_pooling: add 24 | gnn: 25 | layers_pre_mp: 2 26 | layers_mp: 2 27 | layers_post_mp: 1 28 | dim_inner: 16 29 | layer_type: gcnconv 30 | stage_type: stack 31 | batchnorm: false 32 | act: prelu 33 | dropout: 0.1 34 | agg: mean 35 | normalize_adj: false 36 | optim: 37 | optimizer: adam 38 | base_lr: 0.01 39 | max_epoch: 6 40 | -------------------------------------------------------------------------------- /test/graphgym/test_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import 
dataclass 2 | 3 | from torch_geometric.graphgym.config import from_config 4 | 5 | 6 | @dataclass 7 | class MyConfig: 8 | a: int 9 | b: int = 4 10 | 11 | 12 | def my_func(a: int, b: int = 2) -> str: 13 | return f'a={a},b={b}' 14 | 15 | 16 | def test_from_config(): 17 | assert my_func(a=1) == 'a=1,b=2' 18 | 19 | assert my_func.__name__ == from_config(my_func).__name__ 20 | assert from_config(my_func)(cfg=MyConfig(a=1)) == 'a=1,b=4' 21 | assert from_config(my_func)(cfg=MyConfig(a=1, b=1)) == 'a=1,b=1' 22 | assert from_config(my_func)(2, cfg=MyConfig(a=1, b=3)) == 'a=2,b=3' 23 | assert from_config(my_func)(cfg=MyConfig(a=1), b=3) == 'a=1,b=3' 24 | -------------------------------------------------------------------------------- /test/graphgym/test_logger.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.config import set_run_dir 2 | from torch_geometric.graphgym.loader import create_loader 3 | from torch_geometric.graphgym.logger import Logger, LoggerCallback 4 | from torch_geometric.testing import withPackage 5 | 6 | 7 | @withPackage('yacs', 'pytorch_lightning') 8 | def test_logger_callback(): 9 | loaders = create_loader() 10 | assert len(loaders) == 3 11 | 12 | set_run_dir('.') 13 | logger = LoggerCallback() 14 | assert isinstance(logger.train_logger, Logger) 15 | assert isinstance(logger.val_logger, Logger) 16 | assert isinstance(logger.test_logger, Logger) 17 | -------------------------------------------------------------------------------- /test/graphgym/test_register.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch_geometric.graphgym.register as register 4 | from torch_geometric.testing import withPackage 5 | 6 | 7 | @register.register_act('identity') 8 | def identity_act(x: torch.Tensor) -> torch.Tensor: 9 | return x 10 | 11 | 12 | @withPackage('yacs') 13 | def test_register(): 14 | assert len(register.act_dict) == 8 15 | 
assert list(register.act_dict.keys()) == [ 16 | 'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05', 17 | 'identity' 18 | ] 19 | assert str(register.act_dict['relu']()) == 'ReLU()' 20 | 21 | register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3)) 22 | assert len(register.act_dict) == 9 23 | assert 'lrelu_03' in register.act_dict 24 | -------------------------------------------------------------------------------- /test/io/example1.off: -------------------------------------------------------------------------------- 1 | OFF 2 | 4 2 0 3 | 0.0 0.0 0.0 4 | 0.0 1.0 0.0 5 | 1.0 0.0 0.0 6 | 1.0 1.0 0.0 7 | 3 0 1 2 8 | 3 1 2 3 9 | -------------------------------------------------------------------------------- /test/io/example2.off: -------------------------------------------------------------------------------- 1 | OFF 2 | 4 1 0 3 | 0.0 0.0 0.0 4 | 0.0 1.0 0.0 5 | 1.0 0.0 0.0 6 | 1.0 1.0 0.0 7 | 4 0 1 2 3 8 | -------------------------------------------------------------------------------- /test/llm/models/test_git_mol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.llm.models import GITMol 4 | from torch_geometric.testing import withPackage 5 | 6 | 7 | @withPackage('transformers', 'sentencepiece', 'accelerate') 8 | def test_git_mol(): 9 | model = GITMol() 10 | 11 | x = torch.ones(10, 16, dtype=torch.long) 12 | edge_index = torch.tensor([ 13 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 14 | [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], 15 | ]) 16 | edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long) 17 | batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 18 | smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] 19 | captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] 20 | images = torch.randn(1, 3, 224, 224) 21 | 22 | # Test train: 23 | loss = model(x, edge_index, batch, edge_attr, smiles, images, captions) 24 | assert loss >= 0 25 | 
-------------------------------------------------------------------------------- /test/llm/models/test_llm.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from torch_geometric.llm.models import LLM 7 | from torch_geometric.testing import onlyRAG, withPackage 8 | 9 | 10 | @onlyRAG 11 | @withPackage('transformers', 'accelerate') 12 | def test_llm() -> None: 13 | question = ["Is PyG the best open-source GNN library?"] 14 | answer = ["yes!"] 15 | 16 | model = LLM(model_name='Qwen/Qwen3-0.6B', num_params=1, 17 | dtype=torch.float16, 18 | sys_prompt="You're an agent, answer my questions.") 19 | assert str(model) == 'LLM(Qwen/Qwen3-0.6B)' 20 | 21 | loss = model(question, answer) 22 | assert isinstance(loss, Tensor) 23 | assert loss.dim() == 0 24 | assert loss >= 0.0 25 | 26 | pred = model.inference(question) 27 | assert len(pred) == 1 28 | del model 29 | gc.collect() 30 | torch.cuda.empty_cache() 31 | -------------------------------------------------------------------------------- /test/llm/models/test_protein_mpnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.llm.models import ProteinMPNN 4 | from torch_geometric.testing import withPackage 5 | 6 | 7 | @withPackage('torch_cluster') 8 | def test_protein_mpnn(): 9 | num_nodes = 10 10 | vocab_size = 21 11 | 12 | model = ProteinMPNN(vocab_size=vocab_size) 13 | x = torch.randn(num_nodes, 4, 3) 14 | chain_seq_label = torch.randint(0, vocab_size, (num_nodes, )) 15 | mask = torch.ones(num_nodes) 16 | chain_mask_all = torch.ones(num_nodes) 17 | residue_idx = torch.randint(0, 10, (num_nodes, )) 18 | chain_encoding_all = torch.ones(num_nodes) 19 | batch = torch.zeros(num_nodes, dtype=torch.long) 20 | 21 | logits = model(x, chain_seq_label, mask, chain_mask_all, residue_idx, 22 | chain_encoding_all, batch) 23 | assert logits.size() == 
(num_nodes, vocab_size) 24 | -------------------------------------------------------------------------------- /test/llm/models/test_vision_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.llm.models import VisionTransformer 4 | from torch_geometric.testing import onlyFullTest, withCUDA, withPackage 5 | 6 | 7 | @withCUDA 8 | @onlyFullTest 9 | @withPackage('transformers') 10 | def test_vision_transformer(device): 11 | model = VisionTransformer( 12 | model_name='microsoft/swin-base-patch4-window7-224', ).to(device) 13 | assert model.device == device 14 | assert str( 15 | model 16 | ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)' 17 | 18 | images = torch.randn(2, 3, 224, 224).to(device) 19 | 20 | out = model(images) 21 | assert out.device == device 22 | assert out.size() == (2, 49, 1024) 23 | 24 | out = model(images, output_device='cpu') 25 | assert out.is_cpu 26 | assert out.size() == (2, 49, 1024) 27 | -------------------------------------------------------------------------------- /test/loader/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.loader.utils import index_select 5 | 6 | 7 | def test_index_select(): 8 | x = torch.randn(3, 5) 9 | index = torch.tensor([0, 2]) 10 | assert torch.equal(index_select(x, index), x[index]) 11 | assert torch.equal(index_select(x, index, dim=-1), x[..., index]) 12 | 13 | 14 | def test_index_select_out_of_range(): 15 | with pytest.raises(IndexError, match="out of range"): 16 | index_select(torch.randn(3, 5), torch.tensor([0, 2, 3])) 17 | -------------------------------------------------------------------------------- /test/my_config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: KarateClub 3 | - transform@dataset.transform: 4 | - 
NormalizeFeatures 5 | - AddSelfLoops 6 | - model: GCN 7 | - optimizer: Adam 8 | - lr_scheduler: ReduceLROnPlateau 9 | - _self_ 10 | 11 | model: 12 | in_channels: 34 13 | out_channels: 4 14 | hidden_channels: 16 15 | num_layers: 2 16 | -------------------------------------------------------------------------------- /test/nn/aggr/test_deep_sets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import DeepSetsAggregation, Linear 4 | 5 | 6 | def test_deep_sets_aggregation(): 7 | x = torch.randn(6, 16) 8 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 9 | 10 | aggr = DeepSetsAggregation( 11 | local_nn=Linear(16, 32), 12 | global_nn=Linear(32, 64), 13 | ) 14 | aggr.reset_parameters() 15 | assert str(aggr) == ('DeepSetsAggregation(' 16 | 'local_nn=Linear(16, 32, bias=True), ' 17 | 'global_nn=Linear(32, 64, bias=True))') 18 | 19 | out = aggr(x, index) 20 | assert out.size() == (3, 64) 21 | -------------------------------------------------------------------------------- /test/nn/aggr/test_gmt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.aggr import GraphMultisetTransformer 4 | from torch_geometric.testing import is_full_test 5 | 6 | 7 | def test_graph_multiset_transformer(): 8 | x = torch.randn(6, 16) 9 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 10 | 11 | aggr = GraphMultisetTransformer(16, k=2, heads=2) 12 | aggr.reset_parameters() 13 | assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, ' 14 | 'layer_norm=False, dropout=0.0)') 15 | 16 | out = aggr(x, index) 17 | assert out.size() == (3, 16) 18 | 19 | if is_full_test(): 20 | jit = torch.jit.script(aggr) 21 | assert torch.allclose(jit(x, index), out) 22 | -------------------------------------------------------------------------------- /test/nn/aggr/test_gru.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | from torch_geometric.nn import GRUAggregation 4 | 5 | 6 | def test_gru_aggregation(): 7 | x = torch.randn(6, 16) 8 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 9 | 10 | aggr = GRUAggregation(16, 32) 11 | assert str(aggr) == 'GRUAggregation(16, 32)' 12 | 13 | out = aggr(x, index) 14 | assert out.size() == (3, 32) 15 | -------------------------------------------------------------------------------- /test/nn/aggr/test_lstm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.nn import LSTMAggregation 5 | 6 | 7 | def test_lstm_aggregation(): 8 | x = torch.randn(6, 16) 9 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 10 | 11 | aggr = LSTMAggregation(16, 32) 12 | assert str(aggr) == 'LSTMAggregation(16, 32)' 13 | 14 | with pytest.raises(ValueError, match="is not sorted"): 15 | aggr(x, torch.tensor([0, 1, 0, 1, 2, 1])) 16 | 17 | out = aggr(x, index) 18 | assert out.size() == (3, 32) 19 | -------------------------------------------------------------------------------- /test/nn/aggr/test_mlp_aggr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import MLPAggregation 4 | 5 | 6 | def test_mlp_aggregation(): 7 | x = torch.randn(6, 16) 8 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 9 | 10 | aggr = MLPAggregation( 11 | in_channels=16, 12 | out_channels=32, 13 | max_num_elements=3, 14 | num_layers=1, 15 | ) 16 | assert str(aggr) == 'MLPAggregation(16, 32, max_num_elements=3)' 17 | 18 | out = aggr(x, index) 19 | assert out.size() == (3, 32) 20 | -------------------------------------------------------------------------------- /test/nn/aggr/test_patch_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import PatchTransformerAggregation 4 | from torch_geometric.testing import withCUDA 5 | 6 | 7 | 
@withCUDA 8 | def test_patch_transformer_aggregation(device: torch.device) -> None: 9 | aggr = PatchTransformerAggregation( 10 | in_channels=16, 11 | out_channels=32, 12 | patch_size=2, 13 | hidden_channels=8, 14 | num_transformer_blocks=1, 15 | heads=2, 16 | dropout=0.2, 17 | aggr=['sum', 'mean', 'min', 'max', 'var', 'std'], 18 | device=device, 19 | ) 20 | aggr.reset_parameters() 21 | assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)' 22 | 23 | index = torch.tensor([0, 0, 1, 1, 1, 2], device=device) 24 | x = torch.randn(index.size(0), 16, device=device) 25 | 26 | out = aggr(x, index) 27 | assert out.device == device 28 | assert out.size() == (3, aggr.out_channels) 29 | -------------------------------------------------------------------------------- /test/nn/aggr/test_scaler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.nn import DegreeScalerAggregation 5 | 6 | 7 | @pytest.mark.parametrize('train_norm', [True, False]) 8 | def test_degree_scaler_aggregation(train_norm): 9 | x = torch.randn(6, 16) 10 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 11 | ptr = torch.tensor([0, 2, 5, 6]) 12 | deg = torch.tensor([0, 3, 0, 1, 1, 0]) 13 | 14 | aggr = ['mean', 'sum', 'max'] 15 | scaler = [ 16 | 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' 17 | ] 18 | aggr = DegreeScalerAggregation(aggr, scaler, deg, train_norm=train_norm) 19 | assert str(aggr) == 'DegreeScalerAggregation()' 20 | 21 | out = aggr(x, index) 22 | assert out.size() == (3, 240) 23 | assert torch.allclose(torch.jit.script(aggr)(x, index), out) 24 | 25 | with pytest.raises(NotImplementedError, match="requires 'index'"): 26 | aggr(x, ptr=ptr) 27 | -------------------------------------------------------------------------------- /test/nn/aggr/test_set2set.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from 
torch_geometric.nn.aggr import Set2Set 4 | 5 | 6 | def test_set2set(): 7 | set2set = Set2Set(in_channels=2, processing_steps=1) 8 | assert str(set2set) == 'Set2Set(2, 4)' 9 | 10 | N = 4 11 | x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long) 12 | out_1 = set2set(x_1, batch_1).view(-1) 13 | 14 | N = 6 15 | x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long) 16 | out_2 = set2set(x_2, batch_2).view(-1) 17 | 18 | x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1]) 19 | out = set2set(x, batch) 20 | assert out.size() == (2, 4) 21 | assert torch.allclose(out_1, out[0]) 22 | assert torch.allclose(out_2, out[1]) 23 | 24 | x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 1]) 25 | out = set2set(x, batch) 26 | assert out.size() == (2, 4) 27 | assert torch.allclose(out_1, out[1]) 28 | assert torch.allclose(out_2, out[0]) 29 | -------------------------------------------------------------------------------- /test/nn/attention/test_performer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.attention import PerformerAttention 4 | 5 | 6 | def test_performer_attention(): 7 | x = torch.randn(1, 4, 16) 8 | mask = torch.ones([1, 4], dtype=torch.bool) 9 | attn = PerformerAttention(channels=16, heads=4) 10 | out = attn(x, mask) 11 | assert out.shape == (1, 4, 16) 12 | assert str(attn) == ('PerformerAttention(heads=4, ' 13 | 'head_channels=64 kernel=ReLU())') 14 | -------------------------------------------------------------------------------- /test/nn/attention/test_polynormer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.attention import PolynormerAttention 4 | 5 | 6 | def test_performer_attention(): 7 | x = torch.randn(1, 4, 16) 8 | mask = torch.ones([1, 4], dtype=torch.bool) 9 | attn = PolynormerAttention(channels=16, heads=4) 10 | out 
= attn(x, mask) 11 | assert out.shape == (1, 4, 256) 12 | assert str(attn) == 'PolynormerAttention(heads=4, head_channels=64)' 13 | -------------------------------------------------------------------------------- /test/nn/attention/test_qformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.attention import QFormer 4 | 5 | 6 | def test_qformer(): 7 | x = torch.randn(1, 4, 16) 8 | attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4, 9 | num_layers=2) 10 | out = attn(x) 11 | 12 | assert out.shape == (1, 4, 32) 13 | assert str(attn) == ('QFormer(num_heads=4, num_layers=2)') 14 | -------------------------------------------------------------------------------- /test/nn/conv/test_dir_gnn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import DirGNNConv, SAGEConv 4 | 5 | 6 | def test_dir_gnn_conv(): 7 | x = torch.randn(4, 16) 8 | edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) 9 | 10 | conv = DirGNNConv(SAGEConv(16, 32)) 11 | assert str(conv) == 'DirGNNConv(SAGEConv(16, 32, aggr=mean), alpha=0.5)' 12 | 13 | out = conv(x, edge_index) 14 | assert out.size() == (4, 32) 15 | 16 | 17 | def test_static_dir_gnn_conv(): 18 | x = torch.randn(3, 4, 16) 19 | edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]]) 20 | 21 | conv = DirGNNConv(SAGEConv(16, 32)) 22 | 23 | out = conv(x, edge_index) 24 | assert out.size() == (3, 4, 32) 25 | -------------------------------------------------------------------------------- /test/nn/conv/test_static_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.data import Batch, Data 4 | from torch_geometric.nn import ChebConv, GCNConv, MessagePassing 5 | 6 | 7 | class MyConv(MessagePassing): 8 | def forward(self, x, edge_index): 9 | return self.propagate(edge_index, x=x) 10 | 11 
| 12 | def test_static_graph(): 13 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 14 | x1, x2 = torch.randn(3, 8), torch.randn(3, 8) 15 | 16 | data1 = Data(edge_index=edge_index, x=x1) 17 | data2 = Data(edge_index=edge_index, x=x2) 18 | batch = Batch.from_data_list([data1, data2]) 19 | 20 | x = torch.stack([x1, x2], dim=0) 21 | for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]: 22 | out1 = conv(batch.x, batch.edge_index) 23 | assert out1.size(0) == 6 24 | conv.node_dim = 1 25 | out2 = conv(x, edge_index) 26 | assert out2.size()[:2] == (2, 3) 27 | assert torch.allclose(out1, out2.view(-1, out2.size(-1))) 28 | -------------------------------------------------------------------------------- /test/nn/conv/test_x_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import XConv 4 | from torch_geometric.testing import is_full_test, withPackage 5 | 6 | 7 | @withPackage('torch_cluster') 8 | def test_x_conv(): 9 | x = torch.randn(8, 16) 10 | pos = torch.rand(8, 3) 11 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) 12 | 13 | conv = XConv(16, 32, dim=3, kernel_size=2, dilation=2) 14 | assert str(conv) == 'XConv(16, 32)' 15 | 16 | torch.manual_seed(12345) 17 | out1 = conv(x, pos) 18 | assert out1.size() == (8, 32) 19 | 20 | torch.manual_seed(12345) 21 | out2 = conv(x, pos, batch) 22 | assert out2.size() == (8, 32) 23 | 24 | if is_full_test(): 25 | jit = torch.jit.script(conv) 26 | 27 | torch.manual_seed(12345) 28 | assert torch.allclose(jit(x, pos), out1, atol=1e-6) 29 | 30 | torch.manual_seed(12345) 31 | assert torch.allclose(jit(x, pos, batch), out2, atol=1e-6) 32 | -------------------------------------------------------------------------------- /test/nn/dense/test_dmon_pool.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torch_geometric.nn import DMoNPooling 6 | 7 | 8 | def 
test_dmon_pooling(): 9 | batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) 10 | x = torch.randn((batch_size, num_nodes, channels)) 11 | adj = torch.ones((batch_size, num_nodes, num_nodes)) 12 | mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool) 13 | 14 | pool = DMoNPooling([channels, channels], num_clusters) 15 | assert str(pool) == 'DMoNPooling(16, num_clusters=10)' 16 | 17 | s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask) 18 | assert s.size() == (2, 20, 10) 19 | assert x.size() == (2, 10, 16) 20 | assert adj.size() == (2, 10, 10) 21 | assert -1 <= spectral_loss <= 0.5 22 | assert 0 <= ortho_loss <= math.sqrt(2) 23 | assert 0 <= cluster_loss <= math.sqrt(num_clusters) - 1 24 | -------------------------------------------------------------------------------- /test/nn/functional/test_bro.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.functional import bro 4 | 5 | 6 | def test_bro(): 7 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2]) 8 | 9 | g1 = torch.tensor([ 10 | [0.2, 0.2, 0.2, 0.2], 11 | [0.0, 0.2, 0.2, 0.2], 12 | [0.2, 0.0, 0.2, 0.2], 13 | [0.2, 0.2, 0.0, 0.2], 14 | ]) 15 | 16 | g2 = torch.tensor([ 17 | [0.2, 0.2, 0.2, 0.2], 18 | [0.0, 0.2, 0.2, 0.2], 19 | [0.2, 0.0, 0.2, 0.2], 20 | ]) 21 | 22 | g3 = torch.tensor([ 23 | [0.2, 0.2, 0.2, 0.2], 24 | [0.2, 0.0, 0.2, 0.2], 25 | ]) 26 | 27 | s = 0. 
28 | for g in [g1, g2, g3]: 29 | s += torch.norm(g @ g.t() - torch.eye(g.shape[0]), p=2) 30 | 31 | assert torch.isclose(s / 3., bro(torch.cat([g1, g2, g3], dim=0), batch)) 32 | -------------------------------------------------------------------------------- /test/nn/functional/test_gini.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.functional import gini 4 | 5 | 6 | def test_gini(): 7 | w = torch.tensor([[0., 0., 0., 0.], [0., 0., 0., 1000.0]]) 8 | assert torch.isclose(gini(w), torch.tensor(0.5)) 9 | -------------------------------------------------------------------------------- /test/nn/kge/test_distmult.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import DistMult 4 | 5 | 6 | def test_distmult(): 7 | model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32) 8 | assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)' 9 | 10 | head_index = torch.tensor([0, 2, 4, 6, 8]) 11 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 12 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 13 | 14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 15 | for h, r, t in loader: 16 | out = model(h, r, t) 17 | assert out.size() == (5, ) 18 | 19 | loss = model.loss(h, r, t) 20 | assert loss >= 0. 
21 | 22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) 23 | assert 0 <= mean_rank <= 10 24 | assert 0 < mrr <= 1 25 | assert hits == 1.0 26 | -------------------------------------------------------------------------------- /test/nn/kge/test_rotate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import RotatE 4 | 5 | 6 | def test_rotate(): 7 | model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32) 8 | assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)' 9 | 10 | head_index = torch.tensor([0, 2, 4, 6, 8]) 11 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 12 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 13 | 14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 15 | for h, r, t in loader: 16 | out = model(h, r, t) 17 | assert out.size() == (5, ) 18 | 19 | loss = model.loss(h, r, t) 20 | assert loss >= 0. 21 | 22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) 23 | assert 0 <= mean_rank <= 10 24 | assert 0 < mrr <= 1 25 | assert hits == 1.0 26 | -------------------------------------------------------------------------------- /test/nn/kge/test_transe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import TransE 4 | 5 | 6 | def test_transe(): 7 | model = TransE(num_nodes=10, num_relations=5, hidden_channels=32) 8 | assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)' 9 | 10 | head_index = torch.tensor([0, 2, 4, 6, 8]) 11 | rel_type = torch.tensor([0, 1, 2, 3, 4]) 12 | tail_index = torch.tensor([1, 3, 5, 7, 9]) 13 | 14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5) 15 | for h, r, t in loader: 16 | out = model(h, r, t) 17 | assert out.size() == (5, ) 18 | 19 | loss = model.loss(h, r, t) 20 | assert loss >= 0. 
21 | 22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) 23 | assert 0 <= mean_rank <= 10 24 | assert 0 < mrr <= 1 25 | assert hits == 1.0 26 | -------------------------------------------------------------------------------- /test/nn/models/test_attentive_fp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import AttentiveFP 4 | from torch_geometric.testing import is_full_test 5 | 6 | 7 | def test_attentive_fp(): 8 | model = AttentiveFP(8, 16, 32, edge_dim=3, num_layers=2, num_timesteps=2) 9 | assert str(model) == ('AttentiveFP(in_channels=8, hidden_channels=16, ' 10 | 'out_channels=32, edge_dim=3, num_layers=2, ' 11 | 'num_timesteps=2)') 12 | 13 | x = torch.randn(4, 8) 14 | edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) 15 | edge_attr = torch.randn(edge_index.size(1), 3) 16 | batch = torch.tensor([0, 0, 0, 0]) 17 | 18 | out = model(x, edge_index, edge_attr, batch) 19 | assert out.size() == (1, 32) 20 | 21 | if is_full_test(): 22 | jit = torch.jit.script(model) 23 | assert torch.allclose(jit(x, edge_index, edge_attr, batch), out) 24 | -------------------------------------------------------------------------------- /test/nn/models/test_deepgcn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn import ReLU 4 | 5 | from torch_geometric.nn import DeepGCNLayer, GENConv, LayerNorm 6 | 7 | 8 | @pytest.mark.parametrize( 9 | 'block_tuple', 10 | [('res+', 1), ('res', 1), ('dense', 2), ('plain', 1)], 11 | ) 12 | @pytest.mark.parametrize('ckpt_grad', [True, False]) 13 | def test_deepgcn(block_tuple, ckpt_grad): 14 | block, expansion = block_tuple 15 | x = torch.randn(3, 8) 16 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 17 | conv = GENConv(8, 8) 18 | norm = LayerNorm(8) 19 | act = ReLU() 20 | layer = DeepGCNLayer(conv, norm, act, block=block, 
ckpt_grad=ckpt_grad) 21 | assert str(layer) == f'DeepGCNLayer(block={block})' 22 | 23 | out = layer(x, edge_index) 24 | assert out.size() == (3, 8 * expansion) 25 | -------------------------------------------------------------------------------- /test/nn/models/test_gnnff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import GNNFF 4 | from torch_geometric.testing import is_full_test, withPackage 5 | 6 | 7 | @withPackage('torch_sparse') # TODO `triplet` requires `SparseTensor` for now. 8 | @withPackage('torch-cluster') 9 | def test_gnnff(): 10 | z = torch.randint(1, 10, (20, )) 11 | pos = torch.randn(20, 3) 12 | 13 | model = GNNFF( 14 | hidden_node_channels=5, 15 | hidden_edge_channels=5, 16 | num_layers=5, 17 | ) 18 | model.reset_parameters() 19 | 20 | out = model(z, pos) 21 | assert out.size() == (20, 3) 22 | 23 | if is_full_test(): 24 | jit = torch.jit.export(model) 25 | assert torch.allclose(jit(z, pos), out) 26 | -------------------------------------------------------------------------------- /test/nn/models/test_graph_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import GraphUNet 4 | from torch_geometric.testing import is_full_test, onlyLinux 5 | 6 | 7 | @onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows. 
8 | def test_graph_unet(): 9 | model = GraphUNet(16, 32, 8, depth=3) 10 | out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])' 11 | assert str(model) == out 12 | 13 | x = torch.randn(3, 16) 14 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 15 | 16 | out = model(x, edge_index) 17 | assert out.size() == (3, 8) 18 | 19 | if is_full_test(): 20 | jit = torch.jit.export(model) 21 | out = jit(x, edge_index) 22 | assert out.size() == (3, 8) 23 | -------------------------------------------------------------------------------- /test/nn/models/test_mask_label.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn import MaskLabel 4 | 5 | 6 | def test_mask_label(): 7 | model = MaskLabel(2, 10) 8 | assert str(model) == 'MaskLabel()' 9 | 10 | x = torch.rand(4, 10) 11 | y = torch.tensor([1, 0, 1, 0]) 12 | mask = torch.tensor([False, False, True, True]) 13 | 14 | out = model(x, y, mask) 15 | assert out.size() == (4, 10) 16 | assert torch.allclose(out[~mask], x[~mask]) 17 | 18 | model = MaskLabel(2, 10, method='concat') 19 | out = model(x, y, mask) 20 | assert out.size() == (4, 20) 21 | assert torch.allclose(out[:, :10], x) 22 | 23 | 24 | def test_ratio_mask(): 25 | mask = torch.tensor([True, True, True, True, False, False, False, False]) 26 | out = MaskLabel.ratio_mask(mask, 0.5) 27 | assert out[:4].sum() <= 4 and out[4:].sum() == 0 28 | -------------------------------------------------------------------------------- /test/nn/models/test_pmlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.nn.models import PMLP 5 | 6 | 7 | def test_pmlp(): 8 | x = torch.randn(4, 16) 9 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) 10 | 11 | pmlp = PMLP(in_channels=16, hidden_channels=32, out_channels=2, 12 | num_layers=4) 13 | assert str(pmlp) == 'PMLP(16, 2, num_layers=4)' 14 | 15 | 
pmlp.training = True 16 | assert pmlp(x).size() == (4, 2) 17 | 18 | pmlp.training = False 19 | assert pmlp(x, edge_index).size() == (4, 2) 20 | 21 | with pytest.raises(ValueError, match="'edge_index' needs to be present"): 22 | pmlp.training = False 23 | pmlp(x) 24 | -------------------------------------------------------------------------------- /test/nn/models/test_sgformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.nn.models import SGFormer 4 | 5 | 6 | def test_sgformer(): 7 | x = torch.randn(10, 16) 8 | edge_index = torch.tensor([ 9 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 10 | [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], 11 | ]) 12 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) 13 | 14 | model = SGFormer( 15 | in_channels=16, 16 | hidden_channels=128, 17 | out_channels=40, 18 | ) 19 | out = model(x, edge_index, batch) 20 | assert out.size() == (10, 40) 21 | -------------------------------------------------------------------------------- /test/nn/models/test_visnet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.nn import ViSNet 5 | from torch_geometric.testing import withPackage 6 | 7 | 8 | @withPackage('torch_cluster') 9 | @pytest.mark.parametrize('kwargs', [ 10 | dict(lmax=2, derivative=True, vecnorm_type=None, vertex=False), 11 | dict(lmax=1, derivative=False, vecnorm_type='max_min', vertex=True), 12 | ]) 13 | def test_visnet(kwargs): 14 | z = torch.randint(1, 10, (20, )) 15 | pos = torch.randn(20, 3) 16 | batch = torch.zeros(20, dtype=torch.long) 17 | 18 | model = ViSNet(**kwargs) 19 | 20 | model.reset_parameters() 21 | 22 | energy, forces = model(z, pos, batch) 23 | 24 | assert energy.size() == (1, 1) 25 | 26 | if kwargs['derivative']: 27 | assert forces.size() == (20, 3) 28 | else: 29 | assert forces is None 30 | 
# --------------------------------------------------------------------------------
# /test/nn/norm/test_graph_size_norm.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import GraphSizeNorm
from torch_geometric.testing import is_full_test


def test_graph_size_norm():
    """GraphSizeNorm keeps the input shape and matches its TorchScript form."""
    x = torch.randn(100, 16)
    # Ten graphs with ten nodes each: [0] * 10 + [1] * 10 + ... + [9] * 10.
    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))

    norm = GraphSizeNorm()
    assert str(norm) == 'GraphSizeNorm()'

    out = norm(x, batch)
    assert out.size() == (100, 16)

    if is_full_test():
        jit = torch.jit.script(norm)
        assert torch.allclose(jit(x, batch), out)


# --------------------------------------------------------------------------------
# /test/nn/norm/test_mean_subtraction_norm.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import MeanSubtractionNorm
from torch_geometric.testing import is_full_test


def test_mean_subtraction_norm():
    """MeanSubtractionNorm centers the whole input, or each graph separately."""
    x = torch.randn(6, 16)
    batch = torch.tensor([0, 0, 1, 1, 1, 2])

    norm = MeanSubtractionNorm()
    assert str(norm) == 'MeanSubtractionNorm()'

    if is_full_test():
        torch.jit.script(norm)

    out = norm(x)
    assert out.size() == (6, 16)
    assert torch.allclose(out.mean(), torch.tensor(0.), atol=1e-6)

    out = norm(x, batch)
    assert out.size() == (6, 16)
    # Every graph in `batch` (nodes 0-1, nodes 2-4 and node 5) has to be
    # centered on its own, so each slice is checked individually.
    assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6)
    assert torch.allclose(out[2:5].mean(), torch.tensor(0.), atol=1e-6)
    assert torch.allclose(out[5:6].mean(), torch.tensor(0.), atol=1e-6)


# --------------------------------------------------------------------------------
# /test/nn/norm/test_msg_norm.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import MessageNorm
from torch_geometric.testing import is_full_test, withDevice


@withDevice
def test_message_norm(device):
    """MessageNorm works with and without a learnable scale."""
    norm = MessageNorm(learn_scale=True, device=device)
    assert str(norm) == 'MessageNorm(learn_scale=True)'
    x = torch.randn(100, 16, device=device)
    msg = torch.randn(100, 16, device=device)
    out = norm(x, msg)
    assert out.size() == (100, 16)

    if is_full_test():
        jit = torch.jit.script(norm)
        assert torch.allclose(jit(x, msg), out)

    norm = MessageNorm(learn_scale=False, device=device)
    assert str(norm) == 'MessageNorm(learn_scale=False)'
    out = norm(x, msg)
    assert out.size() == (100, 16)

    if is_full_test():
        jit = torch.jit.script(norm)
        assert torch.allclose(jit(x, msg), out)


# --------------------------------------------------------------------------------
# /test/nn/norm/test_pair_norm.py:
# --------------------------------------------------------------------------------
import pytest
import torch

from torch_geometric.nn import PairNorm
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('scale_individually', [False, True])
def test_pair_norm(scale_individually):
    """PairNorm on a batch of two identical graphs matches a single graph."""
    x = torch.randn(100, 16)
    batch = torch.zeros(100, dtype=torch.long)

    norm = PairNorm(scale_individually=scale_individually)
    assert str(norm) == 'PairNorm()'

    if is_full_test():
        torch.jit.script(norm)

    out1 = norm(x)
    assert out1.size() == (100, 16)

    out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0))
    assert torch.allclose(out1, out2[:100], atol=1e-6)
    assert torch.allclose(out1, out2[100:], atol=1e-6)


# --------------------------------------------------------------------------------
# /test/nn/pool/test_consecutive.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn.pool.consecutive import consecutive_cluster


def test_consecutive_cluster():
    """consecutive_cluster relabels cluster ids to 0..K-1 in sorted order."""
    src = torch.tensor([8, 2, 10, 15, 100, 1, 100])

    out, perm = consecutive_cluster(src)
    assert out.tolist() == [2, 1, 3, 4, 5, 0, 5]
    assert perm.tolist() == [5, 1, 0, 2, 3, 6]


# --------------------------------------------------------------------------------
# /test/nn/pool/test_graclus.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import graclus
from torch_geometric.testing import withPackage


@withPackage('torch_cluster')
def test_graclus():
    """graclus puts both endpoints of a single edge into one cluster."""
    edge_index = torch.tensor([[0, 1], [1, 0]])
    assert graclus(edge_index).tolist() == [0, 0]


# --------------------------------------------------------------------------------
# /test/nn/pool/test_mem_pool.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import MemPooling
from torch_geometric.utils import to_dense_batch


def test_mem_pool():
    """Two stacked MemPooling layers produce the expected dense shapes."""
    mpool1 = MemPooling(4, 8, heads=3, num_clusters=2)
    assert str(mpool1) == 'MemPooling(4, 8, heads=3, num_clusters=2)'
    mpool2 = MemPooling(8, 4, heads=2, num_clusters=1)

    x = torch.randn(17, 4)
    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
    _, mask = to_dense_batch(x, batch)

    out1, S = mpool1(x, batch)
    loss = MemPooling.kl_loss(S)
    with torch.autograd.set_detect_anomaly(True):
        loss.backward()
    out2, _ = mpool2(out1)

    assert out1.size() == (5, 2, 8)
    assert out2.size() == (5, 1, 4)
    # Padding entries carry no assignment mass; the real entries sum up to
    # the number of nodes.
    assert S[~mask].sum() == 0
    assert round(S[mask].sum().item()) == x.size(0)
    assert float(loss) > 0


# --------------------------------------------------------------------------------
# /test/nn/pool/test_pool.py:
# --------------------------------------------------------------------------------
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn import radius_graph
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('torch_cluster')
def test_radius_graph_jit():
    """radius_graph can be called from a TorchScript-compiled module."""
    class Net(torch.nn.Module):
        def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
            return radius_graph(x, r=2.5, batch=batch, loop=False)

    x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float)
    batch = torch.tensor([0, 0, 0, 0])

    model = Net()
    jit = torch.jit.script(model)
    assert model(x, batch).size() == jit(x, batch).size()


# --------------------------------------------------------------------------------
# /test/nn/test_data_parallel.py:
# --------------------------------------------------------------------------------
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.nn import DataParallel
from torch_geometric.testing import onlyCUDA


@onlyCUDA
def test_data_parallel_single_gpu():
    """Scattering onto a single device yields a single batch."""
    with pytest.warns(UserWarning, match="much slower"):
        module = DataParallel(torch.nn.Identity())
    data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]
    batches = module.scatter(data_list, device_ids=[0])
    assert len(batches) == 1


@onlyCUDA
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_data_parallel_multi_gpu():
    """Scattering over several device slots yields the expected batch count."""
    with pytest.warns(UserWarning, match="much slower"):
        module = DataParallel(torch.nn.Identity())
    data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]
    batches = module.scatter(data_list, device_ids=[0, 1, 0, 1])
    assert len(batches) == 3


# --------------------------------------------------------------------------------
# /test/nn/test_encoding.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import PositionalEncoding, TemporalEncoding
from torch_geometric.testing import withDevice


@withDevice
def test_positional_encoding(device):
    """PositionalEncoding maps each scalar to a 64-dimensional vector."""
    encoder = PositionalEncoding(64, device=device)
    assert str(encoder) == 'PositionalEncoding(64)'

    x = torch.tensor([1.0, 2.0, 3.0], device=device)
    assert encoder(x).size() == (3, 64)


@withDevice
def test_temporal_encoding(device):
    """TemporalEncoding maps each scalar to a 64-dimensional vector."""
    encoder = TemporalEncoding(64, device=device)
    assert str(encoder) == 'TemporalEncoding(64)'

    x = torch.tensor([1.0, 2.0, 3.0], device=device)
    assert encoder(x).size() == (3, 64)


# --------------------------------------------------------------------------------
# /test/nn/test_fvcore.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import GraphSAGE
from torch_geometric.testing import get_random_edge_index, withPackage


@withPackage('fvcore')
def test_fvcore():
    """fvcore counts GraphSAGE FLOPs per conv layer, all as linear ops."""
    from fvcore.nn import FlopCountAnalysis

    x = torch.randn(10, 16)
    edge_index = get_random_edge_index(10, 10, num_edges=100)

    model = GraphSAGE(16, 32, num_layers=2)

    flops = FlopCountAnalysis(model, (x, edge_index))

    # TODO (matthias) Currently, aggregations are not properly registered.
    assert flops.by_module()['convs.0'] == 2 * 10 * 16 * 32
    assert flops.by_module()['convs.1'] == 2 * 10 * 32 * 32
    assert flops.total() == (flops.by_module()['convs.0'] +
                             flops.by_module()['convs.1'])
    assert flops.by_operator()['linear'] == flops.total()


# --------------------------------------------------------------------------------
# /test/nn/test_fx.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import Tensor


def test_dropout():
    """After torch.fx tracing, p=1 dropout still zeroes inputs in eval mode."""
    class MyModule(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            return F.dropout(x, p=1.0, training=self.training)

    module = MyModule()
    graph_module = torch.fx.symbolic_trace(module)
    graph_module.recompile()

    x = torch.randn(4)

    graph_module.train()
    assert torch.allclose(graph_module(x), torch.zeros_like(x))

    # This is certainly undesired behavior due to tracing :(
    graph_module.eval()
    assert torch.allclose(graph_module(x), torch.zeros_like(x))


# --------------------------------------------------------------------------------
# /test/nn/test_reshape.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn.reshape import Reshape


def test_reshape():
    """Reshape views the input with the configured shape."""
    x = torch.randn(10, 4)

    op = Reshape(5, 2, 4)
    assert str(op) == 'Reshape(5, 2, 4)'

    assert op(x).size() == (5, 2, 4)
    assert torch.equal(op(x).view(10, 4), x)


# --------------------------------------------------------------------------------
# /test/nn/test_to_fixed_size_transformer.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import SumAggregation
from torch_geometric.nn.to_fixed_size_transformer import to_fixed_size


class Model(torch.nn.Module):
    """Sums node features per graph index given in ``batch``."""
    def __init__(self):
        super().__init__()
        self.aggr = SumAggregation()

    def forward(self, x, batch):
        return self.aggr(x, batch, dim=0)


def test_to_fixed_size():
    """to_fixed_size makes the aggregation output have ``batch_size`` rows."""
    x = torch.randn(10, 16)
    batch = torch.zeros(10, dtype=torch.long)

    model = Model()
    assert model(x, batch).size() == (1, 16)

    model = to_fixed_size(model, batch_size=10)
    assert model(x, batch).size() == (10, 16)


# --------------------------------------------------------------------------------
# /test/nn/unpool/test_knn_interpolate.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.nn import knn_interpolate
from torch_geometric.testing import withPackage


@withPackage('torch_cluster')
def test_knn_interpolate():
    """knn_interpolate (k=2) returns the expected values for two graphs."""
    x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]])
    pos_x = torch.tensor([
        [-1.0, 0.0],
        [0.0, 0.0],
        [1.0, 0.0],
        [-2.0, 0.0],
        [0.0, 0.0],
        [2.0, 0.0],
    ])
    pos_y = torch.tensor([
        [-1.0, -1.0],
        [1.0, 1.0],
        [-2.0, -2.0],
        [2.0, 2.0],
    ])
    batch_x = torch.tensor([0, 0, 0, 1, 1, 1])
    batch_y = torch.tensor([0, 0, 1, 1])

    y = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k=2)
    assert y.tolist() == [[4.0], [70.0], [-4.0], [-70.0]]


# --------------------------------------------------------------------------------
# /test/profile/test_benchmark.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.profile import benchmark
from torch_geometric.testing import withPackage


@withPackage('tabulate')
def test_benchmark(capfd):
    """benchmark prints a table with forward, backward and total columns."""
    def add(x, y):
        return x + y

    benchmark(
        funcs=[add],
        args=(torch.randn(10), torch.randn(10)),
        num_steps=1,
        num_warmups=1,
        backward=True,
    )

    out, _ = capfd.readouterr()
    # NOTE(review): the source dump collapsed runs of whitespace; these
    # table-header literals may originally contain column-alignment padding.
    # Verify against the upstream file.
    assert '| Name | Forward | Backward | Total |' in out
    assert '| add |' in out


# --------------------------------------------------------------------------------
# /test/test_debug.py:
# --------------------------------------------------------------------------------
from torch_geometric import debug, is_debug_enabled, set_debug


def test_debug():
    """set_debug works as setter and context manager; debug() enables it."""
    assert is_debug_enabled() is False
    set_debug(True)
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    with set_debug(True):
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    set_debug(True)
    assert is_debug_enabled() is True
    with set_debug(False):
        assert is_debug_enabled() is False
    assert is_debug_enabled() is True
    set_debug(False)
    assert is_debug_enabled() is False

    assert is_debug_enabled() is False
    with debug():
        assert is_debug_enabled() is True
    assert is_debug_enabled() is False


# --------------------------------------------------------------------------------
# /test/test_home.py:
# --------------------------------------------------------------------------------
import os
import os.path as osp

from torch_geometric import get_home_dir, set_home_dir
from torch_geometric.home import DEFAULT_CACHE_DIR


def test_home():
    """Home dir comes from the default, then $PYG_HOME, then set_home_dir."""
    os.environ.pop('PYG_HOME', None)
    home_dir = osp.expanduser(DEFAULT_CACHE_DIR)
    assert get_home_dir() == home_dir

    home_dir = '/tmp/test_pyg1'
    os.environ['PYG_HOME'] = home_dir
    assert get_home_dir() == home_dir

    home_dir = '/tmp/test_pyg2'
    set_home_dir(home_dir)
    assert get_home_dir() == home_dir


# --------------------------------------------------------------------------------
# /test/test_isinstance.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric import is_torch_instance
from torch_geometric.testing import onlyLinux, withPackage


def test_basic():
    """is_torch_instance recognizes a plain module instance."""
    assert is_torch_instance(torch.nn.Linear(1, 1), torch.nn.Linear)


@onlyLinux
@withPackage('torch>=2.0.0')
def test_compile():
    """is_torch_instance sees through the torch.compile wrapper."""
    model = torch.compile(torch.nn.Linear(1, 1))
    assert not isinstance(model, torch.nn.Linear)
    assert is_torch_instance(model, torch.nn.Linear)


# --------------------------------------------------------------------------------
# /test/test_seed.py:
# --------------------------------------------------------------------------------
import random

import numpy as np
import torch

from torch_geometric import seed_everything


def test_seed_everything():
    """After seed_everything(0), Python, NumPy and PyTorch RNGs reproduce."""
    seed_everything(0)

    assert random.randint(0, 100) == 49
    assert random.randint(0, 100) == 97
    assert np.random.randint(0, 100) == 44
    assert np.random.randint(0, 100) == 47
    assert int(torch.randint(0, 100, (1, ))) == 44
    assert int(torch.randint(0, 100, (1, ))) == 39


# --------------------------------------------------------------------------------
# /test/test_warnings.py:
# --------------------------------------------------------------------------------
import warnings
from unittest.mock import patch

import pytest

from torch_geometric.warnings import WarningCache, warn


def test_warn():
    """warn emits a UserWarning carrying the given message."""
    with pytest.warns(UserWarning, match='test'):
        warn('test')


@patch('torch_geometric.is_compiling', return_value=True)
def test_no_warn_if_compiling(_):
    """No warning should be raised to avoid graph breaks when compiling."""
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        warn('test')


def test_warning_cache():
    """WarningCache records each distinct message only once."""
    cache = WarningCache()
    assert len(cache) == 0

    cache.warn('test')
    assert len(cache) == 1
    assert 'test' in cache

    cache.warn('test')
    assert len(cache) == 1

    cache.warn('test2')
    assert len(cache) == 2
    assert 'test2' in cache


# --------------------------------------------------------------------------------
# /test/testing/test_decorators.py:
# --------------------------------------------------------------------------------
import torch_geometric.typing
from torch_geometric.testing import disableExtensions


def test_enable_extensions():
    """WITH_* flags match whether each optional extension can be imported."""
    try:
        import pyg_lib  # noqa
        assert torch_geometric.typing.WITH_PYG_LIB
    except (ImportError, OSError):
        assert not torch_geometric.typing.WITH_PYG_LIB

    try:
        import torch_scatter  # noqa
        assert torch_geometric.typing.WITH_TORCH_SCATTER
    except (ImportError, OSError):
        assert not torch_geometric.typing.WITH_TORCH_SCATTER

    try:
        import torch_sparse  # noqa
        assert torch_geometric.typing.WITH_TORCH_SPARSE
    except (ImportError, OSError):
        assert not torch_geometric.typing.WITH_TORCH_SPARSE


@disableExtensions
def test_disable_extensions():
    """disableExtensions turns every optional-extension flag off."""
    assert not torch_geometric.typing.WITH_PYG_LIB
    assert not torch_geometric.typing.WITH_TORCH_SCATTER
    assert not torch_geometric.typing.WITH_TORCH_SPARSE


# --------------------------------------------------------------------------------
# /test/transforms/test_add_gpse.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.nn import GPSE
from torch_geometric.nn.models.gpse import IdentityHead
from torch_geometric.transforms import AddGPSE

num_nodes = 6
gpse_inner_dim = 512


def test_add_gpse():
    """AddGPSE attaches a ``pestat_GPSE`` encoding of one row per node."""
    x = torch.randn(num_nodes, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    model = GPSE()
    # With an identity head, the encoding width equals `gpse_inner_dim`.
    model.post_mp = IdentityHead()
    transform = AddGPSE(model)

    assert str(transform) == 'AddGPSE()'
    out = transform(data)
    assert out.pestat_GPSE.size() == (num_nodes, gpse_inner_dim)


# --------------------------------------------------------------------------------
# /test/transforms/test_center.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import Center


def test_center():
    """Center shifts the positions so that their mean is the origin."""
    transform = Center()
    assert str(transform) == 'Center()'

    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
    data = Data(pos=pos)

    data = transform(data)
    assert len(data) == 1
    assert data.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]


# --------------------------------------------------------------------------------
# /test/transforms/test_generate_mesh_normals.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import GenerateMeshNormals


def test_generate_mesh_normals():
    """Vertex normals of a planar fan of faces all point along -z."""
    transform = GenerateMeshNormals()
    assert str(transform) == 'GenerateMeshNormals()'

    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [-2.0, 1.0, 0.0],
        [-1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [2.0, 1.0, 0.0],
    ])
    face = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
        [2, 3, 4, 5],
    ])

    data = transform(Data(pos=pos, face=face))
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.face.tolist() == face.tolist()
    assert data.norm.tolist() == [[0.0, 0.0, -1.0]] * 6


# --------------------------------------------------------------------------------
# /test/transforms/test_grid_sampling.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import GridSampling


@withPackage('torch_cluster')
def test_grid_sampling():
    """GridSampling pools the points of each voxel, averaging positions."""
    assert str(GridSampling(5)) == 'GridSampling(size=5)'

    pos = torch.tensor([
        [0.0, 2.0],
        [3.0, 2.0],
        [3.0, 2.0],
        [2.0, 8.0],
        [2.0, 6.0],
    ])
    y = torch.tensor([0, 1, 1, 2, 2])
    batch = torch.tensor([0, 0, 0, 0, 0])

    data = Data(pos=pos, y=y, batch=batch)
    data = GridSampling(size=5, start=0)(data)
    assert len(data) == 3
    assert data.pos.tolist() == [[2, 2], [2, 7]]
    assert data.y.tolist() == [1, 2]
    assert data.batch.tolist() == [0, 0]


# --------------------------------------------------------------------------------
# /test/transforms/test_knn_graph.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import KNNGraph


@withPackage('torch_cluster')
def test_knn_graph():
    """KNNGraph(k=2, force_undirected=True) builds the expected edges."""
    assert str(KNNGraph()) == 'KNNGraph(k=6)'

    pos = torch.tensor([
        [0.0, 0.0],
        [1.0, 0.0],
        [2.0, 0.0],
        [0.0, 1.0],
        [-2.0, 0.0],
        [0.0, -2.0],
    ])

    expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5]
    expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1]

    data = Data(pos=pos)
    data = KNNGraph(k=2, force_undirected=True)(data)
    assert len(data) == 2
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index[0].tolist() == expected_row
    assert data.edge_index[1].tolist() == expected_col


# --------------------------------------------------------------------------------
# /test/transforms/test_linear_transformation.py:
# --------------------------------------------------------------------------------
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import LinearTransformation


@pytest.mark.parametrize('matrix', [
    [[2.0, 0.0], [0.0, 2.0]],
    torch.tensor([[2.0, 0.0], [0.0, 2.0]]),
])
def test_linear_transformation(matrix):
    """LinearTransformation accepts lists or tensors; empty data stays empty."""
    pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])

    transform = LinearTransformation(matrix)
    assert str(transform) == ('LinearTransformation(\n'
                              '[[2. 0.]\n'
                              ' [0. 2.]]\n'
                              ')')

    out = transform(Data(pos=pos))
    assert len(out) == 1
    assert torch.allclose(out.pos, 2 * pos)

    out = transform(Data())
    assert len(out) == 0


# --------------------------------------------------------------------------------
# /test/transforms/test_local_degree_profile.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import LocalDegreeProfile


def test_local_degree_profile():
    """LocalDegreeProfile adds five per-node features, appended to any ``x``.

    Named after the transform under test; it previously shared the name
    ``test_target_indegree`` with the TargetIndegree test.
    """
    assert str(LocalDegreeProfile()) == 'LocalDegreeProfile()'

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    x = torch.tensor([[1.0], [1.0], [1.0], [1.0]])  # One isolated node.

    expected = torch.tensor([
        [1, 2, 2, 2, 0],
        [2, 1, 1, 1, 0],
        [1, 2, 2, 2, 0],
        [0, 0, 0, 0, 0],
    ], dtype=torch.float)

    data = Data(edge_index=edge_index, num_nodes=x.size(0))
    data = LocalDegreeProfile()(data)
    assert torch.allclose(data.x, expected, atol=1e-2)

    data = Data(edge_index=edge_index, x=x)
    data = LocalDegreeProfile()(data)
    assert torch.allclose(data.x[:, :1], x)
    assert torch.allclose(data.x[:, 1:], expected, atol=1e-2)


# --------------------------------------------------------------------------------
# /test/transforms/test_normalize_features.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import NormalizeFeatures


def test_normalize_features():
    """NormalizeFeatures scales each nonzero row of ``x`` to sum to one."""
    transform = NormalizeFeatures()
    assert str(transform) == 'NormalizeFeatures()'

    x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)
    data = Data(x=x)

    data = transform(data)
    assert len(data) == 1
    assert data.x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]


def test_hetero_normalize_features():
    """NormalizeFeatures normalizes ``x`` of every node type."""
    x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)

    data = HeteroData()
    data['v'].x = x
    data['w'].x = x
    data = NormalizeFeatures()(data)
    assert data['v'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]
    assert data['w'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]


# --------------------------------------------------------------------------------
# /test/transforms/test_normalize_scale.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeScale


def test_normalize_scale():
    """NormalizeScale maps positions into the open interval (-1, 1)."""
    transform = NormalizeScale()
    assert str(transform) == 'NormalizeScale()'

    pos = torch.randn((10, 3))
    data = Data(pos=pos)

    data = transform(data)
    assert len(data) == 1
    assert data.pos.min().item() > -1
    assert data.pos.max().item() < 1


# --------------------------------------------------------------------------------
# /test/transforms/test_radius_graph.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withPackage
from torch_geometric.transforms import RadiusGraph
from torch_geometric.utils import coalesce


@withPackage('torch_cluster')
def test_radius_graph():
    """RadiusGraph(r=1.5) builds the expected edges."""
    assert str(RadiusGraph(r=1)) == 'RadiusGraph(r=1)'

    pos = torch.tensor([
        [0.0, 0.0],
        [1.0, 0.0],
        [2.0, 0.0],
        [0.0, 1.0],
        [-2.0, 0.0],
        [0.0, -2.0],
    ])

    data = Data(pos=pos)
    data = RadiusGraph(r=1.5)(data)
    assert len(data) == 2
    assert data.pos.tolist() == pos.tolist()
    assert coalesce(data.edge_index).tolist() == [[0, 0, 1, 1, 1, 2, 3, 3],
                                                  [1, 3, 0, 2, 3, 1, 0, 1]]


# --------------------------------------------------------------------------------
# /test/transforms/test_random_flip.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RandomFlip


def test_random_flip():
    """RandomFlip with p=1 always negates the chosen axis."""
    assert str(RandomFlip(axis=0)) == 'RandomFlip(axis=0, p=0.5)'

    pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])

    data = Data(pos=pos)
    data = RandomFlip(axis=0, p=1)(data)
    assert len(data) == 1
    assert data.pos.tolist() == [[1.0, 1.0], [3.0, 0.0], [-2.0, -1.0]]

    data = Data(pos=pos)
    data = RandomFlip(axis=1, p=1)(data)
    assert len(data) == 1
    assert data.pos.tolist() == [[-1.0, -1.0], [-3.0, 0.0], [2.0, 1.0]]


# --------------------------------------------------------------------------------
# /test/transforms/test_random_jitter.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RandomJitter


def test_random_jitter():
    """RandomJitter offsets stay within the given (per-axis) bounds."""
    assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)'

    pos = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])

    data = Data(pos=pos)
    data = RandomJitter(0)(data)
    assert len(data) == 1
    assert torch.allclose(data.pos, pos)

    data = Data(pos=pos)
    data = RandomJitter(0.1)(data)
    assert len(data) == 1
    assert data.pos.min() >= -0.1
    assert data.pos.max() <= 0.1

    data = Data(pos=pos)
    data = RandomJitter([0.1, 1])(data)
    assert len(data) == 1
    assert data.pos[:, 0].min() >= -0.1
    assert data.pos[:, 0].max() <= 0.1
    assert data.pos[:, 1].min() >= -1
    assert data.pos[:, 1].max() <= 1


# --------------------------------------------------------------------------------
# /test/transforms/test_random_scale.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RandomScale


def test_random_scale():
    """RandomScale with a degenerate range scales by exactly that factor."""
    assert str(RandomScale([1, 2])) == 'RandomScale([1, 2])'

    pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])

    data = Data(pos=pos)
    data = RandomScale([1, 1])(data)
    assert len(data) == 1
    assert data.pos.tolist() == pos.tolist()

    data = Data(pos=pos)
    data = RandomScale([2, 2])(data)
    assert len(data) == 1
    assert data.pos.tolist() == [[-2, -2], [-2, 2], [2, -2], [2, 2]]
# --------------------------------------------------------------------------------
# /test/transforms/test_random_shear.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RandomShear


def test_random_shear():
    """RandomShear(0) is the identity; a nonzero shear moves the points."""
    assert str(RandomShear(0.1)) == 'RandomShear(0.1)'

    pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])

    data = Data(pos=pos)
    data = RandomShear(0)(data)
    assert len(data) == 1
    assert torch.allclose(data.pos, pos)

    data = Data(pos=pos)
    data = RandomShear(0.1)(data)
    assert len(data) == 1
    assert not torch.allclose(data.pos, pos)


# --------------------------------------------------------------------------------
# /test/transforms/test_remove_duplicated_edges.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RemoveDuplicatedEdges


def test_remove_duplicated_edges():
    """Duplicate edges are merged and their ``edge_weight`` values summed."""
    edge_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                               [0, 0, 1, 1, 0, 0, 1, 1]])
    edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=2)

    transform = RemoveDuplicatedEdges()
    assert str(transform) == 'RemoveDuplicatedEdges()'

    out = transform(data)
    assert len(out) == 3
    assert out.num_nodes == 2
    assert out.edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
    assert out.edge_weight.tolist() == [3, 7, 11, 15]


# --------------------------------------------------------------------------------
# /test/transforms/test_remove_training_classes.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import RemoveTrainingClasses


def test_remove_training_classes():
    """Training nodes of the removed classes are dropped from train_mask."""
    y = torch.tensor([1, 0, 0, 2, 1, 3])
    train_mask = torch.tensor([False, False, True, True, True, True])

    data = Data(y=y, train_mask=train_mask)

    transform = RemoveTrainingClasses(classes=[0, 1])
    assert str(transform) == 'RemoveTrainingClasses([0, 1])'

    data = transform(data)
    assert len(data) == 2
    assert torch.equal(data.y, y)
    assert data.train_mask.tolist() == [False, False, False, True, False, True]


# --------------------------------------------------------------------------------
# /test/transforms/test_sample_points.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import SamplePoints


def test_sample_points():
    """SamplePoints draws points on the mesh faces, optionally with normals."""
    assert str(SamplePoints(1024)) == 'SamplePoints(1024)'

    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
    ])
    face = torch.tensor([[0, 1], [1, 2], [2, 3]])

    data = Data(pos=pos)
    data.face = face
    data = SamplePoints(8)(data)
    assert len(data) == 1
    # Check the *sampled* points (`data.pos`), not the input `pos`: they must
    # lie on the unit square spanned by the two faces in the z=0 plane.
    assert data.pos.size() == (8, 3)
    assert data.pos[:, 0].min() >= 0 and data.pos[:, 0].max() <= 1
    assert data.pos[:, 1].min() >= 0 and data.pos[:, 1].max() <= 1
    assert data.pos[:, 2].abs().sum() == 0

    data = Data(pos=pos)
    data.face = face
    data = SamplePoints(8, include_normals=True)(data)
    assert len(data) == 2
    assert data.normal[:, :2].abs().sum() == 0
    assert data.normal[:, 2].abs().sum() == 8


# --------------------------------------------------------------------------------
# /test/transforms/test_sign.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import SIGN


def test_sign():
    """SIGN(K=2) adds the precomputed features ``x1`` and ``x2``."""
    x = torch.ones(5, 3)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 3, 4],
        [1, 0, 3, 2, 4, 3],
    ])
    data = Data(x=x, edge_index=edge_index)

    transform = SIGN(K=2)
    assert str(transform) == 'SIGN(K=2)'

    expected_x1 = torch.tensor([
        [1, 1, 1],
        [1, 1, 1],
        [0.7071, 0.7071, 0.7071],
        [1.4142, 1.4142, 1.4142],
        [0.7071, 0.7071, 0.7071],
    ])
    expected_x2 = torch.ones(5, 3)

    out = transform(data)
    assert len(out) == 4
    assert torch.equal(out.edge_index, edge_index)
    assert torch.allclose(out.x, x)
    assert torch.allclose(out.x1, expected_x1, atol=1e-4)
    assert torch.allclose(out.x2, expected_x2)


# --------------------------------------------------------------------------------
# /test/transforms/test_svd_feature_reduction.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import SVDFeatureReduction


def test_svd_feature_reduction():
    """SVDFeatureReduction projects ``x`` onto its top singular directions.

    Inputs with fewer features than ``out_channels`` are left unchanged.
    """
    assert str(SVDFeatureReduction(10)) == 'SVDFeatureReduction(10)'

    x = torch.randn(4, 16)
    U, S, _ = torch.linalg.svd(x)
    data = Data(x=x)
    data = SVDFeatureReduction(10)(data)
    assert torch.allclose(data.x, torch.mm(U[:, :10], torch.diag(S[:10])))

    x = torch.randn(4, 8)
    data = SVDFeatureReduction(10)(Data(x=x))
    assert torch.allclose(data.x, x)


# --------------------------------------------------------------------------------
# /test/transforms/test_target_indegree.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import TargetIndegree


def test_target_indegree():
    """TargetIndegree appends raw or normalized in-degrees to ``edge_attr``."""
    assert str(TargetIndegree()) == 'TargetIndegree(norm=True, max_value=None)'

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0])

    data = Data(edge_index=edge_index, num_nodes=3)
    data = TargetIndegree(norm=False)(data)
    assert len(data) == 3
    assert data.edge_index.tolist() == edge_index.tolist()
    assert data.edge_attr.tolist() == [[2], [1], [1], [2]]
    assert data.num_nodes == 3

    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
    data = TargetIndegree(norm=True)(data)
    assert len(data) == 3
    assert data.edge_index.tolist() == edge_index.tolist()
    assert data.edge_attr.tolist() == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]]
    assert data.num_nodes == 3


# --------------------------------------------------------------------------------
# /test/transforms/test_to_device.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.data import Data
from torch_geometric.testing import withDevice
from torch_geometric.transforms import ToDevice


@withDevice
def test_to_device(device):
    """ToDevice moves every tensor attribute to ``device``."""
    x = torch.randn(3, 4)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.randn(edge_index.size(1))

    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)

    transform = ToDevice(device)
    assert str(transform) == f'ToDevice({device})'

    data = transform(data)
    for _, value in data:
        assert value.device == device


# --------------------------------------------------------------------------------
# /test/utils/test_degree.py:
# --------------------------------------------------------------------------------
import torch

from torch_geometric.utils import degree


def test_degree():
    """degree counts each index's occurrences in the requested dtype."""
    row = torch.tensor([0, 1, 0, 2, 0])
    deg = degree(row, dtype=torch.long)
    assert deg.dtype == torch.long
    assert deg.tolist() == [3, 1, 1]
-------------------------------------------------------------------------------- /test/utils/test_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import cumsum 4 | 5 | 6 | def test_cumsum(): 7 | x = torch.tensor([2, 4, 1]) 8 | assert cumsum(x).tolist() == [0, 2, 6, 7] 9 | 10 | x = torch.tensor([[2, 4], [3, 6]]) 11 | assert cumsum(x, dim=1).tolist() == [[0, 2, 6], [0, 3, 9]] 12 | -------------------------------------------------------------------------------- /test/utils/test_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.testing import is_full_test 4 | from torch_geometric.utils import grid 5 | 6 | 7 | def test_grid(): 8 | (row, col), pos = grid(height=3, width=2) 9 | 10 | expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2] 11 | expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5] 12 | expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] 13 | expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5] 14 | 15 | expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]] 16 | 17 | assert row.tolist() == expected_row 18 | assert col.tolist() == expected_col 19 | assert pos.tolist() == expected_pos 20 | 21 | if is_full_test(): 22 | jit = torch.jit.script(grid) 23 | (row, col), pos = jit(height=3, width=2) 24 | assert row.tolist() == expected_row 25 | assert col.tolist() == expected_col 26 | assert pos.tolist() == expected_pos 27 | -------------------------------------------------------------------------------- /test/utils/test_index_sort.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.testing import withDevice 4 | from torch_geometric.utils import index_sort 5 | 6 | 7 | @withDevice 8 | def test_index_sort_stable(device): 9 | for _ in range(100): 10 | inputs = torch.randint(0, 4, 
size=(10, ), device=device) 11 | 12 | out = index_sort(inputs, stable=True) 13 | expected = torch.sort(inputs, stable=True) 14 | 15 | assert torch.equal(out[0], expected[0]) 16 | assert torch.equal(out[1], expected[1]) 17 | -------------------------------------------------------------------------------- /test/utils/test_lexsort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch_geometric.utils import lexsort 5 | 6 | 7 | def test_lexsort(): 8 | keys = [torch.randn(100) for _ in range(3)] 9 | 10 | expected = np.lexsort([key.numpy() for key in keys]) 11 | assert torch.equal(lexsort(keys), torch.from_numpy(expected)) 12 | -------------------------------------------------------------------------------- /test/utils/test_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import index_to_mask, mask_select, mask_to_index 4 | 5 | 6 | def test_mask_select(): 7 | src = torch.randn(6, 8) 8 | mask = torch.tensor([False, True, False, True, False, True]) 9 | 10 | out = mask_select(src, 0, mask) 11 | assert out.size() == (3, 8) 12 | assert torch.equal(src[torch.tensor([1, 3, 5])], out) 13 | 14 | jit = torch.jit.script(mask_select) 15 | assert torch.equal(jit(src, 0, mask), out) 16 | 17 | 18 | def test_index_to_mask(): 19 | index = torch.tensor([1, 3, 5]) 20 | 21 | mask = index_to_mask(index) 22 | assert mask.tolist() == [False, True, False, True, False, True] 23 | 24 | mask = index_to_mask(index, size=7) 25 | assert mask.tolist() == [False, True, False, True, False, True, False] 26 | 27 | 28 | def test_mask_to_index(): 29 | mask = torch.tensor([False, True, False, True, False, True]) 30 | 31 | index = mask_to_index(mask) 32 | assert index.tolist() == [1, 3, 5] 33 | -------------------------------------------------------------------------------- /test/utils/test_noise_scheduler.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.utils.noise_scheduler import ( 5 | get_diffusion_beta_schedule, 6 | get_smld_sigma_schedule, 7 | ) 8 | 9 | 10 | def test_get_smld_sigma_schedule(): 11 | expected = torch.tensor([ 12 | 1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637, 13 | 0.04641589, 0.02782559, 0.01668101, 0.01 14 | ]) 15 | out = get_smld_sigma_schedule( 16 | sigma_min=0.01, 17 | sigma_max=1.0, 18 | num_scales=10, 19 | ) 20 | assert torch.allclose(out, expected) 21 | 22 | 23 | @pytest.mark.parametrize( 24 | 'schedule_type', 25 | ['linear', 'quadratic', 'constant', 'sigmoid'], 26 | ) 27 | def test_get_diffusion_beta_schedule(schedule_type): 28 | out = get_diffusion_beta_schedule( 29 | schedule_type, 30 | beta_start=0.1, 31 | beta_end=0.2, 32 | num_diffusion_timesteps=10, 33 | ) 34 | assert out.size() == (10, ) 35 | -------------------------------------------------------------------------------- /test/utils/test_normalize_edge_index.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.utils import normalize_edge_index 5 | 6 | 7 | @pytest.mark.parametrize('add_self_loops', [False, True]) 8 | @pytest.mark.parametrize('symmetric', [False, True]) 9 | def test_normalize_edge_index(add_self_loops: bool, symmetric: bool): 10 | edge_index = torch.tensor([[0, 2, 2, 3], [2, 0, 3, 0]]) 11 | 12 | out = normalize_edge_index( 13 | edge_index, 14 | add_self_loops=add_self_loops, 15 | symmetric=symmetric, 16 | ) 17 | assert isinstance(out, tuple) and len(out) == 2 18 | if not add_self_loops: 19 | assert out[0].equal(edge_index) 20 | else: 21 | assert out[0].tolist() == [ 22 | [0, 2, 2, 3, 0, 1, 2, 3], 23 | [2, 0, 3, 0, 0, 1, 2, 3], 24 | ] 25 | 26 | assert out[1].min() >= 0.0 27 | assert out[1].max() <= 1.0 28 |
-------------------------------------------------------------------------------- /test/utils/test_normalized_cut.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.testing import is_full_test 4 | from torch_geometric.utils import normalized_cut 5 | 6 | 7 | def test_normalized_cut(): 8 | row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4]) 9 | col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3]) 10 | edge_attr = torch.tensor( 11 | [3.0, 3.0, 6.0, 3.0, 6.0, 1.0, 3.0, 2.0, 1.0, 2.0]) 12 | expected = torch.tensor([4.0, 4.0, 5.0, 2.5, 5.0, 1.0, 2.5, 2.0, 1.0, 2.0]) 13 | 14 | out = normalized_cut(torch.stack([row, col], dim=0), edge_attr) 15 | assert torch.allclose(out, expected) 16 | 17 | if is_full_test(): 18 | jit = torch.jit.script(normalized_cut) 19 | out = jit(torch.stack([row, col], dim=0), edge_attr) 20 | assert torch.allclose(out, expected) 21 | -------------------------------------------------------------------------------- /test/utils/test_num_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import to_torch_coo_tensor 4 | from torch_geometric.utils.num_nodes import ( 5 | maybe_num_nodes, 6 | maybe_num_nodes_dict, 7 | ) 8 | 9 | 10 | def test_maybe_num_nodes(): 11 | edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]) 12 | 13 | assert maybe_num_nodes(edge_index, 4) == 4 14 | assert maybe_num_nodes(edge_index) == 3 15 | 16 | adj = to_torch_coo_tensor(edge_index) 17 | assert maybe_num_nodes(adj, 4) == 4 18 | assert maybe_num_nodes(adj) == 3 19 | 20 | 21 | def test_maybe_num_nodes_dict(): 22 | edge_index_dict = { 23 | '1': torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]), 24 | '2': torch.tensor([[0, 0, 1, 3], [1, 2, 0, 4]]) 25 | } 26 | num_nodes_dict = {'2': 6} 27 | 28 | assert maybe_num_nodes_dict(edge_index_dict) == {'1': 3, '2': 5} 29 | assert 
maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) == { 30 | '1': 3, 31 | '2': 6, 32 | } 33 | -------------------------------------------------------------------------------- /test/utils/test_one_hot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import one_hot 4 | 5 | 6 | def test_one_hot(): 7 | index = torch.tensor([0, 1, 2]) 8 | 9 | out = one_hot(index) 10 | assert out.size() == (3, 3) 11 | assert out.dtype == torch.float 12 | assert out.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] 13 | 14 | out = one_hot(index, num_classes=4, dtype=torch.long) 15 | assert out.size() == (3, 4) 16 | assert out.dtype == torch.long 17 | assert out.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]] 18 | -------------------------------------------------------------------------------- /test/utils/test_ppr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_geometric.datasets import KarateClub 5 | from torch_geometric.testing import withPackage 6 | from torch_geometric.utils import get_ppr 7 | 8 | 9 | @withPackage('numba') 10 | @pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])]) 11 | def test_get_ppr(target): 12 | data = KarateClub()[0] 13 | 14 | edge_index, edge_weight = get_ppr( 15 | data.edge_index, 16 | alpha=0.1, 17 | eps=1e-5, 18 | target=target, 19 | ) 20 | 21 | assert edge_index.size(0) == 2 22 | assert edge_index.size(1) == edge_weight.numel() 23 | 24 | min_row = 0 if target is None else target.min() 25 | max_row = data.num_nodes - 1 if target is None else target.max() 26 | assert edge_index[0].min() == min_row and edge_index[0].max() == max_row 27 | assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes 28 | assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0 29 | -------------------------------------------------------------------------------- 
/test/utils/test_repeat.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.utils.repeat import repeat 2 | 3 | 4 | def test_repeat(): 5 | assert repeat(None, length=4) is None 6 | assert repeat(4, length=4) == [4, 4, 4, 4] 7 | assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4] 8 | assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4] 9 | assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4] 10 | -------------------------------------------------------------------------------- /test/utils/test_select.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import narrow, select 4 | 5 | 6 | def test_select(): 7 | src = torch.randn(5, 3) 8 | index = torch.tensor([0, 2, 4]) 9 | mask = torch.tensor([True, False, True, False, True]) 10 | 11 | out = select(src, index, dim=0) 12 | assert torch.equal(out, src[index]) 13 | assert torch.equal(out, select(src, mask, dim=0)) 14 | assert torch.equal(out, torch.tensor(select(src.tolist(), index, dim=0))) 15 | assert torch.equal(out, torch.tensor(select(src.tolist(), mask, dim=0))) 16 | 17 | 18 | def test_narrow(): 19 | src = torch.randn(5, 3) 20 | 21 | out = narrow(src, dim=0, start=2, length=2) 22 | assert torch.equal(out, src[2:4]) 23 | assert torch.equal(out, torch.tensor(narrow(src.tolist(), 0, 2, 2))) 24 | -------------------------------------------------------------------------------- /test/utils/test_tree_decomposition.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_geometric.testing import withPackage 4 | from torch_geometric.utils import tree_decomposition 5 | 6 | 7 | @withPackage('rdkit') 8 | @pytest.mark.parametrize('smiles', [ 9 | r'F/C=C/F', 10 | r'C/C(=C\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2', 11 | ]) 12 | def test_tree_decomposition(smiles): 13 | from rdkit import Chem 14 | mol = Chem.MolFromSmiles(smiles) 15 
| tree_decomposition(mol) # TODO Test output 16 | -------------------------------------------------------------------------------- /test/utils/test_unbatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.utils import unbatch, unbatch_edge_index 4 | 5 | 6 | def test_unbatch(): 7 | src = torch.arange(10) 8 | batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4]) 9 | 10 | out = unbatch(src, batch) 11 | assert len(out) == 5 12 | for i in range(len(out)): 13 | assert torch.equal(out[i], src[batch == i]) 14 | 15 | 16 | def test_unbatch_edge_index(): 17 | edge_index = torch.tensor([ 18 | [0, 1, 1, 2, 2, 3, 4, 5, 5, 6], 19 | [1, 0, 2, 1, 3, 2, 5, 4, 6, 5], 20 | ]) 21 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1]) 22 | 23 | edge_indices = unbatch_edge_index(edge_index, batch) 24 | assert edge_indices[0].tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] 25 | assert edge_indices[1].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] 26 | -------------------------------------------------------------------------------- /test/visualization/test_influence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.datasets import KarateClub 4 | from torch_geometric.nn import GCNConv 5 | from torch_geometric.visualization import influence 6 | 7 | 8 | class Net(torch.nn.Module): 9 | def __init__(self, in_channels, out_channels): 10 | super().__init__() 11 | self.conv1 = GCNConv(in_channels, out_channels) 12 | self.conv2 = GCNConv(out_channels, out_channels) 13 | 14 | def forward(self, x, edge_index): 15 | x = torch.nn.functional.relu(self.conv1(x, edge_index)) 16 | x = self.conv2(x, edge_index) 17 | return x 18 | 19 | 20 | def test_influence(): 21 | data = KarateClub()[0] 22 | x = torch.randn(data.num_nodes, 8) 23 | 24 | out = influence(Net(x.size(1), 16), x, data.edge_index) 25 | assert out.size() == (data.num_nodes, data.num_nodes) 26 | 
assert torch.allclose(out.sum(dim=-1), torch.ones(data.num_nodes), 27 | atol=1e-04) 28 | -------------------------------------------------------------------------------- /torch_geometric/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch_geometric.contrib.transforms # noqa 4 | import torch_geometric.contrib.datasets # noqa 5 | import torch_geometric.contrib.nn # noqa 6 | import torch_geometric.contrib.explain # noqa 7 | 8 | warnings.warn( 9 | "'torch_geometric.contrib' contains experimental code and is subject to " 10 | "change. Please use with caution.", stacklevel=2) 11 | 12 | __all__ = [] 13 | -------------------------------------------------------------------------------- /torch_geometric/contrib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = classes = [] 2 | -------------------------------------------------------------------------------- /torch_geometric/contrib/explain/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.deprecation import deprecated 2 | 3 | from .pgm_explainer import PGMExplainer 4 | from torch_geometric.explain.algorithm.graphmask_explainer import ( 5 | GraphMaskExplainer as NewGraphMaskExplainer) 6 | 7 | GraphMaskExplainer = deprecated( 8 | "use 'torch_geometric.explain.algorithm.GraphMaskExplainer' instead", )( 9 | NewGraphMaskExplainer) 10 | 11 | __all__ = classes = [ 12 | 'PGMExplainer', 13 | ] 14 | -------------------------------------------------------------------------------- /torch_geometric/contrib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import * # noqa 2 | from .models import * # noqa 3 | 4 | __all__ = [] 5 | -------------------------------------------------------------------------------- /torch_geometric/contrib/nn/conv/__init__.py: 
-------------------------------------------------------------------------------- 1 | __all__ = classes = [] 2 | -------------------------------------------------------------------------------- /torch_geometric/contrib/nn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rbcd_attack import PRBCDAttack, GRBCDAttack 2 | 3 | __all__ = classes = [ 4 | 'PRBCDAttack', 5 | 'GRBCDAttack', 6 | ] 7 | -------------------------------------------------------------------------------- /torch_geometric/contrib/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = classes = [] 2 | -------------------------------------------------------------------------------- /torch_geometric/data/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .datamodule import LightningDataset, LightningNodeData, LightningLinkData 2 | 3 | __all__ = classes = [ 4 | 'LightningDataset', 5 | 'LightningNodeData', 6 | 'LightningLinkData', 7 | ] 8 | -------------------------------------------------------------------------------- /torch_geometric/data/makedirs.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.deprecation import deprecated 2 | from torch_geometric.io import fs 3 | 4 | 5 | @deprecated("use 'os.makedirs(path, exist_ok=True)' instead") 6 | def makedirs(path: str): 7 | r"""Recursively creates a directory. 8 | 9 | .. warning:: 10 | 11 | :meth:`makedirs` is deprecated and will be removed soon. 12 | Please use :obj:`os.makedirs(path, exist_ok=True)` instead. 13 | 14 | Args: 15 | path (str): The path to create. 
16 | """ 17 | fs.makedirs(path, exist_ok=True) 18 | -------------------------------------------------------------------------------- /torch_geometric/datasets/graph_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import GraphGenerator 2 | from .ba_graph import BAGraph 3 | from .er_graph import ERGraph 4 | from .grid_graph import GridGraph 5 | from .tree_graph import TreeGraph 6 | 7 | __all__ = classes = [ 8 | 'GraphGenerator', 9 | 'BAGraph', 10 | 'ERGraph', 11 | 'GridGraph', 12 | 'TreeGraph', 13 | ] 14 | -------------------------------------------------------------------------------- /torch_geometric/datasets/motif_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import MotifGenerator 2 | from .custom import CustomMotif 3 | from .house import HouseMotif 4 | from .cycle import CycleMotif 5 | from .grid import GridMotif 6 | 7 | __all__ = classes = [ 8 | 'MotifGenerator', 9 | 'CustomMotif', 10 | 'HouseMotif', 11 | 'CycleMotif', 12 | 'GridMotif', 13 | ] 14 | -------------------------------------------------------------------------------- /torch_geometric/datasets/motif_generator/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from torch_geometric.data import Data 5 | from torch_geometric.resolver import resolver 6 | 7 | 8 | class MotifGenerator(ABC): 9 | r"""An abstract base class for generating a motif.""" 10 | @abstractmethod 11 | def __call__(self) -> Data: 12 | r"""To be implemented by :class:`Motif` subclasses.""" 13 | 14 | @staticmethod 15 | def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator': 16 | import torch_geometric.datasets.motif_generator as _motif_generators 17 | motif_generators = [ 18 | gen for gen in vars(_motif_generators).values() 19 | if isinstance(gen, type) and 
issubclass(gen, MotifGenerator) 20 | ] 21 | return resolver(motif_generators, {}, query, MotifGenerator, 'Motif', 22 | *args, **kwargs) 23 | 24 | def __repr__(self) -> str: 25 | return f'{self.__class__.__name__}()' 26 | -------------------------------------------------------------------------------- /torch_geometric/datasets/motif_generator/house.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.data import Data 4 | from torch_geometric.datasets.motif_generator import CustomMotif 5 | 6 | 7 | class HouseMotif(CustomMotif): 8 | r"""Generates the house-structured motif from the `"GNNExplainer: 9 | Generating Explanations for Graph Neural Networks" 10 | `__ paper, containing 5 nodes and 6 11 | undirected edges. Nodes are labeled according to their structural role: 12 | the top, middle and bottom of the house. 13 | """ 14 | def __init__(self) -> None: 15 | structure = Data( 16 | num_nodes=5, 17 | edge_index=torch.tensor([ 18 | [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4], 19 | [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1], 20 | ]), 21 | y=torch.tensor([0, 0, 1, 1, 2]), 22 | ) 23 | super().__init__(structure) 24 | -------------------------------------------------------------------------------- /torch_geometric/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .cheatsheet import paper_link, has_stats, get_stat, get_children, get_type 2 | 3 | __all__ = [ 4 | 'paper_link', 5 | 'has_stats', 6 | 'get_stat', 7 | 'get_children', 8 | 'get_type', 9 | ] 10 | -------------------------------------------------------------------------------- /torch_geometric/deprecation.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | import warnings 4 | from typing import Any, Callable, Optional 5 | 6 | 7 | def deprecated( 8 | details: Optional[str] = None, 9 | func_name: Optional[str] = 
None, 10 | ) -> Callable: 11 | def decorator(func: Callable) -> Callable: 12 | name = func_name or func.__name__ 13 | 14 | if inspect.isclass(func): 15 | cls = type(func.__name__, (func, ), {}) 16 | cls.__init__ = deprecated(details, name)( # type: ignore 17 | func.__init__) 18 | cls.__doc__ = func.__doc__ 19 | return cls 20 | 21 | @functools.wraps(func) 22 | def wrapper(*args: Any, **kwargs: Any) -> Any: 23 | out = f"'{name}' is deprecated" 24 | if details is not None: 25 | out += f", {details}" 26 | warnings.warn(out, stacklevel=2) 27 | return func(*args, **kwargs) 28 | 29 | return wrapper 30 | 31 | return decorator 32 | -------------------------------------------------------------------------------- /torch_geometric/distributed/dist_context.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | 4 | 5 | class DistRole(Enum): 6 | WORKER = 1 7 | 8 | 9 | @dataclass 10 | class DistContext: 11 | r"""Context information of the current process.""" 12 | rank: int 13 | global_rank: int 14 | world_size: int 15 | global_world_size: int 16 | group_name: str 17 | role: DistRole = DistRole.WORKER 18 | 19 | @property 20 | def worker_name(self) -> str: 21 | return f'{self.group_name}-{self.rank}' 22 | -------------------------------------------------------------------------------- /torch_geometric/explain/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ExplainerConfig, ModelConfig, ThresholdConfig 2 | from .explanation import Explanation, HeteroExplanation 3 | from .algorithm import * # noqa 4 | from .explainer import Explainer 5 | from .metric import * # noqa 6 | 7 | __all__ = [ 8 | 'ExplainerConfig', 9 | 'ModelConfig', 10 | 'ThresholdConfig', 11 | 'Explanation', 12 | 'HeteroExplanation', 13 | 'Explainer', 14 | ] 15 | -------------------------------------------------------------------------------- 
/torch_geometric/explain/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ExplainerAlgorithm 2 | from .dummy_explainer import DummyExplainer 3 | from .gnn_explainer import GNNExplainer 4 | from .captum_explainer import CaptumExplainer 5 | from .pg_explainer import PGExplainer 6 | from .attention_explainer import AttentionExplainer 7 | from .graphmask_explainer import GraphMaskExplainer 8 | 9 | __all__ = classes = [ 10 | 'ExplainerAlgorithm', 11 | 'DummyExplainer', 12 | 'GNNExplainer', 13 | 'CaptumExplainer', 14 | 'PGExplainer', 15 | 'AttentionExplainer', 16 | 'GraphMaskExplainer', 17 | ] 18 | -------------------------------------------------------------------------------- /torch_geometric/explain/metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import groundtruth_metrics 2 | from .fidelity import fidelity, characterization_score, fidelity_curve_auc 3 | from .faithfulness import unfaithfulness 4 | 5 | __all__ = classes = [ 6 | 'groundtruth_metrics', 7 | 'fidelity', 8 | 'characterization_score', 9 | 'fidelity_curve_auc', 10 | 'unfaithfulness', 11 | ] 12 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/benchmark.py: -------------------------------------------------------------------------------- 1 | # Do not change; required for benchmarking 2 | 3 | import torch_geometric_benchmark.torchprof_local as torchprof # noqa 4 | from pytorch_memlab import LineProfiler # noqa 5 | from torch_geometric_benchmark.utils import count_parameters # noqa 6 | from torch_geometric_benchmark.utils import get_gpu_memory_nvdia # noqa 7 | from torch_geometric_benchmark.utils import get_memory_status # noqa 8 | from torch_geometric_benchmark.utils import get_model_size # noqa 9 | 10 | global_line_profiler = LineProfiler() 11 | global_line_profiler.enable() 12 | 
-------------------------------------------------------------------------------- /torch_geometric/graphgym/cmd_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args() -> argparse.Namespace: 5 | r"""Parses the command line arguments.""" 6 | parser = argparse.ArgumentParser(description='GraphGym') 7 | 8 | parser.add_argument('--cfg', dest='cfg_file', type=str, required=True, 9 | help='The configuration file path.') 10 | parser.add_argument('--repeat', type=int, default=1, 11 | help='The number of repeated jobs.') 12 | parser.add_argument('--mark_done', action='store_true', 13 | help='Mark yaml as done after a job has finished.') 14 | parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, 15 | help='See graphgym/config.py for remaining options.') 16 | 17 | return parser.parse_args() 18 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * # noqa 2 | from .config import * # noqa 3 | from .encoder import * # noqa 4 | from .head import * # noqa 5 | from .layer import * # noqa 6 | from .loader import * # noqa 7 | from .loss import * # noqa 8 | from .network import * # noqa 9 | from .optimizer import * # noqa 10 | from .pooling import * # noqa 11 | from .stage import * # noqa 12 | from .train import * # noqa 13 | from .transform import * # noqa 14 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/act/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | 
-------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/head/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/loader/__init__.py: 
-------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/network/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | 
import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/stage/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/contrib/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/imports.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | try: 6 | import lightning.pytorch as pl 7 | _pl_is_available = True 8 | except ImportError: 9 | try: 10 | import pytorch_lightning as pl 11 | _pl_is_available = True 12 
| except ImportError: 13 | _pl_is_available = False 14 | 15 | if _pl_is_available: 16 | LightningModule = pl.LightningModule 17 | Callback = pl.Callback 18 | else: 19 | pl = object 20 | LightningModule = torch.nn.Module 21 | Callback = object 22 | 23 | warnings.warn( 24 | "To use GraphGym, install 'pytorch_lightning' or 'lightning' via " 25 | "'pip install pytorch_lightning' or 'pip install lightning'", 26 | stacklevel=2) 27 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def init_weights(m): 5 | r"""Performs weight initialization. 6 | 7 | Args: 8 | m (nn.Module): PyTorch module 9 | 10 | """ 11 | if (isinstance(m, torch.nn.BatchNorm2d) 12 | or isinstance(m, torch.nn.BatchNorm1d)): 13 | m.weight.data.fill_(1.0) 14 | m.bias.data.zero_() 15 | elif isinstance(m, torch.nn.Linear): 16 | m.weight.data = torch.nn.init.xavier_uniform_( 17 | m.weight.data, gain=torch.nn.init.calculate_gain('relu')) 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/models/pooling.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_pooling 2 | from torch_geometric.nn import ( 3 | global_add_pool, 4 | global_max_pool, 5 | global_mean_pool, 6 | ) 7 | 8 | register_pooling('add', global_add_pool) 9 | register_pooling('mean', global_mean_pool) 10 | register_pooling('max', global_max_pool) 11 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/utils/LICENSE: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/torch_geometric/graphgym/utils/LICENSE -------------------------------------------------------------------------------- /torch_geometric/graphgym/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agg_runs import agg_runs, agg_batch 2 | from .comp_budget import params_count, match_baseline_cfg 3 | from .device import get_current_gpu_usage, auto_select_device 4 | from .epoch import is_eval_epoch, is_ckpt_epoch 5 | from .io import dict_to_json, dict_list_to_json, dict_to_tb, makedirs_rm_exist 6 | from .tools import dummy_context 7 | 8 | __all__ = [ 9 | 'agg_runs', 10 | 'agg_batch', 11 | 'params_count', 12 | 'match_baseline_cfg', 13 | 'get_current_gpu_usage', 14 | 'auto_select_device', 15 | 'is_eval_epoch', 16 | 'is_ckpt_epoch', 17 | 'dict_to_json', 18 | 'dict_list_to_json', 19 | 'dict_to_tb', 20 | 'makedirs_rm_exist', 21 | 'dummy_context', 22 | ] 23 | 24 | classes = __all__ 25 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/utils/epoch.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.config import cfg 2 | 3 | 4 | def is_train_eval_epoch(cur_epoch): 5 | """Determines if the model should be evaluated at the training epoch.""" 6 | return is_eval_epoch(cur_epoch) or not cfg.train.skip_train_eval 7 | 8 | 9 | def is_eval_epoch(cur_epoch): 10 | """Determines if the model should be evaluated at the current epoch.""" 11 | return ((cur_epoch + 1) % cfg.train.eval_period == 0 or cur_epoch == 0 12 | or (cur_epoch + 1) == cfg.optim.max_epoch) 13 | 14 | 15 | def is_ckpt_epoch(cur_epoch): 16 | """Determines if the model should be evaluated at the current epoch.""" 17 | return ((cur_epoch + 1) % cfg.train.ckpt_period == 0 18 | or (cur_epoch + 1) == cfg.optim.max_epoch) 19 | 
-------------------------------------------------------------------------------- /torch_geometric/graphgym/utils/plot.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | 4 | def view_emb(emb, dir): 5 | """Visualize a embedding matrix. 6 | 7 | Args: 8 | emb (torch.tensor): Embedding matrix with shape (N, D). D is the 9 | feature dimension. 10 | dir (str): Output directory for the embedding figure. 11 | """ 12 | import matplotlib.pyplot as plt 13 | import seaborn as sns 14 | from sklearn.decomposition import PCA 15 | 16 | sns.set_context('poster') 17 | 18 | if emb.shape[1] > 2: 19 | pca = PCA(n_components=2) 20 | emb = pca.fit_transform(emb) 21 | plt.figure(figsize=(10, 10)) 22 | plt.scatter(emb[:, 0], emb[:, 1]) 23 | plt.savefig(osp.join(dir, 'emb_pca.png'), dpi=100) 24 | -------------------------------------------------------------------------------- /torch_geometric/graphgym/utils/tools.py: -------------------------------------------------------------------------------- 1 | class dummy_context(): 2 | """Default context manager that does nothing.""" 3 | def __enter__(self): 4 | return None 5 | 6 | def __exit__(self, exc_type, exc_value, traceback): 7 | return False 8 | -------------------------------------------------------------------------------- /torch_geometric/home.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from typing import Optional 4 | 5 | ENV_PYG_HOME = 'PYG_HOME' 6 | DEFAULT_CACHE_DIR = osp.join('~', '.cache', 'pyg') 7 | 8 | _home_dir: Optional[str] = None 9 | 10 | 11 | def get_home_dir() -> str: 12 | r"""Get the cache directory used for storing all :pyg:`PyG`-related data. 13 | 14 | If :meth:`set_home_dir` is not called, the path is given by the environment 15 | variable :obj:`$PYG_HOME` which defaults to :obj:`"~/.cache/pyg"`. 
16 | """ 17 | if _home_dir is not None: 18 | return _home_dir 19 | 20 | return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR)) 21 | 22 | 23 | def set_home_dir(path: str) -> None: 24 | r"""Set the cache directory used for storing all :pyg:`PyG`-related data. 25 | 26 | Args: 27 | path (str): The path to a local folder. 28 | """ 29 | global _home_dir 30 | _home_dir = path 31 | -------------------------------------------------------------------------------- /torch_geometric/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .txt_array import parse_txt_array, read_txt_array 2 | from .tu import read_tu_data 3 | from .planetoid import read_planetoid_data 4 | from .ply import read_ply 5 | from .obj import read_obj 6 | from .sdf import read_sdf, parse_sdf 7 | from .off import read_off, write_off 8 | from .npz import read_npz, parse_npz 9 | 10 | __all__ = [ 11 | 'read_off', 12 | 'write_off', 13 | 'parse_txt_array', 14 | 'read_txt_array', 15 | 'read_tu_data', 16 | 'read_planetoid_data', 17 | 'read_ply', 18 | 'read_obj', 19 | 'read_sdf', 20 | 'parse_sdf', 21 | 'read_npz', 22 | 'parse_npz', 23 | ] 24 | -------------------------------------------------------------------------------- /torch_geometric/io/ply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.data import Data 4 | 5 | try: 6 | import openmesh 7 | except ImportError: 8 | openmesh = None 9 | 10 | 11 | def read_ply(path: str) -> Data: 12 | if openmesh is None: 13 | raise ImportError('`read_ply` requires the `openmesh` package.') 14 | 15 | mesh = openmesh.read_trimesh(path) 16 | pos = torch.from_numpy(mesh.points()).to(torch.float) 17 | face = torch.from_numpy(mesh.face_vertex_indices()) 18 | face = face.t().to(torch.long).contiguous() 19 | return Data(pos=pos, face=face) 20 | -------------------------------------------------------------------------------- 
/torch_geometric/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .large_graph_indexer import LargeGraphIndexer 2 | from .rag_loader import RAGQueryLoader 3 | from .utils import * # noqa 4 | from .models import * # noqa 5 | 6 | __all__ = classes = [ 7 | 'LargeGraphIndexer', 8 | 'RAGQueryLoader', 9 | ] 10 | -------------------------------------------------------------------------------- /torch_geometric/llm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sentence_transformer import SentenceTransformer 2 | from .vision_transformer import VisionTransformer 3 | from .llm import LLM 4 | from .txt2kg import TXT2KG 5 | from .llm_judge import LLMJudge 6 | from .g_retriever import GRetriever 7 | from .molecule_gpt import MoleculeGPT 8 | from .glem import GLEM 9 | from .protein_mpnn import ProteinMPNN 10 | from .git_mol import GITMol 11 | 12 | __all__ = classes = [ 13 | 'SentenceTransformer', 14 | 'VisionTransformer', 15 | 'LLM', 16 | 'LLMJudge', 17 | 'TXT2KG', 18 | 'GRetriever', 19 | 'MoleculeGPT', 20 | 'GLEM', 21 | 'ProteinMPNN', 22 | 'GITMol', 23 | ] 24 | -------------------------------------------------------------------------------- /torch_geometric/llm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .backend_utils import * # noqa 2 | from .feature_store import KNNRAGFeatureStore 3 | from .graph_store import NeighborSamplingRAGGraphStore 4 | from .vectorrag import DocumentRetriever 5 | 6 | __all__ = classes = [ 7 | 'KNNRAGFeatureStore', 8 | 'NeighborSamplingRAGGraphStore', 9 | 'DocumentRetriever', 10 | ] 11 | -------------------------------------------------------------------------------- /torch_geometric/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Any 3 | 4 | _wandb_initialized: bool = False 5 | 6 | 7 | 
def init_wandb(name: str, **kwargs: Any) -> None: 8 | if '--wandb' not in sys.argv: 9 | return 10 | 11 | from datetime import datetime 12 | 13 | import wandb 14 | 15 | wandb.init( 16 | project=name, 17 | entity='pytorch-geometric', 18 | name=datetime.now().strftime('%Y-%m-%d_%H:%M'), 19 | config=kwargs, 20 | ) 21 | 22 | global _wandb_initialized 23 | _wandb_initialized = True 24 | 25 | 26 | def log(**kwargs: Any) -> None: 27 | def _map(value: Any) -> str: 28 | if isinstance(value, int) and not isinstance(value, bool): 29 | return f'{value:03d}' 30 | if isinstance(value, float): 31 | return f'{value:.4f}' 32 | return value 33 | 34 | print(', '.join(f'{key}: {_map(value)}' for key, value in kwargs.items())) 35 | 36 | if _wandb_initialized: 37 | import wandb 38 | wandb.log(kwargs) 39 | -------------------------------------------------------------------------------- /torch_geometric/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .link_pred import ( 4 | LinkPredMetric, 5 | LinkPredMetricCollection, 6 | LinkPredPrecision, 7 | LinkPredRecall, 8 | LinkPredF1, 9 | LinkPredMAP, 10 | LinkPredNDCG, 11 | LinkPredMRR, 12 | LinkPredHitRatio, 13 | LinkPredCoverage, 14 | LinkPredDiversity, 15 | LinkPredPersonalization, 16 | LinkPredAveragePopularity, 17 | ) 18 | 19 | link_pred_metrics = [ 20 | 'LinkPredMetric', 21 | 'LinkPredMetricCollection', 22 | 'LinkPredPrecision', 23 | 'LinkPredRecall', 24 | 'LinkPredF1', 25 | 'LinkPredMAP', 26 | 'LinkPredNDCG', 27 | 'LinkPredMRR', 28 | 'LinkPredHitRatio', 29 | 'LinkPredCoverage', 30 | 'LinkPredDiversity', 31 | 'LinkPredPersonalization', 32 | 'LinkPredAveragePopularity', 33 | ] 34 | 35 | __all__ = link_pred_metrics 36 | -------------------------------------------------------------------------------- /torch_geometric/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .reshape import Reshape 2 | from 
.sequential import Sequential 3 | from .data_parallel import DataParallel 4 | from .to_hetero_transformer import to_hetero 5 | from .to_hetero_with_bases_transformer import to_hetero_with_bases 6 | from .to_fixed_size_transformer import to_fixed_size 7 | from .encoding import PositionalEncoding, TemporalEncoding 8 | from .summary import summary 9 | 10 | from .aggr import * # noqa 11 | from .attention import * # noqa 12 | from .conv import * # noqa 13 | from .pool import * # noqa 14 | from .glob import * # noqa 15 | from .norm import * # noqa 16 | from .unpool import * # noqa 17 | from .dense import * # noqa 18 | from .kge import * # noqa 19 | from .models import * # noqa 20 | from .functional import * # noqa 21 | 22 | __all__ = [ 23 | 'Reshape', 24 | 'Sequential', 25 | 'DataParallel', 26 | 'to_hetero', 27 | 'to_hetero_with_bases', 28 | 'to_fixed_size', 29 | 'PositionalEncoding', 30 | 'TemporalEncoding', 31 | 'summary', 32 | ] 33 | -------------------------------------------------------------------------------- /torch_geometric/nn/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .performer import PerformerAttention 2 | from .qformer import QFormer 3 | from .sgformer import SGFormerAttention 4 | from .polynormer import PolynormerAttention 5 | 6 | __all__ = classes = [ 7 | 'PerformerAttention', 8 | 'QFormer', 9 | 'SGFormerAttention', 10 | 'PolynormerAttention', 11 | ] 12 | -------------------------------------------------------------------------------- /torch_geometric/nn/conv/cugraph/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CuGraphModule 2 | from .sage_conv import CuGraphSAGEConv 3 | from .gat_conv import CuGraphGATConv 4 | from .rgcn_conv import CuGraphRGCNConv 5 | 6 | __all__ = [ 7 | 'CuGraphModule', 8 | 'CuGraphSAGEConv', 9 | 'CuGraphGATConv', 10 | 'CuGraphRGCNConv', 11 | ] 12 | 
-------------------------------------------------------------------------------- /torch_geometric/nn/conv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | r"""GNN utility package.""" 2 | 3 | from .cheatsheet import paper_title, paper_link 4 | from .cheatsheet import supports_sparse_tensor 5 | from .cheatsheet import supports_edge_weights 6 | from .cheatsheet import supports_edge_features 7 | from .cheatsheet import supports_bipartite_graphs 8 | from .cheatsheet import supports_static_graphs 9 | from .cheatsheet import supports_lazy_initialization 10 | from .cheatsheet import processes_heterogeneous_graphs 11 | from .cheatsheet import processes_hypergraphs 12 | from .cheatsheet import processes_point_clouds 13 | 14 | __all__ = [ 15 | 'paper_title', 16 | 'paper_link', 17 | 'supports_sparse_tensor', 18 | 'supports_edge_weights', 19 | 'supports_edge_features', 20 | 'supports_bipartite_graphs', 21 | 'supports_static_graphs', 22 | 'supports_lazy_initialization', 23 | 'processes_heterogeneous_graphs', 24 | 'processes_hypergraphs', 25 | 'processes_point_clouds', 26 | ] 27 | -------------------------------------------------------------------------------- /torch_geometric/nn/dense/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Dense neural network module package. 2 | 3 | This package provides modules applicable for operating on dense tensor 4 | representations. 
5 | """ 6 | 7 | from .linear import Linear, HeteroLinear, HeteroDictLinear 8 | from .dense_gat_conv import DenseGATConv 9 | from .dense_sage_conv import DenseSAGEConv 10 | from .dense_gcn_conv import DenseGCNConv 11 | from .dense_graph_conv import DenseGraphConv 12 | from .dense_gin_conv import DenseGINConv 13 | from .diff_pool import dense_diff_pool 14 | from .mincut_pool import dense_mincut_pool 15 | from .dmon_pool import DMoNPooling 16 | 17 | __all__ = [ 18 | 'Linear', 19 | 'HeteroLinear', 20 | 'HeteroDictLinear', 21 | 'DenseGCNConv', 22 | 'DenseGINConv', 23 | 'DenseGraphConv', 24 | 'DenseSAGEConv', 25 | 'DenseGATConv', 26 | 'dense_diff_pool', 27 | 'dense_mincut_pool', 28 | 'DMoNPooling', 29 | ] 30 | 31 | lin_classes = __all__[:3] 32 | conv_classes = __all__[3:8] 33 | pool_classes = __all__[8:] 34 | -------------------------------------------------------------------------------- /torch_geometric/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Functional operator package.""" 2 | 3 | from .bro import bro 4 | from .gini import gini 5 | 6 | __all__ = classes = [ 7 | 'bro', 8 | 'gini', 9 | ] 10 | -------------------------------------------------------------------------------- /torch_geometric/nn/functional/gini.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def gini(w: torch.Tensor) -> torch.Tensor: 5 | r"""The Gini coefficient from the `"Improving Molecular Graph Neural 6 | Network Explainability with Orthonormalization and Induced Sparsity" 7 | `_ paper. 8 | 9 | Computes a regularization penalty :math:`\in [0, 1]` for each row of a 10 | matrix according to 11 | 12 | .. math:: 13 | \mathcal{L}_\textrm{Gini}^i = \sum_j^n \sum_{j'}^n \frac{|w_{ij} 14 | - w_{ij'}|}{2 (n^2 - n)\bar{w_i}} 15 | 16 | and returns an average over all rows. 17 | 18 | Args: 19 | w (torch.Tensor): A two-dimensional tensor. 
20 | """ 21 | s = 0 22 | for row in w: 23 | t = row.repeat(row.size(0), 1) 24 | u = (t - t.T).abs().sum() / (2 * (row.size(-1)**2 - row.size(-1)) * 25 | row.abs().mean() + torch.finfo().eps) 26 | s += u 27 | s /= w.shape[0] 28 | return s 29 | -------------------------------------------------------------------------------- /torch_geometric/nn/kge/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Knowledge Graph Embedding (KGE) package.""" 2 | 3 | from .base import KGEModel 4 | from .transe import TransE 5 | from .complex import ComplEx 6 | from .distmult import DistMult 7 | from .rotate import RotatE 8 | 9 | __all__ = classes = [ 10 | 'KGEModel', 11 | 'TransE', 12 | 'ComplEx', 13 | 'DistMult', 14 | 'RotatE', 15 | ] 16 | -------------------------------------------------------------------------------- /torch_geometric/nn/kge/loader.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class KGTripletLoader(torch.utils.data.DataLoader): 8 | def __init__(self, head_index: Tensor, rel_type: Tensor, 9 | tail_index: Tensor, **kwargs): 10 | self.head_index = head_index 11 | self.rel_type = rel_type 12 | self.tail_index = tail_index 13 | 14 | super().__init__(range(head_index.numel()), collate_fn=self.sample, 15 | **kwargs) 16 | 17 | def sample(self, index: List[int]) -> Tuple[Tensor, Tensor, Tensor]: 18 | index = torch.tensor(index, device=self.head_index.device) 19 | 20 | head_index = self.head_index[index] 21 | rel_type = self.rel_type[index] 22 | tail_index = self.tail_index[index] 23 | 24 | return head_index, rel_type, tail_index 25 | -------------------------------------------------------------------------------- /torch_geometric/nn/norm/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Normalization package.""" 2 | 3 | from .batch_norm 
import BatchNorm, HeteroBatchNorm 4 | from .instance_norm import InstanceNorm 5 | from .layer_norm import LayerNorm, HeteroLayerNorm 6 | from .graph_norm import GraphNorm 7 | from .graph_size_norm import GraphSizeNorm 8 | from .pair_norm import PairNorm 9 | from .mean_subtraction_norm import MeanSubtractionNorm 10 | from .msg_norm import MessageNorm 11 | from .diff_group_norm import DiffGroupNorm 12 | 13 | __all__ = [ 14 | 'BatchNorm', 15 | 'HeteroBatchNorm', 16 | 'InstanceNorm', 17 | 'LayerNorm', 18 | 'HeteroLayerNorm', 19 | 'GraphNorm', 20 | 'GraphSizeNorm', 21 | 'PairNorm', 22 | 'MeanSubtractionNorm', 23 | 'MessageNorm', 24 | 'DiffGroupNorm', 25 | ] 26 | 27 | classes = __all__ 28 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/connect/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Graph connection package. 2 | 3 | This package provides classes for determining coarsened graph connections in 4 | graph pooling scenarios. 
5 | """ 6 | 7 | from .base import Connect, ConnectOutput 8 | from .filter_edges import FilterEdges 9 | 10 | __all__ = [ 11 | 'Connect', 12 | 'ConnectOutput', 13 | 'FilterEdges', 14 | ] 15 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/consecutive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def consecutive_cluster(src): 5 | unique, inv = torch.unique(src, sorted=True, return_inverse=True) 6 | perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device) 7 | perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm) 8 | return inv, perm 9 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/pool.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from torch_geometric.utils import coalesce, remove_self_loops, scatter 6 | 7 | 8 | def pool_edge( 9 | cluster, 10 | edge_index, 11 | edge_attr: Optional[torch.Tensor] = None, 12 | reduce: Optional[str] = 'sum', 13 | ): 14 | num_nodes = cluster.size(0) 15 | edge_index = cluster[edge_index.view(-1)].view(2, -1) 16 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 17 | if edge_index.numel() > 0: 18 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, 19 | reduce=reduce) 20 | return edge_index, edge_attr 21 | 22 | 23 | def pool_batch(perm, batch): 24 | return batch[perm] 25 | 26 | 27 | def pool_pos(cluster, pos): 28 | return scatter(pos, cluster, dim=0, reduce='mean') 29 | -------------------------------------------------------------------------------- /torch_geometric/nn/pool/select/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Node-selection package. 2 | 3 | This package provides classes for node selection methods in graph pooling 4 | scenarios. 
5 | """ 6 | 7 | from .base import Select, SelectOutput 8 | from .topk import SelectTopK 9 | 10 | __all__ = [ 11 | 'Select', 12 | 'SelectOutput', 13 | 'SelectTopK', 14 | ] 15 | -------------------------------------------------------------------------------- /torch_geometric/nn/reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | class Reshape(torch.nn.Module): 6 | def __init__(self, *shape): 7 | super().__init__() 8 | self.shape = shape 9 | 10 | def forward(self, x: Tensor) -> Tensor: 11 | """""" # noqa: D419 12 | x = x.view(*self.shape) 13 | return x 14 | 15 | def __repr__(self) -> str: 16 | shape = ', '.join([str(dim) for dim in self.shape]) 17 | return f'{self.__class__.__name__}({shape})' 18 | -------------------------------------------------------------------------------- /torch_geometric/nn/sequential.jinja: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | import torch_geometric.typing 7 | {% for module in modules %} 8 | from {{module}} import * 9 | {%- endfor %} 10 | 11 | 12 | def forward( 13 | self, 14 | {%- for param in signature.param_dict.values() %} 15 | {{param.name}}: {{param.type_repr}}, 16 | {%- endfor %} 17 | ) -> {{signature.return_type_repr}}: 18 | 19 | {%- for child in children %} 20 | {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}}) 21 | {%- endfor %} 22 | return {{children[-1].return_names|join(', ')}} 23 | -------------------------------------------------------------------------------- /torch_geometric/nn/unpool/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Unpooling package.""" 2 | 3 | from .knn_interpolate import knn_interpolate 4 | 5 | __all__ = [ 6 | 'knn_interpolate', 7 | ] 8 | 9 | classes = __all__ 10 | 
-------------------------------------------------------------------------------- /torch_geometric/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Graph sampler package.""" 2 | 3 | from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput, 4 | SamplerOutput, HeteroSamplerOutput, NegativeSampling, 5 | NumNeighbors) 6 | from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler 7 | from .hgt_sampler import HGTSampler 8 | 9 | __all__ = classes = [ 10 | 'BaseSampler', 11 | 'NodeSamplerInput', 12 | 'EdgeSamplerInput', 13 | 'SamplerOutput', 14 | 'HeteroSamplerOutput', 15 | 'NumNeighbors', 16 | 'NegativeSampling', 17 | 'NeighborSampler', 18 | 'BidirectionalNeighborSampler', 19 | 'HGTSampler', 20 | ] 21 | -------------------------------------------------------------------------------- /torch_geometric/seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def seed_everything(seed: int) -> None: 8 | r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, 9 | :obj:`numpy` and :python:`Python`. 10 | 11 | Args: 12 | seed (int): The desired seed. 13 | """ 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | -------------------------------------------------------------------------------- /torch_geometric/transforms/center.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from torch_geometric.data import Data, HeteroData 4 | from torch_geometric.data.datapipes import functional_transform 5 | from torch_geometric.transforms import BaseTransform 6 | 7 | 8 | @functional_transform('center') 9 | class Center(BaseTransform): 10 | r"""Centers node positions :obj:`data.pos` around the origin 11 | (functional name: :obj:`center`). 
12 | """ 13 | def forward( 14 | self, 15 | data: Union[Data, HeteroData], 16 | ) -> Union[Data, HeteroData]: 17 | for store in data.node_stores: 18 | if hasattr(store, 'pos'): 19 | store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True) 20 | return data 21 | -------------------------------------------------------------------------------- /torch_geometric/transforms/normalize_scale.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | from torch_geometric.data.datapipes import functional_transform 3 | from torch_geometric.transforms import BaseTransform, Center 4 | 5 | 6 | @functional_transform('normalize_scale') 7 | class NormalizeScale(BaseTransform): 8 | r"""Centers and normalizes node positions to the interval :math:`(-1, 1)` 9 | (functional name: :obj:`normalize_scale`). 10 | """ 11 | def __init__(self) -> None: 12 | self.center = Center() 13 | 14 | def forward(self, data: Data) -> Data: 15 | data = self.center(data) 16 | 17 | assert data.pos is not None 18 | scale = (1.0 / data.pos.abs().max()) * 0.999999 19 | data.pos = data.pos * scale 20 | 21 | return data 22 | -------------------------------------------------------------------------------- /torch_geometric/utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def cumsum(x: Tensor, dim: int = 0) -> Tensor: 6 | r"""Returns the cumulative sum of elements of :obj:`x`. 7 | In contrast to :meth:`torch.cumsum`, prepends the output with zero. 8 | 9 | Args: 10 | x (torch.Tensor): The input tensor. 11 | dim (int, optional): The dimension to do the operation over. 
12 | (default: :obj:`0`) 13 | 14 | Example: 15 | >>> x = torch.tensor([2, 4, 1]) 16 | >>> cumsum(x) 17 | tensor([0, 2, 6, 7]) 18 | 19 | """ 20 | size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:] 21 | out = x.new_empty(size) 22 | 23 | out.narrow(dim, 0, 1).zero_() 24 | torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim))) 25 | 26 | return out 27 | -------------------------------------------------------------------------------- /torch_geometric/utils/mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterator, TypeVar 2 | 3 | T = TypeVar('T') 4 | 5 | 6 | class CastMixin: 7 | @classmethod 8 | def cast(cls: T, *args: Any, **kwargs: Any) -> T: 9 | if len(args) == 1 and len(kwargs) == 0: 10 | elem = args[0] 11 | if elem is None: 12 | return None # type: ignore 13 | if isinstance(elem, CastMixin): 14 | return elem # type: ignore 15 | if isinstance(elem, tuple): 16 | return cls(*elem) # type: ignore 17 | if isinstance(elem, dict): 18 | return cls(**elem) # type: ignore 19 | return cls(*args, **kwargs) # type: ignore 20 | 21 | def __iter__(self) -> Iterator: 22 | return iter(self.__dict__.values()) 23 | -------------------------------------------------------------------------------- /torch_geometric/utils/repeat.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numbers 3 | from typing import Any 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | 9 | def repeat(src: Any, length: int) -> Any: 10 | if src is None: 11 | return None 12 | 13 | if isinstance(src, Tensor): 14 | if src.numel() == 1: 15 | return src.repeat(length) 16 | 17 | if src.numel() > length: 18 | return src[:length] 19 | 20 | if src.numel() < length: 21 | last_elem = src[-1].unsqueeze(0) 22 | padding = last_elem.repeat(length - src.numel()) 23 | return torch.cat([src, padding]) 24 | 25 | return src 26 | 27 | if isinstance(src, numbers.Number): 28 | 
return list(itertools.repeat(src, length)) 29 | 30 | if (len(src) > length): 31 | return src[:length] 32 | 33 | if (len(src) < length): 34 | return src + list(itertools.repeat(src[-1], length - len(src))) 35 | 36 | return src 37 | -------------------------------------------------------------------------------- /torch_geometric/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Visualization package.""" 2 | 3 | from .graph import visualize_graph, visualize_hetero_graph 4 | from .influence import influence 5 | 6 | __all__ = [ 7 | 'visualize_graph', 8 | 'visualize_hetero_graph', 9 | 'influence', 10 | ] 11 | -------------------------------------------------------------------------------- /torch_geometric/visualization/influence.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.autograd import grad 6 | 7 | 8 | def influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor: 9 | x = src.clone().requires_grad_() 10 | out = model(x, *args).sum(dim=-1) 11 | 12 | influences = [] 13 | for j in range(src.size(0)): 14 | influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1) 15 | influences.append(influence / influence.sum()) 16 | 17 | return torch.stack(influences, dim=0) 18 | -------------------------------------------------------------------------------- /torch_geometric/warnings.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Literal 3 | 4 | import torch_geometric 5 | 6 | 7 | def warn(message: str, stacklevel: int = 5) -> None: 8 | if torch_geometric.is_compiling(): 9 | return 10 | 11 | warnings.warn(message, stacklevel=stacklevel) 12 | 13 | 14 | def filterwarnings( 15 | action: Literal['default', 'error', 'ignore', 'always', 'module', 'once'], 16 | message: str, 17 | ) -> None: 18 
| if torch_geometric.is_compiling(): 19 | return 20 | 21 | warnings.filterwarnings(action, message) 22 | 23 | 24 | class WarningCache(set): 25 | """Cache for warnings.""" 26 | def warn(self, message: str, stacklevel: int = 5) -> None: 27 | """Trigger warning message.""" 28 | if message not in self: 29 | self.add(message) 30 | warn(message, stacklevel=stacklevel) 31 | --------------------------------------------------------------------------------