├── .github
├── CODEOWNERS
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE
│ ├── bug-report.yml
│ ├── config.yml
│ ├── documentation.yml
│ ├── feature-request.yml
│ ├── installation.yml
│ └── refactor.yml
├── actions
│ └── setup
│ │ └── action.yml
├── dependabot.yml
├── labeler.yml
└── workflows
│ ├── auto-merge.yml
│ ├── building_nightly.yml
│ ├── changelog.yml
│ ├── documentation.yml
│ ├── examples.yml
│ ├── labeler.yml
│ ├── linting.yml
│ ├── testing_full.yml
│ ├── testing_full_gpu.yml
│ ├── testing_latest.yml
│ ├── testing_minimal.yml
│ ├── testing_nightly.yml
│ ├── testing_prev.yml
│ └── testing_rag.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CITATION.cff
├── LICENSE
├── README.md
├── benchmark
├── README.md
├── citation
│ ├── README.md
│ ├── __init__.py
│ ├── appnp.py
│ ├── arma.py
│ ├── cheb.py
│ ├── datasets.py
│ ├── gat.py
│ ├── gcn.py
│ ├── inference.sh
│ ├── run.sh
│ ├── sgc.py
│ ├── statistics.py
│ └── train_eval.py
├── inference
│ ├── README.md
│ └── inference_benchmark.py
├── kernel
│ ├── README.md
│ ├── __init__.py
│ ├── asap.py
│ ├── datasets.py
│ ├── diff_pool.py
│ ├── edge_pool.py
│ ├── gcn.py
│ ├── gin.py
│ ├── global_attention.py
│ ├── graclus.py
│ ├── graph_sage.py
│ ├── main.py
│ ├── main_performance.py
│ ├── sag_pool.py
│ ├── set2set.py
│ ├── sort_pool.py
│ ├── statistics.py
│ ├── top_k.py
│ └── train_eval.py
├── loader
│ └── neighbor_loader.py
├── multi_gpu
│ └── training
│ │ ├── README.md
│ │ ├── common.py
│ │ ├── training_benchmark_cuda.py
│ │ └── training_benchmark_xpu.py
├── points
│ ├── README.md
│ ├── __init__.py
│ ├── datasets.py
│ ├── edge_cnn.py
│ ├── mpnn.py
│ ├── point_cnn.py
│ ├── point_net.py
│ ├── spline_cnn.py
│ ├── statistics.py
│ └── train_eval.py
├── runtime
│ ├── README.md
│ ├── __init__.py
│ ├── dgl
│ │ ├── gat.py
│ │ ├── gcn.py
│ │ ├── hidden.py
│ │ ├── main.py
│ │ ├── rgcn.py
│ │ └── train.py
│ ├── gat.py
│ ├── gcn.py
│ ├── main.py
│ ├── rgcn.py
│ └── train.py
├── setup.py
├── training
│ ├── README.md
│ └── training_benchmark.py
└── utils
│ ├── __init__.py
│ ├── hetero_gat.py
│ ├── hetero_sage.py
│ └── utils.py
├── codecov.yml
├── docker
├── Dockerfile
├── Dockerfile.xpu
├── README.md
└── singularity
├── docs
├── Makefile
├── README.md
├── requirements.txt
└── source
│ ├── .gitignore
│ ├── _figures
│ ├── .gitignore
│ ├── architecture.pdf
│ ├── architecture.svg
│ ├── build.sh
│ ├── dist_part.png
│ ├── dist_proc.png
│ ├── dist_sampling.png
│ ├── graph.svg
│ ├── graph.tex
│ ├── graphgps_layer.png
│ ├── graphgym_design_space.png
│ ├── graphgym_evaluation.png
│ ├── graphgym_results.png
│ ├── hg_example.svg
│ ├── hg_example.tex
│ ├── intel_kumo.png
│ ├── meshcnn_edge_adjacency.svg
│ ├── point_cloud1.png
│ ├── point_cloud2.png
│ ├── point_cloud3.png
│ ├── point_cloud4.png
│ ├── remote_1.png
│ ├── remote_2.png
│ ├── remote_3.png
│ ├── shallow_node_embeddings.png
│ ├── to_hetero.svg
│ ├── to_hetero.tex
│ ├── to_hetero_with_bases.svg
│ ├── to_hetero_with_bases.tex
│ └── training_affinity.png
│ ├── _static
│ ├── js
│ │ └── version_alert.js
│ └── thumbnails
│ │ ├── create_dataset.png
│ │ ├── create_gnn.png
│ │ ├── dataset_splitting.png
│ │ ├── distributed_pyg.png
│ │ ├── explain.png
│ │ ├── graph_transformer.png
│ │ ├── heterogeneous.png
│ │ ├── load_csv.png
│ │ ├── multi_gpu_vanilla.png
│ │ ├── neighbor_loader.png
│ │ ├── point_cloud.png
│ │ └── shallow_node_embeddings.png
│ ├── _templates
│ └── autosummary
│ │ ├── class.rst
│ │ ├── inherited_class.rst
│ │ ├── metrics.rst
│ │ ├── nn.rst
│ │ └── only_class.rst
│ ├── advanced
│ ├── batching.rst
│ ├── compile.rst
│ ├── cpu_affinity.rst
│ ├── graphgym.rst
│ ├── hgam.rst
│ ├── jit.rst
│ ├── remote.rst
│ └── sparse_tensor.rst
│ ├── cheatsheet
│ ├── data_cheatsheet.rst
│ └── gnn_cheatsheet.rst
│ ├── conf.py
│ ├── external
│ └── resources.rst
│ ├── get_started
│ ├── colabs.rst
│ └── introduction.rst
│ ├── index.rst
│ ├── install
│ ├── installation.rst
│ └── quick-start.html
│ ├── modules
│ ├── contrib.rst
│ ├── data.rst
│ ├── datasets.rst
│ ├── distributed.rst
│ ├── explain.rst
│ ├── graphgym.rst
│ ├── llm.rst
│ ├── loader.rst
│ ├── metrics.rst
│ ├── nn.rst
│ ├── profile.rst
│ ├── root.rst
│ ├── sampler.rst
│ ├── transforms.rst
│ └── utils.rst
│ ├── notes
│ ├── batching.rst
│ ├── cheatsheet.rst
│ ├── colabs.rst
│ ├── create_dataset.rst
│ ├── create_gnn.rst
│ ├── data_cheatsheet.rst
│ ├── explain.rst
│ ├── graphgym.rst
│ ├── heterogeneous.rst
│ ├── installation.rst
│ ├── introduction.rst
│ ├── jit.rst
│ ├── load_csv.rst
│ ├── remote.rst
│ ├── resources.rst
│ └── sparse_tensor.rst
│ └── tutorial
│ ├── application.rst
│ ├── compile.rst
│ ├── create_dataset.rst
│ ├── create_gnn.rst
│ ├── dataset.rst
│ ├── dataset_splitting.rst
│ ├── distributed.rst
│ ├── distributed_pyg.rst
│ ├── explain.rst
│ ├── gnn_design.rst
│ ├── graph_transformer.rst
│ ├── heterogeneous.rst
│ ├── load_csv.rst
│ ├── multi_gpu_vanilla.rst
│ ├── multi_node_multi_gpu_vanilla.rst
│ ├── neighbor_loader.rst
│ ├── point_cloud.rst
│ └── shallow_node_embeddings.rst
├── examples
├── README.md
├── agnn.py
├── ar_link_pred.py
├── argva_node_clustering.py
├── arma.py
├── attentive_fp.py
├── autoencoder.py
├── cluster_gcn_ppi.py
├── cluster_gcn_reddit.py
├── colors_topk_pool.py
├── compile
│ ├── gcn.py
│ └── gin.py
├── contrib
│ ├── README.md
│ ├── pgm_explainer_graph_classification.py
│ ├── pgm_explainer_node_classification.py
│ ├── rbcd_attack.py
│ └── rbcd_attack_poisoning.py
├── cora.py
├── correct_and_smooth.py
├── cpp
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── main.cpp
│ └── save_model.py
├── datapipe.py
├── dgcnn_classification.py
├── dgcnn_segmentation.py
├── dir_gnn.py
├── distributed
│ ├── README.md
│ ├── graphlearn_for_pytorch
│ │ ├── README.md
│ │ ├── dist_train_sage_sup_config.yml
│ │ ├── dist_train_sage_supervised.py
│ │ ├── launch.py
│ │ └── partition_ogbn_dataset.py
│ ├── kuzu
│ │ ├── README.md
│ │ └── papers_100M
│ │ │ ├── README.md
│ │ │ ├── prepare_data.py
│ │ │ └── train.py
│ └── pyg
│ │ ├── README.md
│ │ ├── launch.py
│ │ ├── node_ogb_cpu.py
│ │ ├── partition_graph.py
│ │ ├── run_dist.sh
│ │ └── temporal_link_movielens_cpu.py
├── dna.py
├── egc.py
├── equilibrium_median.py
├── explain
│ ├── README.md
│ ├── captum_explainer.py
│ ├── captum_explainer_hetero_link.py
│ ├── gnn_explainer.py
│ ├── gnn_explainer_ba_shapes.py
│ ├── gnn_explainer_link_pred.py
│ └── graphmask_explainer.py
├── faust.py
├── film.py
├── gat.py
├── gcn.py
├── gcn2_cora.py
├── gcn2_ppi.py
├── geniepath.py
├── glnn.py
├── gpse.py
├── graph_gps.py
├── graph_sage_unsup.py
├── graph_sage_unsup_ppi.py
├── graph_saint.py
├── graph_unet.py
├── hetero
│ ├── README.md
│ ├── bipartite_sage.py
│ ├── bipartite_sage_unsup.py
│ ├── dmgi_unsup.py
│ ├── han_imdb.py
│ ├── hetero_conv_dblp.py
│ ├── hetero_link_pred.py
│ ├── hgt_dblp.py
│ ├── hierarchical_sage.py
│ ├── load_csv.py
│ ├── metapath2vec.py
│ ├── recommender_system.py
│ ├── temporal_link_pred.py
│ └── to_hetero_mag.py
├── hierarchical_sampling.py
├── infomax_inductive.py
├── infomax_transductive.py
├── jit
│ ├── README.md
│ ├── film.py
│ ├── gat.py
│ ├── gcn.py
│ └── gin.py
├── kge_fb15k_237.py
├── label_prop.py
├── lcm_aggr_2nd_min.py
├── lightgcn.py
├── link_pred.py
├── linkx.py
├── llm
│ ├── README.md
│ ├── g_retriever.py
│ ├── git_mol.py
│ ├── glem.py
│ ├── molecule_gpt.py
│ ├── nvtx_examples
│ │ ├── README.md
│ │ ├── nvtx_rag_backend_example.py
│ │ ├── nvtx_run.sh
│ │ └── nvtx_webqsp_example.py
│ ├── protein_mpnn.py
│ └── txt2kg_rag.py
├── lpformer.py
├── mem_pool.py
├── mixhop.py
├── mnist_graclus.py
├── mnist_nn_conv.py
├── mnist_voxel_grid.py
├── multi_gpu
│ ├── README.md
│ ├── data_parallel.py
│ ├── distributed_batching.py
│ ├── distributed_sampling.py
│ ├── distributed_sampling_multinode.py
│ ├── distributed_sampling_multinode.sbatch
│ ├── distributed_sampling_xpu.py
│ ├── mag240m_graphsage.py
│ ├── model_parallel.py
│ ├── ogbn_train_cugraph.py
│ ├── papers100m_gcn.py
│ ├── papers100m_gcn_cugraph_multinode.py
│ ├── papers100m_gcn_multinode.py
│ ├── pcqm4m_ogb.py
│ └── taobao.py
├── mutag_gin.py
├── node2vec.py
├── ogbn_proteins_deepgcn.py
├── ogbn_train.py
├── ogbn_train_cugraph.py
├── ogc.py
├── pmlp.py
├── pna.py
├── point_transformer_classification.py
├── point_transformer_segmentation.py
├── pointnet2_classification.py
├── pointnet2_segmentation.py
├── ppi.py
├── proteins_diff_pool.py
├── proteins_dmon_pool.py
├── proteins_gmt.py
├── proteins_mincut_pool.py
├── proteins_topk_pool.py
├── pytorch_ignite
│ ├── README.md
│ └── gin.py
├── pytorch_lightning
│ ├── README.md
│ ├── gin.py
│ ├── graph_sage.py
│ └── relational_gnn.py
├── qm9_nn_conv.py
├── qm9_pretrained_dimenet.py
├── qm9_pretrained_schnet.py
├── quiver
│ ├── README.md
│ ├── multi_gpu_quiver.py
│ └── single_gpu_quiver.py
├── randlanet_classification.py
├── randlanet_segmentation.py
├── rdl.py
├── rect.py
├── reddit.py
├── renet.py
├── rev_gnn.py
├── rgat.py
├── rgcn.py
├── rgcn_link_pred.py
├── seal_link_pred.py
├── sgc.py
├── shadow.py
├── sign.py
├── signed_gcn.py
├── super_gat.py
├── tagcn.py
├── tensorboard_logging.py
├── tgn.py
├── triangles_sag_pool.py
├── unimp_arxiv.py
├── upfd.py
└── wl_kernel.py
├── graphgym
├── agg_batch.py
├── configs
│ ├── example.yaml
│ └── pyg
│ │ ├── example_graph.yaml
│ │ ├── example_link.yaml
│ │ └── example_node.yaml
├── configs_gen.py
├── custom_graphgym
│ ├── __init__.py
│ ├── act
│ │ ├── __init__.py
│ │ └── example.py
│ ├── config
│ │ ├── __init__.py
│ │ └── example.py
│ ├── encoder
│ │ ├── __init__.py
│ │ └── example.py
│ ├── head
│ │ ├── __init__.py
│ │ └── example.py
│ ├── layer
│ │ ├── __init__.py
│ │ └── example.py
│ ├── loader
│ │ ├── __init__.py
│ │ └── example.py
│ ├── loss
│ │ ├── __init__.py
│ │ └── example.py
│ ├── network
│ │ ├── __init__.py
│ │ └── example.py
│ ├── optimizer
│ │ ├── __init__.py
│ │ └── example.py
│ ├── pooling
│ │ ├── __init__.py
│ │ └── example.py
│ ├── stage
│ │ ├── __init__.py
│ │ └── example.py
│ ├── train
│ │ ├── __init__.py
│ │ └── example.py
│ └── transform
│ │ └── __init__.py
├── grids
│ ├── example.txt
│ └── pyg
│ │ └── example.txt
├── main.py
├── parallel.sh
├── results
│ └── example_node_grid_example
│ │ └── agg
│ │ ├── test.csv
│ │ ├── test_best.csv
│ │ ├── test_bestepoch.csv
│ │ ├── train.csv
│ │ ├── train_best.csv
│ │ ├── train_bestepoch.csv
│ │ ├── val.csv
│ │ ├── val_best.csv
│ │ └── val_bestepoch.csv
├── run_batch.sh
├── run_single.sh
└── sample
│ ├── dimensions.txt
│ └── dimensionsatt.txt
├── pyproject.toml
├── readthedocs.yml
├── test
├── conftest.py
├── contrib
│ ├── explain
│ │ └── test_pgm_explainer.py
│ └── nn
│ │ └── models
│ │ └── test_rbcd_attack.py
├── data
│ ├── lightning
│ │ └── test_datamodule.py
│ ├── test_batch.py
│ ├── test_data.py
│ ├── test_database.py
│ ├── test_datapipes.py
│ ├── test_dataset.py
│ ├── test_dataset_summary.py
│ ├── test_feature_store.py
│ ├── test_graph_store.py
│ ├── test_hetero_data.py
│ ├── test_hypergraph_data.py
│ ├── test_inherit.py
│ ├── test_on_disk_dataset.py
│ ├── test_remote_backend_utils.py
│ ├── test_storage.py
│ ├── test_temporal.py
│ └── test_view.py
├── datasets
│ ├── graph_generator
│ │ ├── test_ba_graph.py
│ │ ├── test_er_graph.py
│ │ ├── test_grid_graph.py
│ │ └── test_tree_graph.py
│ ├── motif_generator
│ │ ├── test_custom_motif.py
│ │ ├── test_cycle_motif.py
│ │ ├── test_grid_motif.py
│ │ └── test_house_motif.py
│ ├── test_ba_shapes.py
│ ├── test_bzr.py
│ ├── test_elliptic.py
│ ├── test_enzymes.py
│ ├── test_explainer_dataset.py
│ ├── test_fake.py
│ ├── test_git_mol_dataset.py
│ ├── test_imdb_binary.py
│ ├── test_infection_dataset.py
│ ├── test_karate.py
│ ├── test_medshapenet.py
│ ├── test_molecule_gpt_dataset.py
│ ├── test_mutag.py
│ ├── test_planetoid.py
│ ├── test_protein_mpnn_dataset.py
│ ├── test_snap_dataset.py
│ ├── test_suite_sparse.py
│ ├── test_tag_dataset.py
│ ├── test_teeth3ds.py
│ └── test_web_qsp_dataset.py
├── distributed
│ ├── test_dist_link_neighbor_loader.py
│ ├── test_dist_link_neighbor_sampler.py
│ ├── test_dist_neighbor_loader.py
│ ├── test_dist_neighbor_sampler.py
│ ├── test_dist_utils.py
│ ├── test_local_feature_store.py
│ ├── test_local_graph_store.py
│ ├── test_partition.py
│ └── test_rpc.py
├── explain
│ ├── algorithm
│ │ ├── test_attention_explainer.py
│ │ ├── test_captum.py
│ │ ├── test_captum_explainer.py
│ │ ├── test_captum_hetero.py
│ │ ├── test_explain_algorithm_utils.py
│ │ ├── test_gnn_explainer.py
│ │ ├── test_graphmask_explainer.py
│ │ └── test_pg_explainer.py
│ ├── conftest.py
│ ├── metric
│ │ ├── test_basic_metric.py
│ │ ├── test_faithfulness.py
│ │ └── test_fidelity.py
│ ├── test_explain_config.py
│ ├── test_explainer.py
│ ├── test_explanation.py
│ ├── test_hetero_explainer.py
│ └── test_hetero_explanation.py
├── graphgym
│ ├── example_node.yml
│ ├── test_config.py
│ ├── test_graphgym.py
│ ├── test_logger.py
│ └── test_register.py
├── io
│ ├── example1.off
│ ├── example2.off
│ ├── test_fs.py
│ └── test_off.py
├── llm
│ ├── models
│ │ ├── test_g_retriever.py
│ │ ├── test_git_mol.py
│ │ ├── test_glem.py
│ │ ├── test_llm.py
│ │ ├── test_molecule_gpt.py
│ │ ├── test_protein_mpnn.py
│ │ ├── test_sentence_transformer.py
│ │ └── test_vision_transformer.py
│ ├── test_large_graph_indexer.py
│ ├── test_rag_loader.py
│ └── utils
│ │ ├── test_rag_backend_utils.py
│ │ ├── test_rag_feature_store.py
│ │ ├── test_rag_graph_store.py
│ │ └── test_vectorrag.py
├── loader
│ ├── test_cache.py
│ ├── test_cluster.py
│ ├── test_dataloader.py
│ ├── test_dynamic_batch_sampler.py
│ ├── test_graph_saint.py
│ ├── test_hgt_loader.py
│ ├── test_ibmb_loader.py
│ ├── test_imbalanced_sampler.py
│ ├── test_link_neighbor_loader.py
│ ├── test_mixin.py
│ ├── test_neighbor_loader.py
│ ├── test_neighbor_sampler.py
│ ├── test_prefetch.py
│ ├── test_random_node_loader.py
│ ├── test_shadow.py
│ ├── test_temporal_dataloader.py
│ ├── test_utils.py
│ └── test_zip_loader.py
├── metrics
│ └── test_link_pred_metric.py
├── my_config.yaml
├── nn
│ ├── aggr
│ │ ├── test_aggr_utils.py
│ │ ├── test_attention.py
│ │ ├── test_basic.py
│ │ ├── test_deep_sets.py
│ │ ├── test_equilibrium.py
│ │ ├── test_fused.py
│ │ ├── test_gmt.py
│ │ ├── test_gru.py
│ │ ├── test_lcm.py
│ │ ├── test_lstm.py
│ │ ├── test_mlp_aggr.py
│ │ ├── test_multi.py
│ │ ├── test_patch_transformer.py
│ │ ├── test_quantile.py
│ │ ├── test_scaler.py
│ │ ├── test_set2set.py
│ │ ├── test_set_transformer.py
│ │ ├── test_sort.py
│ │ └── test_variance_preserving.py
│ ├── attention
│ │ ├── test_performer_attention.py
│ │ ├── test_polynormer_attention.py
│ │ └── test_qformer.py
│ ├── conv
│ │ ├── cugraph
│ │ │ ├── test_cugraph_gat_conv.py
│ │ │ ├── test_cugraph_rgcn_conv.py
│ │ │ └── test_cugraph_sage_conv.py
│ │ ├── test_agnn_conv.py
│ │ ├── test_antisymmetric_conv.py
│ │ ├── test_appnp.py
│ │ ├── test_arma_conv.py
│ │ ├── test_cg_conv.py
│ │ ├── test_cheb_conv.py
│ │ ├── test_cluster_gcn_conv.py
│ │ ├── test_create_gnn.py
│ │ ├── test_dir_gnn_conv.py
│ │ ├── test_dna_conv.py
│ │ ├── test_edge_conv.py
│ │ ├── test_eg_conv.py
│ │ ├── test_fa_conv.py
│ │ ├── test_feast_conv.py
│ │ ├── test_film_conv.py
│ │ ├── test_fused_gat_conv.py
│ │ ├── test_gat_conv.py
│ │ ├── test_gated_graph_conv.py
│ │ ├── test_gatv2_conv.py
│ │ ├── test_gcn2_conv.py
│ │ ├── test_gcn_conv.py
│ │ ├── test_gen_conv.py
│ │ ├── test_general_conv.py
│ │ ├── test_gin_conv.py
│ │ ├── test_gmm_conv.py
│ │ ├── test_gps_conv.py
│ │ ├── test_graph_conv.py
│ │ ├── test_gravnet_conv.py
│ │ ├── test_han_conv.py
│ │ ├── test_heat_conv.py
│ │ ├── test_hetero_conv.py
│ │ ├── test_hgt_conv.py
│ │ ├── test_hypergraph_conv.py
│ │ ├── test_le_conv.py
│ │ ├── test_lg_conv.py
│ │ ├── test_meshcnn_conv.py
│ │ ├── test_message_passing.py
│ │ ├── test_mf_conv.py
│ │ ├── test_mixhop_conv.py
│ │ ├── test_nn_conv.py
│ │ ├── test_pan_conv.py
│ │ ├── test_pdn_conv.py
│ │ ├── test_pna_conv.py
│ │ ├── test_point_conv.py
│ │ ├── test_point_gnn_conv.py
│ │ ├── test_point_transformer_conv.py
│ │ ├── test_ppf_conv.py
│ │ ├── test_res_gated_graph_conv.py
│ │ ├── test_rgat_conv.py
│ │ ├── test_rgcn_conv.py
│ │ ├── test_sage_conv.py
│ │ ├── test_sg_conv.py
│ │ ├── test_signed_conv.py
│ │ ├── test_simple_conv.py
│ │ ├── test_spline_conv.py
│ │ ├── test_ssg_conv.py
│ │ ├── test_static_graph.py
│ │ ├── test_supergat_conv.py
│ │ ├── test_tag_conv.py
│ │ ├── test_transformer_conv.py
│ │ ├── test_wl_conv.py
│ │ ├── test_wl_conv_continuous.py
│ │ ├── test_x_conv.py
│ │ └── utils
│ │ │ └── test_gnn_cheatsheet.py
│ ├── dense
│ │ ├── test_dense_gat_conv.py
│ │ ├── test_dense_gcn_conv.py
│ │ ├── test_dense_gin_conv.py
│ │ ├── test_dense_graph_conv.py
│ │ ├── test_dense_sage_conv.py
│ │ ├── test_diff_pool.py
│ │ ├── test_dmon_pool.py
│ │ ├── test_linear.py
│ │ └── test_mincut_pool.py
│ ├── functional
│ │ ├── test_bro.py
│ │ └── test_gini.py
│ ├── kge
│ │ ├── test_complex.py
│ │ ├── test_distmult.py
│ │ ├── test_rotate.py
│ │ └── test_transe.py
│ ├── models
│ │ ├── test_attentive_fp.py
│ │ ├── test_attract_repel.py
│ │ ├── test_autoencoder.py
│ │ ├── test_basic_gnn.py
│ │ ├── test_correct_and_smooth.py
│ │ ├── test_deep_graph_infomax.py
│ │ ├── test_deepgcn.py
│ │ ├── test_dimenet.py
│ │ ├── test_gnnff.py
│ │ ├── test_gpse.py
│ │ ├── test_graph_mixer.py
│ │ ├── test_graph_unet.py
│ │ ├── test_jumping_knowledge.py
│ │ ├── test_label_prop.py
│ │ ├── test_lightgcn.py
│ │ ├── test_linkx.py
│ │ ├── test_lpformer.py
│ │ ├── test_mask_label.py
│ │ ├── test_meta.py
│ │ ├── test_metapath2vec.py
│ │ ├── test_mlp.py
│ │ ├── test_neural_fingerprint.py
│ │ ├── test_node2vec.py
│ │ ├── test_pmlp.py
│ │ ├── test_polynormer.py
│ │ ├── test_re_net.py
│ │ ├── test_rect.py
│ │ ├── test_rev_gnn.py
│ │ ├── test_schnet.py
│ │ ├── test_sgformer.py
│ │ ├── test_signed_gcn.py
│ │ ├── test_tgn.py
│ │ └── test_visnet.py
│ ├── norm
│ │ ├── test_batch_norm.py
│ │ ├── test_diff_group_norm.py
│ │ ├── test_graph_norm.py
│ │ ├── test_graph_size_norm.py
│ │ ├── test_instance_norm.py
│ │ ├── test_layer_norm.py
│ │ ├── test_mean_subtraction_norm.py
│ │ ├── test_msg_norm.py
│ │ └── test_pair_norm.py
│ ├── pool
│ │ ├── connect
│ │ │ └── test_filter_edges.py
│ │ ├── select
│ │ │ └── test_select_topk.py
│ │ ├── test_approx_knn.py
│ │ ├── test_asap.py
│ │ ├── test_avg_pool.py
│ │ ├── test_cluster_pool.py
│ │ ├── test_consecutive.py
│ │ ├── test_decimation.py
│ │ ├── test_edge_pool.py
│ │ ├── test_glob.py
│ │ ├── test_graclus.py
│ │ ├── test_knn.py
│ │ ├── test_max_pool.py
│ │ ├── test_mem_pool.py
│ │ ├── test_pan_pool.py
│ │ ├── test_pool.py
│ │ ├── test_sag_pool.py
│ │ ├── test_topk_pool.py
│ │ └── test_voxel_grid.py
│ ├── test_compile_basic.py
│ ├── test_compile_conv.py
│ ├── test_compile_dynamic.py
│ ├── test_data_parallel.py
│ ├── test_encoding.py
│ ├── test_fvcore.py
│ ├── test_fx.py
│ ├── test_inits.py
│ ├── test_model_hub.py
│ ├── test_model_summary.py
│ ├── test_module_dict.py
│ ├── test_parameter_dict.py
│ ├── test_reshape.py
│ ├── test_resolver.py
│ ├── test_sequential.py
│ ├── test_to_fixed_size_transformer.py
│ ├── test_to_hetero_module.py
│ ├── test_to_hetero_transformer.py
│ ├── test_to_hetero_with_bases_transformer.py
│ └── unpool
│ │ └── test_knn_interpolate.py
├── profile
│ ├── test_benchmark.py
│ ├── test_nvtx.py
│ ├── test_profile.py
│ ├── test_profile_utils.py
│ └── test_profiler.py
├── sampler
│ ├── test_sampler_base.py
│ └── test_sampler_neighbor_sampler.py
├── test_config_mixin.py
├── test_config_store.py
├── test_debug.py
├── test_edge_index.py
├── test_experimental.py
├── test_hash_tensor.py
├── test_home.py
├── test_index.py
├── test_inspector.py
├── test_isinstance.py
├── test_onnx.py
├── test_seed.py
├── test_typing.py
├── test_warnings.py
├── testing
│ └── test_decorators.py
├── transforms
│ ├── test_add_gpse.py
│ ├── test_add_metapaths.py
│ ├── test_add_positional_encoding.py
│ ├── test_add_remaining_self_loops.py
│ ├── test_add_self_loops.py
│ ├── test_cartesian.py
│ ├── test_center.py
│ ├── test_compose.py
│ ├── test_constant.py
│ ├── test_delaunay.py
│ ├── test_distance.py
│ ├── test_face_to_edge.py
│ ├── test_feature_propagation.py
│ ├── test_fixed_points.py
│ ├── test_gcn_norm.py
│ ├── test_gdc.py
│ ├── test_generate_mesh_normals.py
│ ├── test_grid_sampling.py
│ ├── test_half_hop.py
│ ├── test_knn_graph.py
│ ├── test_laplacian_lambda_max.py
│ ├── test_largest_connected_components.py
│ ├── test_line_graph.py
│ ├── test_linear_transformation.py
│ ├── test_local_cartesian.py
│ ├── test_local_degree_profile.py
│ ├── test_mask_transform.py
│ ├── test_node_property_split.py
│ ├── test_normalize_features.py
│ ├── test_normalize_rotation.py
│ ├── test_normalize_scale.py
│ ├── test_one_hot_degree.py
│ ├── test_pad.py
│ ├── test_point_pair_features.py
│ ├── test_polar.py
│ ├── test_radius_graph.py
│ ├── test_random_flip.py
│ ├── test_random_jitter.py
│ ├── test_random_link_split.py
│ ├── test_random_node_split.py
│ ├── test_random_rotate.py
│ ├── test_random_scale.py
│ ├── test_random_shear.py
│ ├── test_remove_duplicated_edges.py
│ ├── test_remove_isolated_nodes.py
│ ├── test_remove_self_loops.py
│ ├── test_remove_training_classes.py
│ ├── test_rooted_subgraph.py
│ ├── test_sample_points.py
│ ├── test_sign.py
│ ├── test_spherical.py
│ ├── test_svd_feature_reduction.py
│ ├── test_target_indegree.py
│ ├── test_to_dense.py
│ ├── test_to_device.py
│ ├── test_to_sparse_tensor.py
│ ├── test_to_superpixels.py
│ ├── test_to_undirected.py
│ ├── test_two_hop.py
│ └── test_virtual_node.py
├── utils
│ ├── conftest.py
│ ├── test_assortativity.py
│ ├── test_augmentation.py
│ ├── test_coalesce.py
│ ├── test_convert.py
│ ├── test_cross_entropy.py
│ ├── test_degree.py
│ ├── test_dropout.py
│ ├── test_embedding.py
│ ├── test_functions.py
│ ├── test_geodesic.py
│ ├── test_grid.py
│ ├── test_hetero.py
│ ├── test_homophily.py
│ ├── test_index_sort.py
│ ├── test_isolated.py
│ ├── test_laplacian.py
│ ├── test_lexsort.py
│ ├── test_loop.py
│ ├── test_map.py
│ ├── test_mask.py
│ ├── test_mesh_laplacian.py
│ ├── test_negative_sampling.py
│ ├── test_nested.py
│ ├── test_noise_scheduler.py
│ ├── test_normalize_edge_index.py
│ ├── test_normalized_cut.py
│ ├── test_num_nodes.py
│ ├── test_one_hot.py
│ ├── test_ppr.py
│ ├── test_random.py
│ ├── test_repeat.py
│ ├── test_scatter.py
│ ├── test_segment.py
│ ├── test_select.py
│ ├── test_smiles.py
│ ├── test_softmax.py
│ ├── test_sort_edge_index.py
│ ├── test_sparse.py
│ ├── test_spmm.py
│ ├── test_subgraph.py
│ ├── test_to_dense_adj.py
│ ├── test_to_dense_batch.py
│ ├── test_total_influence.py
│ ├── test_train_test_split_edges.py
│ ├── test_tree_decomposition.py
│ ├── test_trim_to_layer.py
│ ├── test_unbatch.py
│ └── test_undirected.py
└── visualization
│ ├── test_graph_visualization.py
│ └── test_influence.py
└── torch_geometric
├── __init__.py
├── _compile.py
├── _onnx.py
├── backend.py
├── config_mixin.py
├── config_store.py
├── contrib
├── __init__.py
├── datasets
│ └── __init__.py
├── explain
│ ├── __init__.py
│ └── pgm_explainer.py
├── nn
│ ├── __init__.py
│ ├── conv
│ │ └── __init__.py
│ └── models
│ │ ├── __init__.py
│ │ └── rbcd_attack.py
└── transforms
│ └── __init__.py
├── data
├── __init__.py
├── batch.py
├── collate.py
├── data.py
├── database.py
├── datapipes.py
├── dataset.py
├── download.py
├── extract.py
├── feature_store.py
├── graph_store.py
├── hetero_data.py
├── hypergraph_data.py
├── in_memory_dataset.py
├── lightning
│ ├── __init__.py
│ └── datamodule.py
├── makedirs.py
├── on_disk_dataset.py
├── remote_backend_utils.py
├── separate.py
├── storage.py
├── summary.py
├── temporal.py
└── view.py
├── datasets
├── __init__.py
├── actor.py
├── airfrans.py
├── airports.py
├── amazon.py
├── amazon_book.py
├── amazon_products.py
├── aminer.py
├── aqsol.py
├── attributed_graph_dataset.py
├── ba2motif_dataset.py
├── ba_multi_shapes.py
├── ba_shapes.py
├── bitcoin_otc.py
├── brca_tgca.py
├── citation_full.py
├── city.py
├── coauthor.py
├── coma.py
├── cornell.py
├── dblp.py
├── dbp15k.py
├── deezer_europe.py
├── dgraph.py
├── dynamic_faust.py
├── elliptic.py
├── elliptic_temporal.py
├── email_eu_core.py
├── entities.py
├── explainer_dataset.py
├── facebook.py
├── fake.py
├── faust.py
├── flickr.py
├── freebase.py
├── gdelt.py
├── gdelt_lite.py
├── ged_dataset.py
├── gemsec.py
├── geometry.py
├── git_mol_dataset.py
├── github.py
├── gnn_benchmark_dataset.py
├── graph_generator
│ ├── __init__.py
│ ├── ba_graph.py
│ ├── base.py
│ ├── er_graph.py
│ ├── grid_graph.py
│ └── tree_graph.py
├── heterophilous_graph_dataset.py
├── hgb_dataset.py
├── hm.py
├── hydro_net.py
├── icews.py
├── igmc_dataset.py
├── imdb.py
├── infection_dataset.py
├── instruct_mol_dataset.py
├── jodie.py
├── karate.py
├── last_fm.py
├── lastfm_asia.py
├── linkx_dataset.py
├── lrgb.py
├── malnet_tiny.py
├── md17.py
├── medshapenet.py
├── mixhop_synthetic_dataset.py
├── mnist_superpixels.py
├── modelnet.py
├── molecule_gpt_dataset.py
├── molecule_net.py
├── motif_generator
│ ├── __init__.py
│ ├── base.py
│ ├── custom.py
│ ├── cycle.py
│ ├── grid.py
│ └── house.py
├── movie_lens.py
├── movie_lens_100k.py
├── movie_lens_1m.py
├── myket.py
├── nell.py
├── neurograph.py
├── ogb_mag.py
├── omdb.py
├── opf.py
├── ose_gvcs.py
├── particle.py
├── pascal.py
├── pascal_pf.py
├── pcpnet_dataset.py
├── pcqm4m.py
├── planetoid.py
├── polblogs.py
├── ppi.py
├── protein_mpnn_dataset.py
├── qm7.py
├── qm9.py
├── rcdd.py
├── reddit.py
├── reddit2.py
├── rel_link_pred_dataset.py
├── s3dis.py
├── sbm_dataset.py
├── shapenet.py
├── shrec2016.py
├── snap_dataset.py
├── suite_sparse.py
├── tag_dataset.py
├── taobao.py
├── teeth3ds.py
├── tosca.py
├── tu_dataset.py
├── twitch.py
├── upfd.py
├── utils
│ ├── __init__.py
│ └── cheatsheet.py
├── web_qsp_dataset.py
├── webkb.py
├── wikics.py
├── wikidata.py
├── wikipedia_network.py
├── willow_object_class.py
├── word_net.py
├── yelp.py
└── zinc.py
├── debug.py
├── deprecation.py
├── device.py
├── distributed
├── __init__.py
├── dist_context.py
├── dist_link_neighbor_loader.py
├── dist_loader.py
├── dist_neighbor_loader.py
├── dist_neighbor_sampler.py
├── event_loop.py
├── local_feature_store.py
├── local_graph_store.py
├── partition.py
├── rpc.py
└── utils.py
├── edge_index.py
├── experimental.py
├── explain
├── __init__.py
├── algorithm
│ ├── __init__.py
│ ├── attention_explainer.py
│ ├── base.py
│ ├── captum.py
│ ├── captum_explainer.py
│ ├── dummy_explainer.py
│ ├── gnn_explainer.py
│ ├── graphmask_explainer.py
│ ├── pg_explainer.py
│ └── utils.py
├── config.py
├── explainer.py
├── explanation.py
└── metric
│ ├── __init__.py
│ ├── basic.py
│ ├── faithfulness.py
│ └── fidelity.py
├── graphgym
├── __init__.py
├── benchmark.py
├── checkpoint.py
├── cmd_args.py
├── config.py
├── contrib
│ ├── __init__.py
│ ├── act
│ │ └── __init__.py
│ ├── config
│ │ └── __init__.py
│ ├── encoder
│ │ └── __init__.py
│ ├── head
│ │ └── __init__.py
│ ├── layer
│ │ ├── __init__.py
│ │ └── generalconv.py
│ ├── loader
│ │ └── __init__.py
│ ├── loss
│ │ └── __init__.py
│ ├── network
│ │ └── __init__.py
│ ├── optimizer
│ │ └── __init__.py
│ ├── pooling
│ │ └── __init__.py
│ ├── stage
│ │ └── __init__.py
│ ├── train
│ │ └── __init__.py
│ └── transform
│ │ └── __init__.py
├── imports.py
├── init.py
├── loader.py
├── logger.py
├── loss.py
├── model_builder.py
├── models
│ ├── __init__.py
│ ├── act.py
│ ├── encoder.py
│ ├── gnn.py
│ ├── head.py
│ ├── layer.py
│ ├── pooling.py
│ └── transform.py
├── optim.py
├── register.py
├── train.py
└── utils
│ ├── LICENSE
│ ├── __init__.py
│ ├── agg_runs.py
│ ├── comp_budget.py
│ ├── device.py
│ ├── epoch.py
│ ├── io.py
│ ├── plot.py
│ └── tools.py
├── hash_tensor.py
├── home.py
├── index.py
├── inspector.py
├── io
├── __init__.py
├── fs.py
├── npz.py
├── obj.py
├── off.py
├── planetoid.py
├── ply.py
├── sdf.py
├── tu.py
└── txt_array.py
├── isinstance.py
├── lazy_loader.py
├── llm
├── __init__.py
├── large_graph_indexer.py
├── models
│ ├── __init__.py
│ ├── g_retriever.py
│ ├── git_mol.py
│ ├── glem.py
│ ├── llm.py
│ ├── llm_judge.py
│ ├── molecule_gpt.py
│ ├── protein_mpnn.py
│ ├── sentence_transformer.py
│ ├── txt2kg.py
│ └── vision_transformer.py
├── rag_loader.py
└── utils
│ ├── __init__.py
│ ├── backend_utils.py
│ ├── feature_store.py
│ ├── graph_store.py
│ └── vectorrag.py
├── loader
├── __init__.py
├── base.py
├── cache.py
├── cluster.py
├── data_list_loader.py
├── dataloader.py
├── dense_data_loader.py
├── dynamic_batch_sampler.py
├── graph_saint.py
├── hgt_loader.py
├── ibmb_loader.py
├── imbalanced_sampler.py
├── link_loader.py
├── link_neighbor_loader.py
├── mixin.py
├── neighbor_loader.py
├── neighbor_sampler.py
├── node_loader.py
├── prefetch.py
├── random_node_loader.py
├── shadow.py
├── temporal_dataloader.py
├── utils.py
└── zip_loader.py
├── logging.py
├── metrics
├── __init__.py
└── link_pred.py
├── nn
├── __init__.py
├── aggr
│ ├── __init__.py
│ ├── attention.py
│ ├── base.py
│ ├── basic.py
│ ├── deep_sets.py
│ ├── equilibrium.py
│ ├── fused.py
│ ├── gmt.py
│ ├── gru.py
│ ├── lcm.py
│ ├── lstm.py
│ ├── mlp.py
│ ├── multi.py
│ ├── patch_transformer.py
│ ├── quantile.py
│ ├── scaler.py
│ ├── set2set.py
│ ├── set_transformer.py
│ ├── sort.py
│ ├── utils.py
│ └── variance_preserving.py
├── attention
│ ├── __init__.py
│ ├── performer.py
│ ├── polynormer.py
│ ├── qformer.py
│ └── sgformer.py
├── conv
│ ├── __init__.py
│ ├── agnn_conv.py
│ ├── antisymmetric_conv.py
│ ├── appnp.py
│ ├── arma_conv.py
│ ├── cg_conv.py
│ ├── cheb_conv.py
│ ├── cluster_gcn_conv.py
│ ├── collect.jinja
│ ├── cugraph
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── gat_conv.py
│ │ ├── rgcn_conv.py
│ │ └── sage_conv.py
│ ├── dir_gnn_conv.py
│ ├── dna_conv.py
│ ├── edge_conv.py
│ ├── edge_updater.jinja
│ ├── eg_conv.py
│ ├── fa_conv.py
│ ├── feast_conv.py
│ ├── film_conv.py
│ ├── fused_gat_conv.py
│ ├── gat_conv.py
│ ├── gated_graph_conv.py
│ ├── gatv2_conv.py
│ ├── gcn2_conv.py
│ ├── gcn_conv.py
│ ├── gen_conv.py
│ ├── general_conv.py
│ ├── gin_conv.py
│ ├── gmm_conv.py
│ ├── gps_conv.py
│ ├── graph_conv.py
│ ├── gravnet_conv.py
│ ├── han_conv.py
│ ├── heat_conv.py
│ ├── hetero_conv.py
│ ├── hgt_conv.py
│ ├── hypergraph_conv.py
│ ├── le_conv.py
│ ├── lg_conv.py
│ ├── meshcnn_conv.py
│ ├── message_passing.py
│ ├── mf_conv.py
│ ├── mixhop_conv.py
│ ├── nn_conv.py
│ ├── pan_conv.py
│ ├── pdn_conv.py
│ ├── pna_conv.py
│ ├── point_conv.py
│ ├── point_gnn_conv.py
│ ├── point_transformer_conv.py
│ ├── ppf_conv.py
│ ├── propagate.jinja
│ ├── res_gated_graph_conv.py
│ ├── rgat_conv.py
│ ├── rgcn_conv.py
│ ├── sage_conv.py
│ ├── sg_conv.py
│ ├── signed_conv.py
│ ├── simple_conv.py
│ ├── spline_conv.py
│ ├── ssg_conv.py
│ ├── supergat_conv.py
│ ├── tag_conv.py
│ ├── transformer_conv.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── cheatsheet.py
│ ├── wl_conv.py
│ ├── wl_conv_continuous.py
│ └── x_conv.py
├── data_parallel.py
├── dense
│ ├── __init__.py
│ ├── dense_gat_conv.py
│ ├── dense_gcn_conv.py
│ ├── dense_gin_conv.py
│ ├── dense_graph_conv.py
│ ├── dense_sage_conv.py
│ ├── diff_pool.py
│ ├── dmon_pool.py
│ ├── linear.py
│ └── mincut_pool.py
├── encoding.py
├── functional
│ ├── __init__.py
│ ├── bro.py
│ └── gini.py
├── fx.py
├── glob.py
├── inits.py
├── kge
│ ├── __init__.py
│ ├── base.py
│ ├── complex.py
│ ├── distmult.py
│ ├── loader.py
│ ├── rotate.py
│ └── transe.py
├── lr_scheduler.py
├── model_hub.py
├── models
│ ├── __init__.py
│ ├── attentive_fp.py
│ ├── attract_repel.py
│ ├── autoencoder.py
│ ├── basic_gnn.py
│ ├── captum.py
│ ├── correct_and_smooth.py
│ ├── deep_graph_infomax.py
│ ├── deepgcn.py
│ ├── dimenet.py
│ ├── dimenet_utils.py
│ ├── gnnff.py
│ ├── gpse.py
│ ├── graph_mixer.py
│ ├── graph_unet.py
│ ├── jumping_knowledge.py
│ ├── label_prop.py
│ ├── lightgcn.py
│ ├── linkx.py
│ ├── lpformer.py
│ ├── mask_label.py
│ ├── meta.py
│ ├── metapath2vec.py
│ ├── mlp.py
│ ├── neural_fingerprint.py
│ ├── node2vec.py
│ ├── pmlp.py
│ ├── polynormer.py
│ ├── re_net.py
│ ├── rect.py
│ ├── rev_gnn.py
│ ├── schnet.py
│ ├── sgformer.py
│ ├── signed_gcn.py
│ ├── tgn.py
│ └── visnet.py
├── module_dict.py
├── norm
│ ├── __init__.py
│ ├── batch_norm.py
│ ├── diff_group_norm.py
│ ├── graph_norm.py
│ ├── graph_size_norm.py
│ ├── instance_norm.py
│ ├── layer_norm.py
│ ├── mean_subtraction_norm.py
│ ├── msg_norm.py
│ └── pair_norm.py
├── parameter_dict.py
├── pool
│ ├── __init__.py
│ ├── approx_knn.py
│ ├── asap.py
│ ├── avg_pool.py
│ ├── cluster_pool.py
│ ├── connect
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── filter_edges.py
│ ├── consecutive.py
│ ├── decimation.py
│ ├── edge_pool.py
│ ├── glob.py
│ ├── graclus.py
│ ├── knn.py
│ ├── max_pool.py
│ ├── mem_pool.py
│ ├── pan_pool.py
│ ├── pool.py
│ ├── sag_pool.py
│ ├── select
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── topk.py
│ ├── topk_pool.py
│ └── voxel_grid.py
├── reshape.py
├── resolver.py
├── sequential.jinja
├── sequential.py
├── summary.py
├── to_fixed_size_transformer.py
├── to_hetero_module.py
├── to_hetero_transformer.py
├── to_hetero_with_bases_transformer.py
└── unpool
│ ├── __init__.py
│ └── knn_interpolate.py
├── profile
├── __init__.py
├── benchmark.py
├── nvtx.py
├── profile.py
├── profiler.py
└── utils.py
├── resolver.py
├── sampler
├── __init__.py
├── base.py
├── hgt_sampler.py
├── neighbor_sampler.py
└── utils.py
├── seed.py
├── template.py
├── testing
├── __init__.py
├── asserts.py
├── data.py
├── decorators.py
├── distributed.py
├── feature_store.py
└── graph_store.py
├── transforms
├── __init__.py
├── add_gpse.py
├── add_metapaths.py
├── add_positional_encoding.py
├── add_remaining_self_loops.py
├── add_self_loops.py
├── base_transform.py
├── cartesian.py
├── center.py
├── compose.py
├── constant.py
├── delaunay.py
├── distance.py
├── face_to_edge.py
├── feature_propagation.py
├── fixed_points.py
├── gcn_norm.py
├── gdc.py
├── generate_mesh_normals.py
├── grid_sampling.py
├── half_hop.py
├── knn_graph.py
├── laplacian_lambda_max.py
├── largest_connected_components.py
├── line_graph.py
├── linear_transformation.py
├── local_cartesian.py
├── local_degree_profile.py
├── mask.py
├── node_property_split.py
├── normalize_features.py
├── normalize_rotation.py
├── normalize_scale.py
├── one_hot_degree.py
├── pad.py
├── point_pair_features.py
├── polar.py
├── radius_graph.py
├── random_flip.py
├── random_jitter.py
├── random_link_split.py
├── random_node_split.py
├── random_rotate.py
├── random_scale.py
├── random_shear.py
├── remove_duplicated_edges.py
├── remove_isolated_nodes.py
├── remove_self_loops.py
├── remove_training_classes.py
├── rooted_subgraph.py
├── sample_points.py
├── sign.py
├── spherical.py
├── svd_feature_reduction.py
├── target_indegree.py
├── to_dense.py
├── to_device.py
├── to_sparse_tensor.py
├── to_superpixels.py
├── to_undirected.py
├── two_hop.py
└── virtual_node.py
├── typing.py
├── utils
├── __init__.py
├── _assortativity.py
├── _coalesce.py
├── _degree.py
├── _grid.py
├── _homophily.py
├── _index_sort.py
├── _lexsort.py
├── _negative_sampling.py
├── _normalize_edge_index.py
├── _normalized_cut.py
├── _one_hot.py
├── _scatter.py
├── _segment.py
├── _select.py
├── _softmax.py
├── _sort_edge_index.py
├── _spmm.py
├── _subgraph.py
├── _to_dense_adj.py
├── _to_dense_batch.py
├── _train_test_split_edges.py
├── _tree_decomposition.py
├── _trim_to_layer.py
├── _unbatch.py
├── augmentation.py
├── convert.py
├── cross_entropy.py
├── dropout.py
├── embedding.py
├── functions.py
├── geodesic.py
├── hetero.py
├── influence.py
├── isolated.py
├── laplacian.py
├── loop.py
├── map.py
├── mask.py
├── mesh_laplacian.py
├── mixin.py
├── nested.py
├── noise_scheduler.py
├── num_nodes.py
├── ppr.py
├── random.py
├── repeat.py
├── smiles.py
├── sparse.py
└── undirected.py
├── visualization
├── __init__.py
├── graph.py
└── influence.py
└── warnings.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
2 |
3 | * @rusty1s @akihironitta
4 |
5 | *.py @rusty1s @wsad1 @akihironitta
6 |
7 | /.github/ @rusty1s @akihironitta
8 |
9 | /.github/CODEOWNERS @rusty1s
10 |
11 | /torch_geometric/data/ @rusty1s @mananshah99 @akihironitta
12 |
13 | /torch_geometric/loader/ @rusty1s @mananshah99 @akihironitta
14 |
15 | /torch_geometric/sampler/ @rusty1s @mananshah99 @akihironitta
16 |
17 | /docs/ @rusty1s @akihironitta
18 |
19 | /torch_geometric/nn/conv/cugraph @tingyu66
20 |
21 | /examples/llm @puririshi98
22 |
23 | /torch_geometric/llm @puririshi98
24 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: 🙏 Ask a Question
4 | url: https://github.com/pyg-team/pytorch_geometric/discussions/new
5 | about: Ask and answer PyG related questions
6 | - name: 💬 Slack
7 | url: https://join.slack.com/t/torchgeometricco/shared_invite/zt-p6br3yuo-BxRoe36OHHLF6jYU8xHtBA
8 | about: Chat with our community
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.yml:
--------------------------------------------------------------------------------
1 | name: "📚 Typos and Doc Fixes"
2 | description: "Tell us about how we can improve our documentation"
3 | labels: documentation
4 |
5 | body:
6 | - type: textarea
7 | attributes:
8 | label: 📚 Describe the documentation issue
9 | description: |
10 | A clear and concise description of the issue.
11 | validations:
12 | required: true
13 | - type: textarea
14 | attributes:
15 | label: Suggest a potential alternative/fix
16 | description: |
17 | Tell us how we could improve the documentation in this regard.
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: "🚀 Feature Request"
2 | description: "Propose a new PyG feature"
3 | labels: feature
4 |
5 | body:
6 | - type: textarea
7 | attributes:
8 | label: 🚀 The feature, motivation and pitch
9 | description: >
10 | A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
11 | validations:
12 | required: true
13 | - type: textarea
14 | attributes:
15 | label: Alternatives
16 | description: >
17 | A description of any alternative solutions or features you've considered, if any.
18 | - type: textarea
19 | attributes:
20 | label: Additional context
21 | description: >
22 | Add any other context or screenshots about the feature request.
23 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/refactor.yml:
--------------------------------------------------------------------------------
1 | name: "🛠 Refactor"
2 | description: "Suggest a code refactor or deprecation"
3 | labels: refactor
4 |
5 | body:
6 | - type: textarea
7 | attributes:
8 | label: 🛠 Proposed Refactor
9 | description: |
10 | A clear and concise description of the refactor proposal. Please outline the motivation for the proposal. If this is related to another GitHub issue, please link here too.
11 | validations:
12 | required: true
13 | - type: textarea
14 | attributes:
15 | label: Suggest a potential alternative/fix
16 | description: |
17 | Tell us how we could improve the code in this regard.
18 | validations:
19 | required: true
20 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference
2 | version: 2
3 | updates:
4 | - package-ecosystem: "github-actions"
5 | directories:
6 | - "/"
7 | - "/.github/actions/setup"
8 | schedule:
9 | interval: "daily"
10 | time: "00:00"
11 | labels:
12 | - "ci"
13 | - "skip-changelog"
14 | pull-request-branch-name:
15 | separator: "-"
16 | open-pull-requests-limit: 10
17 |
--------------------------------------------------------------------------------
/.github/workflows/auto-merge.yml:
--------------------------------------------------------------------------------
1 | name: Auto-merge Bot PRs
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request_target:
5 | types: [opened, reopened]
6 |
7 | permissions:
8 | contents: write
9 | pull-requests: write
10 |
11 | jobs:
12 | auto-merge:
13 | runs-on: ubuntu-latest
14 | if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }}
15 | steps:
16 | - uses: actions/checkout@v5
17 |
18 | - name: Label bot PRs
19 | run: gh pr edit --add-label "ci,skip-changelog" ${{ github.event.pull_request.html_url }}
20 | env:
21 | GITHUB_TOKEN: ${{ secrets.PAT }}
22 |
23 | - name: Auto-approve
24 | uses: hmarr/auto-approve-action@v4
25 | with:
26 | github-token: ${{ secrets.PAT }}
27 |
28 | - name: Enable auto-merge
29 | run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }}
30 | env:
31 | GITHUB_TOKEN: ${{ secrets.PAT }}
32 |
--------------------------------------------------------------------------------
/.github/workflows/changelog.yml:
--------------------------------------------------------------------------------
1 | name: Changelog Enforcer
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request:
5 | types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled]
6 |
7 | jobs:
8 |
9 | changelog:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Enforce changelog entry
14 | uses: dangoslen/changelog-enforcer@v3
15 | with:
16 | skipLabels: skip-changelog
17 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: PR Labeler
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request:
5 |
6 | jobs:
7 |
8 | assign-labels:
9 | if: github.repository == 'pyg-team/pytorch_geometric'
10 | runs-on: ubuntu-latest
11 |
12 | permissions:
13 | contents: read
14 | pull-requests: write
15 |
16 | steps:
17 | - name: Add PR labels
18 | uses: actions/labeler@v6
19 | continue-on-error: true
20 | with:
21 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
22 | sync-labels: true
23 |
24 | assign-author:
25 | if: github.repository == 'pyg-team/pytorch_geometric'
26 | runs-on: ubuntu-latest
27 |
28 | steps:
29 | - name: Add PR author
30 | uses: samspills/assign-pr-to-author@v1.0
31 | continue-on-error: true
32 | if: github.event_name == 'pull_request'
33 | with:
34 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .pytest_cache/
3 | .DS_Store
4 | data/
5 | build/
6 | dist/
7 | alpha/
8 | runs/
9 | wandb/
10 | .cache/
11 | .eggs/
12 | lightning_logs/
13 | outputs/
14 | graphgym/datasets/
15 | graphgym/results/
16 | *.egg-info/
17 | .ipynb_checkpoints
18 | .coverage
19 | .coverage.*
20 | coverage.xml
21 | .vscode
22 | .idea
23 | .venv
24 | *.out
25 | *.pt
26 | *.onnx
27 | examples/**/*.png
28 | examples/**/*.pdf
29 | benchmark/results/
30 | .mypy_cache/
31 | uv.lock
32 |
33 | !torch_geometric/data/
34 | !test/data/
35 |
--------------------------------------------------------------------------------
/benchmark/README.md:
--------------------------------------------------------------------------------
1 | # PyG Benchmark Suite
2 |
3 | This benchmark suite provides evaluation scripts for **[semi-supervised node classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/citation)**, **[graph classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/kernel)**, **[point cloud classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/points)**, and **[runtimes](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/runtime)** in order to compare various methods in homogeneous evaluation scenarios.
4 | In particular, we take care to avoid performing hyperparameter and model selection on the test set, and instead use an additional validation set.
5 |
6 | ## Installation
7 |
8 | ```
9 | $ pip install -e .
10 | ```
11 |
--------------------------------------------------------------------------------
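Once the editable install above has been run from within `benchmark/`, the sub-packages defined below (`citation`, `kernel`, `points`, `runtime`, `utils`) become importable. A minimal usage sketch, assuming PyG itself is installed and the Planetoid download is permitted:

```python
# Minimal sketch: the editable install exposes the benchmark sub-packages,
# e.g. the Planetoid helper from benchmark/citation/datasets.py shown below.
from citation import get_planetoid_dataset

dataset = get_planetoid_dataset('Cora', normalize_features=True)
print(dataset[0])  # e.g. Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], ...)
```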
/benchmark/citation/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import get_planetoid_dataset
2 | from .train_eval import random_planetoid_splits, run
3 |
4 | __all__ = [
5 | 'get_planetoid_dataset',
6 | 'random_planetoid_splits',
7 | 'run',
8 | ]
9 |
--------------------------------------------------------------------------------
/benchmark/citation/datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch_geometric.transforms as T
4 | from torch_geometric.datasets import Planetoid
5 |
6 |
7 | def get_planetoid_dataset(name, normalize_features=False, transform=None):
8 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
9 | dataset = Planetoid(path, name)
10 |
11 | if transform is not None and normalize_features:
12 | dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
13 | elif normalize_features:
14 | dataset.transform = T.NormalizeFeatures()
15 | elif transform is not None:
16 | dataset.transform = transform
17 |
18 | return dataset
19 |
--------------------------------------------------------------------------------
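As a usage sketch of the helper above: when both `normalize_features` and a `transform` are given, they are chained into a single `T.Compose`. The extra transform below (`T.TargetIndegree`) is only an illustrative choice, not something the benchmark prescribes:

```python
import torch_geometric.transforms as T

from citation import get_planetoid_dataset

# Both options are combined into T.Compose([T.NormalizeFeatures(), transform]):
dataset = get_planetoid_dataset(
    'CiteSeer',
    normalize_features=True,
    transform=T.TargetIndegree(),  # illustrative extra transform
)
print(dataset.transform)  # Compose of NormalizeFeatures and TargetIndegree
```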
/benchmark/citation/statistics.py:
--------------------------------------------------------------------------------
1 | from citation import get_planetoid_dataset
2 |
3 |
4 | def print_dataset(dataset):
5 | data = dataset[0]
6 | print('Name', dataset)
7 | print('Nodes', data.num_nodes)
8 | print('Edges', data.num_edges // 2)
9 | print('Features', dataset.num_features)
10 | print('Classes', dataset.num_classes)
11 | print('Label rate', data.train_mask.sum().item() / data.num_nodes)
12 | print()
13 |
14 |
15 | for name in ['Cora', 'CiteSeer', 'PubMed']:
16 | print_dataset(get_planetoid_dataset(name))
17 |
--------------------------------------------------------------------------------
/benchmark/kernel/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import get_dataset
2 | from .train_eval import cross_validation_with_val_set
3 |
4 | __all__ = [
5 | 'get_dataset',
6 | 'cross_validation_with_val_set',
7 | ]
8 |
--------------------------------------------------------------------------------
/benchmark/kernel/statistics.py:
--------------------------------------------------------------------------------
1 | from kernel.datasets import get_dataset
2 |
3 |
4 | def print_dataset(dataset):
5 | num_nodes = num_edges = 0
6 | for data in dataset:
7 | num_nodes += data.num_nodes
8 | num_edges += data.num_edges
9 |
10 | print('Name', dataset)
11 | print('Graphs', len(dataset))
12 | print('Nodes', num_nodes / len(dataset))
13 | print('Edges', (num_edges // 2) / len(dataset))
14 | print('Features', dataset.num_features)
15 | print('Classes', dataset.num_classes)
16 | print()
17 |
18 |
19 | for name in ['MUTAG', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']:
20 | print_dataset(get_dataset(name))
21 |
--------------------------------------------------------------------------------
/benchmark/points/README.md:
--------------------------------------------------------------------------------
1 | # Point Cloud Classification
2 |
3 | Evaluation scripts for various methods on the ModelNet10 dataset:
4 |
5 | - **[MPNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/mpnn.py)**: `python mpnn.py`
6 | - **[PointNet++](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_net.py)**: `python point_net.py`
7 | - **[EdgeCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/edge_cnn.py)**: `python edge_cnn.py`
8 | - **[SplineCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/spline_cnn.py)**: `python spline_cnn.py`
9 | - **[PointCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)**: `python point_cnn.py`
10 |
--------------------------------------------------------------------------------
/benchmark/points/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import get_dataset
2 | from .train_eval import run
3 |
4 | __all__ = [
5 | 'get_dataset',
6 | 'run',
7 | ]
8 |
--------------------------------------------------------------------------------
/benchmark/points/datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch_geometric.transforms as T
4 | from torch_geometric.datasets import ModelNet
5 |
6 |
7 | def get_dataset(num_points):
8 | name = 'ModelNet10'
9 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
10 | pre_transform = T.NormalizeScale()
11 | transform = T.SamplePoints(num_points)
12 |
13 | train_dataset = ModelNet(path, name='10', train=True, transform=transform,
14 | pre_transform=pre_transform)
15 | test_dataset = ModelNet(path, name='10', train=False, transform=transform,
16 | pre_transform=pre_transform)
17 |
18 | return train_dataset, test_dataset
19 |
--------------------------------------------------------------------------------
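A sketch of how the two ModelNet10 splits returned above are typically consumed; the batch size is illustrative rather than the value used by the benchmark scripts:

```python
from torch_geometric.loader import DataLoader

from points.datasets import get_dataset

train_dataset, test_dataset = get_dataset(num_points=1024)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

batch = next(iter(train_loader))
print(batch.pos.size())  # [32 * 1024, 3]: sampled point positions per mini-batch
```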
/benchmark/points/statistics.py:
--------------------------------------------------------------------------------
1 | from points.datasets import get_dataset
2 |
3 | from torch_geometric.transforms import RadiusGraph
4 |
5 |
6 | def print_dataset(train_dataset, test_dataset):
7 | num_nodes = num_edges = 0
8 | for data in train_dataset:
9 | data = RadiusGraph(0.2)(data)
10 | num_nodes += data.num_nodes
11 | num_edges += data.num_edges
12 | for data in test_dataset:
13 | data = RadiusGraph(0.2)(data)
14 | num_nodes += data.num_nodes
15 | num_edges += data.num_edges
16 |
17 | num_graphs = len(train_dataset) + len(test_dataset)
18 | print('Graphs', num_graphs)
19 | print('Nodes', num_nodes / num_graphs)
20 | print('Edges', (num_edges // 2) / num_graphs)
21 | print('Label rate', len(train_dataset) / num_graphs)
22 | print()
23 |
24 |
25 | print_dataset(*get_dataset(num_points=1024))
26 |
--------------------------------------------------------------------------------
/benchmark/runtime/README.md:
--------------------------------------------------------------------------------
1 | # Runtimes
2 |
3 | Run the test suite for PyG via
4 |
5 | ```
6 | python main.py
7 | ```
8 |
9 | Install `dgl` and run the test suite for DGL via
10 |
11 | ```
12 | cd dgl
13 | python main.py
14 | ```
15 |
--------------------------------------------------------------------------------
/benchmark/runtime/__init__.py:
--------------------------------------------------------------------------------
1 | from .train import train_runtime
2 |
3 | __all__ = [
4 | 'train_runtime',
5 | ]
6 |
--------------------------------------------------------------------------------
/benchmark/runtime/dgl/hidden.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import warnings
4 |
5 | warnings.filterwarnings('ignore')
6 |
7 |
8 | class HiddenPrint:
9 | def __enter__(self):
10 | self._original_stdout = sys.stdout
11 | sys.stdout = open(os.devnull, 'w')
12 |
13 | def __exit__(self, exc_type, exc_val, exc_tb):
14 | sys.stdout.close()
15 | sys.stdout = self._original_stdout
16 |
--------------------------------------------------------------------------------
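`HiddenPrint` above is a small context manager that redirects `sys.stdout` to `os.devnull` for the duration of a `with` block (used to silence verbose DGL output). A usage sketch, assuming it is run from within `benchmark/runtime/dgl/` so the module is importable directly:

```python
from hidden import HiddenPrint  # assumes the working directory is benchmark/runtime/dgl/

with HiddenPrint():
    print('this message is swallowed')  # redirected to os.devnull
print('stdout is restored here')
```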
/benchmark/runtime/gat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch_geometric.nn import GATConv
5 |
6 |
7 | class GAT(torch.nn.Module):
8 | def __init__(self, in_channels, out_channels):
9 | super().__init__()
10 | self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
11 | self.conv2 = GATConv(8 * 8, out_channels, dropout=0.6)
12 |
13 | def forward(self, data):
14 | x, edge_index = data.x, data.edge_index
15 | x = F.dropout(x, p=0.6, training=self.training)
16 | x = F.elu(self.conv1(x, edge_index))
17 | x = F.dropout(x, p=0.6, training=self.training)
18 | x = self.conv2(x, edge_index)
19 | return F.log_softmax(x, dim=1)
20 |
--------------------------------------------------------------------------------
/benchmark/runtime/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch_geometric.nn import GCNConv
5 |
6 |
7 | class GCN(torch.nn.Module):
8 | def __init__(self, in_channels, out_channels):
9 | super().__init__()
10 | self.conv1 = GCNConv(in_channels, 16, cached=True)
11 | self.conv2 = GCNConv(16, out_channels, cached=True)
12 |
13 | def forward(self, data):
14 | x, edge_index = data.x, data.edge_index
15 | x = F.relu(self.conv1(x, edge_index))
16 | x = F.dropout(x, training=self.training)
17 | x = self.conv2(x, edge_index)
18 | return F.log_softmax(x, dim=1)
19 |
--------------------------------------------------------------------------------
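A sketch of a single training step with the two-layer GCN above on Cora; the optimizer settings are illustrative and not necessarily the benchmark defaults:

```python
import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid

from gcn import GCN  # assumes the working directory is benchmark/runtime/

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]

model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
optimizer.zero_grad()
out = model(data)  # log-probabilities of shape [num_nodes, num_classes]
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
```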
/benchmark/runtime/rgcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch_geometric.nn import FastRGCNConv
5 |
6 |
7 | class RGCN(torch.nn.Module):
8 | def __init__(self, in_channels, out_channels, num_relations):
9 | super().__init__()
10 | self.conv1 = FastRGCNConv(in_channels, 16, num_relations, num_bases=30)
11 | self.conv2 = FastRGCNConv(16, out_channels, num_relations,
12 | num_bases=30)
13 |
14 | def forward(self, data):
15 | edge_index, edge_type = data.edge_index, data.edge_type
16 | x = F.relu(self.conv1(None, edge_index, edge_type))
17 | x = self.conv2(x, edge_index, edge_type)
18 | return F.log_softmax(x, dim=1)
19 |
--------------------------------------------------------------------------------
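Because the forward pass above calls `self.conv1(None, edge_index, edge_type)`, `FastRGCNConv` falls back to per-node-index embeddings, so `in_channels` should equal the number of nodes. A sketch on synthetic data (all sizes below are made up for illustration):

```python
import torch

from torch_geometric.data import Data

from rgcn import RGCN  # assumes the working directory is benchmark/runtime/

num_nodes, num_classes, num_relations = 100, 4, 30
data = Data(
    edge_index=torch.randint(num_nodes, (2, 500)),
    edge_type=torch.randint(num_relations, (500,)),
    num_nodes=num_nodes,
)

model = RGCN(num_nodes, num_classes, num_relations)  # in_channels == num_nodes
out = model(data)
print(out.size())  # [100, 4]
```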
/benchmark/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name='torch_geometric_benchmark',
5 | version='0.1.0',
6 | description='PyG Benchmark Suite',
7 | author='Matthias Fey',
8 | author_email='matthias.fey@tu-dortmund.de',
9 | url='https://github.com/pyg-team/pytorch_geometric_benchmark',
10 | install_requires=['scikit-learn'],
11 | packages=find_packages(),
12 | )
13 |
--------------------------------------------------------------------------------
/benchmark/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import emit_itt
2 | from .utils import get_dataset, get_dataset_with_transformation
3 | from .utils import get_model
4 | from .utils import get_split_masks
5 | from .utils import save_benchmark_data, write_to_csv
6 | from .utils import test
7 |
8 | __all__ = [
9 | 'emit_itt',
10 | 'get_dataset',
11 | 'get_dataset_with_transformation',
12 | 'get_model',
13 | 'get_split_masks',
14 | 'save_benchmark_data',
15 | 'write_to_csv',
16 | 'test',
17 | ]
18 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # See: https://docs.codecov.io/docs/codecov-yaml
2 | coverage:
3 | range: 80..100
4 | round: down
5 | precision: 2
6 | status:
7 | project:
8 | default:
9 | target: 80%
10 | threshold: 1%
11 | patch:
12 | default:
13 | target: 80%
14 | threshold: 1%
15 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04
2 |
3 | # Based on NGC PyG 24.09 image:
4 | # https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09
5 |
6 | # install pip
7 | RUN apt-get update && apt-get install -y python3-pip
8 |
9 | # install PyTorch - latest stable version
10 | RUN pip install torch torchvision torchaudio
11 |
12 | # install graphviz - latest stable version
13 | RUN apt-get install -y graphviz graphviz-dev
14 | RUN pip install pygraphviz
15 |
16 | # install python packages with NGC PyG 24.09 image versions
17 | RUN pip install torch_geometric==2.6.0
18 | RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5
19 |
20 | # install cugraph
21 | RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | SPHINXBUILD = sphinx-build
2 | SPHINXPROJ = pytorch_geometric
3 | SOURCEDIR = source
4 | BUILDDIR = build
5 |
6 | .PHONY: help Makefile
7 |
8 | %: Makefile
9 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
10 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Building Documentation
2 |
3 | To build the documentation:
4 |
5 | 1. [Build and install](https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md#developing-pytorch-geometric) PyG from source.
6 | 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via
7 | ```
8 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git
9 | ```
10 | 1. Generate the documentation file via:
11 | ```
12 | cd docs
13 | make html
14 | ```
15 |
16 | The documentation is now available to view by opening `docs/build/html/index.html`.
17 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl
2 | numpy>=1.19.5
3 | git+https://github.com/pyg-team/pyg_sphinx_theme.git
4 |
--------------------------------------------------------------------------------
/docs/source/.gitignore:
--------------------------------------------------------------------------------
1 | generated/
2 |
--------------------------------------------------------------------------------
/docs/source/_figures/.gitignore:
--------------------------------------------------------------------------------
1 | *.aux
2 | *.log
3 | *.pdf
4 |
--------------------------------------------------------------------------------
/docs/source/_figures/architecture.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/architecture.pdf
--------------------------------------------------------------------------------
/docs/source/_figures/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | for filename in *.tex; do
4 | basename=$(basename $filename .tex)
5 | pdflatex "$basename.tex"
6 | pdf2svg "$basename.pdf" "$basename.svg"
7 | done
8 |
--------------------------------------------------------------------------------
/docs/source/_figures/dist_part.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_part.png
--------------------------------------------------------------------------------
/docs/source/_figures/dist_proc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_proc.png
--------------------------------------------------------------------------------
/docs/source/_figures/dist_sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/dist_sampling.png
--------------------------------------------------------------------------------
/docs/source/_figures/graph.tex:
--------------------------------------------------------------------------------
1 | \documentclass{standalone}
2 |
3 | \usepackage{tikz}
4 |
5 | \begin{document}
6 |
7 | \begin{tikzpicture}
8 | \node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0};
9 | \node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1};
10 | \node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2};
11 |
12 | \path[draw] (0) -- (1);
13 | \path[draw] (1) -- (2);
14 | \end{tikzpicture}
15 |
16 | \end{document}
17 |
--------------------------------------------------------------------------------
/docs/source/_figures/graphgps_layer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgps_layer.png
--------------------------------------------------------------------------------
/docs/source/_figures/graphgym_design_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_design_space.png
--------------------------------------------------------------------------------
/docs/source/_figures/graphgym_evaluation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_evaluation.png
--------------------------------------------------------------------------------
/docs/source/_figures/graphgym_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/graphgym_results.png
--------------------------------------------------------------------------------
/docs/source/_figures/hg_example.tex:
--------------------------------------------------------------------------------
1 | \documentclass{standalone}
2 |
3 | \usepackage{tikz}
4 |
5 | \begin{document}
6 |
7 | \begin{tikzpicture}
8 | \node[draw,rectangle, align=center] (0) at (0, 0) {\textbf{Author}\\ $1,134,649$ nodes};
9 | \node[draw,rectangle, align=center] (1) at (4, 2) {\textbf{Paper}\\ $736,389$ nodes};
10 | \node[draw,rectangle, align=center] (2) at (8, 0) {\textbf{Institution}\\ $8,740$ nodes};
11 | \node[draw,rectangle, align=center] (3) at (4, 4) {\textbf{Field of Study}\\ $59,965$ nodes};
12 |
13 | \path[->,>=stealth] (0) edge [above left] node[align=center] {\textbf{writes}\\$7,145,660$ edges} (1.south);
14 | \path[->,>=stealth] (0) edge [below] node[align=center] {\textbf{affiliated with}\\$1,043,998$ edges} (2);
15 | \path[->,>=stealth,every loop/.style={looseness=3}] (1) edge [out=350, in=10, loop, right] node[align=center] {\textbf{cites}\\$5,416,271$ edges} (1);
16 | \path[->,>=stealth] (1) edge [left] node[align=center] {\textbf{has topic}\\$7,505,078$ edges} (3);
17 | \end{tikzpicture}
18 |
19 | \end{document}
20 |
--------------------------------------------------------------------------------
/docs/source/_figures/intel_kumo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/intel_kumo.png
--------------------------------------------------------------------------------
/docs/source/_figures/point_cloud1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud1.png
--------------------------------------------------------------------------------
/docs/source/_figures/point_cloud2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud2.png
--------------------------------------------------------------------------------
/docs/source/_figures/point_cloud3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud3.png
--------------------------------------------------------------------------------
/docs/source/_figures/point_cloud4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/point_cloud4.png
--------------------------------------------------------------------------------
/docs/source/_figures/remote_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_1.png
--------------------------------------------------------------------------------
/docs/source/_figures/remote_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_2.png
--------------------------------------------------------------------------------
/docs/source/_figures/remote_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/remote_3.png
--------------------------------------------------------------------------------
/docs/source/_figures/shallow_node_embeddings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/shallow_node_embeddings.png
--------------------------------------------------------------------------------
/docs/source/_figures/training_affinity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_figures/training_affinity.png
--------------------------------------------------------------------------------
/docs/source/_static/js/version_alert.js:
--------------------------------------------------------------------------------
1 | function warnOnLatestVersion() {
2 | if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") {
3 | return; // not on ReadTheDocs and not latest.
4 | }
5 |
6 | var note = document.createElement('div');
7 | note.setAttribute('class', 'admonition note');
8 |   note.innerHTML = "<p class='first admonition-title'>Note</p> " +
9 |     "<p class='last'> " +
10 |     "This documentation is for an unreleased development version. " +
11 |     "Click <a href='/en/stable/'>here</a> to access the documentation of the current stable release." +
12 |     "</p>";
13 |
14 | var parent = document.querySelector('#pyg-documentation');
15 | if (parent)
16 | parent.insertBefore(note, parent.querySelector('h1'));
17 | }
18 |
19 | document.addEventListener('DOMContentLoaded', warnOnLatestVersion);
20 |
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/create_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/create_dataset.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/create_gnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/create_gnn.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/dataset_splitting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/dataset_splitting.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/distributed_pyg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/distributed_pyg.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/explain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/explain.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/graph_transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/graph_transformer.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/heterogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/heterogeneous.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/load_csv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/load_csv.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/multi_gpu_vanilla.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/multi_gpu_vanilla.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/neighbor_loader.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/neighbor_loader.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/point_cloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/point_cloud.png
--------------------------------------------------------------------------------
/docs/source/_static/thumbnails/shallow_node_embeddings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/docs/source/_static/thumbnails/shallow_node_embeddings.png
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/inherited_class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 | :members:
8 | :inherited-members:
9 | :special-members: __cat_dim__, __inc__
10 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/metrics.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 | :members: update, compute, reset
8 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/nn.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | {% if objname != "MessagePassing" %}
6 | .. autoclass:: {{ objname }}
7 | :show-inheritance:
8 | :members:
9 | :exclude-members: forward, reset_parameters, message, message_and_aggregate, edge_update, aggregate, update
10 |
11 | .. automethod:: forward
12 | .. automethod:: reset_parameters
13 | {% else %}
14 | .. autoclass:: {{ objname }}
15 | :show-inheritance:
16 | :members:
17 | {% endif %}
18 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/only_class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/source/modules/distributed.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.distributed
2 | ===========================
3 |
4 | .. warning::
5 | ``torch_geometric.distributed`` has been deprecated since 2.7.0 and will
6 | no longer be maintained. For distributed training, refer to :ref:`our
7 |     tutorials on distributed training <distributed_tutorials>` or `cuGraph
8 | examples `_.
9 |
10 | .. currentmodule:: torch_geometric.distributed
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | {% for cls in torch_geometric.distributed.classes %}
15 | {{ cls }}
16 | {% endfor %}
17 |
18 | .. automodule:: torch_geometric.distributed
19 | :members:
20 |
--------------------------------------------------------------------------------
/docs/source/modules/llm.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.llm
2 | =======================
3 |
4 | .. currentmodule:: torch_geometric.llm
5 |
6 | .. autosummary::
7 | :nosignatures:
8 | {% for cls in torch_geometric.llm.classes %}
9 | {{ cls }}
10 | {% endfor %}
11 |
12 | .. automodule:: torch_geometric.llm
13 | :members:
14 |
15 |
16 | Models
17 | ----------------
18 |
19 | .. currentmodule:: torch_geometric.llm.models
20 |
21 | .. autosummary::
22 | :nosignatures:
23 | :toctree: ../generated
24 |
25 | {% for name in torch_geometric.llm.models.classes %}
26 | {{ name }}
27 | {% endfor %}
28 |
29 | Utils
30 | ----------------
31 |
32 | .. currentmodule:: torch_geometric.llm.utils
33 |
34 | .. autosummary::
35 | :nosignatures:
36 | :toctree: ../generated
37 |
38 | {% for name in torch_geometric.llm.utils.classes %}
39 | {{ name }}
40 | {% endfor %}
41 |
--------------------------------------------------------------------------------
/docs/source/modules/loader.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.loader
2 | ======================
3 |
4 | .. currentmodule:: torch_geometric.loader
5 |
6 | .. autosummary::
7 | :nosignatures:
8 | {% for cls in torch_geometric.loader.classes %}
9 | {{ cls }}
10 | {% endfor %}
11 |
12 | .. automodule:: torch_geometric.loader
13 | :members:
14 |
--------------------------------------------------------------------------------
/docs/source/modules/metrics.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.metrics
2 | =======================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | Link Prediction Metrics
8 | -----------------------
9 |
10 | .. currentmodule:: torch_geometric.metrics
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | :toctree: ../generated
15 | :template: autosummary/metrics.rst
16 |
17 | {% for name in torch_geometric.metrics.link_pred_metrics %}
18 | {{ name }}
19 | {% endfor %}
20 |
--------------------------------------------------------------------------------
/docs/source/modules/profile.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.profile
2 | =======================
3 |
4 | .. currentmodule:: torch_geometric.profile
5 |
6 | .. autosummary::
7 | :nosignatures:
8 | {% for cls in torch_geometric.profile.classes %}
9 | {{ cls }}
10 | {% endfor %}
11 |
12 | .. automodule:: torch_geometric.profile
13 | :members:
14 | :undoc-members:
15 |
--------------------------------------------------------------------------------
/docs/source/modules/root.rst:
--------------------------------------------------------------------------------
1 | torch_geometric
2 | ===============
3 |
4 | Tensor Objects
5 | --------------
6 |
7 | .. currentmodule:: torch_geometric
8 |
9 | .. autosummary::
10 | :nosignatures:
11 | :toctree: ../generated
12 |
13 | Index
14 | EdgeIndex
15 | HashTensor
16 |
17 | Functions
18 | ---------
19 |
20 | .. automodule:: torch_geometric.seed
21 | :members:
22 |
23 | .. automodule:: torch_geometric.home
24 | :members:
25 |
26 | .. automodule:: torch_geometric._compile
27 | :members:
28 | :exclude-members: compile
29 |
30 | .. automodule:: torch_geometric.debug
31 | :members:
32 |
33 | .. automodule:: torch_geometric.experimental
34 | :members:
35 |
--------------------------------------------------------------------------------
/docs/source/modules/sampler.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.sampler
2 | =======================
3 |
4 | .. currentmodule:: torch_geometric.sampler
5 |
6 | .. autosummary::
7 | :nosignatures:
8 | {% for cls in torch_geometric.sampler.classes %}
9 | {{ cls }}
10 | {% endfor %}
11 |
12 | .. autoclass:: torch_geometric.sampler.BaseSampler
13 | :members:
14 |
15 | .. automodule:: torch_geometric.sampler
16 | :members:
17 | :exclude-members: sample_from_nodes, sample_from_edges, edge_permutation, BaseSampler
18 |
--------------------------------------------------------------------------------
/docs/source/modules/utils.rst:
--------------------------------------------------------------------------------
1 | torch_geometric.utils
2 | =====================
3 |
4 | .. currentmodule:: torch_geometric.utils
5 |
6 | .. autosummary::
7 | :nosignatures:
8 | {% for cls in torch_geometric.utils.classes %}
9 | {{ cls }}
10 | {% endfor %}
11 |
12 | .. automodule:: torch_geometric.utils
13 | :members:
14 |
--------------------------------------------------------------------------------
/docs/source/notes/batching.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/batching.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/colabs.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../get_started/colabs.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/create_dataset.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../tutorial/create_dataset.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/create_gnn.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../tutorial/create_gnn.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/explain.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../tutorial/explain.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/graphgym.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/graphgym.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/heterogeneous.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../tutorial/heterogeneous.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/installation.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../install/installation.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/introduction.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../get_started/introduction.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/jit.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/jit.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/load_csv.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../tutorial/load_csv.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/remote.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/remote.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/resources.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../external/resources.rst
4 |
--------------------------------------------------------------------------------
/docs/source/notes/sparse_tensor.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/sparse_tensor.rst
4 |
--------------------------------------------------------------------------------
/docs/source/tutorial/application.rst:
--------------------------------------------------------------------------------
1 | Use-Cases & Applications
2 | ========================
3 |
4 | .. nbgallery::
5 | :name: rst-gallery
6 |
7 | neighbor_loader
8 | point_cloud
9 | explain
10 | shallow_node_embeddings
11 | graph_transformer
12 |
--------------------------------------------------------------------------------
/docs/source/tutorial/compile.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | .. include:: ../advanced/compile.rst
4 |
--------------------------------------------------------------------------------
/docs/source/tutorial/dataset.rst:
--------------------------------------------------------------------------------
1 | Working with Graph Datasets
2 | ===========================
3 |
4 | .. nbgallery::
5 | :name: rst-gallery
6 |
7 | create_dataset
8 | load_csv
9 | dataset_splitting
10 |
--------------------------------------------------------------------------------
/docs/source/tutorial/distributed.rst:
--------------------------------------------------------------------------------
1 | .. _distributed_tutorials:
2 |
3 | Distributed Training
4 | ====================
5 |
6 | .. nbgallery::
7 | :name: rst-gallery
8 |
9 | multi_gpu_vanilla
10 | multi_node_multi_gpu_vanilla
11 | distributed_pyg
12 |
--------------------------------------------------------------------------------
/docs/source/tutorial/gnn_design.rst:
--------------------------------------------------------------------------------
1 | Design of Graph Neural Networks
2 | ===============================
3 |
4 | .. nbgallery::
5 | :name: rst-gallery
6 |
7 | create_gnn
8 | heterogeneous
9 |
--------------------------------------------------------------------------------
/examples/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 | project(hello-world)
3 |
4 | # The first thing to do is to tell CMake to find the TorchScatter
5 | # and TorchSparse libraries. The package pulls in all the necessary
6 | # torch libraries, so there is no need to add `find_package(Torch)`.
7 | find_package(TorchScatter REQUIRED)
8 | find_package(TorchSparse REQUIRED)
9 |
10 | find_package(Python3 COMPONENTS Development)
11 |
12 | add_executable(hello-world main.cpp)
13 |
14 | # We now need to link the TorchScatter and TorchSparse libraries
15 | # to our executable. We can do that by using the
16 | # TorchScatter::TorchScatter and TorchSparse::TorchSparse targets,
17 | # which also adds all the necessary torch dependencies.
18 | target_compile_features(hello-world PUBLIC cxx_range_for)
19 | target_link_libraries(hello-world TorchScatter::TorchScatter)
20 | target_link_libraries(hello-world TorchSparse::TorchSparse)
21 | target_link_libraries(hello-world ${CUDA_cusparse_LIBRARY})
22 | set_property(TARGET hello-world PROPERTY CXX_STANDARD 14)
23 |
--------------------------------------------------------------------------------
/examples/cpp/main.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/script.h>
2 | #include <torchscatter/scatter.h>
3 | #include <torchsparse/sparse.h>
4 |
5 | #include <iostream>
6 |
7 | int main(int argc, const char *argv[]) {
8 | if (argc != 2) {
9 |     std::cerr << "usage: hello-world <path-to-exported-model>\n";
10 | return -1;
11 | }
12 |
13 | torch::jit::script::Module model;
14 | try {
15 | model = torch::jit::load(argv[1]);
16 | } catch (const c10::Error &e) {
17 | std::cerr << "error loading the model\n";
18 | return -1;
19 | }
20 |
21 | auto x = torch::randn({5, 32});
22 | auto edge_index = torch::tensor({
23 | {0, 1, 1, 2, 2, 3, 3, 4},
24 | {1, 0, 2, 1, 3, 2, 4, 3},
25 | });
26 |
27 |   std::vector<torch::jit::IValue> inputs;
28 | inputs.push_back(x);
29 | inputs.push_back(edge_index);
30 |
31 | auto out = model.forward(inputs).toTensor();
32 | std::cout << "output tensor shape: " << out.sizes() << std::endl;
33 | }
34 |
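The program above loads a TorchScript module produced in Python and feeds it a `5 x 32` feature matrix plus an `edge_index` tensor. As a hypothetical companion sketch (the file name `model.pt` is illustrative, and it assumes a recent PyG release where message-passing layers can be scripted with `torch.jit.script` directly; older releases required calling `.jittable()` first), the export step might look like:

```python
import torch

from torch_geometric.nn import GCNConv

# A single GCN layer whose input size matches the `torch::randn({5, 32})`
# feature matrix constructed in main.cpp.
conv = GCNConv(in_channels=32, out_channels=16)

# Script the layer and serialize it; the resulting file is what the C++
# program expects as its single command-line argument.
torch.jit.script(conv).save('model.pt')
```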
--------------------------------------------------------------------------------
/examples/distributed/README.md:
--------------------------------------------------------------------------------
1 | # Examples for Distributed Graph Learning
2 |
3 | This directory contains examples for distributed graph learning.
4 | The examples are organized into two subdirectories:
5 |
6 | 1. [`pyg`](./pyg): Distributed training via PyG's own `torch_geometric.distributed` package (deprecated).
7 | 1. [`graphlearn_for_pytorch`](./graphlearn_for_pytorch): Distributed training via the external [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) package.
8 | 1. [`kuzu`](./kuzu): Remote backend via the [Kùzu](https://kuzudb.com/) graph database.
9 |
--------------------------------------------------------------------------------
/examples/jit/README.md:
--------------------------------------------------------------------------------
1 | # JIT Examples
2 |
3 | This directory contains examples demonstrating the use of Just-In-Time (JIT) compilation in different GNN models.
4 |
5 | | Example | Description |
6 | | ---------------------- | ----------------------------------------------------------------- |
7 | | [`gcn.py`](./gcn.py) | JIT compilation in `GCN` |
8 | | [`gat.py`](./gat.py) | JIT compilation in `GAT` |
9 | | [`gin.py`](./gin.py) | JIT compilation in `GIN` |
10 | | [`film.py`](./film.py) | JIT compilation in [`GNN-FiLM`](https://arxiv.org/abs/1906.12192) |
11 |
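As a rough illustration of the pattern these examples share (a minimal sketch rather than a copy of `gcn.py`; it assumes a recent PyG release where `MessagePassing` layers can be scripted directly, whereas older releases required `.jittable()`):

```python
import torch

from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


model = torch.jit.script(GCN(8, 4))  # Compile the model via TorchScript.

x = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
out = model(x, edge_index)  # Runs through the scripted module.
```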
--------------------------------------------------------------------------------
/examples/llm/nvtx_examples/nvtx_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if the user provided a Python file
4 | if [ -z "$1" ]; then
5 |     echo "Usage: $0 <python_file>"
6 | exit 1
7 | fi
8 |
9 | # Check if the provided file exists
10 | if [[ ! -f "$1" ]]; then
11 | echo "Error: File '$1' does not exist."
12 | exit 1
13 | fi
14 |
15 | # Check if the provided file is a Python file
16 | if [[ ! "$1" == *.py ]]; then
17 | echo "Error: '$1' is not a Python file."
18 | exit 1
19 | fi
20 |
21 | # Get the base name of the Python file
22 | python_file=$(basename "$1")
23 |
24 | # Run nsys profile on the Python file
25 | nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1"
26 |
27 | echo "Profile data saved as profile_${python_file%.py}.nsys-rep"
28 |
--------------------------------------------------------------------------------
/examples/llm/nvtx_examples/nvtx_webqsp_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 | from torch_geometric.datasets import web_qsp_dataset
6 | from torch_geometric.profile import nvtxit
7 |
8 | # Apply Patches
9 | web_qsp_dataset.retrieval_via_pcst = nvtxit()(
10 | web_qsp_dataset.retrieval_via_pcst)
11 | web_qsp_dataset.WebQSPDataset.process = nvtxit()(
12 | web_qsp_dataset.WebQSPDataset.process)
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--capture-torch-kernels", "-k", action="store_true")
17 | args = parser.parse_args()
18 | if args.capture_torch_kernels:
19 | with torch.autograd.profiler.emit_nvtx():
20 | ds = web_qsp_dataset.WebQSPDataset('baseline', split='val')
21 | else:
22 | ds = web_qsp_dataset.WebQSPDataset('baseline', split='val')
23 |
--------------------------------------------------------------------------------
/examples/pytorch_ignite/README.md:
--------------------------------------------------------------------------------
1 | # Examples for PyTorch Ignite
2 |
3 | This directory provides examples showcasing the integration of PyG with [PyTorch Ignite](https://pytorch.org/ignite/index.html).
4 |
5 | | Example | Description |
6 | | -------------------- | ---------------------------------------------------------------- |
7 | | [`gin.py`](./gin.py) | Demonstrates how to implement the GIN model using PyTorch Ignite |
8 |
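To sketch what such an integration might look like (an illustrative outline rather than the `gin.py` example itself; `FakeDataset` stands in for a real benchmark dataset):

```python
import torch
import torch.nn.functional as F
from ignite.engine import Engine

from torch_geometric.datasets import FakeDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GIN, global_add_pool

# Synthetic graph-classification data and a small GIN backbone:
dataset = FakeDataset(num_graphs=100, avg_num_nodes=20, num_channels=16,
                      num_classes=4, task='graph')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

gnn = GIN(in_channels=16, hidden_channels=32, num_layers=3, out_channels=4)
optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01)


def train_step(engine, batch):
    # Ignite drives the loop; each call processes one mini-batch of graphs.
    gnn.train()
    optimizer.zero_grad()
    out = global_add_pool(gnn(batch.x, batch.edge_index), batch.batch)
    loss = F.cross_entropy(out, batch.y)
    loss.backward()
    optimizer.step()
    return loss.item()


Engine(train_step).run(loader, max_epochs=5)
```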
--------------------------------------------------------------------------------
/examples/pytorch_lightning/README.md:
--------------------------------------------------------------------------------
1 | # Examples for PyTorch Lightning
2 |
3 | This directory provides examples showcasing the integration of PyG with [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning).
4 |
5 | | Example | Description |
6 | | ------------------------------------------ | ------------------------------------------------------------------------------------ |
7 | | [`graph_sage.py`](./graph_sage.py) | Combines PyG and PyTorch Lightning for node classification via the `GraphSAGE` model |
8 | | [`gin.py`](./gin.py) | Combines PyG and PyTorch Lightning for graph classification via the `GIN` model |
9 | | [`relational_gnn.py`](./relational_gnn.py) | Combines PyG and PyTorch Lightning for heterogeneous node classification |
10 |
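A minimal sketch of the overall pattern (illustrative only, not one of the listed examples; `FakeDataset` stands in for a real dataset):

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from torch_geometric.datasets import FakeDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphSAGE


class NodeClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.gnn = GraphSAGE(in_channels=16, hidden_channels=32,
                             num_layers=2, out_channels=4)

    def training_step(self, batch, batch_idx):
        out = self.gnn(batch.x, batch.edge_index)
        return F.cross_entropy(out, batch.y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


# A single synthetic graph with node-level labels:
dataset = FakeDataset(num_graphs=1, avg_num_nodes=1000, num_channels=16,
                      num_classes=4, task='node')
loader = DataLoader(dataset, batch_size=1)

pl.Trainer(max_epochs=5, accelerator='cpu').fit(NodeClassifier(), loader)
```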
--------------------------------------------------------------------------------
/graphgym/agg_batch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from torch_geometric.graphgym.utils.agg_runs import agg_batch
4 |
5 |
6 | def parse_args():
7 | """Parses the arguments."""
8 | parser = argparse.ArgumentParser(
9 | description='Train a classification model')
10 | parser.add_argument('--dir', dest='dir', help='Dir for batch of results',
11 | required=True, type=str)
12 | parser.add_argument('--metric', dest='metric',
13 | help='metric to select best epoch', required=False,
14 | type=str, default='auto')
15 | return parser.parse_args()
16 |
17 |
18 | args = parse_args()
19 | agg_batch(args.dir, args.metric)
20 |
--------------------------------------------------------------------------------
/graphgym/configs/example.yaml:
--------------------------------------------------------------------------------
1 | # The recommended basic settings for GNN
2 | out_dir: results
3 | dataset:
4 | format: PyG
5 | name: Cora
6 | task: node
7 | task_type: classification
8 | transductive: true
9 | split: [0.8, 0.2]
10 | transform: none
11 | train:
12 | batch_size: 32
13 | eval_period: 20
14 | ckpt_period: 100
15 | model:
16 | type: gnn
17 | loss_fun: cross_entropy
18 | edge_decoding: dot
19 | graph_pooling: add
20 | gnn:
21 | layers_pre_mp: 1
22 | layers_mp: 2
23 | layers_post_mp: 1
24 | dim_inner: 256
25 | layer_type: generalconv
26 | stage_type: stack
27 | batchnorm: true
28 | act: prelu
29 | dropout: 0.0
30 | agg: add
31 | normalize_adj: false
32 | optim:
33 | optimizer: adam
34 | base_lr: 0.01
35 | max_epoch: 400
36 |
--------------------------------------------------------------------------------
/graphgym/configs/pyg/example_graph.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | dataset:
3 | format: OGB
4 | name: ogbg-molhiv
5 | task: graph
6 | task_type: classification
7 | node_encoder: true
8 | node_encoder_name: Atom
9 | edge_encoder: true
10 | edge_encoder_name: Bond
11 | train:
12 | batch_size: 128
13 | eval_period: 1
14 | ckpt_period: 100
15 | sampler: full_batch
16 | model:
17 | type: gnn
18 | loss_fun: cross_entropy
19 | edge_decoding: dot
20 | graph_pooling: add
21 | gnn:
22 | layers_pre_mp: 1
23 | layers_mp: 2
24 | layers_post_mp: 1
25 | dim_inner: 300
26 | layer_type: generalconv
27 | stage_type: stack
28 | batchnorm: true
29 | act: prelu
30 | dropout: 0.0
31 | agg: mean
32 | normalize_adj: false
33 | optim:
34 | optimizer: adam
35 | base_lr: 0.01
36 | max_epoch: 100
37 |
--------------------------------------------------------------------------------
/graphgym/configs/pyg/example_link.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | dataset:
3 | format: OGB
4 | name: ogbl-collab
5 | task: link_pred
6 | task_type: classification
7 | node_encoder: false
8 | node_encoder_name: Atom
9 | edge_encoder: false
10 | edge_encoder_name: Bond
11 | train:
12 | batch_size: 128
13 | eval_period: 1
14 | ckpt_period: 100
15 | sampler: full_batch
16 | model:
17 | type: gnn
18 | loss_fun: cross_entropy
19 | edge_decoding: dot
20 | graph_pooling: add
21 | gnn:
22 | layers_pre_mp: 1
23 | layers_mp: 2
24 | layers_post_mp: 1
25 | dim_inner: 300
26 | layer_type: gcnconv
27 | stage_type: stack
28 | batchnorm: true
29 | act: prelu
30 | dropout: 0.0
31 | agg: mean
32 | normalize_adj: false
33 | optim:
34 | optimizer: adam
35 | base_lr: 0.01
36 | max_epoch: 100
37 |
--------------------------------------------------------------------------------
/graphgym/configs/pyg/example_node.yaml:
--------------------------------------------------------------------------------
1 | out_dir: results
2 | dataset:
3 | format: PyG
4 | name: Cora
5 | task: node
6 | task_type: classification
7 | node_encoder: false
8 | node_encoder_name: Atom
9 | edge_encoder: false
10 | edge_encoder_name: Bond
11 | train:
12 | batch_size: 128
13 | eval_period: 1
14 | ckpt_period: 100
15 | sampler: full_batch
16 | model:
17 | type: gnn
18 | loss_fun: cross_entropy
19 | edge_decoding: dot
20 | graph_pooling: add
21 | gnn:
22 | layers_pre_mp: 0
23 | layers_mp: 2
24 | layers_post_mp: 1
25 | dim_inner: 16
26 | layer_type: gcnconv
27 | stage_type: stack
28 | batchnorm: false
29 | act: prelu
30 | dropout: 0.1
31 | agg: mean
32 | normalize_adj: false
33 | optim:
34 | optimizer: adam
35 | base_lr: 0.01
36 | max_epoch: 200
37 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/__init__.py:
--------------------------------------------------------------------------------
1 | from .act import * # noqa
2 | from .config import * # noqa
3 | from .encoder import * # noqa
4 | from .head import * # noqa
5 | from .layer import * # noqa
6 | from .loader import * # noqa
7 | from .loss import * # noqa
8 | from .network import * # noqa
9 | from .optimizer import * # noqa
10 | from .pooling import * # noqa
11 | from .stage import * # noqa
12 | from .train import * # noqa
13 | from .transform import * # noqa
14 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/act/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/act/example.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from torch_geometric.graphgym.config import cfg
7 | from torch_geometric.graphgym.register import register_act
8 |
9 |
10 | class SWISH(nn.Module):
11 | def __init__(self, inplace=False):
12 | super().__init__()
13 | self.inplace = inplace
14 |
15 | def forward(self, x):
16 | if self.inplace:
17 | x.mul_(torch.sigmoid(x))
18 | return x
19 | else:
20 | return x * torch.sigmoid(x)
21 |
22 |
23 | register_act('swish', partial(SWISH, inplace=cfg.mem.inplace))
24 | register_act('lrelu_03', partial(nn.LeakyReLU, 0.3, inplace=cfg.mem.inplace))
25 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/config/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/config/example.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | from torch_geometric.graphgym.register import register_config
4 |
5 |
6 | @register_config('example')
7 | def set_cfg_example(cfg):
8 | r"""This function sets the default config value for customized options
9 |     :return: customized configuration used by the experiment.
10 | """
11 | # ----------------------------------------------------------------------- #
12 | # Customized options
13 | # ----------------------------------------------------------------------- #
14 |
15 | # example argument
16 | cfg.example_arg = 'example'
17 |
18 | # example argument group
19 | cfg.example_group = CN()
20 |
21 | # then argument can be specified within the group
22 | cfg.example_group.example_arg = 'example'
23 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/head/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/head/example.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from torch_geometric.graphgym.register import register_head
4 |
5 |
6 | @register_head('head')
7 | class ExampleNodeHead(nn.Module):
8 | r"""Head of GNN for node prediction."""
9 | def __init__(self, dim_in, dim_out):
10 | super().__init__()
11 | self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True)
12 |
13 | def _apply_index(self, batch):
14 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]:
15 | return batch.x[batch.node_label_index], batch.node_label
16 | else:
17 | return batch.x[batch.node_label_index], \
18 | batch.node_label[batch.node_label_index]
19 |
20 | def forward(self, batch):
21 | batch = self.layer_post_mp(batch)
22 | pred, label = self._apply_index(batch)
23 | return pred, label
24 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/layer/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/loader/example.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import QM7b
2 | from torch_geometric.graphgym.register import register_loader
3 |
4 |
5 | @register_loader('example')
6 | def load_dataset_example(format, name, dataset_dir):
7 | dataset_dir = f'{dataset_dir}/{name}'
8 | if format == 'PyG':
9 | if name == 'QM7b':
10 | dataset_raw = QM7b(dataset_dir)
11 | return dataset_raw
12 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/loss/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.graphgym.config import cfg
4 | from torch_geometric.graphgym.register import register_loss
5 |
6 |
7 | @register_loss('smoothl1')
8 | def loss_example(pred, true):
9 | if cfg.model.loss_fun == 'smoothl1':
10 | l1_loss = torch.nn.SmoothL1Loss()
11 | loss = l1_loss(pred, true)
12 | return loss, pred
13 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/network/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/optimizer/example.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | from torch.nn import Parameter
4 | from torch.optim import Adagrad, Optimizer
5 | from torch.optim.lr_scheduler import ReduceLROnPlateau
6 |
7 | import torch_geometric.graphgym.register as register
8 |
9 |
10 | @register.register_optimizer('adagrad')
11 | def adagrad_optimizer(params: Iterator[Parameter], base_lr: float,
12 | weight_decay: float) -> Adagrad:
13 | return Adagrad(params, lr=base_lr, weight_decay=weight_decay)
14 |
15 |
16 | @register.register_scheduler('pleateau')
17 | def plateau_scheduler(optimizer: Optimizer, patience: int,
18 | lr_decay: float) -> ReduceLROnPlateau:
19 | return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay)
20 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/pooling/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/pooling/example.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_pooling
2 | from torch_geometric.utils import scatter
3 |
4 |
5 | @register_pooling('example')
6 | def global_example_pool(x, batch, size=None):
7 | size = batch.max().item() + 1 if size is None else size
8 | return scatter(x, batch, dim=0, dim_size=size, reduce='sum')
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/stage/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/train/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/custom_graphgym/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/graphgym/grids/example.txt:
--------------------------------------------------------------------------------
1 | # Format for each row: name in config.py; alias; range to search
2 | # No spaces, except between these 3 fields
3 | # Line breaks are used to union different grid search spaces
4 | # Feel free to add '#' to add comments
5 |
6 |
7 | # (1) dataset configurations
8 | dataset.format format ['PyG']
9 | dataset.name dataset ['TU_ENZYMES','TU_PROTEINS']
10 | dataset.task task ['graph']
11 | dataset.transductive trans [False]
12 | # (2) The recommended GNN design space, 96 models in total
13 | gnn.layers_pre_mp l_pre [1,2]
14 | gnn.layers_mp l_mp [2,4,6,8]
15 | gnn.layers_post_mp l_post [2,3]
16 | gnn.stage_type stage ['skipsum','skipconcat']
17 | gnn.agg agg ['add','mean','max']
18 |
--------------------------------------------------------------------------------
/graphgym/grids/pyg/example.txt:
--------------------------------------------------------------------------------
1 | # Format for each row: name in config.py; alias; range to search
2 | # No spaces, except between these 3 fields
3 | # Line breaks are used to union different grid search spaces
4 | # Feel free to add '#' to add comments
5 |
6 |
7 | gnn.layers_pre_mp l_pre [1,2]
8 | gnn.layers_mp l_mp [2,4,6]
9 | gnn.layers_post_mp l_post [1,2]
10 | gnn.stage_type stage ['stack','skipsum','skipconcat']
11 | gnn.dim_inner dim [64]
12 | optim.base_lr lr [0.01]
13 | optim.max_epoch epoch [200]
14 |
--------------------------------------------------------------------------------
/graphgym/parallel.sh:
--------------------------------------------------------------------------------
1 | CONFIG_DIR=$1
2 | REPEAT=$2
3 | MAX_JOBS=${3:-2}
4 | SLEEP=${4:-1}
5 | MAIN=${5:-main}
6 |
7 | (
8 | trap 'kill 0' SIGINT
9 | CUR_JOBS=0
10 | for CONFIG in "$CONFIG_DIR"/*.yaml; do
11 | if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
12 | ((CUR_JOBS >= MAX_JOBS)) && wait -n
13 | python $MAIN.py --cfg $CONFIG --repeat $REPEAT --mark_done &
14 | echo $CONFIG
15 | sleep $SLEEP
16 | ((++CUR_JOBS))
17 | fi
18 | done
19 |
20 | wait
21 | )
22 |
--------------------------------------------------------------------------------
/graphgym/run_single.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Test for running a single experiment. --repeat specifies how many different random seeds to run.
4 | python main.py --cfg configs/pyg/example_node.yaml --repeat 3 # node classification
5 | python main.py --cfg configs/pyg/example_link.yaml --repeat 3 # link prediction
6 | python main.py --cfg configs/pyg/example_graph.yaml --repeat 3 # graph classification
7 |
--------------------------------------------------------------------------------
/graphgym/sample/dimensions.txt:
--------------------------------------------------------------------------------
1 | act bn drop agg l_mp l_pre l_post stage batch lr optim epoch
2 |
--------------------------------------------------------------------------------
/graphgym/sample/dimensionsatt.txt:
--------------------------------------------------------------------------------
1 | l_tw
2 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | sphinx:
4 | configuration: docs/source/conf.py
5 |
6 | build:
7 | os: ubuntu-24.04
8 | tools:
9 | python: "3.10"
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 | - method: pip
15 | path: .
16 |
17 | formats: []
18 |
--------------------------------------------------------------------------------
/test/datasets/graph_generator/test_ba_graph.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.graph_generator import BAGraph
2 |
3 |
4 | def test_ba_graph():
5 | graph_generator = BAGraph(num_nodes=300, num_edges=5)
6 | assert str(graph_generator) == 'BAGraph(num_nodes=300, num_edges=5)'
7 |
8 | data = graph_generator()
9 | assert len(data) == 2
10 | assert data.num_nodes == 300
11 | assert data.num_edges <= 2 * 300 * 5
12 |
--------------------------------------------------------------------------------
/test/datasets/graph_generator/test_er_graph.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.graph_generator import ERGraph
2 |
3 |
4 | def test_er_graph():
5 | graph_generator = ERGraph(num_nodes=300, edge_prob=0.1)
6 | assert str(graph_generator) == 'ERGraph(num_nodes=300, edge_prob=0.1)'
7 |
8 | data = graph_generator()
9 | assert len(data) == 2
10 | assert data.num_nodes == 300
11 | assert data.num_edges >= 300 * 300 * 0.05
12 | assert data.num_edges <= 300 * 300 * 0.15
13 |
--------------------------------------------------------------------------------
/test/datasets/graph_generator/test_grid_graph.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.graph_generator import GridGraph
2 |
3 |
4 | def test_grid_graph():
5 | graph_generator = GridGraph(height=10, width=10)
6 | assert str(graph_generator) == 'GridGraph(height=10, width=10)'
7 |
8 | data = graph_generator()
9 | assert len(data) == 2
10 | assert data.num_nodes == 100
11 | assert data.num_edges == 784
12 |
--------------------------------------------------------------------------------
/test/datasets/graph_generator/test_tree_graph.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_geometric.datasets.graph_generator import TreeGraph
4 |
5 |
6 | @pytest.mark.parametrize('undirected', [False, True])
7 | def test_tree_graph(undirected):
8 | graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected)
9 | assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, '
10 | f'undirected={undirected})')
11 |
12 | data = graph_generator()
13 | assert len(data) == 3
14 | assert data.num_nodes == 7
15 | assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2]
16 | if not undirected:
17 | assert data.edge_index.tolist() == [
18 | [0, 0, 1, 1, 2, 2],
19 | [1, 2, 3, 4, 5, 6],
20 | ]
21 | else:
22 | assert data.edge_index.tolist() == [
23 | [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6],
24 | [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2],
25 | ]
26 |
--------------------------------------------------------------------------------
/test/datasets/motif_generator/test_cycle_motif.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.motif_generator import CycleMotif
2 |
3 |
4 | def test_cycle_motif():
5 | motif_generator = CycleMotif(5)
6 | assert str(motif_generator) == 'CycleMotif(5)'
7 |
8 | motif = motif_generator()
9 | assert len(motif) == 2
10 | assert motif.num_nodes == 5
11 | assert motif.num_edges == 10
12 | assert motif.edge_index.tolist() == [
13 | [0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
14 | [1, 4, 0, 2, 1, 3, 2, 4, 0, 3],
15 | ]
16 |
--------------------------------------------------------------------------------
/test/datasets/motif_generator/test_grid_motif.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.motif_generator import GridMotif
2 |
3 |
4 | def test_grid_motif():
5 | motif_generator = GridMotif()
6 | assert str(motif_generator) == 'GridMotif()'
7 |
8 | motif = motif_generator()
9 | assert len(motif) == 3
10 | assert motif.num_nodes == 9
11 | assert motif.num_edges == 24
12 | assert motif.edge_index.size() == (2, 24)
13 | assert motif.edge_index.min() == 0
14 | assert motif.edge_index.max() == 8
15 | assert motif.y.size() == (9, )
16 | assert motif.y.min() == 0
17 | assert motif.y.max() == 2
18 |
--------------------------------------------------------------------------------
/test/datasets/motif_generator/test_house_motif.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets.motif_generator import HouseMotif
2 |
3 |
4 | def test_house_motif():
5 | motif_generator = HouseMotif()
6 | assert str(motif_generator) == 'HouseMotif()'
7 |
8 | motif = motif_generator()
9 | assert len(motif) == 3
10 | assert motif.num_nodes == 5
11 | assert motif.num_edges == 12
12 | assert motif.y.min() == 0 and motif.y.max() == 2
13 |
--------------------------------------------------------------------------------
/test/datasets/test_ba_shapes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | def test_ba_shapes(get_dataset):
5 | with pytest.warns(UserWarning, match="is deprecated"):
6 | dataset = get_dataset(name='BAShapes')
7 |
8 | assert str(dataset) == 'BAShapes()'
9 | assert len(dataset) == 1
10 | assert dataset.num_features == 10
11 | assert dataset.num_classes == 4
12 |
13 | data = dataset[0]
14 | assert len(data) == 5
15 | assert data.edge_index.size(1) >= 1120
16 | assert data.x.size() == (700, 10)
17 | assert data.y.size() == (700, )
18 | assert data.expl_mask.sum() == 60
19 | assert data.edge_label.sum() == 960
20 |
--------------------------------------------------------------------------------
/test/datasets/test_bzr.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.testing import onlyFullTest, onlyOnline
2 |
3 |
4 | @onlyOnline
5 | @onlyFullTest
6 | def test_bzr(get_dataset):
7 | dataset = get_dataset(name='BZR')
8 | assert len(dataset) == 405
9 | assert dataset.num_features == 53
10 | assert dataset.num_node_labels == 53
11 | assert dataset.num_node_attributes == 3
12 | assert dataset.num_classes == 2
13 | assert str(dataset) == 'BZR(405)'
14 | assert len(dataset[0]) == 3
15 |
16 |
17 | @onlyOnline
18 | @onlyFullTest
19 | def test_bzr_with_node_attr(get_dataset):
20 | dataset = get_dataset(name='BZR', use_node_attr=True)
21 | assert dataset.num_features == 56
22 | assert dataset.num_node_labels == 53
23 | assert dataset.num_node_attributes == 3
24 |
--------------------------------------------------------------------------------
/test/datasets/test_git_mol_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 |
5 | from torch_geometric.datasets import GitMolDataset
6 | from torch_geometric.testing import onlyFullTest, withPackage
7 |
8 |
9 | @onlyFullTest
10 | @withPackage('torchvision', 'rdkit', 'PIL')
11 | @pytest.mark.parametrize('split', [
12 | (0, 3610),
13 | (1, 451),
14 | (2, 451),
15 | ])
16 | def test_git_mol_dataset(split: Tuple[int, int]) -> None:
17 | dataset = GitMolDataset(root='./data/GITMol', split=split[0])
18 |
19 | assert len(dataset) == split[1]
20 | assert dataset[0].image.size() == (1, 3, 224, 224)
21 | assert dataset[0].num_node_features == 9
22 | assert dataset[0].num_edge_features == 3
23 |
--------------------------------------------------------------------------------
/test/datasets/test_imdb_binary.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.testing import onlyFullTest, onlyOnline
2 |
3 |
4 | @onlyOnline
5 | @onlyFullTest
6 | def test_imdb_binary(get_dataset):
7 | dataset = get_dataset(name='IMDB-BINARY')
8 | assert len(dataset) == 1000
9 | assert dataset.num_features == 0
10 | assert dataset.num_classes == 2
11 | assert str(dataset) == 'IMDB-BINARY(1000)'
12 |
13 | data = dataset[0]
14 | assert len(data) == 3
15 | assert data.edge_index.size() == (2, 146)
16 | assert data.y.size() == (1, )
17 | assert data.num_nodes == 20
18 |
--------------------------------------------------------------------------------
/test/datasets/test_karate.py:
--------------------------------------------------------------------------------
1 | def test_karate(get_dataset):
2 | dataset = get_dataset(name='KarateClub')
3 | assert str(dataset) == 'KarateClub()'
4 | assert len(dataset) == 1
5 | assert dataset.num_features == 34
6 | assert dataset.num_classes == 4
7 |
8 | assert len(dataset[0]) == 4
9 | assert dataset[0].edge_index.size() == (2, 156)
10 | assert dataset[0].x.size() == (34, 34)
11 | assert dataset[0].y.size() == (34, )
12 | assert dataset[0].train_mask.size() == (34, )
13 | assert dataset[0].train_mask.sum().item() == 4
14 |
--------------------------------------------------------------------------------
/test/datasets/test_medshapenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.datasets import MedShapeNet
5 | from torch_geometric.testing import withPackage
6 |
7 |
8 | @withPackage('MedShapeNet')
9 | def test_medshapenet():
10 | dataset = MedShapeNet(root="./data/MedShapeNet", size=1)
11 |
12 | assert str(dataset) == f'MedShapeNet({len(dataset)})'
13 |
14 | assert isinstance(dataset[0], Data)
15 | assert dataset.num_classes == 8
16 |
17 | assert isinstance(dataset[0].pos, torch.Tensor)
18 | assert len(dataset[0].pos) > 0
19 |
20 | assert isinstance(dataset[0].face, torch.Tensor)
21 | assert len(dataset[0].face) == 3
22 |
23 | assert isinstance(dataset[0].y, torch.Tensor)
24 | assert len(dataset[0].y) == 1
25 |
--------------------------------------------------------------------------------
/test/datasets/test_molecule_gpt_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import MoleculeGPTDataset
2 | from torch_geometric.testing import onlyOnline, withPackage
3 |
4 |
5 | @onlyOnline
6 | @withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit')
7 | def test_molecule_gpt_dataset():
8 | dataset = MoleculeGPTDataset(
9 | root='./data/MoleculeGPT',
10 | num_units=10,
11 | )
12 | assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})'
13 | assert dataset.num_edge_features == 4
14 | assert dataset.num_node_features == 6
15 |
--------------------------------------------------------------------------------
/test/datasets/test_mutag.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.testing import onlyOnline
2 |
3 |
4 | @onlyOnline
5 | def test_mutag(get_dataset):
6 | dataset = get_dataset(name='MUTAG')
7 | assert len(dataset) == 188
8 | assert dataset.num_features == 7
9 | assert dataset.num_classes == 2
10 | assert str(dataset) == 'MUTAG(188)'
11 |
12 | assert len(dataset[0]) == 4
13 | assert dataset[0].edge_attr.size(1) == 4
14 |
15 |
16 | @onlyOnline
17 | def test_mutag_with_node_attr(get_dataset):
18 | dataset = get_dataset(name='MUTAG', use_node_attr=True)
19 | assert dataset.num_features == 7
20 |
--------------------------------------------------------------------------------
/test/datasets/test_protein_mpnn_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import ProteinMPNNDataset
2 | from torch_geometric.testing import onlyOnline, withPackage
3 |
4 |
5 | @onlyOnline
6 | @withPackage('pandas')
7 | def test_protein_mpnn_dataset():
8 | dataset = ProteinMPNNDataset(root='./data/ProteinMPNN')
9 |
10 | assert len(dataset) == 150
11 | assert dataset[0].x.size() == (229, 4, 3)
12 | assert dataset[0].chain_seq_label.size() == (229, )
13 | assert dataset[0].mask.size() == (229, )
14 | assert dataset[0].chain_mask_all.size() == (229, )
15 | assert dataset[0].residue_idx.size() == (229, )
16 | assert dataset[0].chain_encoding_all.size() == (229, )
17 |
--------------------------------------------------------------------------------
/test/datasets/test_suite_sparse.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.testing import onlyFullTest, onlyOnline
2 |
3 |
4 | @onlyOnline
5 | @onlyFullTest
6 | def test_suite_sparse_dataset(get_dataset):
7 | dataset = get_dataset(group='DIMACS10', name='citationCiteseer')
8 | assert str(dataset) == ('SuiteSparseMatrixCollection('
9 | 'group=DIMACS10, name=citationCiteseer)')
10 | assert len(dataset) == 1
11 |
12 |
13 | @onlyOnline
14 | @onlyFullTest
15 | def test_illc1850_suite_sparse_dataset(get_dataset):
16 | dataset = get_dataset(group='HB', name='illc1850')
17 | assert str(dataset) == ('SuiteSparseMatrixCollection('
18 | 'group=HB, name=illc1850)')
19 | assert len(dataset) == 1
20 |
--------------------------------------------------------------------------------
/test/datasets/test_tag_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import TAGDataset
2 | from torch_geometric.testing import onlyFullTest, withPackage
3 |
4 |
5 | @onlyFullTest
6 | @withPackage('ogb')
7 | def test_tag_dataset() -> None:
8 | from ogb.nodeproppred import PygNodePropPredDataset
9 |
10 | root = './data/ogb'
11 | hf_model = 'prajjwal1/bert-tiny'
12 | token_on_disk = True
13 |
14 | dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
15 | tag_dataset = TAGDataset(root, dataset, hf_model,
16 | token_on_disk=token_on_disk)
17 |
18 | assert 169343 == tag_dataset[0].num_nodes \
19 | == len(tag_dataset.text) \
20 | == len(tag_dataset.llm_explanation) \
21 | == len(tag_dataset.llm_prediction)
22 | assert 1166243 == tag_dataset[0].num_edges
23 |
--------------------------------------------------------------------------------
/test/datasets/test_teeth3ds.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Data
2 | from torch_geometric.datasets import Teeth3DS
3 | from torch_geometric.testing import withPackage
4 |
5 |
6 | @withPackage('trimesh', 'fpsample')
7 | def test_teeth3ds(tmp_path) -> None:
8 | dataset = Teeth3DS(root=tmp_path, split='sample', train=True)
9 |
10 | assert len(dataset) > 0
11 | data = dataset[0]
12 | assert isinstance(data, Data)
13 | assert data.pos.size(1) == 3
14 | assert data.x.size(0) == data.pos.size(0)
15 | assert data.y.size(0) == data.pos.size(0)
16 | assert isinstance(data.jaw, str)
17 |
--------------------------------------------------------------------------------
/test/graphgym/example_node.yml:
--------------------------------------------------------------------------------
1 | tensorboard_each_run: false
2 | tensorboard_agg: false
3 | dataset:
4 | format: PyG
5 | name: Cora
6 | task: node
7 | task_type: classification
8 | node_encoder: false
9 | node_encoder_name: Atom
10 | edge_encoder: false
11 | edge_encoder_name: Bond
12 | train:
13 | batch_size: 128
14 | eval_period: 2
15 | ckpt_period: 100
16 | enable_ckpt: false
17 | skip_train_eval: true
18 | sampler: full_batch
19 | model:
20 | type: gnn
21 | loss_fun: cross_entropy
22 | edge_decoding: dot
23 | graph_pooling: add
24 | gnn:
25 | layers_pre_mp: 2
26 | layers_mp: 2
27 | layers_post_mp: 1
28 | dim_inner: 16
29 | layer_type: gcnconv
30 | stage_type: stack
31 | batchnorm: false
32 | act: prelu
33 | dropout: 0.1
34 | agg: mean
35 | normalize_adj: false
36 | optim:
37 | optimizer: adam
38 | base_lr: 0.01
39 | max_epoch: 6
40 |
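Note: this YAML is the same kind of experiment configuration that GraphGym's main.py consumes via --cfg (compare run_single.sh earlier); the dotted names in the grid files above (e.g. gnn.layers_mp, optim.base_lr) address exactly these nested keys. A minimal, hypothetical sketch for inspecting the nesting with PyYAML (file path assumed):

    import yaml

    with open('test/graphgym/example_node.yml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['dataset']['name'])     # 'Cora'
    print(cfg['gnn']['layer_type'])   # 'gcnconv'
    print(cfg['optim']['max_epoch'])  # 6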
--------------------------------------------------------------------------------
/test/graphgym/test_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from torch_geometric.graphgym.config import from_config
4 |
5 |
6 | @dataclass
7 | class MyConfig:
8 | a: int
9 | b: int = 4
10 |
11 |
12 | def my_func(a: int, b: int = 2) -> str:
13 | return f'a={a},b={b}'
14 |
15 |
16 | def test_from_config():
17 | assert my_func(a=1) == 'a=1,b=2'
18 |
19 | assert my_func.__name__ == from_config(my_func).__name__
20 | assert from_config(my_func)(cfg=MyConfig(a=1)) == 'a=1,b=4'
21 | assert from_config(my_func)(cfg=MyConfig(a=1, b=1)) == 'a=1,b=1'
22 | assert from_config(my_func)(2, cfg=MyConfig(a=1, b=3)) == 'a=2,b=3'
23 | assert from_config(my_func)(cfg=MyConfig(a=1), b=3) == 'a=1,b=3'
24 |
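Note: the assertions above pin down the precedence of from_config: arguments missing from the call are filled from the same-named fields of the passed cfg dataclass, while positional arguments (the `2`) and explicit keyword arguments (`b=3`) take precedence over the config values.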
--------------------------------------------------------------------------------
/test/graphgym/test_logger.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.config import set_run_dir
2 | from torch_geometric.graphgym.loader import create_loader
3 | from torch_geometric.graphgym.logger import Logger, LoggerCallback
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('yacs', 'pytorch_lightning')
8 | def test_logger_callback():
9 | loaders = create_loader()
10 | assert len(loaders) == 3
11 |
12 | set_run_dir('.')
13 | logger = LoggerCallback()
14 | assert isinstance(logger.train_logger, Logger)
15 | assert isinstance(logger.val_logger, Logger)
16 | assert isinstance(logger.test_logger, Logger)
17 |
--------------------------------------------------------------------------------
/test/graphgym/test_register.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch_geometric.graphgym.register as register
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @register.register_act('identity')
8 | def identity_act(x: torch.Tensor) -> torch.Tensor:
9 | return x
10 |
11 |
12 | @withPackage('yacs')
13 | def test_register():
14 | assert len(register.act_dict) == 8
15 | assert list(register.act_dict.keys()) == [
16 | 'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',
17 | 'identity'
18 | ]
19 | assert str(register.act_dict['relu']()) == 'ReLU()'
20 |
21 | register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))
22 | assert len(register.act_dict) == 9
23 | assert 'lrelu_03' in register.act_dict
24 |
--------------------------------------------------------------------------------
/test/io/example1.off:
--------------------------------------------------------------------------------
1 | OFF
2 | 4 2 0
3 | 0.0 0.0 0.0
4 | 0.0 1.0 0.0
5 | 1.0 0.0 0.0
6 | 1.0 1.0 0.0
7 | 3 0 1 2
8 | 3 1 2 3
9 |
--------------------------------------------------------------------------------
/test/io/example2.off:
--------------------------------------------------------------------------------
1 | OFF
2 | 4 1 0
3 | 0.0 0.0 0.0
4 | 0.0 1.0 0.0
5 | 1.0 0.0 0.0
6 | 1.0 1.0 0.0
7 | 4 0 1 2 3
8 |
--------------------------------------------------------------------------------
/test/llm/models/test_git_mol.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.llm.models import GITMol
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('transformers', 'sentencepiece', 'accelerate')
8 | def test_git_mol():
9 | model = GITMol()
10 |
11 | x = torch.ones(10, 16, dtype=torch.long)
12 | edge_index = torch.tensor([
13 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
14 | [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
15 | ])
16 | edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long)
17 | batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
18 | smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O']
19 | captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.']
20 | images = torch.randn(1, 3, 224, 224)
21 |
22 | # Test train:
23 | loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
24 | assert loss >= 0
25 |
--------------------------------------------------------------------------------
/test/llm/models/test_llm.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from torch_geometric.llm.models import LLM
7 | from torch_geometric.testing import onlyRAG, withPackage
8 |
9 |
10 | @onlyRAG
11 | @withPackage('transformers', 'accelerate')
12 | def test_llm() -> None:
13 | question = ["Is PyG the best open-source GNN library?"]
14 | answer = ["yes!"]
15 |
16 | model = LLM(model_name='Qwen/Qwen3-0.6B', num_params=1,
17 | dtype=torch.float16,
18 | sys_prompt="You're an agent, answer my questions.")
19 | assert str(model) == 'LLM(Qwen/Qwen3-0.6B)'
20 |
21 | loss = model(question, answer)
22 | assert isinstance(loss, Tensor)
23 | assert loss.dim() == 0
24 | assert loss >= 0.0
25 |
26 | pred = model.inference(question)
27 | assert len(pred) == 1
28 | del model
29 | gc.collect()
30 | torch.cuda.empty_cache()
31 |
--------------------------------------------------------------------------------
/test/llm/models/test_protein_mpnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.llm.models import ProteinMPNN
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('torch_cluster')
8 | def test_protein_mpnn():
9 | num_nodes = 10
10 | vocab_size = 21
11 |
12 | model = ProteinMPNN(vocab_size=vocab_size)
13 | x = torch.randn(num_nodes, 4, 3)
14 | chain_seq_label = torch.randint(0, vocab_size, (num_nodes, ))
15 | mask = torch.ones(num_nodes)
16 | chain_mask_all = torch.ones(num_nodes)
17 | residue_idx = torch.randint(0, 10, (num_nodes, ))
18 | chain_encoding_all = torch.ones(num_nodes)
19 | batch = torch.zeros(num_nodes, dtype=torch.long)
20 |
21 | logits = model(x, chain_seq_label, mask, chain_mask_all, residue_idx,
22 | chain_encoding_all, batch)
23 | assert logits.size() == (num_nodes, vocab_size)
24 |
--------------------------------------------------------------------------------
/test/llm/models/test_vision_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.llm.models import VisionTransformer
4 | from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
5 |
6 |
7 | @withCUDA
8 | @onlyFullTest
9 | @withPackage('transformers')
10 | def test_vision_transformer(device):
11 | model = VisionTransformer(
12 | model_name='microsoft/swin-base-patch4-window7-224', ).to(device)
13 | assert model.device == device
14 | assert str(
15 | model
16 | ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)'
17 |
18 | images = torch.randn(2, 3, 224, 224).to(device)
19 |
20 | out = model(images)
21 | assert out.device == device
22 | assert out.size() == (2, 49, 1024)
23 |
24 | out = model(images, output_device='cpu')
25 | assert out.is_cpu
26 | assert out.size() == (2, 49, 1024)
27 |
--------------------------------------------------------------------------------
/test/loader/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.loader.utils import index_select
5 |
6 |
7 | def test_index_select():
8 | x = torch.randn(3, 5)
9 | index = torch.tensor([0, 2])
10 | assert torch.equal(index_select(x, index), x[index])
11 | assert torch.equal(index_select(x, index, dim=-1), x[..., index])
12 |
13 |
14 | def test_index_select_out_of_range():
15 | with pytest.raises(IndexError, match="out of range"):
16 | index_select(torch.randn(3, 5), torch.tensor([0, 2, 3]))
17 |
--------------------------------------------------------------------------------
/test/my_config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: KarateClub
3 | - transform@dataset.transform:
4 | - NormalizeFeatures
5 | - AddSelfLoops
6 | - model: GCN
7 | - optimizer: Adam
8 | - lr_scheduler: ReduceLROnPlateau
9 | - _self_
10 |
11 | model:
12 | in_channels: 34
13 | out_channels: 4
14 | hidden_channels: 16
15 | num_layers: 2
16 |
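Note: this reads as a Hydra-style primary config: the defaults list composes the dataset, model, optimizer, and lr_scheduler groups, `transform@dataset.transform` places the two transforms under dataset.transform, and `_self_` last means the trailing model block (in_channels, out_channels, ...) is merged after the defaults and overrides them. This assumes standard Hydra composition semantics; the file itself does not name the loader that consumes it.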
--------------------------------------------------------------------------------
/test/nn/aggr/test_deep_sets.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import DeepSetsAggregation, Linear
4 |
5 |
6 | def test_deep_sets_aggregation():
7 | x = torch.randn(6, 16)
8 | index = torch.tensor([0, 0, 1, 1, 1, 2])
9 |
10 | aggr = DeepSetsAggregation(
11 | local_nn=Linear(16, 32),
12 | global_nn=Linear(32, 64),
13 | )
14 | aggr.reset_parameters()
15 | assert str(aggr) == ('DeepSetsAggregation('
16 | 'local_nn=Linear(16, 32, bias=True), '
17 | 'global_nn=Linear(32, 64, bias=True))')
18 |
19 | out = aggr(x, index)
20 | assert out.size() == (3, 64)
21 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_gmt.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.aggr import GraphMultisetTransformer
4 | from torch_geometric.testing import is_full_test
5 |
6 |
7 | def test_graph_multiset_transformer():
8 | x = torch.randn(6, 16)
9 | index = torch.tensor([0, 0, 1, 1, 1, 2])
10 |
11 | aggr = GraphMultisetTransformer(16, k=2, heads=2)
12 | aggr.reset_parameters()
13 | assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, '
14 | 'layer_norm=False, dropout=0.0)')
15 |
16 | out = aggr(x, index)
17 | assert out.size() == (3, 16)
18 |
19 | if is_full_test():
20 | jit = torch.jit.script(aggr)
21 | assert torch.allclose(jit(x, index), out)
22 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_gru.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import GRUAggregation
4 |
5 |
6 | def test_gru_aggregation():
7 | x = torch.randn(6, 16)
8 | index = torch.tensor([0, 0, 1, 1, 1, 2])
9 |
10 | aggr = GRUAggregation(16, 32)
11 | assert str(aggr) == 'GRUAggregation(16, 32)'
12 |
13 | out = aggr(x, index)
14 | assert out.size() == (3, 32)
15 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_lstm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.nn import LSTMAggregation
5 |
6 |
7 | def test_lstm_aggregation():
8 | x = torch.randn(6, 16)
9 | index = torch.tensor([0, 0, 1, 1, 1, 2])
10 |
11 | aggr = LSTMAggregation(16, 32)
12 | assert str(aggr) == 'LSTMAggregation(16, 32)'
13 |
14 | with pytest.raises(ValueError, match="is not sorted"):
15 | aggr(x, torch.tensor([0, 1, 0, 1, 2, 1]))
16 |
17 | out = aggr(x, index)
18 | assert out.size() == (3, 32)
19 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_mlp_aggr.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import MLPAggregation
4 |
5 |
6 | def test_mlp_aggregation():
7 | x = torch.randn(6, 16)
8 | index = torch.tensor([0, 0, 1, 1, 1, 2])
9 |
10 | aggr = MLPAggregation(
11 | in_channels=16,
12 | out_channels=32,
13 | max_num_elements=3,
14 | num_layers=1,
15 | )
16 | assert str(aggr) == 'MLPAggregation(16, 32, max_num_elements=3)'
17 |
18 | out = aggr(x, index)
19 | assert out.size() == (3, 32)
20 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_patch_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import PatchTransformerAggregation
4 | from torch_geometric.testing import withCUDA
5 |
6 |
7 | @withCUDA
8 | def test_patch_transformer_aggregation(device: torch.device) -> None:
9 | aggr = PatchTransformerAggregation(
10 | in_channels=16,
11 | out_channels=32,
12 | patch_size=2,
13 | hidden_channels=8,
14 | num_transformer_blocks=1,
15 | heads=2,
16 | dropout=0.2,
17 | aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],
18 | device=device,
19 | )
20 | aggr.reset_parameters()
21 | assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)'
22 |
23 | index = torch.tensor([0, 0, 1, 1, 1, 2], device=device)
24 | x = torch.randn(index.size(0), 16, device=device)
25 |
26 | out = aggr(x, index)
27 | assert out.device == device
28 | assert out.size() == (3, aggr.out_channels)
29 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_scaler.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.nn import DegreeScalerAggregation
5 |
6 |
7 | @pytest.mark.parametrize('train_norm', [True, False])
8 | def test_degree_scaler_aggregation(train_norm):
9 | x = torch.randn(6, 16)
10 | index = torch.tensor([0, 0, 1, 1, 1, 2])
11 | ptr = torch.tensor([0, 2, 5, 6])
12 | deg = torch.tensor([0, 3, 0, 1, 1, 0])
13 |
14 | aggr = ['mean', 'sum', 'max']
15 | scaler = [
16 | 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'
17 | ]
18 | aggr = DegreeScalerAggregation(aggr, scaler, deg, train_norm=train_norm)
19 | assert str(aggr) == 'DegreeScalerAggregation()'
20 |
21 | out = aggr(x, index)
22 | assert out.size() == (3, 240)
23 | assert torch.allclose(torch.jit.script(aggr)(x, index), out)
24 |
25 | with pytest.raises(NotImplementedError, match="requires 'index'"):
26 | aggr(x, ptr=ptr)
27 |
--------------------------------------------------------------------------------
/test/nn/aggr/test_set2set.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.aggr import Set2Set
4 |
5 |
6 | def test_set2set():
7 | set2set = Set2Set(in_channels=2, processing_steps=1)
8 | assert str(set2set) == 'Set2Set(2, 4)'
9 |
10 | N = 4
11 | x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
12 | out_1 = set2set(x_1, batch_1).view(-1)
13 |
14 | N = 6
15 | x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)
16 | out_2 = set2set(x_2, batch_2).view(-1)
17 |
18 | x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1])
19 | out = set2set(x, batch)
20 | assert out.size() == (2, 4)
21 | assert torch.allclose(out_1, out[0])
22 | assert torch.allclose(out_2, out[1])
23 |
24 | x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 1])
25 | out = set2set(x, batch)
26 | assert out.size() == (2, 4)
27 | assert torch.allclose(out_1, out[1])
28 | assert torch.allclose(out_2, out[0])
29 |
--------------------------------------------------------------------------------
/test/nn/attention/test_performer_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.attention import PerformerAttention
4 |
5 |
6 | def test_performer_attention():
7 | x = torch.randn(1, 4, 16)
8 | mask = torch.ones([1, 4], dtype=torch.bool)
9 | attn = PerformerAttention(channels=16, heads=4)
10 | out = attn(x, mask)
11 | assert out.shape == (1, 4, 16)
12 | assert str(attn) == ('PerformerAttention(heads=4, '
13 | 'head_channels=64 kernel=ReLU())')
14 |
--------------------------------------------------------------------------------
/test/nn/attention/test_polynormer_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.attention import PolynormerAttention
4 |
5 |
6 | def test_polynormer_attention():
7 | x = torch.randn(1, 4, 16)
8 | mask = torch.ones([1, 4], dtype=torch.bool)
9 | attn = PolynormerAttention(channels=16, heads=4)
10 | out = attn(x, mask)
11 | assert out.shape == (1, 4, 256)
12 | assert str(attn) == 'PolynormerAttention(heads=4, head_channels=64)'
13 |
--------------------------------------------------------------------------------
/test/nn/attention/test_qformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.attention import QFormer
4 |
5 |
6 | def test_qformer():
7 | x = torch.randn(1, 4, 16)
8 | attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4,
9 | num_layers=2)
10 | out = attn(x)
11 |
12 | assert out.shape == (1, 4, 32)
13 | assert str(attn) == ('QFormer(num_heads=4, num_layers=2)')
14 |
--------------------------------------------------------------------------------
/test/nn/conv/test_dir_gnn_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import DirGNNConv, SAGEConv
4 |
5 |
6 | def test_dir_gnn_conv():
7 | x = torch.randn(4, 16)
8 | edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
9 |
10 | conv = DirGNNConv(SAGEConv(16, 32))
11 | assert str(conv) == 'DirGNNConv(SAGEConv(16, 32, aggr=mean), alpha=0.5)'
12 |
13 | out = conv(x, edge_index)
14 | assert out.size() == (4, 32)
15 |
16 |
17 | def test_static_dir_gnn_conv():
18 | x = torch.randn(3, 4, 16)
19 | edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
20 |
21 | conv = DirGNNConv(SAGEConv(16, 32))
22 |
23 | out = conv(x, edge_index)
24 | assert out.size() == (3, 4, 32)
25 |
--------------------------------------------------------------------------------
/test/nn/conv/test_static_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Batch, Data
4 | from torch_geometric.nn import ChebConv, GCNConv, MessagePassing
5 |
6 |
7 | class MyConv(MessagePassing):
8 | def forward(self, x, edge_index):
9 | return self.propagate(edge_index, x=x)
10 |
11 |
12 | def test_static_graph():
13 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
14 | x1, x2 = torch.randn(3, 8), torch.randn(3, 8)
15 |
16 | data1 = Data(edge_index=edge_index, x=x1)
17 | data2 = Data(edge_index=edge_index, x=x2)
18 | batch = Batch.from_data_list([data1, data2])
19 |
20 | x = torch.stack([x1, x2], dim=0)
21 | for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]:
22 | out1 = conv(batch.x, batch.edge_index)
23 | assert out1.size(0) == 6
24 | conv.node_dim = 1
25 | out2 = conv(x, edge_index)
26 | assert out2.size()[:2] == (2, 3)
27 | assert torch.allclose(out1, out2.view(-1, out2.size(-1)))
28 |
--------------------------------------------------------------------------------
/test/nn/conv/test_x_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import XConv
4 | from torch_geometric.testing import is_full_test, withPackage
5 |
6 |
7 | @withPackage('torch_cluster')
8 | def test_x_conv():
9 | x = torch.randn(8, 16)
10 | pos = torch.rand(8, 3)
11 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
12 |
13 | conv = XConv(16, 32, dim=3, kernel_size=2, dilation=2)
14 | assert str(conv) == 'XConv(16, 32)'
15 |
16 | torch.manual_seed(12345)
17 | out1 = conv(x, pos)
18 | assert out1.size() == (8, 32)
19 |
20 | torch.manual_seed(12345)
21 | out2 = conv(x, pos, batch)
22 | assert out2.size() == (8, 32)
23 |
24 | if is_full_test():
25 | jit = torch.jit.script(conv)
26 |
27 | torch.manual_seed(12345)
28 | assert torch.allclose(jit(x, pos), out1, atol=1e-6)
29 |
30 | torch.manual_seed(12345)
31 | assert torch.allclose(jit(x, pos, batch), out2, atol=1e-6)
32 |
--------------------------------------------------------------------------------
/test/nn/dense/test_dmon_pool.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 | from torch_geometric.nn import DMoNPooling
6 |
7 |
8 | def test_dmon_pooling():
9 | batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
10 | x = torch.randn((batch_size, num_nodes, channels))
11 | adj = torch.ones((batch_size, num_nodes, num_nodes))
12 | mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)
13 |
14 | pool = DMoNPooling([channels, channels], num_clusters)
15 | assert str(pool) == 'DMoNPooling(16, num_clusters=10)'
16 |
17 | s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
18 | assert s.size() == (2, 20, 10)
19 | assert x.size() == (2, 10, 16)
20 | assert adj.size() == (2, 10, 10)
21 | assert -1 <= spectral_loss <= 0.5
22 | assert 0 <= ortho_loss <= math.sqrt(2)
23 | assert 0 <= cluster_loss <= math.sqrt(num_clusters) - 1
24 |
--------------------------------------------------------------------------------
/test/nn/functional/test_bro.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.functional import bro
4 |
5 |
6 | def test_bro():
7 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2])
8 |
9 | g1 = torch.tensor([
10 | [0.2, 0.2, 0.2, 0.2],
11 | [0.0, 0.2, 0.2, 0.2],
12 | [0.2, 0.0, 0.2, 0.2],
13 | [0.2, 0.2, 0.0, 0.2],
14 | ])
15 |
16 | g2 = torch.tensor([
17 | [0.2, 0.2, 0.2, 0.2],
18 | [0.0, 0.2, 0.2, 0.2],
19 | [0.2, 0.0, 0.2, 0.2],
20 | ])
21 |
22 | g3 = torch.tensor([
23 | [0.2, 0.2, 0.2, 0.2],
24 | [0.2, 0.0, 0.2, 0.2],
25 | ])
26 |
27 | s = 0.
28 | for g in [g1, g2, g3]:
29 | s += torch.norm(g @ g.t() - torch.eye(g.shape[0]), p=2)
30 |
31 | assert torch.isclose(s / 3., bro(torch.cat([g1, g2, g3], dim=0), batch))
32 |
--------------------------------------------------------------------------------
/test/nn/functional/test_gini.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.functional import gini
4 |
5 |
6 | def test_gini():
7 | w = torch.tensor([[0., 0., 0., 0.], [0., 0., 0., 1000.0]])
8 | assert torch.isclose(gini(w), torch.tensor(0.5))
9 |
--------------------------------------------------------------------------------
/test/nn/kge/test_distmult.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import DistMult
4 |
5 |
6 | def test_distmult():
7 | model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32)
8 | assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)'
9 |
10 | head_index = torch.tensor([0, 2, 4, 6, 8])
11 | rel_type = torch.tensor([0, 1, 2, 3, 4])
12 | tail_index = torch.tensor([1, 3, 5, 7, 9])
13 |
14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
15 | for h, r, t in loader:
16 | out = model(h, r, t)
17 | assert out.size() == (5, )
18 |
19 | loss = model.loss(h, r, t)
20 | assert loss >= 0.
21 |
22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
23 | assert 0 <= mean_rank <= 10
24 | assert 0 < mrr <= 1
25 | assert hits == 1.0
26 |
--------------------------------------------------------------------------------
/test/nn/kge/test_rotate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import RotatE
4 |
5 |
6 | def test_rotate():
7 | model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32)
8 | assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)'
9 |
10 | head_index = torch.tensor([0, 2, 4, 6, 8])
11 | rel_type = torch.tensor([0, 1, 2, 3, 4])
12 | tail_index = torch.tensor([1, 3, 5, 7, 9])
13 |
14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
15 | for h, r, t in loader:
16 | out = model(h, r, t)
17 | assert out.size() == (5, )
18 |
19 | loss = model.loss(h, r, t)
20 | assert loss >= 0.
21 |
22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
23 | assert 0 <= mean_rank <= 10
24 | assert 0 < mrr <= 1
25 | assert hits == 1.0
26 |
--------------------------------------------------------------------------------
/test/nn/kge/test_transe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import TransE
4 |
5 |
6 | def test_transe():
7 | model = TransE(num_nodes=10, num_relations=5, hidden_channels=32)
8 | assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)'
9 |
10 | head_index = torch.tensor([0, 2, 4, 6, 8])
11 | rel_type = torch.tensor([0, 1, 2, 3, 4])
12 | tail_index = torch.tensor([1, 3, 5, 7, 9])
13 |
14 | loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
15 | for h, r, t in loader:
16 | out = model(h, r, t)
17 | assert out.size() == (5, )
18 |
19 | loss = model.loss(h, r, t)
20 | assert loss >= 0.
21 |
22 | mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
23 | assert 0 <= mean_rank <= 10
24 | assert 0 < mrr <= 1
25 | assert hits == 1.0
26 |
--------------------------------------------------------------------------------
/test/nn/models/test_attentive_fp.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import AttentiveFP
4 | from torch_geometric.testing import is_full_test
5 |
6 |
7 | def test_attentive_fp():
8 | model = AttentiveFP(8, 16, 32, edge_dim=3, num_layers=2, num_timesteps=2)
9 | assert str(model) == ('AttentiveFP(in_channels=8, hidden_channels=16, '
10 | 'out_channels=32, edge_dim=3, num_layers=2, '
11 | 'num_timesteps=2)')
12 |
13 | x = torch.randn(4, 8)
14 | edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
15 | edge_attr = torch.randn(edge_index.size(1), 3)
16 | batch = torch.tensor([0, 0, 0, 0])
17 |
18 | out = model(x, edge_index, edge_attr, batch)
19 | assert out.size() == (1, 32)
20 |
21 | if is_full_test():
22 | jit = torch.jit.script(model)
23 | assert torch.allclose(jit(x, edge_index, edge_attr, batch), out)
24 |
--------------------------------------------------------------------------------
/test/nn/models/test_deepgcn.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.nn import ReLU
4 |
5 | from torch_geometric.nn import DeepGCNLayer, GENConv, LayerNorm
6 |
7 |
8 | @pytest.mark.parametrize(
9 | 'block_tuple',
10 | [('res+', 1), ('res', 1), ('dense', 2), ('plain', 1)],
11 | )
12 | @pytest.mark.parametrize('ckpt_grad', [True, False])
13 | def test_deepgcn(block_tuple, ckpt_grad):
14 | block, expansion = block_tuple
15 | x = torch.randn(3, 8)
16 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
17 | conv = GENConv(8, 8)
18 | norm = LayerNorm(8)
19 | act = ReLU()
20 | layer = DeepGCNLayer(conv, norm, act, block=block, ckpt_grad=ckpt_grad)
21 | assert str(layer) == f'DeepGCNLayer(block={block})'
22 |
23 | out = layer(x, edge_index)
24 | assert out.size() == (3, 8 * expansion)
25 |
--------------------------------------------------------------------------------
/test/nn/models/test_gnnff.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import GNNFF
4 | from torch_geometric.testing import is_full_test, withPackage
5 |
6 |
7 | @withPackage('torch_sparse') # TODO `triplet` requires `SparseTensor` for now.
8 | @withPackage('torch_cluster')
9 | def test_gnnff():
10 | z = torch.randint(1, 10, (20, ))
11 | pos = torch.randn(20, 3)
12 |
13 | model = GNNFF(
14 | hidden_node_channels=5,
15 | hidden_edge_channels=5,
16 | num_layers=5,
17 | )
18 | model.reset_parameters()
19 |
20 | out = model(z, pos)
21 | assert out.size() == (20, 3)
22 |
23 | if is_full_test():
24 | jit = torch.jit.export(model)
25 | assert torch.allclose(jit(z, pos), out)
26 |
--------------------------------------------------------------------------------
/test/nn/models/test_graph_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import GraphUNet
4 | from torch_geometric.testing import is_full_test, onlyLinux
5 |
6 |
7 | @onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
8 | def test_graph_unet():
9 | model = GraphUNet(16, 32, 8, depth=3)
10 | out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])'
11 | assert str(model) == out
12 |
13 | x = torch.randn(3, 16)
14 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
15 |
16 | out = model(x, edge_index)
17 | assert out.size() == (3, 8)
18 |
19 | if is_full_test():
20 | jit = torch.jit.export(model)
21 | out = jit(x, edge_index)
22 | assert out.size() == (3, 8)
23 |
--------------------------------------------------------------------------------
/test/nn/models/test_mask_label.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import MaskLabel
4 |
5 |
6 | def test_mask_label():
7 | model = MaskLabel(2, 10)
8 | assert str(model) == 'MaskLabel()'
9 |
10 | x = torch.rand(4, 10)
11 | y = torch.tensor([1, 0, 1, 0])
12 | mask = torch.tensor([False, False, True, True])
13 |
14 | out = model(x, y, mask)
15 | assert out.size() == (4, 10)
16 | assert torch.allclose(out[~mask], x[~mask])
17 |
18 | model = MaskLabel(2, 10, method='concat')
19 | out = model(x, y, mask)
20 | assert out.size() == (4, 20)
21 | assert torch.allclose(out[:, :10], x)
22 |
23 |
24 | def test_ratio_mask():
25 | mask = torch.tensor([True, True, True, True, False, False, False, False])
26 | out = MaskLabel.ratio_mask(mask, 0.5)
27 | assert out[:4].sum() <= 4 and out[4:].sum() == 0
28 |
--------------------------------------------------------------------------------
/test/nn/models/test_pmlp.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.nn.models import PMLP
5 |
6 |
7 | def test_pmlp():
8 | x = torch.randn(4, 16)
9 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
10 |
11 | pmlp = PMLP(in_channels=16, hidden_channels=32, out_channels=2,
12 | num_layers=4)
13 | assert str(pmlp) == 'PMLP(16, 2, num_layers=4)'
14 |
15 | pmlp.training = True
16 | assert pmlp(x).size() == (4, 2)
17 |
18 | pmlp.training = False
19 | assert pmlp(x, edge_index).size() == (4, 2)
20 |
21 | with pytest.raises(ValueError, match="'edge_index' needs to be present"):
22 | pmlp.training = False
23 | pmlp(x)
24 |
--------------------------------------------------------------------------------
/test/nn/models/test_sgformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.models import SGFormer
4 |
5 |
6 | def test_sgformer():
7 | x = torch.randn(10, 16)
8 | edge_index = torch.tensor([
9 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
10 | [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
11 | ])
12 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
13 |
14 | model = SGFormer(
15 | in_channels=16,
16 | hidden_channels=128,
17 | out_channels=40,
18 | )
19 | out = model(x, edge_index, batch)
20 | assert out.size() == (10, 40)
21 |
--------------------------------------------------------------------------------
/test/nn/models/test_visnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.nn import ViSNet
5 | from torch_geometric.testing import withPackage
6 |
7 |
8 | @withPackage('torch_cluster')
9 | @pytest.mark.parametrize('kwargs', [
10 | dict(lmax=2, derivative=True, vecnorm_type=None, vertex=False),
11 | dict(lmax=1, derivative=False, vecnorm_type='max_min', vertex=True),
12 | ])
13 | def test_visnet(kwargs):
14 | z = torch.randint(1, 10, (20, ))
15 | pos = torch.randn(20, 3)
16 | batch = torch.zeros(20, dtype=torch.long)
17 |
18 | model = ViSNet(**kwargs)
19 |
20 | model.reset_parameters()
21 |
22 | energy, forces = model(z, pos, batch)
23 |
24 | assert energy.size() == (1, 1)
25 |
26 | if kwargs['derivative']:
27 | assert forces.size() == (20, 3)
28 | else:
29 | assert forces is None
30 |
--------------------------------------------------------------------------------
/test/nn/norm/test_graph_size_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import GraphSizeNorm
4 | from torch_geometric.testing import is_full_test
5 |
6 |
7 | def test_graph_size_norm():
8 | x = torch.randn(100, 16)
9 | batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))
10 |
11 | norm = GraphSizeNorm()
12 | assert str(norm) == 'GraphSizeNorm()'
13 |
14 | out = norm(x, batch)
15 | assert out.size() == (100, 16)
16 |
17 | if is_full_test():
18 | jit = torch.jit.script(norm)
19 | assert torch.allclose(jit(x, batch), out)
20 |
--------------------------------------------------------------------------------
/test/nn/norm/test_mean_subtraction_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import MeanSubtractionNorm
4 | from torch_geometric.testing import is_full_test
5 |
6 |
7 | def test_mean_subtraction_norm():
8 | x = torch.randn(6, 16)
9 | batch = torch.tensor([0, 0, 1, 1, 1, 2])
10 |
11 | norm = MeanSubtractionNorm()
12 | assert str(norm) == 'MeanSubtractionNorm()'
13 |
14 | if is_full_test():
15 | torch.jit.script(norm)
16 |
17 | out = norm(x)
18 | assert out.size() == (6, 16)
19 | assert torch.allclose(out.mean(), torch.tensor(0.), atol=1e-6)
20 |
21 | out = norm(x, batch)
22 | assert out.size() == (6, 16)
23 | assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6)
24 | assert torch.allclose(out[2:5].mean(), torch.tensor(0.), atol=1e-6)
25 |
--------------------------------------------------------------------------------
/test/nn/norm/test_msg_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import MessageNorm
4 | from torch_geometric.testing import is_full_test, withDevice
5 |
6 |
7 | @withDevice
8 | def test_message_norm(device):
9 | norm = MessageNorm(learn_scale=True, device=device)
10 | assert str(norm) == 'MessageNorm(learn_scale=True)'
11 | x = torch.randn(100, 16, device=device)
12 | msg = torch.randn(100, 16, device=device)
13 | out = norm(x, msg)
14 | assert out.size() == (100, 16)
15 |
16 | if is_full_test():
17 | jit = torch.jit.script(norm)
18 | assert torch.allclose(jit(x, msg), out)
19 |
20 | norm = MessageNorm(learn_scale=False, device=device)
21 | assert str(norm) == 'MessageNorm(learn_scale=False)'
22 | out = norm(x, msg)
23 | assert out.size() == (100, 16)
24 |
25 | if is_full_test():
26 | jit = torch.jit.script(norm)
27 | assert torch.allclose(jit(x, msg), out)
28 |
--------------------------------------------------------------------------------
/test/nn/norm/test_pair_norm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.nn import PairNorm
5 | from torch_geometric.testing import is_full_test
6 |
7 |
8 | @pytest.mark.parametrize('scale_individually', [False, True])
9 | def test_pair_norm(scale_individually):
10 | x = torch.randn(100, 16)
11 | batch = torch.zeros(100, dtype=torch.long)
12 |
13 | norm = PairNorm(scale_individually=scale_individually)
14 | assert str(norm) == 'PairNorm()'
15 |
16 | if is_full_test():
17 | torch.jit.script(norm)
18 |
19 | out1 = norm(x)
20 | assert out1.size() == (100, 16)
21 |
22 | out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0))
23 | assert torch.allclose(out1, out2[:100], atol=1e-6)
24 | assert torch.allclose(out1, out2[100:], atol=1e-6)
25 |
--------------------------------------------------------------------------------
/test/nn/pool/test_consecutive.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.pool.consecutive import consecutive_cluster
4 |
5 |
6 | def test_consecutive_cluster():
7 | src = torch.tensor([8, 2, 10, 15, 100, 1, 100])
8 |
9 | out, perm = consecutive_cluster(src)
10 | assert out.tolist() == [2, 1, 3, 4, 5, 0, 5]
11 | assert perm.tolist() == [5, 1, 0, 2, 3, 6]
12 |
--------------------------------------------------------------------------------
/test/nn/pool/test_graclus.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import graclus
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('torch_cluster')
8 | def test_graclus():
9 | edge_index = torch.tensor([[0, 1], [1, 0]])
10 | assert graclus(edge_index).tolist() == [0, 0]
11 |
--------------------------------------------------------------------------------
/test/nn/pool/test_mem_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import MemPooling
4 | from torch_geometric.utils import to_dense_batch
5 |
6 |
7 | def test_mem_pool():
8 | mpool1 = MemPooling(4, 8, heads=3, num_clusters=2)
9 | assert str(mpool1) == 'MemPooling(4, 8, heads=3, num_clusters=2)'
10 | mpool2 = MemPooling(8, 4, heads=2, num_clusters=1)
11 |
12 | x = torch.randn(17, 4)
13 | batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])
14 | _, mask = to_dense_batch(x, batch)
15 |
16 | out1, S = mpool1(x, batch)
17 | loss = MemPooling.kl_loss(S)
18 | with torch.autograd.set_detect_anomaly(True):
19 | loss.backward()
20 | out2, _ = mpool2(out1)
21 |
22 | assert out1.size() == (5, 2, 8)
23 | assert out2.size() == (5, 1, 4)
24 | assert S[~mask].sum() == 0
25 | assert round(S[mask].sum().item()) == x.size(0)
26 | assert float(loss) > 0
27 |
--------------------------------------------------------------------------------
/test/nn/pool/test_pool.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | from torch_geometric.nn import radius_graph
7 | from torch_geometric.testing import onlyFullTest, withPackage
8 |
9 |
10 | @onlyFullTest
11 | @withPackage('torch_cluster')
12 | def test_radius_graph_jit():
13 | class Net(torch.nn.Module):
14 | def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
15 | return radius_graph(x, r=2.5, batch=batch, loop=False)
16 |
17 | x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float)
18 | batch = torch.tensor([0, 0, 0, 0])
19 |
20 | model = Net()
21 | jit = torch.jit.script(model)
22 | assert model(x, batch).size() == jit(x, batch).size()
23 |
--------------------------------------------------------------------------------
/test/nn/test_data_parallel.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.data import Data
5 | from torch_geometric.nn import DataParallel
6 | from torch_geometric.testing import onlyCUDA
7 |
8 |
9 | @onlyCUDA
10 | def test_data_parallel_single_gpu():
11 | with pytest.warns(UserWarning, match="much slower"):
12 | module = DataParallel(torch.nn.Identity())
13 | data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]
14 | batches = module.scatter(data_list, device_ids=[0])
15 | assert len(batches) == 1
16 |
17 |
18 | @onlyCUDA
19 | @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
20 | def test_data_parallel_multi_gpu():
21 | with pytest.warns(UserWarning, match="much slower"):
22 | module = DataParallel(torch.nn.Identity())
23 | data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]
24 | batches = module.scatter(data_list, device_ids=[0, 1, 0, 1])
25 | assert len(batches) == 3
26 |
--------------------------------------------------------------------------------
/test/nn/test_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import PositionalEncoding, TemporalEncoding
4 | from torch_geometric.testing import withDevice
5 |
6 |
7 | @withDevice
8 | def test_positional_encoding(device):
9 | encoder = PositionalEncoding(64, device=device)
10 | assert str(encoder) == 'PositionalEncoding(64)'
11 |
12 | x = torch.tensor([1.0, 2.0, 3.0], device=device)
13 | assert encoder(x).size() == (3, 64)
14 |
15 |
16 | @withDevice
17 | def test_temporal_encoding(device):
18 | encoder = TemporalEncoding(64, device=device)
19 | assert str(encoder) == 'TemporalEncoding(64)'
20 |
21 | x = torch.tensor([1.0, 2.0, 3.0], device=device)
22 | assert encoder(x).size() == (3, 64)
23 |
--------------------------------------------------------------------------------
/test/nn/test_fvcore.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import GraphSAGE
4 | from torch_geometric.testing import get_random_edge_index, withPackage
5 |
6 |
7 | @withPackage('fvcore')
8 | def test_fvcore():
9 | from fvcore.nn import FlopCountAnalysis
10 |
11 | x = torch.randn(10, 16)
12 | edge_index = get_random_edge_index(10, 10, num_edges=100)
13 |
14 | model = GraphSAGE(16, 32, num_layers=2)
15 |
16 | flops = FlopCountAnalysis(model, (x, edge_index))
17 |
18 | # TODO (matthias) Currently, aggregations are not properly registered.
19 | assert flops.by_module()['convs.0'] == 2 * 10 * 16 * 32
20 | assert flops.by_module()['convs.1'] == 2 * 10 * 32 * 32
21 | assert flops.total() == (flops.by_module()['convs.0'] +
22 | flops.by_module()['convs.1'])
23 | assert flops.by_operator()['linear'] == flops.total()
24 |
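The expected counts above reflect the two dense projections inside each SAGEConv layer (one for the node's own features and one for the aggregated neighbor features, assuming the default root_weight=True), i.e. 2 * num_nodes * in_channels * out_channels multiply-accumulates per layer. A quick sanity check of the constants, separate from the test itself:

num_nodes = 10
conv0_flops = 2 * num_nodes * 16 * 32  # 10,240 for the 16 -> 32 layer
conv1_flops = 2 * num_nodes * 32 * 32  # 20,480 for the 32 -> 32 layer
assert conv0_flops + conv1_flops == 30_720  # equals flops.total() asserted above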
--------------------------------------------------------------------------------
/test/nn/test_fx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import Tensor
4 |
5 |
6 | def test_dropout():
7 | class MyModule(torch.nn.Module):
8 | def forward(self, x: Tensor) -> Tensor:
9 | return F.dropout(x, p=1.0, training=self.training)
10 |
11 | module = MyModule()
12 | graph_module = torch.fx.symbolic_trace(module)
13 | graph_module.recompile()
14 |
15 | x = torch.randn(4)
16 |
17 | graph_module.train()
18 | assert torch.allclose(graph_module(x), torch.zeros_like(x))
19 |
20 | # This is certainly undesired behavior due to tracing :(
21 | graph_module.eval()
22 | assert torch.allclose(graph_module(x), torch.zeros_like(x))
23 |
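The "undesired behavior" comment refers to how torch.fx.symbolic_trace evaluates self.training eagerly: the boolean is frozen into the traced graph as a constant, so calling graph_module.eval() afterwards does not switch the captured F.dropout call off. A small inspection sketch (not part of the test) that makes this visible:

print(graph_module.code)
# The generated forward() contains a call roughly of the form
# torch.nn.functional.dropout(x, p = 1.0, training = True), with the
# training flag baked in as a literal instead of being read from the module.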
--------------------------------------------------------------------------------
/test/nn/test_reshape.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn.reshape import Reshape
4 |
5 |
6 | def test_reshape():
7 | x = torch.randn(10, 4)
8 |
9 | op = Reshape(5, 2, 4)
10 | assert str(op) == 'Reshape(5, 2, 4)'
11 |
12 | assert op(x).size() == (5, 2, 4)
13 | assert torch.equal(op(x).view(10, 4), x)
14 |
--------------------------------------------------------------------------------
/test/nn/test_to_fixed_size_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import SumAggregation
4 | from torch_geometric.nn.to_fixed_size_transformer import to_fixed_size
5 |
6 |
7 | class Model(torch.nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 | self.aggr = SumAggregation()
11 |
12 | def forward(self, x, batch):
13 | return self.aggr(x, batch, dim=0)
14 |
15 |
16 | def test_to_fixed_size():
17 | x = torch.randn(10, 16)
18 | batch = torch.zeros(10, dtype=torch.long)
19 |
20 | model = Model()
21 | assert model(x, batch).size() == (1, 16)
22 |
23 | model = to_fixed_size(model, batch_size=10)
24 | assert model(x, batch).size() == (10, 16)
25 |
--------------------------------------------------------------------------------
/test/nn/unpool/test_knn_interpolate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.nn import knn_interpolate
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('torch_cluster')
8 | def test_knn_interpolate():
9 | x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]])
10 | pos_x = torch.tensor([
11 | [-1.0, 0.0],
12 | [0.0, 0.0],
13 | [1.0, 0.0],
14 | [-2.0, 0.0],
15 | [0.0, 0.0],
16 | [2.0, 0.0],
17 | ])
18 | pos_y = torch.tensor([
19 | [-1.0, -1.0],
20 | [1.0, 1.0],
21 | [-2.0, -2.0],
22 | [2.0, 2.0],
23 | ])
24 | batch_x = torch.tensor([0, 0, 0, 1, 1, 1])
25 | batch_y = torch.tensor([0, 0, 1, 1])
26 |
27 | y = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k=2)
28 | assert y.tolist() == [[4.0], [70.0], [-4.0], [-70.0]]
29 |
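The expected rows follow from the inverse-squared-distance weighting used by knn_interpolate (as introduced in PointNet++): each target position averages the features of its k nearest sources with weights 1 / d^2. A worked check for the first target at (-1, -1), whose two nearest batch-0 sources sit at (-1, 0) and (0, 0) with squared distances 1 and 2:

w1, w2 = 1.0 / 1.0, 1.0 / 2.0
y0 = (w1 * 1.0 + w2 * 10.0) / (w1 + w2)
assert y0 == 4.0  # matches the first entry of the expected output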
--------------------------------------------------------------------------------
/test/profile/test_benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.profile import benchmark
4 | from torch_geometric.testing import withPackage
5 |
6 |
7 | @withPackage('tabulate')
8 | def test_benchmark(capfd):
9 | def add(x, y):
10 | return x + y
11 |
12 | benchmark(
13 | funcs=[add],
14 | args=(torch.randn(10), torch.randn(10)),
15 | num_steps=1,
16 | num_warmups=1,
17 | backward=True,
18 | )
19 |
20 | out, _ = capfd.readouterr()
21 | assert '| Name | Forward | Backward | Total |' in out
22 | assert '| add |' in out
23 |
--------------------------------------------------------------------------------
/test/test_debug.py:
--------------------------------------------------------------------------------
1 | from torch_geometric import debug, is_debug_enabled, set_debug
2 |
3 |
4 | def test_debug():
5 | assert is_debug_enabled() is False
6 | set_debug(True)
7 | assert is_debug_enabled() is True
8 | set_debug(False)
9 | assert is_debug_enabled() is False
10 |
11 | assert is_debug_enabled() is False
12 | with set_debug(True):
13 | assert is_debug_enabled() is True
14 | assert is_debug_enabled() is False
15 |
16 | assert is_debug_enabled() is False
17 | set_debug(True)
18 | assert is_debug_enabled() is True
19 | with set_debug(False):
20 | assert is_debug_enabled() is False
21 | assert is_debug_enabled() is True
22 | set_debug(False)
23 | assert is_debug_enabled() is False
24 |
25 | assert is_debug_enabled() is False
26 | with debug():
27 | assert is_debug_enabled() is True
28 | assert is_debug_enabled() is False
29 |
--------------------------------------------------------------------------------
/test/test_home.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | from torch_geometric import get_home_dir, set_home_dir
5 | from torch_geometric.home import DEFAULT_CACHE_DIR
6 |
7 |
8 | def test_home():
9 | os.environ.pop('PYG_HOME', None)
10 | home_dir = osp.expanduser(DEFAULT_CACHE_DIR)
11 | assert get_home_dir() == home_dir
12 |
13 | home_dir = '/tmp/test_pyg1'
14 | os.environ['PYG_HOME'] = home_dir
15 | assert get_home_dir() == home_dir
16 |
17 | home_dir = '/tmp/test_pyg2'
18 | set_home_dir(home_dir)
19 | assert get_home_dir() == home_dir
20 |
--------------------------------------------------------------------------------
/test/test_isinstance.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric import is_torch_instance
4 | from torch_geometric.testing import onlyLinux, withPackage
5 |
6 |
7 | def test_basic():
8 | assert is_torch_instance(torch.nn.Linear(1, 1), torch.nn.Linear)
9 |
10 |
11 | @onlyLinux
12 | @withPackage('torch>=2.0.0')
13 | def test_compile():
14 | model = torch.compile(torch.nn.Linear(1, 1))
15 | assert not isinstance(model, torch.nn.Linear)
16 | assert is_torch_instance(model, torch.nn.Linear)
17 |
--------------------------------------------------------------------------------
/test/test_seed.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from torch_geometric import seed_everything
7 |
8 |
9 | def test_seed_everything():
10 | seed_everything(0)
11 |
12 | assert random.randint(0, 100) == 49
13 | assert random.randint(0, 100) == 97
14 | assert np.random.randint(0, 100) == 44
15 | assert np.random.randint(0, 100) == 47
16 | assert int(torch.randint(0, 100, (1, ))) == 44
17 | assert int(torch.randint(0, 100, (1, ))) == 39
18 |
--------------------------------------------------------------------------------
/test/test_warnings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from torch_geometric.warnings import WarningCache, warn
7 |
8 |
9 | def test_warn():
10 | with pytest.warns(UserWarning, match='test'):
11 | warn('test')
12 |
13 |
14 | @patch('torch_geometric.is_compiling', return_value=True)
15 | def test_no_warn_if_compiling(_):
16 | """No warning should be raised to avoid graph breaks when compiling."""
17 | with warnings.catch_warnings():
18 | warnings.simplefilter('error')
19 | warn('test')
20 |
21 |
22 | def test_warning_cache():
23 | cache = WarningCache()
24 | assert len(cache) == 0
25 |
26 | cache.warn('test')
27 | assert len(cache) == 1
28 | assert 'test' in cache
29 |
30 | cache.warn('test')
31 | assert len(cache) == 1
32 |
33 | cache.warn('test2')
34 | assert len(cache) == 2
35 | assert 'test2' in cache
36 |
--------------------------------------------------------------------------------
/test/testing/test_decorators.py:
--------------------------------------------------------------------------------
1 | import torch_geometric.typing
2 | from torch_geometric.testing import disableExtensions
3 |
4 |
5 | def test_enable_extensions():
6 | try:
7 | import pyg_lib # noqa
8 | assert torch_geometric.typing.WITH_PYG_LIB
9 | except (ImportError, OSError):
10 | assert not torch_geometric.typing.WITH_PYG_LIB
11 |
12 | try:
13 | import torch_scatter # noqa
14 | assert torch_geometric.typing.WITH_TORCH_SCATTER
15 | except (ImportError, OSError):
16 | assert not torch_geometric.typing.WITH_TORCH_SCATTER
17 |
18 | try:
19 | import torch_sparse # noqa
20 | assert torch_geometric.typing.WITH_TORCH_SPARSE
21 | except (ImportError, OSError):
22 | assert not torch_geometric.typing.WITH_TORCH_SPARSE
23 |
24 |
25 | @disableExtensions
26 | def test_disable_extensions():
27 | assert not torch_geometric.typing.WITH_PYG_LIB
28 | assert not torch_geometric.typing.WITH_TORCH_SCATTER
29 | assert not torch_geometric.typing.WITH_TORCH_SPARSE
30 |
--------------------------------------------------------------------------------
/test/transforms/test_add_gpse.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.nn import GPSE
5 | from torch_geometric.nn.models.gpse import IdentityHead
6 | from torch_geometric.transforms import AddGPSE
7 |
8 | num_nodes = 6
9 | gpse_inner_dim = 512
10 |
11 |
12 | def test_gpse():
13 | x = torch.randn(num_nodes, 4)
14 | edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
15 | [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
16 | data = Data(x=x, edge_index=edge_index)
17 |
18 | model = GPSE()
19 | model.post_mp = IdentityHead()
20 | transform = AddGPSE(model)
21 |
22 | assert str(transform) == 'AddGPSE()'
23 | out = transform(data)
24 | assert out.pestat_GPSE.size() == (num_nodes, gpse_inner_dim)
25 |
--------------------------------------------------------------------------------
/test/transforms/test_center.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import Center
5 |
6 |
7 | def test_center():
8 | transform = Center()
9 | assert str(transform) == 'Center()'
10 |
11 | pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
12 | data = Data(pos=pos)
13 |
14 | data = transform(data)
15 | assert len(data) == 1
16 | assert data.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]
17 |
--------------------------------------------------------------------------------
/test/transforms/test_generate_mesh_normals.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import GenerateMeshNormals
5 |
6 |
7 | def test_generate_mesh_normals():
8 | transform = GenerateMeshNormals()
9 | assert str(transform) == 'GenerateMeshNormals()'
10 |
11 | pos = torch.tensor([
12 | [0.0, 0.0, 0.0],
13 | [-2.0, 1.0, 0.0],
14 | [-1.0, 1.0, 0.0],
15 | [0.0, 1.0, 0.0],
16 | [1.0, 1.0, 0.0],
17 | [2.0, 1.0, 0.0],
18 | ])
19 | face = torch.tensor([
20 | [0, 0, 0, 0],
21 | [1, 2, 3, 4],
22 | [2, 3, 4, 5],
23 | ])
24 |
25 | data = transform(Data(pos=pos, face=face))
26 | assert len(data) == 3
27 | assert data.pos.tolist() == pos.tolist()
28 | assert data.face.tolist() == face.tolist()
29 | assert data.norm.tolist() == [[0.0, 0.0, -1.0]] * 6
30 |
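The expected normal [0, 0, -1] follows from the cross product of two edges of a face; every face here lies in the z = 0 plane with the same vertex ordering, so all nodes share the same normal. A worked check on the first face (vertices 0, 1, 2), independent of the test:

import torch
v0 = torch.tensor([0.0, 0.0, 0.0])
v1 = torch.tensor([-2.0, 1.0, 0.0])
v2 = torch.tensor([-1.0, 1.0, 0.0])
n = torch.linalg.cross(v1 - v0, v2 - v0)
assert torch.equal(n / n.norm(), torch.tensor([0.0, 0.0, -1.0]))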
--------------------------------------------------------------------------------
/test/transforms/test_grid_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.testing import withPackage
5 | from torch_geometric.transforms import GridSampling
6 |
7 |
8 | @withPackage('torch_cluster')
9 | def test_grid_sampling():
10 | assert str(GridSampling(5)) == 'GridSampling(size=5)'
11 |
12 | pos = torch.tensor([
13 | [0.0, 2.0],
14 | [3.0, 2.0],
15 | [3.0, 2.0],
16 | [2.0, 8.0],
17 | [2.0, 6.0],
18 | ])
19 | y = torch.tensor([0, 1, 1, 2, 2])
20 | batch = torch.tensor([0, 0, 0, 0, 0])
21 |
22 | data = Data(pos=pos, y=y, batch=batch)
23 | data = GridSampling(size=5, start=0)(data)
24 | assert len(data) == 3
25 | assert data.pos.tolist() == [[2, 2], [2, 7]]
26 | assert data.y.tolist() == [1, 2]
27 | assert data.batch.tolist() == [0, 0]
28 |
--------------------------------------------------------------------------------
/test/transforms/test_knn_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.testing import withPackage
5 | from torch_geometric.transforms import KNNGraph
6 |
7 |
8 | @withPackage('torch_cluster')
9 | def test_knn_graph():
10 | assert str(KNNGraph()) == 'KNNGraph(k=6)'
11 |
12 | pos = torch.tensor([
13 | [0.0, 0.0],
14 | [1.0, 0.0],
15 | [2.0, 0.0],
16 | [0.0, 1.0],
17 | [-2.0, 0.0],
18 | [0.0, -2.0],
19 | ])
20 |
21 | expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5]
22 | expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1]
23 |
24 | data = Data(pos=pos)
25 | data = KNNGraph(k=2, force_undirected=True)(data)
26 | assert len(data) == 2
27 | assert data.pos.tolist() == pos.tolist()
28 | assert data.edge_index[0].tolist() == expected_row
29 | assert data.edge_index[1].tolist() == expected_col
30 |
--------------------------------------------------------------------------------
/test/transforms/test_linear_transformation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.data import Data
5 | from torch_geometric.transforms import LinearTransformation
6 |
7 |
8 | @pytest.mark.parametrize('matrix', [
9 | [[2.0, 0.0], [0.0, 2.0]],
10 | torch.tensor([[2.0, 0.0], [0.0, 2.0]]),
11 | ])
12 | def test_linear_transformation(matrix):
13 | pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])
14 |
15 | transform = LinearTransformation(matrix)
16 | assert str(transform) == ('LinearTransformation(\n'
17 | '[[2. 0.]\n'
18 | ' [0. 2.]]\n'
19 | ')')
20 |
21 | out = transform(Data(pos=pos))
22 | assert len(out) == 1
23 | assert torch.allclose(out.pos, 2 * pos)
24 |
25 | out = transform(Data())
26 | assert len(out) == 0
27 |
--------------------------------------------------------------------------------
/test/transforms/test_local_degree_profile.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import LocalDegreeProfile
5 |
6 |
7 | def test_local_degree_profile():
8 | assert str(LocalDegreeProfile()) == 'LocalDegreeProfile()'
9 |
10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
11 | x = torch.tensor([[1.0], [1.0], [1.0], [1.0]]) # One isolated node.
12 |
13 | expected = torch.tensor([
14 | [1, 2, 2, 2, 0],
15 | [2, 1, 1, 1, 0],
16 | [1, 2, 2, 2, 0],
17 | [0, 0, 0, 0, 0],
18 | ], dtype=torch.float)
19 |
20 | data = Data(edge_index=edge_index, num_nodes=x.size(0))
21 | data = LocalDegreeProfile()(data)
22 | assert torch.allclose(data.x, expected, atol=1e-2)
23 |
24 | data = Data(edge_index=edge_index, x=x)
25 | data = LocalDegreeProfile()(data)
26 | assert torch.allclose(data.x[:, :1], x)
27 | assert torch.allclose(data.x[:, 1:], expected, atol=1e-2)
28 |
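Each row of `expected` is a node's local degree profile, i.e. [deg(v), min, max, mean, std] over the degrees of its neighbors; node 1 has degree 2 and both neighbors (nodes 0 and 2) have degree 1, while the isolated node 3 yields the all-zero row. A worked check for node 1 (the std convention does not matter here, since both neighbor degrees are equal):

import torch
neighbor_deg = torch.tensor([1.0, 1.0])
row_1 = [2.0, neighbor_deg.min().item(), neighbor_deg.max().item(),
         neighbor_deg.mean().item(), neighbor_deg.std(unbiased=False).item()]
assert row_1 == [2.0, 1.0, 1.0, 1.0, 0.0]  # second row of `expected` above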
--------------------------------------------------------------------------------
/test/transforms/test_normalize_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data, HeteroData
4 | from torch_geometric.transforms import NormalizeFeatures
5 |
6 |
7 | def test_normalize_features():
8 | transform = NormalizeFeatures()
9 | assert str(transform) == 'NormalizeFeatures()'
10 |
11 | x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)
12 | data = Data(x=x)
13 |
14 | data = transform(data)
15 | assert len(data) == 1
16 | assert data.x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]
17 |
18 |
19 | def test_hetero_normalize_features():
20 | x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)
21 |
22 | data = HeteroData()
23 | data['v'].x = x
24 | data['w'].x = x
25 | data = NormalizeFeatures()(data)
26 | assert data['v'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]
27 | assert data['w'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]
28 |
--------------------------------------------------------------------------------
/test/transforms/test_normalize_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import NormalizeScale
5 |
6 |
7 | def test_normalize_scale():
8 | transform = NormalizeScale()
9 | assert str(transform) == 'NormalizeScale()'
10 |
11 | pos = torch.randn((10, 3))
12 | data = Data(pos=pos)
13 |
14 | data = transform(data)
15 | assert len(data) == 1
16 | assert data.pos.min().item() > -1
17 | assert data.pos.max().item() < 1
18 |
--------------------------------------------------------------------------------
/test/transforms/test_radius_graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.testing import withPackage
5 | from torch_geometric.transforms import RadiusGraph
6 | from torch_geometric.utils import coalesce
7 |
8 |
9 | @withPackage('torch_cluster')
10 | def test_radius_graph():
11 | assert str(RadiusGraph(r=1)) == 'RadiusGraph(r=1)'
12 |
13 | pos = torch.tensor([
14 | [0.0, 0.0],
15 | [1.0, 0.0],
16 | [2.0, 0.0],
17 | [0.0, 1.0],
18 | [-2.0, 0.0],
19 | [0.0, -2.0],
20 | ])
21 |
22 | data = Data(pos=pos)
23 | data = RadiusGraph(r=1.5)(data)
24 | assert len(data) == 2
25 | assert data.pos.tolist() == pos.tolist()
26 | assert coalesce(data.edge_index).tolist() == [[0, 0, 1, 1, 1, 2, 3, 3],
27 | [1, 3, 0, 2, 3, 1, 0, 1]]
28 |
--------------------------------------------------------------------------------
/test/transforms/test_random_flip.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RandomFlip
5 |
6 |
7 | def test_random_flip():
8 | assert str(RandomFlip(axis=0)) == 'RandomFlip(axis=0, p=0.5)'
9 |
10 | pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])
11 |
12 | data = Data(pos=pos)
13 | data = RandomFlip(axis=0, p=1)(data)
14 | assert len(data) == 1
15 | assert data.pos.tolist() == [[1.0, 1.0], [3.0, 0.0], [-2.0, -1.0]]
16 |
17 | data = Data(pos=pos)
18 | data = RandomFlip(axis=1, p=1)(data)
19 | assert len(data) == 1
20 | assert data.pos.tolist() == [[-1.0, -1.0], [-3.0, 0.0], [2.0, 1.0]]
21 |
--------------------------------------------------------------------------------
/test/transforms/test_random_jitter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RandomJitter
5 |
6 |
7 | def test_random_jitter():
8 | assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)'
9 |
10 | pos = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
11 |
12 | data = Data(pos=pos)
13 | data = RandomJitter(0)(data)
14 | assert len(data) == 1
15 | assert torch.allclose(data.pos, pos)
16 |
17 | data = Data(pos=pos)
18 | data = RandomJitter(0.1)(data)
19 | assert len(data) == 1
20 | assert data.pos.min() >= -0.1
21 | assert data.pos.max() <= 0.1
22 |
23 | data = Data(pos=pos)
24 | data = RandomJitter([0.1, 1])(data)
25 | assert len(data) == 1
26 | assert data.pos[:, 0].min() >= -0.1
27 | assert data.pos[:, 0].max() <= 0.1
28 | assert data.pos[:, 1].min() >= -1
29 | assert data.pos[:, 1].max() <= 1
30 |
--------------------------------------------------------------------------------
/test/transforms/test_random_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RandomScale
5 |
6 |
7 | def test_random_scale():
8 | assert str(RandomScale([1, 2])) == 'RandomScale([1, 2])'
9 |
10 | pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
11 |
12 | data = Data(pos=pos)
13 | data = RandomScale([1, 1])(data)
14 | assert len(data) == 1
15 | assert data.pos.tolist() == pos.tolist()
16 |
17 | data = Data(pos=pos)
18 | data = RandomScale([2, 2])(data)
19 | assert len(data) == 1
20 | assert data.pos.tolist() == [[-2, -2], [-2, 2], [2, -2], [2, 2]]
21 |
--------------------------------------------------------------------------------
/test/transforms/test_random_shear.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RandomShear
5 |
6 |
7 | def test_random_shear():
8 | assert str(RandomShear(0.1)) == 'RandomShear(0.1)'
9 |
10 | pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
11 |
12 | data = Data(pos=pos)
13 | data = RandomShear(0)(data)
14 | assert len(data) == 1
15 | assert torch.allclose(data.pos, pos)
16 |
17 | data = Data(pos=pos)
18 | data = RandomShear(0.1)(data)
19 | assert len(data) == 1
20 | assert not torch.allclose(data.pos, pos)
21 |
--------------------------------------------------------------------------------
/test/transforms/test_remove_duplicated_edges.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RemoveDuplicatedEdges
5 |
6 |
7 | def test_remove_duplicated_edges():
8 | edge_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
9 | [0, 0, 1, 1, 0, 0, 1, 1]])
10 | edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
11 | data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=2)
12 |
13 | transform = RemoveDuplicatedEdges()
14 | assert str(transform) == 'RemoveDuplicatedEdges()'
15 |
16 | out = transform(data)
17 | assert len(out) == 3
18 | assert out.num_nodes == 2
19 | assert out.edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
20 | assert out.edge_weight.tolist() == [3, 7, 11, 15]
21 |
--------------------------------------------------------------------------------
/test/transforms/test_remove_training_classes.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import RemoveTrainingClasses
5 |
6 |
7 | def test_remove_training_classes():
8 | y = torch.tensor([1, 0, 0, 2, 1, 3])
9 | train_mask = torch.tensor([False, False, True, True, True, True])
10 |
11 | data = Data(y=y, train_mask=train_mask)
12 |
13 | transform = RemoveTrainingClasses(classes=[0, 1])
14 | assert str(transform) == 'RemoveTrainingClasses([0, 1])'
15 |
16 | data = transform(data)
17 | assert len(data) == 2
18 | assert torch.equal(data.y, y)
19 | assert data.train_mask.tolist() == [False, False, False, True, False, True]
20 |
--------------------------------------------------------------------------------
/test/transforms/test_sample_points.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import SamplePoints
5 |
6 |
7 | def test_sample_points():
8 | assert str(SamplePoints(1024)) == 'SamplePoints(1024)'
9 |
10 | pos = torch.tensor([
11 | [0.0, 0.0, 0.0],
12 | [1.0, 0.0, 0.0],
13 | [0.0, 1.0, 0.0],
14 | [1.0, 1.0, 0.0],
15 | ])
16 | face = torch.tensor([[0, 1], [1, 2], [2, 3]])
17 |
18 | data = Data(pos=pos)
19 | data.face = face
20 | data = SamplePoints(8)(data)
21 | assert len(data) == 1
22 | assert data.pos[:, 0].min() >= 0 and data.pos[:, 0].max() <= 1
23 | assert data.pos[:, 1].min() >= 0 and data.pos[:, 1].max() <= 1
24 | assert data.pos[:, 2].abs().sum() == 0
25 |
26 | data = Data(pos=pos)
27 | data.face = face
28 | data = SamplePoints(8, include_normals=True)(data)
29 | assert len(data) == 2
30 | assert data.normal[:, :2].abs().sum() == 0
31 | assert data.normal[:, 2].abs().sum() == 8
32 |
--------------------------------------------------------------------------------
/test/transforms/test_sign.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import SIGN
5 |
6 |
7 | def test_sign():
8 | x = torch.ones(5, 3)
9 | edge_index = torch.tensor([
10 | [0, 1, 2, 3, 3, 4],
11 | [1, 0, 3, 2, 4, 3],
12 | ])
13 | data = Data(x=x, edge_index=edge_index)
14 |
15 | transform = SIGN(K=2)
16 | assert str(transform) == 'SIGN(K=2)'
17 |
18 | expected_x1 = torch.tensor([
19 | [1, 1, 1],
20 | [1, 1, 1],
21 | [0.7071, 0.7071, 0.7071],
22 | [1.4142, 1.4142, 1.4142],
23 | [0.7071, 0.7071, 0.7071],
24 | ])
25 | expected_x2 = torch.ones(5, 3)
26 |
27 | out = transform(data)
28 | assert len(out) == 4
29 | assert torch.equal(out.edge_index, edge_index)
30 | assert torch.allclose(out.x, x)
31 | assert torch.allclose(out.x1, expected_x1, atol=1e-4)
32 | assert torch.allclose(out.x2, expected_x2)
33 |
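expected_x1 corresponds to one step of symmetrically normalized propagation, x1 = D^{-1/2} A D^{-1/2} x, without self-loops on the graph above: nodes 0, 1, 2 and 4 have degree 1 while node 3 has degree 2, so node 3 receives 1/sqrt(1 * 2) from each of its two neighbors. A quick numeric check of those entries, separate from the test:

import math
assert abs(2 / math.sqrt(2) - 1.4142) < 1e-4  # node 3 in expected_x1
assert abs(1 / math.sqrt(2) - 0.7071) < 1e-4  # nodes 2 and 4 in expected_x1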
--------------------------------------------------------------------------------
/test/transforms/test_svd_feature_reduction.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import SVDFeatureReduction
5 |
6 |
7 | def test_svd_feature_reduction():
8 | assert str(SVDFeatureReduction(10)) == 'SVDFeatureReduction(10)'
9 |
10 | x = torch.randn(4, 16)
11 | U, S, _ = torch.linalg.svd(x)
12 | data = Data(x=x)
13 | data = SVDFeatureReduction(10)(data)
14 | assert torch.allclose(data.x, torch.mm(U[:, :10], torch.diag(S[:10])))
15 |
16 | x = torch.randn(4, 8)
17 | data.x = x
18 | data = SVDFeatureReduction(10)(Data(x=x))
19 | assert torch.allclose(data.x, x)
20 |
--------------------------------------------------------------------------------
/test/transforms/test_target_indegree.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import TargetIndegree
5 |
6 |
7 | def test_target_indegree():
8 | assert str(TargetIndegree()) == 'TargetIndegree(norm=True, max_value=None)'
9 |
10 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
11 | edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0])
12 |
13 | data = Data(edge_index=edge_index, num_nodes=3)
14 | data = TargetIndegree(norm=False)(data)
15 | assert len(data) == 3
16 | assert data.edge_index.tolist() == edge_index.tolist()
17 | assert data.edge_attr.tolist() == [[2], [1], [1], [2]]
18 | assert data.num_nodes == 3
19 |
20 | data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)
21 | data = TargetIndegree(norm=True)(data)
22 | assert len(data) == 3
23 | assert data.edge_index.tolist() == edge_index.tolist()
24 | assert data.edge_attr.tolist() == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]]
25 | assert data.num_nodes == 3
26 |
--------------------------------------------------------------------------------
/test/transforms/test_to_device.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.testing import withDevice
5 | from torch_geometric.transforms import ToDevice
6 |
7 |
8 | @withDevice
9 | def test_to_device(device):
10 | x = torch.randn(3, 4)
11 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
12 | edge_weight = torch.randn(edge_index.size(1))
13 |
14 | data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)
15 |
16 | transform = ToDevice(device)
17 | assert str(transform) == f'ToDevice({device})'
18 |
19 | data = transform(data)
20 | for _, value in data:
21 | assert value.device == device
22 |
--------------------------------------------------------------------------------
/test/utils/test_degree.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import degree
4 |
5 |
6 | def test_degree():
7 | row = torch.tensor([0, 1, 0, 2, 0])
8 | deg = degree(row, dtype=torch.long)
9 | assert deg.dtype == torch.long
10 | assert deg.tolist() == [3, 1, 1]
11 |
--------------------------------------------------------------------------------
/test/utils/test_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import cumsum
4 |
5 |
6 | def test_cumsum():
7 | x = torch.tensor([2, 4, 1])
8 | assert cumsum(x).tolist() == [0, 2, 6, 7]
9 |
10 | x = torch.tensor([[2, 4], [3, 6]])
11 | assert cumsum(x, dim=1).tolist() == [[0, 2, 6], [0, 3, 9]]
12 |
--------------------------------------------------------------------------------
/test/utils/test_grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.testing import is_full_test
4 | from torch_geometric.utils import grid
5 |
6 |
7 | def test_grid():
8 | (row, col), pos = grid(height=3, width=2)
9 |
10 | expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]
11 | expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5]
12 | expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
13 | expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]
14 |
15 | expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]]
16 |
17 | assert row.tolist() == expected_row
18 | assert col.tolist() == expected_col
19 | assert pos.tolist() == expected_pos
20 |
21 | if is_full_test():
22 | jit = torch.jit.script(grid)
23 | (row, col), pos = jit(height=3, width=2)
24 | assert row.tolist() == expected_row
25 | assert col.tolist() == expected_col
26 | assert pos.tolist() == expected_pos
27 |
--------------------------------------------------------------------------------
/test/utils/test_index_sort.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.testing import withDevice
4 | from torch_geometric.utils import index_sort
5 |
6 |
7 | @withDevice
8 | def test_index_sort_stable(device):
9 | for _ in range(100):
10 | inputs = torch.randint(0, 4, size=(10, ), device=device)
11 |
12 | out = index_sort(inputs, stable=True)
13 | expected = torch.sort(inputs, stable=True)
14 |
15 | assert torch.equal(out[0], expected[0])
16 | assert torch.equal(out[1], expected[1])
17 |
--------------------------------------------------------------------------------
/test/utils/test_lexsort.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from torch_geometric.utils import lexsort
5 |
6 |
7 | def test_lexsort():
8 | keys = [torch.randn(100) for _ in range(3)]
9 |
10 | expected = np.lexsort([key.numpy() for key in keys])
11 | assert torch.equal(lexsort(keys), torch.from_numpy(expected))
12 |
--------------------------------------------------------------------------------
/test/utils/test_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import index_to_mask, mask_select, mask_to_index
4 |
5 |
6 | def test_mask_select():
7 | src = torch.randn(6, 8)
8 | mask = torch.tensor([False, True, False, True, False, True])
9 |
10 | out = mask_select(src, 0, mask)
11 | assert out.size() == (3, 8)
12 | assert torch.equal(src[torch.tensor([1, 3, 5])], out)
13 |
14 | jit = torch.jit.script(mask_select)
15 | assert torch.equal(jit(src, 0, mask), out)
16 |
17 |
18 | def test_index_to_mask():
19 | index = torch.tensor([1, 3, 5])
20 |
21 | mask = index_to_mask(index)
22 | assert mask.tolist() == [False, True, False, True, False, True]
23 |
24 | mask = index_to_mask(index, size=7)
25 | assert mask.tolist() == [False, True, False, True, False, True, False]
26 |
27 |
28 | def test_mask_to_index():
29 | mask = torch.tensor([False, True, False, True, False, True])
30 |
31 | index = mask_to_index(mask)
32 | assert index.tolist() == [1, 3, 5]
33 |
--------------------------------------------------------------------------------
/test/utils/test_noise_scheduler.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.utils.noise_scheduler import (
5 | get_diffusion_beta_schedule,
6 | get_smld_sigma_schedule,
7 | )
8 |
9 |
10 | def test_get_smld_sigma_schedule():
11 | expected = torch.tensor([
12 | 1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
13 | 0.04641589, 0.02782559, 0.01668101, 0.01
14 | ])
15 | out = get_smld_sigma_schedule(
16 | sigma_min=0.01,
17 | sigma_max=1.0,
18 | num_scales=10,
19 | )
20 | assert torch.allclose(out, expected)
21 |
22 |
23 | @pytest.mark.parametrize(
24 | 'schedule_type',
25 | ['linear', 'quadratic', 'constant', 'sigmoid'],
26 | )
27 | def test_get_diffusion_beta_schedule(schedule_type):
28 | out = get_diffusion_beta_schedule(
29 | schedule_type,
30 | beta_start=0.1,
31 | beta_end=0.2,
32 | num_diffusion_timesteps=10,
33 | )
34 | assert out.size() == (10, )
35 |
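The expected SMLD sigmas form a geometric sequence from sigma_max down to sigma_min, i.e. sigma_i = sigma_max * (sigma_min / sigma_max) ** (i / (num_scales - 1)). With sigma_max = 1.0, sigma_min = 0.01 and 10 scales, the second entry is 0.01 ** (1 / 9) ≈ 0.59948, matching the tensor above. A quick check under that assumption:

sigma_1 = 1.0 * (0.01 / 1.0) ** (1 / 9)
assert abs(sigma_1 - 0.59948425) < 1e-5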
--------------------------------------------------------------------------------
/test/utils/test_normalize_edge_index.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.utils import normalize_edge_index
5 |
6 |
7 | @pytest.mark.parametrize('add_self_loops', [False, True])
8 | @pytest.mark.parametrize('symmetric', [False, True])
9 | def test_normalize_edge_index(add_self_loops: bool, symmetric: bool):
10 | edge_index = torch.tensor([[0, 2, 2, 3], [2, 0, 3, 0]])
11 |
12 | out = normalize_edge_index(
13 | edge_index,
14 | add_self_loops=add_self_loops,
15 | symmetric=symmetric,
16 | )
17 | assert isinstance(out, tuple) and len(out) == 2
18 | if not add_self_loops:
19 | assert out[0].equal(edge_index)
20 | else:
21 | assert out[0].tolist() == [
22 | [0, 2, 2, 3, 0, 1, 2, 3],
23 | [2, 0, 3, 0, 0, 1, 2, 3],
24 | ]
25 |
26 | assert out[1].min() >= 0.0
27 | assert out[1].max() <= 1.0
28 |
--------------------------------------------------------------------------------
/test/utils/test_normalized_cut.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.testing import is_full_test
4 | from torch_geometric.utils import normalized_cut
5 |
6 |
7 | def test_normalized_cut():
8 | row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4])
9 | col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3])
10 | edge_attr = torch.tensor(
11 | [3.0, 3.0, 6.0, 3.0, 6.0, 1.0, 3.0, 2.0, 1.0, 2.0])
12 | expected = torch.tensor([4.0, 4.0, 5.0, 2.5, 5.0, 1.0, 2.5, 2.0, 1.0, 2.0])
13 |
14 | out = normalized_cut(torch.stack([row, col], dim=0), edge_attr)
15 | assert torch.allclose(out, expected)
16 |
17 | if is_full_test():
18 | jit = torch.jit.script(normalized_cut)
19 | out = jit(torch.stack([row, col], dim=0), edge_attr)
20 | assert torch.allclose(out, expected)
21 |
--------------------------------------------------------------------------------
/test/utils/test_num_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import to_torch_coo_tensor
4 | from torch_geometric.utils.num_nodes import (
5 | maybe_num_nodes,
6 | maybe_num_nodes_dict,
7 | )
8 |
9 |
10 | def test_maybe_num_nodes():
11 | edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]])
12 |
13 | assert maybe_num_nodes(edge_index, 4) == 4
14 | assert maybe_num_nodes(edge_index) == 3
15 |
16 | adj = to_torch_coo_tensor(edge_index)
17 | assert maybe_num_nodes(adj, 4) == 4
18 | assert maybe_num_nodes(adj) == 3
19 |
20 |
21 | def test_maybe_num_nodes_dict():
22 | edge_index_dict = {
23 | '1': torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]),
24 | '2': torch.tensor([[0, 0, 1, 3], [1, 2, 0, 4]])
25 | }
26 | num_nodes_dict = {'2': 6}
27 |
28 | assert maybe_num_nodes_dict(edge_index_dict) == {'1': 3, '2': 5}
29 | assert maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) == {
30 | '1': 3,
31 | '2': 6,
32 | }
33 |
--------------------------------------------------------------------------------
/test/utils/test_one_hot.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import one_hot
4 |
5 |
6 | def test_one_hot():
7 | index = torch.tensor([0, 1, 2])
8 |
9 | out = one_hot(index)
10 | assert out.size() == (3, 3)
11 | assert out.dtype == torch.float
12 | assert out.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
13 |
14 | out = one_hot(index, num_classes=4, dtype=torch.long)
15 | assert out.size() == (3, 4)
16 | assert out.dtype == torch.long
17 | assert out.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
18 |
--------------------------------------------------------------------------------
/test/utils/test_ppr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_geometric.datasets import KarateClub
5 | from torch_geometric.testing import withPackage
6 | from torch_geometric.utils import get_ppr
7 |
8 |
9 | @withPackage('numba')
10 | @pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])])
11 | def test_get_ppr(target):
12 | data = KarateClub()[0]
13 |
14 | edge_index, edge_weight = get_ppr(
15 | data.edge_index,
16 | alpha=0.1,
17 | eps=1e-5,
18 | target=target,
19 | )
20 |
21 | assert edge_index.size(0) == 2
22 | assert edge_index.size(1) == edge_weight.numel()
23 |
24 | min_row = 0 if target is None else target.min()
25 | max_row = data.num_nodes - 1 if target is None else target.max()
26 | assert edge_index[0].min() == min_row and edge_index[0].max() == max_row
27 | assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes
28 | assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0
29 |
--------------------------------------------------------------------------------
/test/utils/test_repeat.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.utils.repeat import repeat
2 |
3 |
4 | def test_repeat():
5 | assert repeat(None, length=4) is None
6 | assert repeat(4, length=4) == [4, 4, 4, 4]
7 | assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4]
8 | assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4]
9 | assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4]
10 |
--------------------------------------------------------------------------------
/test/utils/test_select.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import narrow, select
4 |
5 |
6 | def test_select():
7 | src = torch.randn(5, 3)
8 | index = torch.tensor([0, 2, 4])
9 | mask = torch.tensor([True, False, True, False, True])
10 |
11 | out = select(src, index, dim=0)
12 | assert torch.equal(out, src[index])
13 | assert torch.equal(out, select(src, mask, dim=0))
14 | assert torch.equal(out, torch.tensor(select(src.tolist(), index, dim=0)))
15 | assert torch.equal(out, torch.tensor(select(src.tolist(), mask, dim=0)))
16 |
17 |
18 | def test_narrow():
19 | src = torch.randn(5, 3)
20 |
21 | out = narrow(src, dim=0, start=2, length=2)
22 | assert torch.equal(out, src[2:4])
23 | assert torch.equal(out, torch.tensor(narrow(src.tolist(), 0, 2, 2)))
24 |
--------------------------------------------------------------------------------
/test/utils/test_tree_decomposition.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_geometric.testing import withPackage
4 | from torch_geometric.utils import tree_decomposition
5 |
6 |
7 | @withPackage('rdkit')
8 | @pytest.mark.parametrize('smiles', [
9 | r'F/C=C/F',
10 | r'C/C(=C\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2',
11 | ])
12 | def test_tree_decomposition(smiles):
13 | from rdkit import Chem
14 | mol = Chem.MolFromSmiles(smiles)
15 | tree_decomposition(mol) # TODO Test output
16 |
--------------------------------------------------------------------------------
/test/utils/test_unbatch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.utils import unbatch, unbatch_edge_index
4 |
5 |
6 | def test_unbatch():
7 | src = torch.arange(10)
8 | batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4])
9 |
10 | out = unbatch(src, batch)
11 | assert len(out) == 5
12 | for i in range(len(out)):
13 | assert torch.equal(out[i], src[batch == i])
14 |
15 |
16 | def test_unbatch_edge_index():
17 | edge_index = torch.tensor([
18 | [0, 1, 1, 2, 2, 3, 4, 5, 5, 6],
19 | [1, 0, 2, 1, 3, 2, 5, 4, 6, 5],
20 | ])
21 | batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
22 |
23 | edge_indices = unbatch_edge_index(edge_index, batch)
24 | assert edge_indices[0].tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
25 | assert edge_indices[1].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
26 |
--------------------------------------------------------------------------------
/test/visualization/test_influence.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.datasets import KarateClub
4 | from torch_geometric.nn import GCNConv
5 | from torch_geometric.visualization import influence
6 |
7 |
8 | class Net(torch.nn.Module):
9 | def __init__(self, in_channels, out_channels):
10 | super().__init__()
11 | self.conv1 = GCNConv(in_channels, out_channels)
12 | self.conv2 = GCNConv(out_channels, out_channels)
13 |
14 | def forward(self, x, edge_index):
15 | x = torch.nn.functional.relu(self.conv1(x, edge_index))
16 | x = self.conv2(x, edge_index)
17 | return x
18 |
19 |
20 | def test_influence():
21 | data = KarateClub()[0]
22 | x = torch.randn(data.num_nodes, 8)
23 |
24 | out = influence(Net(x.size(1), 16), x, data.edge_index)
25 | assert out.size() == (data.num_nodes, data.num_nodes)
26 | assert torch.allclose(out.sum(dim=-1), torch.ones(data.num_nodes),
27 | atol=1e-04)
28 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch_geometric.contrib.transforms # noqa
4 | import torch_geometric.contrib.datasets # noqa
5 | import torch_geometric.contrib.nn # noqa
6 | import torch_geometric.contrib.explain # noqa
7 |
8 | warnings.warn(
9 | "'torch_geometric.contrib' contains experimental code and is subject to "
10 | "change. Please use with caution.", stacklevel=2)
11 |
12 | __all__ = []
13 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = classes = []
2 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/explain/__init__.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.deprecation import deprecated
2 |
3 | from .pgm_explainer import PGMExplainer
4 | from torch_geometric.explain.algorithm.graphmask_explainer import (
5 | GraphMaskExplainer as NewGraphMaskExplainer)
6 |
7 | GraphMaskExplainer = deprecated(
8 | "use 'torch_geometric.explain.algorithm.GraphMaskExplainer' instead", )(
9 | NewGraphMaskExplainer)
10 |
11 | __all__ = classes = [
12 | 'PGMExplainer',
13 | ]
14 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv import * # noqa
2 | from .models import * # noqa
3 |
4 | __all__ = []
5 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/nn/conv/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = classes = []
2 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/nn/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .rbcd_attack import PRBCDAttack, GRBCDAttack
2 |
3 | __all__ = classes = [
4 | 'PRBCDAttack',
5 | 'GRBCDAttack',
6 | ]
7 |
--------------------------------------------------------------------------------
/torch_geometric/contrib/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = classes = []
2 |
--------------------------------------------------------------------------------
/torch_geometric/data/lightning/__init__.py:
--------------------------------------------------------------------------------
1 | from .datamodule import LightningDataset, LightningNodeData, LightningLinkData
2 |
3 | __all__ = classes = [
4 | 'LightningDataset',
5 | 'LightningNodeData',
6 | 'LightningLinkData',
7 | ]
8 |
--------------------------------------------------------------------------------
/torch_geometric/data/makedirs.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.deprecation import deprecated
2 | from torch_geometric.io import fs
3 |
4 |
5 | @deprecated("use 'os.makedirs(path, exist_ok=True)' instead")
6 | def makedirs(path: str):
7 | r"""Recursively creates a directory.
8 |
9 | .. warning::
10 |
11 | :meth:`makedirs` is deprecated and will be removed soon.
12 | Please use :obj:`os.makedirs(path, exist_ok=True)` instead.
13 |
14 | Args:
15 | path (str): The path to create.
16 | """
17 | fs.makedirs(path, exist_ok=True)
18 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/graph_generator/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import GraphGenerator
2 | from .ba_graph import BAGraph
3 | from .er_graph import ERGraph
4 | from .grid_graph import GridGraph
5 | from .tree_graph import TreeGraph
6 |
7 | __all__ = classes = [
8 | 'GraphGenerator',
9 | 'BAGraph',
10 | 'ERGraph',
11 | 'GridGraph',
12 | 'TreeGraph',
13 | ]
14 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/motif_generator/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import MotifGenerator
2 | from .custom import CustomMotif
3 | from .house import HouseMotif
4 | from .cycle import CycleMotif
5 | from .grid import GridMotif
6 |
7 | __all__ = classes = [
8 | 'MotifGenerator',
9 | 'CustomMotif',
10 | 'HouseMotif',
11 | 'CycleMotif',
12 | 'GridMotif',
13 | ]
14 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/motif_generator/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | from torch_geometric.data import Data
5 | from torch_geometric.resolver import resolver
6 |
7 |
8 | class MotifGenerator(ABC):
9 | r"""An abstract base class for generating a motif."""
10 | @abstractmethod
11 | def __call__(self) -> Data:
12 | r"""To be implemented by :class:`Motif` subclasses."""
13 |
14 | @staticmethod
15 | def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator':
16 | import torch_geometric.datasets.motif_generator as _motif_generators
17 | motif_generators = [
18 | gen for gen in vars(_motif_generators).values()
19 | if isinstance(gen, type) and issubclass(gen, MotifGenerator)
20 | ]
21 | return resolver(motif_generators, {}, query, MotifGenerator, 'Motif',
22 | *args, **kwargs)
23 |
24 | def __repr__(self) -> str:
25 | return f'{self.__class__.__name__}()'
26 |
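resolve() gathers every MotifGenerator subclass exported by the motif_generator package and hands them to the shared resolver, so a motif can be requested either as an instance or by name. A minimal usage sketch, assuming the resolver's usual case-insensitive matching with the 'Motif' suffix stripped:

from torch_geometric.datasets.motif_generator import HouseMotif, MotifGenerator

motif = MotifGenerator.resolve('house')
assert isinstance(motif, HouseMotif)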
--------------------------------------------------------------------------------
/torch_geometric/datasets/motif_generator/house.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 | from torch_geometric.datasets.motif_generator import CustomMotif
5 |
6 |
7 | class HouseMotif(CustomMotif):
8 | r"""Generates the house-structured motif from the `"GNNExplainer:
9 | Generating Explanations for Graph Neural Networks"
10 | `__ paper, containing 5 nodes and 6
11 | undirected edges. Nodes are labeled according to their structural role:
12 | the top, middle and bottom of the house.
13 | """
14 | def __init__(self) -> None:
15 | structure = Data(
16 | num_nodes=5,
17 | edge_index=torch.tensor([
18 | [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4],
19 | [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1],
20 | ]),
21 | y=torch.tensor([0, 0, 1, 1, 2]),
22 | )
23 | super().__init__(structure)
24 |
--------------------------------------------------------------------------------
/torch_geometric/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .cheatsheet import paper_link, has_stats, get_stat, get_children, get_type
2 |
3 | __all__ = [
4 | 'paper_link',
5 | 'has_stats',
6 | 'get_stat',
7 | 'get_children',
8 | 'get_type',
9 | ]
10 |
--------------------------------------------------------------------------------
/torch_geometric/deprecation.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import inspect
3 | import warnings
4 | from typing import Any, Callable, Optional
5 |
6 |
7 | def deprecated(
8 | details: Optional[str] = None,
9 | func_name: Optional[str] = None,
10 | ) -> Callable:
11 | def decorator(func: Callable) -> Callable:
12 | name = func_name or func.__name__
13 |
14 | if inspect.isclass(func):
15 | cls = type(func.__name__, (func, ), {})
16 | cls.__init__ = deprecated(details, name)( # type: ignore
17 | func.__init__)
18 | cls.__doc__ = func.__doc__
19 | return cls
20 |
21 | @functools.wraps(func)
22 | def wrapper(*args: Any, **kwargs: Any) -> Any:
23 | out = f"'{name}' is deprecated"
24 | if details is not None:
25 | out += f", {details}"
26 | warnings.warn(out, stacklevel=2)
27 | return func(*args, **kwargs)
28 |
29 | return wrapper
30 |
31 | return decorator
32 |
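The decorator wraps plain functions directly and, for classes, returns a subclass whose __init__ is wrapped, so the warning is emitted at call or instantiation time. A minimal usage sketch with a hypothetical function:

from torch_geometric.deprecation import deprecated

@deprecated("use 'new_add' instead")
def old_add(a, b):
    return a + b

old_add(1, 2)  # UserWarning: 'old_add' is deprecated, use 'new_add' instead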
--------------------------------------------------------------------------------
/torch_geometric/distributed/dist_context.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 |
4 |
5 | class DistRole(Enum):
6 | WORKER = 1
7 |
8 |
9 | @dataclass
10 | class DistContext:
11 | r"""Context information of the current process."""
12 | rank: int
13 | global_rank: int
14 | world_size: int
15 | global_world_size: int
16 | group_name: str
17 | role: DistRole = DistRole.WORKER
18 |
19 | @property
20 | def worker_name(self) -> str:
21 | return f'{self.group_name}-{self.rank}'
22 |
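worker_name simply joins the group name and the local rank, which is convenient for naming per-process RPC workers. For example:

ctx = DistContext(rank=0, global_rank=0, world_size=2,
                  global_world_size=2, group_name='sampler')
assert ctx.worker_name == 'sampler-0'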
--------------------------------------------------------------------------------
/torch_geometric/explain/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import ExplainerConfig, ModelConfig, ThresholdConfig
2 | from .explanation import Explanation, HeteroExplanation
3 | from .algorithm import * # noqa
4 | from .explainer import Explainer
5 | from .metric import * # noqa
6 |
7 | __all__ = [
8 | 'ExplainerConfig',
9 | 'ModelConfig',
10 | 'ThresholdConfig',
11 | 'Explanation',
12 | 'HeteroExplanation',
13 | 'Explainer',
14 | ]
15 |
--------------------------------------------------------------------------------
/torch_geometric/explain/algorithm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ExplainerAlgorithm
2 | from .dummy_explainer import DummyExplainer
3 | from .gnn_explainer import GNNExplainer
4 | from .captum_explainer import CaptumExplainer
5 | from .pg_explainer import PGExplainer
6 | from .attention_explainer import AttentionExplainer
7 | from .graphmask_explainer import GraphMaskExplainer
8 |
9 | __all__ = classes = [
10 | 'ExplainerAlgorithm',
11 | 'DummyExplainer',
12 | 'GNNExplainer',
13 | 'CaptumExplainer',
14 | 'PGExplainer',
15 | 'AttentionExplainer',
16 | 'GraphMaskExplainer',
17 | ]
18 |
--------------------------------------------------------------------------------
/torch_geometric/explain/metric/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import groundtruth_metrics
2 | from .fidelity import fidelity, characterization_score, fidelity_curve_auc
3 | from .faithfulness import unfaithfulness
4 |
5 | __all__ = classes = [
6 | 'groundtruth_metrics',
7 | 'fidelity',
8 | 'characterization_score',
9 | 'fidelity_curve_auc',
10 | 'unfaithfulness',
11 | ]
12 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/benchmark.py:
--------------------------------------------------------------------------------
1 | # Do not change; required for benchmarking
2 |
3 | import torch_geometric_benchmark.torchprof_local as torchprof # noqa
4 | from pytorch_memlab import LineProfiler # noqa
5 | from torch_geometric_benchmark.utils import count_parameters # noqa
6 | from torch_geometric_benchmark.utils import get_gpu_memory_nvdia # noqa
7 | from torch_geometric_benchmark.utils import get_memory_status # noqa
8 | from torch_geometric_benchmark.utils import get_model_size # noqa
9 |
10 | global_line_profiler = LineProfiler()
11 | global_line_profiler.enable()
12 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/cmd_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args() -> argparse.Namespace:
5 | r"""Parses the command line arguments."""
6 | parser = argparse.ArgumentParser(description='GraphGym')
7 |
8 | parser.add_argument('--cfg', dest='cfg_file', type=str, required=True,
9 | help='The configuration file path.')
10 | parser.add_argument('--repeat', type=int, default=1,
11 | help='The number of repeated jobs.')
12 | parser.add_argument('--mark_done', action='store_true',
13 | help='Mark yaml as done after a job has finished.')
14 | parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
15 | help='See graphgym/config.py for remaining options.')
16 |
17 | return parser.parse_args()
18 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | from .act import * # noqa
2 | from .config import * # noqa
3 | from .encoder import * # noqa
4 | from .head import * # noqa
5 | from .layer import * # noqa
6 | from .loader import * # noqa
7 | from .loss import * # noqa
8 | from .network import * # noqa
9 | from .optimizer import * # noqa
10 | from .pooling import * # noqa
11 | from .stage import * # noqa
12 | from .train import * # noqa
13 | from .transform import * # noqa
14 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/act/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/config/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/head/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/layer/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/loader/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/network/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/pooling/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/stage/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/train/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/contrib/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname, basename, isfile, join
2 | import glob
3 |
4 | modules = glob.glob(join(dirname(__file__), "*.py"))
5 | __all__ = [
6 | basename(f)[:-3] for f in modules
7 | if isfile(f) and not f.endswith('__init__.py')
8 | ]
9 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/imports.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 | try:
6 | import lightning.pytorch as pl
7 | _pl_is_available = True
8 | except ImportError:
9 | try:
10 | import pytorch_lightning as pl
11 | _pl_is_available = True
12 | except ImportError:
13 | _pl_is_available = False
14 |
15 | if _pl_is_available:
16 | LightningModule = pl.LightningModule
17 | Callback = pl.Callback
18 | else:
19 | pl = object
20 | LightningModule = torch.nn.Module
21 | Callback = object
22 |
23 | warnings.warn(
24 | "To use GraphGym, install 'pytorch_lightning' or 'lightning' via "
25 | "'pip install pytorch_lightning' or 'pip install lightning'",
26 | stacklevel=2)
27 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def init_weights(m):
5 | r"""Performs weight initialization.
6 |
7 | Args:
8 | m (nn.Module): PyTorch module
9 |
10 | """
11 | if (isinstance(m, torch.nn.BatchNorm2d)
12 | or isinstance(m, torch.nn.BatchNorm1d)):
13 | m.weight.data.fill_(1.0)
14 | m.bias.data.zero_()
15 | elif isinstance(m, torch.nn.Linear):
16 | m.weight.data = torch.nn.init.xavier_uniform_(
17 | m.weight.data, gain=torch.nn.init.calculate_gain('relu'))
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 |
--------------------------------------------------------------------------------
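
Usage sketch for `init_weights` above: it is meant to be applied recursively over a model's submodules via `torch.nn.Module.apply`; the toy model below is only for illustration:

import torch

from torch_geometric.graphgym.init import init_weights

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.BatchNorm1d(32),
    torch.nn.Linear(32, 8),
)
model.apply(init_weights)  # re-initializes Linear and BatchNorm layers in place
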
/torch_geometric/graphgym/models/pooling.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.register import register_pooling
2 | from torch_geometric.nn import (
3 | global_add_pool,
4 | global_max_pool,
5 | global_mean_pool,
6 | )
7 |
8 | register_pooling('add', global_add_pool)
9 | register_pooling('mean', global_mean_pool)
10 | register_pooling('max', global_max_pool)
11 |
--------------------------------------------------------------------------------
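
The module above only registers the three built-in global pooling functions. A minimal sketch of registering an additional pooling under a new key, mirroring the direct-call style used in the file; the `'min'` key and `global_min_pool` function are hypothetical examples, not part of the library:

from torch_geometric.graphgym.register import register_pooling
from torch_geometric.utils import scatter


def global_min_pool(x, batch, size=None):
    # Hypothetical example: feature-wise minimum over all nodes of each graph.
    return scatter(x, batch, dim=0, dim_size=size, reduce='min')


register_pooling('min', global_min_pool)  # illustrative key choice
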
/torch_geometric/graphgym/utils/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch_geometric/6c22779538bba0cc6650110f8edcfcd05da5acb4/torch_geometric/graphgym/utils/LICENSE
--------------------------------------------------------------------------------
/torch_geometric/graphgym/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agg_runs import agg_runs, agg_batch
2 | from .comp_budget import params_count, match_baseline_cfg
3 | from .device import get_current_gpu_usage, auto_select_device
4 | from .epoch import is_eval_epoch, is_ckpt_epoch
5 | from .io import dict_to_json, dict_list_to_json, dict_to_tb, makedirs_rm_exist
6 | from .tools import dummy_context
7 |
8 | __all__ = [
9 | 'agg_runs',
10 | 'agg_batch',
11 | 'params_count',
12 | 'match_baseline_cfg',
13 | 'get_current_gpu_usage',
14 | 'auto_select_device',
15 | 'is_eval_epoch',
16 | 'is_ckpt_epoch',
17 | 'dict_to_json',
18 | 'dict_list_to_json',
19 | 'dict_to_tb',
20 | 'makedirs_rm_exist',
21 | 'dummy_context',
22 | ]
23 |
24 | classes = __all__
25 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/utils/epoch.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.graphgym.config import cfg
2 |
3 |
4 | def is_train_eval_epoch(cur_epoch):
5 |     """Determines whether the training split should also be evaluated at the current epoch."""
6 | return is_eval_epoch(cur_epoch) or not cfg.train.skip_train_eval
7 |
8 |
9 | def is_eval_epoch(cur_epoch):
10 | """Determines if the model should be evaluated at the current epoch."""
11 | return ((cur_epoch + 1) % cfg.train.eval_period == 0 or cur_epoch == 0
12 | or (cur_epoch + 1) == cfg.optim.max_epoch)
13 |
14 |
15 | def is_ckpt_epoch(cur_epoch):
16 |     """Determines if a checkpoint should be saved at the current epoch."""
17 | return ((cur_epoch + 1) % cfg.train.ckpt_period == 0
18 | or (cur_epoch + 1) == cfg.optim.max_epoch)
19 |
--------------------------------------------------------------------------------
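
Usage sketch for the epoch helpers above, assuming the global GraphGym `cfg` has been initialized with its defaults; the loop body is left as placeholder comments:

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch

for cur_epoch in range(cfg.optim.max_epoch):
    # ... run one epoch of training here ...
    if is_eval_epoch(cur_epoch):
        pass  # evaluate on the validation split
    if is_ckpt_epoch(cur_epoch):
        pass  # write a checkpoint
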
/torch_geometric/graphgym/utils/plot.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 |
4 | def view_emb(emb, dir):
5 |     """Visualizes an embedding matrix.
6 |
7 | Args:
8 | emb (torch.tensor): Embedding matrix with shape (N, D). D is the
9 | feature dimension.
10 | dir (str): Output directory for the embedding figure.
11 | """
12 | import matplotlib.pyplot as plt
13 | import seaborn as sns
14 | from sklearn.decomposition import PCA
15 |
16 | sns.set_context('poster')
17 |
18 | if emb.shape[1] > 2:
19 | pca = PCA(n_components=2)
20 | emb = pca.fit_transform(emb)
21 | plt.figure(figsize=(10, 10))
22 | plt.scatter(emb[:, 0], emb[:, 1])
23 | plt.savefig(osp.join(dir, 'emb_pca.png'), dpi=100)
24 |
--------------------------------------------------------------------------------
/torch_geometric/graphgym/utils/tools.py:
--------------------------------------------------------------------------------
1 | class dummy_context():
2 | """Default context manager that does nothing."""
3 | def __enter__(self):
4 | return None
5 |
6 | def __exit__(self, exc_type, exc_value, traceback):
7 | return False
8 |
--------------------------------------------------------------------------------
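
Usage sketch for `dummy_context` above: it acts as a no-op stand-in wherever a real context manager is optional; the `use_amp` flag is an assumption for illustration:

import torch

from torch_geometric.graphgym.utils.tools import dummy_context

use_amp = False  # illustrative flag
ctx = torch.cuda.amp.autocast() if use_amp else dummy_context()
with ctx:
    pass  # forward pass would go here
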
/torch_geometric/home.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from typing import Optional
4 |
5 | ENV_PYG_HOME = 'PYG_HOME'
6 | DEFAULT_CACHE_DIR = osp.join('~', '.cache', 'pyg')
7 |
8 | _home_dir: Optional[str] = None
9 |
10 |
11 | def get_home_dir() -> str:
12 | r"""Get the cache directory used for storing all :pyg:`PyG`-related data.
13 |
14 | If :meth:`set_home_dir` is not called, the path is given by the environment
15 | variable :obj:`$PYG_HOME` which defaults to :obj:`"~/.cache/pyg"`.
16 | """
17 | if _home_dir is not None:
18 | return _home_dir
19 |
20 | return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR))
21 |
22 |
23 | def set_home_dir(path: str) -> None:
24 | r"""Set the cache directory used for storing all :pyg:`PyG`-related data.
25 |
26 | Args:
27 | path (str): The path to a local folder.
28 | """
29 | global _home_dir
30 | _home_dir = path
31 |
--------------------------------------------------------------------------------
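
Usage sketch for the cache-directory helpers above; the override path is illustrative:

from torch_geometric.home import get_home_dir, set_home_dir

print(get_home_dir())           # e.g. '/home/user/.cache/pyg' unless $PYG_HOME is set
set_home_dir('/tmp/pyg_cache')  # illustrative path
assert get_home_dir() == '/tmp/pyg_cache'
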
/torch_geometric/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .txt_array import parse_txt_array, read_txt_array
2 | from .tu import read_tu_data
3 | from .planetoid import read_planetoid_data
4 | from .ply import read_ply
5 | from .obj import read_obj
6 | from .sdf import read_sdf, parse_sdf
7 | from .off import read_off, write_off
8 | from .npz import read_npz, parse_npz
9 |
10 | __all__ = [
11 | 'read_off',
12 | 'write_off',
13 | 'parse_txt_array',
14 | 'read_txt_array',
15 | 'read_tu_data',
16 | 'read_planetoid_data',
17 | 'read_ply',
18 | 'read_obj',
19 | 'read_sdf',
20 | 'parse_sdf',
21 | 'read_npz',
22 | 'parse_npz',
23 | ]
24 |
--------------------------------------------------------------------------------
/torch_geometric/io/ply.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_geometric.data import Data
4 |
5 | try:
6 | import openmesh
7 | except ImportError:
8 | openmesh = None
9 |
10 |
11 | def read_ply(path: str) -> Data:
12 | if openmesh is None:
13 | raise ImportError('`read_ply` requires the `openmesh` package.')
14 |
15 | mesh = openmesh.read_trimesh(path)
16 | pos = torch.from_numpy(mesh.points()).to(torch.float)
17 | face = torch.from_numpy(mesh.face_vertex_indices())
18 | face = face.t().to(torch.long).contiguous()
19 | return Data(pos=pos, face=face)
20 |
--------------------------------------------------------------------------------
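
Usage sketch for `read_ply` above, assuming the optional `openmesh` dependency is installed; the file name is a placeholder:

from torch_geometric.io import read_ply

data = read_ply('mesh.ply')  # placeholder path; requires `pip install openmesh`
print(data.pos.size(), data.face.size())  # positions [N, 3], faces [3, F]
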
/torch_geometric/llm/__init__.py:
--------------------------------------------------------------------------------
1 | from .large_graph_indexer import LargeGraphIndexer
2 | from .rag_loader import RAGQueryLoader
3 | from .utils import * # noqa
4 | from .models import * # noqa
5 |
6 | __all__ = classes = [
7 | 'LargeGraphIndexer',
8 | 'RAGQueryLoader',
9 | ]
10 |
--------------------------------------------------------------------------------
/torch_geometric/llm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .sentence_transformer import SentenceTransformer
2 | from .vision_transformer import VisionTransformer
3 | from .llm import LLM
4 | from .txt2kg import TXT2KG
5 | from .llm_judge import LLMJudge
6 | from .g_retriever import GRetriever
7 | from .molecule_gpt import MoleculeGPT
8 | from .glem import GLEM
9 | from .protein_mpnn import ProteinMPNN
10 | from .git_mol import GITMol
11 |
12 | __all__ = classes = [
13 | 'SentenceTransformer',
14 | 'VisionTransformer',
15 | 'LLM',
16 | 'LLMJudge',
17 | 'TXT2KG',
18 | 'GRetriever',
19 | 'MoleculeGPT',
20 | 'GLEM',
21 | 'ProteinMPNN',
22 | 'GITMol',
23 | ]
24 |
--------------------------------------------------------------------------------
/torch_geometric/llm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .backend_utils import * # noqa
2 | from .feature_store import KNNRAGFeatureStore
3 | from .graph_store import NeighborSamplingRAGGraphStore
4 | from .vectorrag import DocumentRetriever
5 |
6 | __all__ = classes = [
7 | 'KNNRAGFeatureStore',
8 | 'NeighborSamplingRAGGraphStore',
9 | 'DocumentRetriever',
10 | ]
11 |
--------------------------------------------------------------------------------
/torch_geometric/logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Any
3 |
4 | _wandb_initialized: bool = False
5 |
6 |
7 | def init_wandb(name: str, **kwargs: Any) -> None:
8 | if '--wandb' not in sys.argv:
9 | return
10 |
11 | from datetime import datetime
12 |
13 | import wandb
14 |
15 | wandb.init(
16 | project=name,
17 | entity='pytorch-geometric',
18 | name=datetime.now().strftime('%Y-%m-%d_%H:%M'),
19 | config=kwargs,
20 | )
21 |
22 | global _wandb_initialized
23 | _wandb_initialized = True
24 |
25 |
26 | def log(**kwargs: Any) -> None:
27 | def _map(value: Any) -> str:
28 | if isinstance(value, int) and not isinstance(value, bool):
29 | return f'{value:03d}'
30 | if isinstance(value, float):
31 | return f'{value:.4f}'
32 | return value
33 |
34 | print(', '.join(f'{key}: {_map(value)}' for key, value in kwargs.items()))
35 |
36 | if _wandb_initialized:
37 | import wandb
38 | wandb.log(kwargs)
39 |
--------------------------------------------------------------------------------
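
Usage sketch for the logging helpers above: `init_wandb` is a no-op unless the script was launched with a `--wandb` flag, so calling it unconditionally is safe; the project name and metric are placeholders:

from torch_geometric.logging import init_wandb, log

init_wandb(name='example-project', lr=0.01, epochs=5)  # no-op without '--wandb'
for epoch in range(5):
    loss = 1.0 / (epoch + 1)  # placeholder metric
    log(Epoch=epoch, Loss=loss)
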
/torch_geometric/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | from .link_pred import (
4 | LinkPredMetric,
5 | LinkPredMetricCollection,
6 | LinkPredPrecision,
7 | LinkPredRecall,
8 | LinkPredF1,
9 | LinkPredMAP,
10 | LinkPredNDCG,
11 | LinkPredMRR,
12 | LinkPredHitRatio,
13 | LinkPredCoverage,
14 | LinkPredDiversity,
15 | LinkPredPersonalization,
16 | LinkPredAveragePopularity,
17 | )
18 |
19 | link_pred_metrics = [
20 | 'LinkPredMetric',
21 | 'LinkPredMetricCollection',
22 | 'LinkPredPrecision',
23 | 'LinkPredRecall',
24 | 'LinkPredF1',
25 | 'LinkPredMAP',
26 | 'LinkPredNDCG',
27 | 'LinkPredMRR',
28 | 'LinkPredHitRatio',
29 | 'LinkPredCoverage',
30 | 'LinkPredDiversity',
31 | 'LinkPredPersonalization',
32 | 'LinkPredAveragePopularity',
33 | ]
34 |
35 | __all__ = link_pred_metrics
36 |
--------------------------------------------------------------------------------
/torch_geometric/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .reshape import Reshape
2 | from .sequential import Sequential
3 | from .data_parallel import DataParallel
4 | from .to_hetero_transformer import to_hetero
5 | from .to_hetero_with_bases_transformer import to_hetero_with_bases
6 | from .to_fixed_size_transformer import to_fixed_size
7 | from .encoding import PositionalEncoding, TemporalEncoding
8 | from .summary import summary
9 |
10 | from .aggr import * # noqa
11 | from .attention import * # noqa
12 | from .conv import * # noqa
13 | from .pool import * # noqa
14 | from .glob import * # noqa
15 | from .norm import * # noqa
16 | from .unpool import * # noqa
17 | from .dense import * # noqa
18 | from .kge import * # noqa
19 | from .models import * # noqa
20 | from .functional import * # noqa
21 |
22 | __all__ = [
23 | 'Reshape',
24 | 'Sequential',
25 | 'DataParallel',
26 | 'to_hetero',
27 | 'to_hetero_with_bases',
28 | 'to_fixed_size',
29 | 'PositionalEncoding',
30 | 'TemporalEncoding',
31 | 'summary',
32 | ]
33 |
--------------------------------------------------------------------------------
/torch_geometric/nn/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .performer import PerformerAttention
2 | from .qformer import QFormer
3 | from .sgformer import SGFormerAttention
4 | from .polynormer import PolynormerAttention
5 |
6 | __all__ = classes = [
7 | 'PerformerAttention',
8 | 'QFormer',
9 | 'SGFormerAttention',
10 | 'PolynormerAttention',
11 | ]
12 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/cugraph/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CuGraphModule
2 | from .sage_conv import CuGraphSAGEConv
3 | from .gat_conv import CuGraphGATConv
4 | from .rgcn_conv import CuGraphRGCNConv
5 |
6 | __all__ = [
7 | 'CuGraphModule',
8 | 'CuGraphSAGEConv',
9 | 'CuGraphGATConv',
10 | 'CuGraphRGCNConv',
11 | ]
12 |
--------------------------------------------------------------------------------
/torch_geometric/nn/conv/utils/__init__.py:
--------------------------------------------------------------------------------
1 | r"""GNN utility package."""
2 |
3 | from .cheatsheet import paper_title, paper_link
4 | from .cheatsheet import supports_sparse_tensor
5 | from .cheatsheet import supports_edge_weights
6 | from .cheatsheet import supports_edge_features
7 | from .cheatsheet import supports_bipartite_graphs
8 | from .cheatsheet import supports_static_graphs
9 | from .cheatsheet import supports_lazy_initialization
10 | from .cheatsheet import processes_heterogeneous_graphs
11 | from .cheatsheet import processes_hypergraphs
12 | from .cheatsheet import processes_point_clouds
13 |
14 | __all__ = [
15 | 'paper_title',
16 | 'paper_link',
17 | 'supports_sparse_tensor',
18 | 'supports_edge_weights',
19 | 'supports_edge_features',
20 | 'supports_bipartite_graphs',
21 | 'supports_static_graphs',
22 | 'supports_lazy_initialization',
23 | 'processes_heterogeneous_graphs',
24 | 'processes_hypergraphs',
25 | 'processes_point_clouds',
26 | ]
27 |
--------------------------------------------------------------------------------
/torch_geometric/nn/dense/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Dense neural network module package.
2 |
3 | This package provides modules applicable for operating on dense tensor
4 | representations.
5 | """
6 |
7 | from .linear import Linear, HeteroLinear, HeteroDictLinear
8 | from .dense_gat_conv import DenseGATConv
9 | from .dense_sage_conv import DenseSAGEConv
10 | from .dense_gcn_conv import DenseGCNConv
11 | from .dense_graph_conv import DenseGraphConv
12 | from .dense_gin_conv import DenseGINConv
13 | from .diff_pool import dense_diff_pool
14 | from .mincut_pool import dense_mincut_pool
15 | from .dmon_pool import DMoNPooling
16 |
17 | __all__ = [
18 | 'Linear',
19 | 'HeteroLinear',
20 | 'HeteroDictLinear',
21 | 'DenseGCNConv',
22 | 'DenseGINConv',
23 | 'DenseGraphConv',
24 | 'DenseSAGEConv',
25 | 'DenseGATConv',
26 | 'dense_diff_pool',
27 | 'dense_mincut_pool',
28 | 'DMoNPooling',
29 | ]
30 |
31 | lin_classes = __all__[:3]
32 | conv_classes = __all__[3:8]
33 | pool_classes = __all__[8:]
34 |
--------------------------------------------------------------------------------
/torch_geometric/nn/functional/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Functional operator package."""
2 |
3 | from .bro import bro
4 | from .gini import gini
5 |
6 | __all__ = classes = [
7 | 'bro',
8 | 'gini',
9 | ]
10 |
--------------------------------------------------------------------------------
/torch_geometric/nn/functional/gini.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def gini(w: torch.Tensor) -> torch.Tensor:
5 | r"""The Gini coefficient from the `"Improving Molecular Graph Neural
6 | Network Explainability with Orthonormalization and Induced Sparsity"
7 |     <https://arxiv.org/abs/2105.04854>`_ paper.
8 |
9 | Computes a regularization penalty :math:`\in [0, 1]` for each row of a
10 | matrix according to
11 |
12 | .. math::
13 | \mathcal{L}_\textrm{Gini}^i = \sum_j^n \sum_{j'}^n \frac{|w_{ij}
14 | - w_{ij'}|}{2 (n^2 - n)\bar{w_i}}
15 |
16 | and returns an average over all rows.
17 |
18 | Args:
19 | w (torch.Tensor): A two-dimensional tensor.
20 | """
21 | s = 0
22 | for row in w:
23 | t = row.repeat(row.size(0), 1)
24 | u = (t - t.T).abs().sum() / (2 * (row.size(-1)**2 - row.size(-1)) *
25 | row.abs().mean() + torch.finfo().eps)
26 | s += u
27 | s /= w.shape[0]
28 | return s
29 |
--------------------------------------------------------------------------------
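
Behaviour sketch for `gini` above, using a tiny two-row matrix: a fully concentrated row contributes a penalty close to 1, a uniform row close to 0:

import torch

from torch_geometric.nn.functional import gini

w = torch.tensor([
    [1.0, 0.0, 0.0],  # highly sparse row -> penalty close to 1
    [0.5, 0.5, 0.5],  # uniform row       -> penalty of 0
])
print(gini(w))  # average over both rows, here ~0.5
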
/torch_geometric/nn/kge/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Knowledge Graph Embedding (KGE) package."""
2 |
3 | from .base import KGEModel
4 | from .transe import TransE
5 | from .complex import ComplEx
6 | from .distmult import DistMult
7 | from .rotate import RotatE
8 |
9 | __all__ = classes = [
10 | 'KGEModel',
11 | 'TransE',
12 | 'ComplEx',
13 | 'DistMult',
14 | 'RotatE',
15 | ]
16 |
--------------------------------------------------------------------------------
/torch_geometric/nn/kge/loader.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 |
7 | class KGTripletLoader(torch.utils.data.DataLoader):
8 | def __init__(self, head_index: Tensor, rel_type: Tensor,
9 | tail_index: Tensor, **kwargs):
10 | self.head_index = head_index
11 | self.rel_type = rel_type
12 | self.tail_index = tail_index
13 |
14 | super().__init__(range(head_index.numel()), collate_fn=self.sample,
15 | **kwargs)
16 |
17 | def sample(self, index: List[int]) -> Tuple[Tensor, Tensor, Tensor]:
18 | index = torch.tensor(index, device=self.head_index.device)
19 |
20 | head_index = self.head_index[index]
21 | rel_type = self.rel_type[index]
22 | tail_index = self.tail_index[index]
23 |
24 | return head_index, rel_type, tail_index
25 |
--------------------------------------------------------------------------------
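
Usage sketch for `KGTripletLoader` above: it mini-batches aligned head/relation/tail index tensors and is usually created via `KGEModel.loader`, but it can also be instantiated directly; the toy triplets below are placeholders:

import torch

from torch_geometric.nn.kge.loader import KGTripletLoader

head_index = torch.tensor([0, 1, 2, 3])
rel_type = torch.tensor([0, 0, 1, 1])
tail_index = torch.tensor([1, 2, 3, 0])

loader = KGTripletLoader(head_index, rel_type, tail_index, batch_size=2)
for head, rel, tail in loader:
    print(head, rel, tail)  # three tensors of shape [2] per batch
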
/torch_geometric/nn/norm/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Normalization package."""
2 |
3 | from .batch_norm import BatchNorm, HeteroBatchNorm
4 | from .instance_norm import InstanceNorm
5 | from .layer_norm import LayerNorm, HeteroLayerNorm
6 | from .graph_norm import GraphNorm
7 | from .graph_size_norm import GraphSizeNorm
8 | from .pair_norm import PairNorm
9 | from .mean_subtraction_norm import MeanSubtractionNorm
10 | from .msg_norm import MessageNorm
11 | from .diff_group_norm import DiffGroupNorm
12 |
13 | __all__ = [
14 | 'BatchNorm',
15 | 'HeteroBatchNorm',
16 | 'InstanceNorm',
17 | 'LayerNorm',
18 | 'HeteroLayerNorm',
19 | 'GraphNorm',
20 | 'GraphSizeNorm',
21 | 'PairNorm',
22 | 'MeanSubtractionNorm',
23 | 'MessageNorm',
24 | 'DiffGroupNorm',
25 | ]
26 |
27 | classes = __all__
28 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/connect/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Graph connection package.
2 |
3 | This package provides classes for determining coarsened graph connections in
4 | graph pooling scenarios.
5 | """
6 |
7 | from .base import Connect, ConnectOutput
8 | from .filter_edges import FilterEdges
9 |
10 | __all__ = [
11 | 'Connect',
12 | 'ConnectOutput',
13 | 'FilterEdges',
14 | ]
15 |
--------------------------------------------------------------------------------
/torch_geometric/nn/pool/consecutive.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def consecutive_cluster(src):
5 | unique, inv = torch.unique(src, sorted=True, return_inverse=True)
6 | perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
7 | perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)
8 | return inv, perm
9 |
--------------------------------------------------------------------------------
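
Behaviour sketch for `consecutive_cluster` above: it relabels arbitrary cluster ids into the consecutive range `[0, num_clusters)` and also returns, per cluster, the index of one of its members:

import torch

from torch_geometric.nn.pool.consecutive import consecutive_cluster

src = torch.tensor([7, 2, 7, 5])
inv, perm = consecutive_cluster(src)
# inv:  tensor([2, 0, 2, 1])  -> consecutive id for every entry of `src`
# perm: tensor([1, 3, 2])     -> index into `src` of one member per cluster
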
/torch_geometric/nn/pool/pool.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from torch_geometric.utils import coalesce, remove_self_loops, scatter
6 |
7 |
8 | def pool_edge(
9 | cluster,
10 | edge_index,
11 | edge_attr: Optional[torch.Tensor] = None,
12 | reduce: Optional[str] = 'sum',
13 | ):
14 | num_nodes = cluster.size(0)
15 | edge_index = cluster[edge_index.view(-1)].view(2, -1)
16 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
17 | if edge_index.numel() > 0:
18 | edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
19 | reduce=reduce)
20 | return edge_index, edge_attr
21 |
22 |
23 | def pool_batch(perm, batch):
24 | return batch[perm]
25 |
26 |
27 | def pool_pos(cluster, pos):
28 | return scatter(pos, cluster, dim=0, reduce='mean')
29 |
--------------------------------------------------------------------------------
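
Behaviour sketch for `pool_edge` above: endpoints are mapped through the cluster assignment, resulting self-loops are dropped, and duplicate edges are coalesced according to `reduce`:

import torch

from torch_geometric.nn.pool.pool import pool_edge

cluster = torch.tensor([0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
new_edge_index, _ = pool_edge(cluster, edge_index)
# Intra-cluster edges collapse to self-loops and are removed,
# leaving tensor([[0, 1], [1, 0]]).
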
/torch_geometric/nn/pool/select/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Node-selection package.
2 |
3 | This package provides classes for node selection methods in graph pooling
4 | scenarios.
5 | """
6 |
7 | from .base import Select, SelectOutput
8 | from .topk import SelectTopK
9 |
10 | __all__ = [
11 | 'Select',
12 | 'SelectOutput',
13 | 'SelectTopK',
14 | ]
15 |
--------------------------------------------------------------------------------
/torch_geometric/nn/reshape.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | class Reshape(torch.nn.Module):
6 | def __init__(self, *shape):
7 | super().__init__()
8 | self.shape = shape
9 |
10 | def forward(self, x: Tensor) -> Tensor:
11 | """""" # noqa: D419
12 | x = x.view(*self.shape)
13 | return x
14 |
15 | def __repr__(self) -> str:
16 | shape = ', '.join([str(dim) for dim in self.shape])
17 | return f'{self.__class__.__name__}({shape})'
18 |
--------------------------------------------------------------------------------
/torch_geometric/nn/sequential.jinja:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import torch
4 | from torch import Tensor
5 |
6 | import torch_geometric.typing
7 | {% for module in modules %}
8 | from {{module}} import *
9 | {%- endfor %}
10 |
11 |
12 | def forward(
13 | self,
14 | {%- for param in signature.param_dict.values() %}
15 | {{param.name}}: {{param.type_repr}},
16 | {%- endfor %}
17 | ) -> {{signature.return_type_repr}}:
18 |
19 | {%- for child in children %}
20 | {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
21 | {%- endfor %}
22 | return {{children[-1].return_names|join(', ')}}
23 |
--------------------------------------------------------------------------------
/torch_geometric/nn/unpool/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Unpooling package."""
2 |
3 | from .knn_interpolate import knn_interpolate
4 |
5 | __all__ = [
6 | 'knn_interpolate',
7 | ]
8 |
9 | classes = __all__
10 |
--------------------------------------------------------------------------------
/torch_geometric/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Graph sampler package."""
2 |
3 | from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
4 | SamplerOutput, HeteroSamplerOutput, NegativeSampling,
5 | NumNeighbors)
6 | from .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler
7 | from .hgt_sampler import HGTSampler
8 |
9 | __all__ = classes = [
10 | 'BaseSampler',
11 | 'NodeSamplerInput',
12 | 'EdgeSamplerInput',
13 | 'SamplerOutput',
14 | 'HeteroSamplerOutput',
15 | 'NumNeighbors',
16 | 'NegativeSampling',
17 | 'NeighborSampler',
18 | 'BidirectionalNeighborSampler',
19 | 'HGTSampler',
20 | ]
21 |
--------------------------------------------------------------------------------
/torch_geometric/seed.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def seed_everything(seed: int) -> None:
8 | r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
9 | :obj:`numpy` and :python:`Python`.
10 |
11 | Args:
12 | seed (int): The desired seed.
13 | """
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
--------------------------------------------------------------------------------
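
Usage sketch for `seed_everything` above, typically called once at the start of a script:

from torch_geometric import seed_everything

seed_everything(42)  # seeds Python's `random`, NumPy, and PyTorch (CPU + CUDA)
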
/torch_geometric/transforms/center.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from torch_geometric.data import Data, HeteroData
4 | from torch_geometric.data.datapipes import functional_transform
5 | from torch_geometric.transforms import BaseTransform
6 |
7 |
8 | @functional_transform('center')
9 | class Center(BaseTransform):
10 | r"""Centers node positions :obj:`data.pos` around the origin
11 | (functional name: :obj:`center`).
12 | """
13 | def forward(
14 | self,
15 | data: Union[Data, HeteroData],
16 | ) -> Union[Data, HeteroData]:
17 | for store in data.node_stores:
18 | if hasattr(store, 'pos'):
19 | store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True)
20 | return data
21 |
--------------------------------------------------------------------------------
/torch_geometric/transforms/normalize_scale.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Data
2 | from torch_geometric.data.datapipes import functional_transform
3 | from torch_geometric.transforms import BaseTransform, Center
4 |
5 |
6 | @functional_transform('normalize_scale')
7 | class NormalizeScale(BaseTransform):
8 | r"""Centers and normalizes node positions to the interval :math:`(-1, 1)`
9 | (functional name: :obj:`normalize_scale`).
10 | """
11 | def __init__(self) -> None:
12 | self.center = Center()
13 |
14 | def forward(self, data: Data) -> Data:
15 | data = self.center(data)
16 |
17 | assert data.pos is not None
18 | scale = (1.0 / data.pos.abs().max()) * 0.999999
19 | data.pos = data.pos * scale
20 |
21 | return data
22 |
--------------------------------------------------------------------------------
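
Usage sketch for `Center` and `NormalizeScale` above: like all transforms, they can be called directly on a `Data` object or attached to a dataset via its `transform` argument; the point cloud below is a placeholder:

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeScale

data = Data(pos=torch.tensor([[0.0, 2.0], [4.0, 6.0]]))  # placeholder positions
data = NormalizeScale()(data)
print(data.pos)  # centered around the origin and scaled into (-1, 1)
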
/torch_geometric/utils/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 |
5 | def cumsum(x: Tensor, dim: int = 0) -> Tensor:
6 | r"""Returns the cumulative sum of elements of :obj:`x`.
7 | In contrast to :meth:`torch.cumsum`, prepends the output with zero.
8 |
9 | Args:
10 | x (torch.Tensor): The input tensor.
11 | dim (int, optional): The dimension to do the operation over.
12 | (default: :obj:`0`)
13 |
14 | Example:
15 | >>> x = torch.tensor([2, 4, 1])
16 | >>> cumsum(x)
17 | tensor([0, 2, 6, 7])
18 |
19 | """
20 | size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:]
21 | out = x.new_empty(size)
22 |
23 | out.narrow(dim, 0, 1).zero_()
24 | torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim)))
25 |
26 | return out
27 |
--------------------------------------------------------------------------------
/torch_geometric/utils/mixin.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterator, TypeVar
2 |
3 | T = TypeVar('T')
4 |
5 |
6 | class CastMixin:
7 | @classmethod
8 | def cast(cls: T, *args: Any, **kwargs: Any) -> T:
9 | if len(args) == 1 and len(kwargs) == 0:
10 | elem = args[0]
11 | if elem is None:
12 | return None # type: ignore
13 | if isinstance(elem, CastMixin):
14 | return elem # type: ignore
15 | if isinstance(elem, tuple):
16 | return cls(*elem) # type: ignore
17 | if isinstance(elem, dict):
18 | return cls(**elem) # type: ignore
19 | return cls(*args, **kwargs) # type: ignore
20 |
21 | def __iter__(self) -> Iterator:
22 | return iter(self.__dict__.values())
23 |
--------------------------------------------------------------------------------
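
Behaviour sketch for `CastMixin` above: `cast` lets a class accept `None`, an existing instance, a tuple, or a dict interchangeably; the `Point` dataclass is a hypothetical example:

from dataclasses import dataclass

from torch_geometric.utils.mixin import CastMixin


@dataclass
class Point(CastMixin):  # hypothetical example class
    x: int
    y: int


assert Point.cast((1, 2)) == Point(1, 2)
assert Point.cast({'x': 1, 'y': 2}) == Point(1, 2)
assert Point.cast(None) is None
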
/torch_geometric/utils/repeat.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numbers
3 | from typing import Any
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 |
9 | def repeat(src: Any, length: int) -> Any:
10 | if src is None:
11 | return None
12 |
13 | if isinstance(src, Tensor):
14 | if src.numel() == 1:
15 | return src.repeat(length)
16 |
17 | if src.numel() > length:
18 | return src[:length]
19 |
20 | if src.numel() < length:
21 | last_elem = src[-1].unsqueeze(0)
22 | padding = last_elem.repeat(length - src.numel())
23 | return torch.cat([src, padding])
24 |
25 | return src
26 |
27 | if isinstance(src, numbers.Number):
28 | return list(itertools.repeat(src, length))
29 |
30 | if (len(src) > length):
31 | return src[:length]
32 |
33 | if (len(src) < length):
34 | return src + list(itertools.repeat(src[-1], length - len(src)))
35 |
36 | return src
37 |
--------------------------------------------------------------------------------
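
Behaviour sketch for `repeat` above: the input is broadcast or truncated to exactly `length` entries, covering numbers, sequences, and tensors:

import torch

from torch_geometric.utils.repeat import repeat

print(repeat(3, length=4))                            # [3, 3, 3, 3]
print(repeat([1, 2], length=4))                       # [1, 2, 2, 2]
print(repeat(torch.tensor([1, 2, 3, 4]), length=2))   # tensor([1, 2])
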
/torch_geometric/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Visualization package."""
2 |
3 | from .graph import visualize_graph, visualize_hetero_graph
4 | from .influence import influence
5 |
6 | __all__ = [
7 | 'visualize_graph',
8 | 'visualize_hetero_graph',
9 | 'influence',
10 | ]
11 |
--------------------------------------------------------------------------------
/torch_geometric/visualization/influence.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.autograd import grad
6 |
7 |
8 | def influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor:
9 | x = src.clone().requires_grad_()
10 | out = model(x, *args).sum(dim=-1)
11 |
12 | influences = []
13 | for j in range(src.size(0)):
14 | influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1)
15 | influences.append(influence / influence.sum())
16 |
17 | return torch.stack(influences, dim=0)
18 |
--------------------------------------------------------------------------------
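
Usage sketch for `influence` above: for each source node it measures how strongly every node's input features affect that node's output, normalized per row. The model and graph below are small placeholders, assuming the standard `GCN` model from `torch_geometric.nn`:

import torch

from torch_geometric.nn import GCN
from torch_geometric.visualization import influence

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
model = GCN(in_channels=8, hidden_channels=16, num_layers=2)
scores = influence(model, x, edge_index)  # shape [4, 4], each row sums to 1
print(scores)
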
/torch_geometric/warnings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Literal
3 |
4 | import torch_geometric
5 |
6 |
7 | def warn(message: str, stacklevel: int = 5) -> None:
8 | if torch_geometric.is_compiling():
9 | return
10 |
11 | warnings.warn(message, stacklevel=stacklevel)
12 |
13 |
14 | def filterwarnings(
15 | action: Literal['default', 'error', 'ignore', 'always', 'module', 'once'],
16 | message: str,
17 | ) -> None:
18 | if torch_geometric.is_compiling():
19 | return
20 |
21 | warnings.filterwarnings(action, message)
22 |
23 |
24 | class WarningCache(set):
25 | """Cache for warnings."""
26 | def warn(self, message: str, stacklevel: int = 5) -> None:
27 | """Trigger warning message."""
28 | if message not in self:
29 | self.add(message)
30 | warn(message, stacklevel=stacklevel)
31 |
--------------------------------------------------------------------------------
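
Behaviour sketch for `WarningCache` above: it de-duplicates warnings so each message is emitted only on its first occurrence:

from torch_geometric.warnings import WarningCache

cache = WarningCache()
cache.warn('feature X is experimental')  # emitted
cache.warn('feature X is experimental')  # silently skipped (already seen)
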