├── .gitignore ├── LICENSE ├── README.md ├── assets └── JittorGeometric_logo.png ├── docs ├── Makefile ├── make.bat ├── readthedocs.yaml ├── requirements.txt └── source │ ├── conf.py │ ├── data │ └── data.rst │ ├── dataloader │ └── dataloader.rst │ ├── datasets │ └── datasets.rst │ ├── get_started │ └── introduction.rst │ ├── index.rst │ ├── install │ └── installation.rst │ ├── io │ └── io.rst │ ├── modules │ ├── nn.rst │ └── nn │ │ ├── aggr.rst │ │ ├── conv.rst │ │ └── dense.rst │ ├── ops │ └── ops.rst │ ├── partition │ └── partition.rst │ ├── transforms │ └── transforms.rst │ └── utils │ └── utils.rst ├── examples ├── appnp_example.py ├── bernnet_example.py ├── cheb_example.py ├── chebnet2_example.py ├── dimenet_example.py ├── dygformer_example.py ├── dyrep_example.py ├── egnn_example.py ├── evennet_example.py ├── gat_example.py ├── gcn2_example.py ├── gcn_example.py ├── gprgnn_example.py ├── graphmixer_example.py ├── jodie_example.py ├── optbasisgnn_example.py ├── sage_example.py ├── schnet_example.py ├── sgc_example.py ├── spherenet_example.py ├── tgn_example.py └── unimol_example.py ├── jittor_geometric ├── __init__.py ├── data │ ├── __init__.py │ ├── batch.py │ ├── conformer.py │ ├── data.py │ ├── dataset.py │ ├── dictionary.py │ ├── download.py │ ├── graphchunk.py │ ├── in_memory_dataset.py │ ├── makedirs.py │ └── temporal.py ├── dataloader │ ├── __init__.py │ ├── cluster_loader.py │ ├── dataloader.py │ ├── general_loader.py │ ├── graphsaint_loader.py │ ├── neighbor_loader.py │ ├── random_node_loader.py │ └── temporal_dataloader.py ├── datasets │ ├── __init__.py │ ├── amazon.py │ ├── geomgcn.py │ ├── hetero.py │ ├── jodie.py │ ├── linkx.py │ ├── master.csv │ ├── molecule_net.py │ ├── ogb.py │ ├── planetoid.py │ ├── qm9.py │ ├── reddit.py │ ├── tgb_seq.py │ └── wikipedia_network.py ├── evaluate │ ├── __init__.py │ └── evaluators.py ├── io │ ├── __init__.py │ ├── npz.py │ ├── ogb.py │ ├── ogb_raw.py │ ├── planetoid.py │ └── txt_array.py ├── nn │ ├── 
__init__.py │ ├── aggr │ │ ├── __init__.py │ │ ├── base.py │ │ ├── basic.py │ │ └── multi.py │ ├── conv │ │ ├── __init__.py │ │ ├── appnp_conv.py │ │ ├── bernnet_conv.py │ │ ├── cheb_conv.py │ │ ├── chebnet2_conv.py │ │ ├── clustergcn_conv.py │ │ ├── egnn_conv.py │ │ ├── even_conv.py │ │ ├── gat_conv.py │ │ ├── gcn2_conv.py │ │ ├── gcn_conv.py │ │ ├── gpr_conv.py │ │ ├── message_passing.py │ │ ├── message_passiong_nts.py │ │ ├── optbasis_conv.py │ │ ├── sage_conv.py │ │ ├── sg_conv.py │ │ ├── spherenet_conv.py │ │ ├── transformer_conv.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── inspector.py │ │ │ └── typing.py │ ├── dense │ │ ├── __init__.py │ │ ├── merge_predictor.py │ │ └── time_encoder.py │ ├── inits.py │ ├── models │ │ ├── __init__.py │ │ ├── dimenet.py │ │ ├── dimenet_utils.py │ │ ├── dygformer.py │ │ ├── dyrep.py │ │ ├── graphmixer.py │ │ ├── jodie.py │ │ ├── schnet.py │ │ ├── tgn.py │ │ ├── transformers.py │ │ └── unimol.py │ └── pool │ │ ├── __init__.py │ │ └── glob.py ├── ops │ ├── README.md │ ├── __init__.py │ ├── aggregateWithWeight.py │ ├── cootocsc.py │ ├── cootocsr.py │ ├── cpp │ │ ├── aggregate_op.cc │ │ ├── aggregate_op.h │ │ ├── cootocsc_op.cc │ │ ├── cootocsc_op.h │ │ ├── cootocsr_op.cc │ │ ├── cootocsr_op.h │ │ ├── edgesoftmax_op.cc │ │ ├── edgesoftmax_op.h │ │ ├── edgesoftmaxbackward_op.cc │ │ ├── edgesoftmaxbackward_op.h │ │ ├── edgetovertex_op.cc │ │ ├── edgetovertex_op.h │ │ ├── scattertoedge_op.cc │ │ ├── scattertoedge_op.h │ │ ├── spmmcoo_op.cc │ │ ├── spmmcoo_op.h │ │ ├── spmmcsr_op.cc │ │ ├── spmmcsr_op.h │ │ ├── toundirected_op.cc │ │ └── toundirected_op.h │ ├── edgesoftmax.py │ ├── repeat_interleave.py │ ├── saparse_ops.py │ ├── scatterToEdge.py │ ├── scatterToVertex.py │ ├── spmmcoo.py │ ├── spmmcsr.py │ └── toundirected.py ├── partition │ ├── __init__.py │ ├── chunk_manager.py │ └── partition_graph.py ├── tests │ ├── test_aggregate.py │ ├── test_csr.py │ ├── test_edgesoftmax.py │ ├── test_edgetovertex.py │ ├── test_fromtonodes.py 
│ ├── test_mp_spmm.py │ ├── test_repeat_interleave.py │ ├── test_scatter_to_edge.py │ ├── test_spmmcoo.py │ ├── test_spmmcsr.py │ └── test_undirected.py ├── transforms │ ├── __init__.py │ └── normalize_features.py ├── typing.py └── utils │ ├── __init__.py │ ├── coalesce.py │ ├── degree.py │ ├── get_laplacian.py │ ├── induced_graph.py │ ├── isolated.py │ ├── loop.py │ ├── neighbor_sampler.py │ ├── num_nodes.py │ ├── one_hot.py │ ├── scatter.py │ ├── smiles.py │ ├── sort.py │ ├── sort_edge_index.py │ ├── sparse.py │ └── undirected.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | my 2 | .refresh 3 | __pycache__ 4 | .ipynb_checkpoints/ 5 | .vscode/ 6 | __res/ 7 | perf.data 8 | perf.data.old 9 | *.swp 10 | *.ipynb 11 | *.pdf 12 | *.zip 13 | *.tgz 14 | test.py 15 | extern/mkl/mkldnn_lnx*/* 16 | build/ 17 | venv/ 18 | !*.src.md 19 | !README.md 20 | !README.cn.md 21 | !CHANGELOG.md 22 | python/jittor.egg-info 23 | dist/ 24 | !doc/source/* 25 | core 26 | *.log 27 | dhp_test.py 28 | dhp_test_arxiv.py 29 | examples/data/ 30 | -------------------------------------------------------------------------------- /assets/JittorGeometric_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlgRUC/JittorGeometric/e6fb9a3401d58d9da7b3f20f07a807c7a329e0fb/assets/JittorGeometric_logo.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | - method: pip 12 | path: .
13 | 14 | sphinx: 15 | configuration: docs/source/conf.py 16 | 17 | formats: [] -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.24.0 2 | astunparse==1.6.3 3 | autograd==1.7.0 4 | cupy==13.3.0 5 | Flask==3.1.0 6 | huggingface_hub==0.27.1 7 | jittor==1.3.9.14 8 | numpy==1.24.0 9 | pandas==2.2.3 10 | Pillow==11.1.0 11 | PyMetis==2023.1.1 12 | pyparsing==3.2.1 13 | pywebio==1.8.3 14 | recommonmark==0.7.1 15 | schnetpack==2.0.0 16 | scikit_learn==1.6.1 17 | scipy==1.15.1 18 | setuptools==69.5.1 19 | six==1.16.0 20 | sphinx_rtd_theme==3.0.2 21 | sympy==1.13.3 22 | tqdm==4.66.4 23 | sphinx_autodoc_typehints 24 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # For the full list of built-in configuration values, see the documentation: 6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 7 | 8 | # -- Project information ----------------------------------------------------- 9 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 10 | 11 | project = 'Jittor_Geometric' 12 | copyright = '2025, Jittor_Geometric_Team' 13 | author = 'Jittor_Geometric_Team' 14 | release = '1.0.0' 15 | 16 | # -- General configuration --------------------------------------------------- 17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 18 | 19 | extensions = [ 20 | 'sphinx.ext.autodoc', 21 | 'sphinx.ext.napoleon', 22 | 'sphinx.ext.viewcode', 23 | 'sphinx_autodoc_typehints', 24 | 'sphinx.ext.autosummary', 25 | ] 26 | 27 | templates_path = ['_templates'] 28 | exclude_patterns = [] 29 | 30 | 31 | 32 | # -- Options for HTML output -------------------------------------------------
33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 34 | 35 | html_theme = 'sphinx_rtd_theme' 36 | html_static_path = ['_static'] 37 | 38 | html_build_dir = os.environ.get('READTHEDOCS_OUTPUT', 'docs/en/build/html') 39 | -------------------------------------------------------------------------------- /docs/source/data/data.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.data 2 | ========================= 3 | 4 | .. automodule:: jittor_geometric.data 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/dataloader/dataloader.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.dataloader 2 | =========================== 3 | 4 | .. automodule:: jittor_geometric.dataloader 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/datasets/datasets.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.datasets 2 | ========================= 3 | 4 | .. automodule:: jittor_geometric.datasets 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/get_started/introduction.rst: -------------------------------------------------------------------------------- 1 | Get Started 2 | =========== 3 | 4 | Welcome to JittorGeometric! This guide will walk you through the basic steps to get started with graph neural networks using this library. 5 | 6 | Quick Start 7 | ----------- 8 | 9 | Let's start by building a simple Graph Neural Network (GNN) model using `jittor_geometric`.
10 | 11 | Step 1: Import Libraries 12 | ------------------------ 13 | 14 | First, import the necessary libraries: 15 | 16 | .. code-block:: python 17 | 18 | import jittor as jt 19 | from jittor import nn 20 | from jittor_geometric.nn import GCNConv 21 | from jittor_geometric.datasets import Planetoid 22 | 23 | Step 2: Load a Dataset 24 | ---------------------- 25 | 26 | We will use the popular `Planetoid` dataset (e.g., Cora) for this example: 27 | 28 | .. code-block:: python 29 | 30 | dataset = Planetoid(root='your_path', name='Cora') 31 | data = dataset[0] # Getting the first graph 32 | # Prepare data 33 | from jittor_geometric.ops import cootocsr, cootocsc 34 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 35 | v_num = data.x.shape[0] # number of nodes 36 | edge_index, edge_weight = data.edge_index, data.edge_attr 37 | edge_index, edge_weight = gcn_norm(edge_index, edge_weight, v_num, improved=False, add_self_loops=True) 38 | with jt.no_grad(): 39 | data.csc = cootocsc(edge_index, edge_weight, v_num) 40 | data.csr = cootocsr(edge_index, edge_weight, v_num) 41 | 42 | Step 3: Define a Simple GCN Model 43 | --------------------------------- 44 | 45 | Now, let's define a basic Graph Convolutional Network (GCN) model: 46 | 47 | .. code-block:: python 48 | 49 | class GCNModel(jt.Module): 50 | def __init__(self, dataset, dropout=0.8): 51 | super(GCNModel, self).__init__() 52 | self.conv1 = GCNConv(in_channels=dataset.num_features, out_channels=256, spmm=False) 53 | self.conv2 = GCNConv(in_channels=256, out_channels=dataset.num_classes, spmm=False) 54 | self.dropout = dropout 55 | 56 | def execute(self): 57 | x, csc, csr = data.x, data.csc, data.csr 58 | x = nn.relu(self.conv1(x, csc, csr)) 59 | x = nn.dropout(x, self.dropout, is_train=self.training) 60 | x = self.conv2(x, csc, csr) 61 | return nn.log_softmax(x, dim=1) 62 | 63 | Step 4: Training the Model 64 | -------------------------- 65 | 66 | Let's train the model on the dataset: 67 | 68 | ..
code-block:: python 69 | 70 | # Initialize the model 71 | model = GCNModel(dataset) 72 | 73 | # Set optimizer 74 | optimizer = nn.Adam(params=model.parameters(), lr=0.001, weight_decay=5e-4) 75 | 76 | # Training loop 77 | for epoch in range(200): 78 | model.train() 79 | pred = model()[data.train_mask] 80 | label = data.y[data.train_mask] 81 | loss = nn.nll_loss(pred, label) 82 | optimizer.step(loss) 83 | 84 | Step 5: Evaluate the Model 85 | -------------------------- 86 | 87 | After training, evaluate the model's performance: 88 | 89 | .. code-block:: python 90 | 91 | model.eval() 92 | out = model() 93 | pred, _ = jt.argmax(out[data.test_mask], dim=1) 94 | y_test = data.y[data.test_mask] 95 | accuracy = pred.equal(y_test).sum().item() / data.test_mask.sum().item() 96 | print(f'Accuracy: {accuracy * 100:.2f}%') 97 | 98 | Congratulations, you have successfully trained and tested a GNN model using `jittor_geometric`! 99 | 100 | Next Steps 101 | ---------- 102 | 103 | - Explore more datasets: the other `Planetoid` graphs (CiteSeer, PubMed), `Amazon`, `WikipediaNetwork`, etc. 104 | - Try other graph neural network layers like `SAGEConv`, `GATConv`, etc. 105 | - Check out the documentation for more advanced features. 106 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to JittorGeometric's Documentation 2 | =========================================== 3 | 4 | Overview 5 | -------- 6 | 7 | JittorGeometric is a library designed for machine learning on graph data based on the Jittor framework. 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: JittorGeometric 12 | 13 | install/installation 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Get Started 18 | 19 | get_started/introduction 20 | 21 | ..
toctree:: 22 | :maxdepth: 2 23 | :caption: API References 24 | 25 | modules/nn 26 | data/data 27 | datasets/datasets 28 | dataloader/dataloader 29 | io/io 30 | partition/partition 31 | transforms/transforms 32 | ops/ops 33 | utils/utils -------------------------------------------------------------------------------- /docs/source/install/installation.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Installation 3 | ============ 4 | 5 | This document describes how to install jittor_geometric. 6 | 7 | System Requirements 8 | ------------------- 9 | 10 | - **Python**: 3.10 or 3.11 (the pinned ``scipy==1.15.1`` requires Python 3.10 or newer, and the pinned ``numpy==1.24.0`` does not support Python 3.12 or newer). 11 | - **GCC** (Linux only): 5.4 or higher. 12 | - **CMake**: 3.10 or higher. 13 | - **CUDA** (optional): 11.2 (for GPU support). 14 | - **cuDNN** (optional): Compatible with the installed CUDA version. 15 | 16 | Dependencies 17 | ----------------- 18 | 19 | - ase==3.24.0 20 | - astunparse==1.6.3 21 | - autograd==1.7.0 22 | - cupy==13.3.0 23 | - Flask==3.1.0 24 | - einops 25 | - huggingface_hub==0.27.1 26 | - jittor==1.3.9.14 27 | - numpy==1.24.0 28 | - pandas==2.2.3 29 | - Pillow==11.1.0 30 | - PyMetis==2023.1.1 31 | - pyparsing==3.2.1 32 | - pywebio==1.8.3 33 | - recommonmark==0.7.1 34 | - schnetpack==2.0.0 35 | - scikit_learn==1.6.1 36 | - scipy==1.15.1 37 | - setuptools==69.5.1 38 | - six==1.16.0 39 | - sphinx_rtd_theme==3.0.2 40 | - sympy==1.13.3 41 | - tqdm==4.66.4 42 | 43 | Installation Steps 44 | ------------------ 45 | 46 | 1. Install Jittor:: 47 | 48 | python -m pip install git+https://github.com/Jittor/jittor.git 49 | 50 | 2. Install the other dependencies, for example:: 51 | 52 | pip install astunparse==1.6.3 autograd==1.7.0 cupy==13.3.0 numpy==1.24.0 pandas==2.2.3 Pillow==11.1.0 PyMetis==2023.1.1 six==1.16.0 pyparsing==3.2.1 scipy==1.15.1 setuptools==69.5.1 sympy==1.13.3 tqdm==4.66.4 einops huggingface_hub==0.27.1 53 | 54 | 3.
Install the package:: 55 | 56 | git clone https://github.com/AlgRUC/JittorGeometric.git 57 | cd JittorGeometric 58 | pip install . 59 | 60 | 4. Verify the installation: 61 | Run ``examples/gcn_example.py`` to check that jittor_geometric is installed correctly. 62 | 63 | 64 | Troubleshooting 65 | --------------- 66 | 67 | - Newer CUDA versions may not be supported yet; CUDA 11.2 is recommended. 68 | - On Linux, ensure that GCC 5.4 or higher is installed. 69 | - Ensure that CMake 3.10 or higher is installed and accessible in your environment. 70 | 71 | If you have any questions or would like to contribute, please feel free to contact runlin_lei@ruc.edu.cn. 72 | -------------------------------------------------------------------------------- /docs/source/io/io.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.io 2 | ========================= 3 | 4 | .. automodule:: jittor_geometric.io 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/modules/nn.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ========== 3 | 4 | Modules for Jittor Geometric. 5 | 6 | This section covers the aggregation, convolution, and dense layers. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | nn/aggr 12 | nn/conv 13 | nn/dense 14 | -------------------------------------------------------------------------------- /docs/source/modules/nn/aggr.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.nn.aggr 2 | ========================= 3 | 4 | Aggregation layers used in Graph Neural Networks. 5 | 6 | ..
automodule:: jittor_geometric.nn.aggr 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/source/modules/nn/conv.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.nn.conv 2 | ========================= 3 | 4 | Convolutional layers used in Graph Neural Networks. 5 | 6 | .. automodule:: jittor_geometric.nn.conv 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/source/modules/nn/dense.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.nn.dense 2 | ========================= 3 | 4 | Dense layers used in Dynamic Graph Neural Networks. 5 | 6 | .. automodule:: jittor_geometric.nn.dense 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/source/ops/ops.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.ops 2 | ========================= 3 | 4 | .. automodule:: jittor_geometric.ops 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/partition/partition.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.partition 2 | ========================== 3 | 4 | .. automodule:: jittor_geometric.partition 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/transforms/transforms.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.transforms 2 | =========================== 3 | 4 | ..
automodule:: jittor_geometric.transforms 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | ``` 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/source/utils/utils.rst: -------------------------------------------------------------------------------- 1 | jittor_geometric.utils 2 | ========================= 3 | 4 | .. automodule:: jittor_geometric.utils 5 | :imported-members: 6 | :members: 7 | :undoc-members: 8 | ``` 9 | 10 | 11 | -------------------------------------------------------------------------------- /examples/cheb_example.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import argparse 3 | 4 | import jittor as jt 5 | from jittor import nn 6 | from jittor_geometric.datasets import Planetoid 7 | import jittor_geometric.transforms as T 8 | from jittor_geometric.nn import GCNConv, ChebConv, SGConv, GCN2Conv 9 | 10 | jt.flags.use_cuda = 0 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--use_gdc', action='store_true', 14 | help='Use GDC preprocessing.') 15 | args = parser.parse_args() 16 | 17 | dataset = 'Cora' 18 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) 19 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 20 | data = dataset[0] 21 | 22 | if args.use_gdc: 23 | gdc = T.GDC(self_loop_weight=1, normalization_in='sym', 24 | normalization_out='col', 25 | diffusion_kwargs=dict(method='ppr', alpha=0.05), 26 | sparsification_kwargs=dict(method='topk', k=128, 27 | dim=0), exact=True) 28 | data = gdc(data) 29 | 30 | 31 | class Net(nn.Module): 32 | def __init__(self): 33 | super(Net, self).__init__() 34 | self.conv1 = ChebConv(data.num_features, 16, K=2) 35 | self.conv2 = ChebConv(16, data.num_features, K=2) 36 | 37 | def execute(self): 38 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 39 | x = nn.relu(self.conv1(x, edge_index, edge_weight)) 40 
| x = nn.dropout(x) 41 | x = self.conv2(x, edge_index, edge_weight) 42 | return nn.log_softmax(x, dim=1) 43 | 44 | 45 | model, data = Net(), data 46 | optimizer = nn.Adam([ 47 | dict(params=model.conv1.parameters(), weight_decay=5e-4), 48 | dict(params=model.conv2.parameters(), weight_decay=0) 49 | ], lr=0.01) # Only perform weight-decay on first convolution. 50 | 51 | 52 | def train(): 53 | model.train() 54 | pred = model()[data.train_mask] 55 | label = data.y[data.train_mask] 56 | loss = nn.nll_loss(pred, label) 57 | optimizer.step(loss) 58 | 59 | 60 | def test(): 61 | model.eval() 62 | logits, accs = model(), [] 63 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 64 | y_ = data.y[mask] 65 | mask = mask 66 | tmp = [] 67 | for i in range(mask.shape[0]): 68 | if mask[i] == True: 69 | tmp.append(logits[i]) 70 | logits_ = jt.stack(tmp) 71 | pred, _ = jt.argmax(logits_, dim=1) 72 | acc = pred.equal(y_).sum().item() / mask.sum().item() 73 | accs.append(acc) 74 | return accs 75 | 76 | 77 | # train() 78 | best_val_acc = test_acc = 0 79 | for epoch in range(1, 201): 80 | train() 81 | train_acc, val_acc, tmp_test_acc = test() 82 | if val_acc > best_val_acc: 83 | best_val_acc = val_acc 84 | test_acc = tmp_test_acc 85 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 86 | print(log.format(epoch, train_acc, best_val_acc, test_acc)) 87 | -------------------------------------------------------------------------------- /examples/evennet_example.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: ivam 3 | Date: 2024-12-13 4 | Description: 5 | ''' 6 | import os.path as osp 7 | import argparse 8 | 9 | import jittor as jt 10 | from jittor import nn 11 | import sys,os 12 | root = osp.dirname(osp.dirname(osp.abspath(__file__))) 13 | sys.path.append(root) 14 | from jittor_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, OGBNodePropPredDataset, HeteroDataset, Reddit 15 | import 
jittor_geometric.transforms as T 16 | from jittor_geometric.nn import EvenNet 17 | import time 18 | from jittor_geometric.ops import cootocsr,cootocsc 19 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 20 | 21 | jt.flags.use_cuda = 1 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--use_gdc', action='store_true', 24 | help='Use GDC preprocessing.') 25 | parser.add_argument('--dataset', default="cora", help='graph dataset') 26 | parser.add_argument('--alpha', type=float, default=0.2, help='alpha for PPR') 27 | parser.add_argument('--K', type=int, default=10, help='number of coe') 28 | parser.add_argument('--spmm', action='store_true', help='whether using spmm') 29 | args = parser.parse_args() 30 | dataset=args.dataset 31 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data') 32 | 33 | if dataset in ['computers', 'photo']: 34 | dataset = Amazon(path, dataset, transform=T.NormalizeFeatures()) 35 | elif dataset in ['cora', 'citeseer', 'pubmed']: 36 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 37 | elif dataset in ['chameleon', 'squirrel']: 38 | dataset = WikipediaNetwork(path, dataset, geom_gcn_preprocess=False) 39 | elif dataset in ['ogbn-arxiv','ogbn-products','ogbn-papers100M']: 40 | dataset = OGBNodePropPredDataset(name=dataset, root=path) 41 | elif dataset in ['roman_empire', 'amazon_ratings', 'minesweeper', 'questions', 'tolokers']: 42 | dataset = HeteroDataset(path, dataset) 43 | elif dataset in ['reddit']: 44 | dataset = Reddit(os.path.join(path, 'Reddit')) 45 | 46 | data = dataset[0] 47 | total_forward_time = 0.0 48 | total_backward_time = 0.0 49 | v_num = data.x.shape[0] 50 | edge_index, edge_weight = data.edge_index, data.edge_attr 51 | edge_index, edge_weight = gcn_norm( 52 | edge_index, edge_weight,v_num, 53 | improved=False, add_self_loops=False) 54 | with jt.no_grad(): 55 | data.csc = cootocsc(edge_index, edge_weight, v_num) 56 | data.csr = cootocsr(edge_index, edge_weight, v_num) 57 | 58 | 
59 | class Net(nn.Module): 60 | def __init__(self, dataset, dropout=0.5): 61 | super(Net, self).__init__() 62 | hidden = 64 63 | self.lin1 = nn.Linear(dataset.num_features, hidden) 64 | self.lin2 = nn.Linear(hidden, dataset.num_classes) 65 | 66 | self.prop = EvenNet(args.K, args.alpha) 67 | self.dropout = dropout 68 | 69 | def execute(self): 70 | x, csc, csr = data.x, data.csc, data.csr 71 | x = nn.dropout(x, self.dropout) 72 | x = nn.relu(self.lin1(x)) 73 | x = nn.dropout(x, self.dropout) 74 | x = self.lin2(x) 75 | x = self.prop(x, csc, csr) 76 | 77 | return nn.log_softmax(x, dim=1) 78 | 79 | 80 | model, data = Net(dataset), data 81 | optimizer = nn.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4) 82 | 83 | def train(): 84 | global total_forward_time, total_backward_time 85 | model.train() 86 | pred = model()[data.train_mask] 87 | label = data.y[data.train_mask] 88 | loss = nn.nll_loss(pred, label) 89 | optimizer.step(loss) 90 | 91 | def test(): 92 | model.eval() 93 | logits, accs = model(), [] 94 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 95 | y_ = data.y[mask] 96 | logits_=logits[mask] 97 | pred, _ = jt.argmax(logits_, dim=1) 98 | acc = pred.equal(y_).sum().item() / mask.sum().item() 99 | accs.append(acc) 100 | return accs 101 | 102 | 103 | train() 104 | best_val_acc = test_acc = 0 105 | start = time.time() 106 | for epoch in range(1, 201): 107 | train() 108 | train_acc, val_acc, tmp_test_acc = test() 109 | if val_acc > best_val_acc: 110 | best_val_acc = val_acc 111 | test_acc = tmp_test_acc 112 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 113 | print(log.format(epoch, train_acc, best_val_acc, test_acc)) 114 | 115 | jt.sync_all() 116 | end = time.time() 117 | print("Training_time"+str(end-start)) -------------------------------------------------------------------------------- /examples/gcn2_example.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 
import jittor as jt 3 | from jittor import nn 4 | import sys 5 | root = osp.dirname(osp.dirname(osp.abspath(__file__))) 6 | sys.path.append(root) 7 | from jittor_geometric.datasets import Planetoid 8 | import jittor_geometric.transforms as T 9 | 10 | from jittor_geometric.nn import GCN2Conv 11 | from jittor_geometric.ops import cootocsr,cootocsc 12 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 13 | from math import log 14 | import argparse 15 | 16 | 17 | jt.flags.use_cuda = 1 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--spmm', action='store_true', help='whether using spmm') 21 | args = parser.parse_args() 22 | 23 | dataset_name = 'cora' 24 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data') 25 | 26 | if dataset_name in ['cora', 'citeseer', 'pubmed']: 27 | dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures()) 28 | else: 29 | # See more dataset examples in ./dataset_example.py (NOTE(review): that file is not in examples/; sage_example.py shows loaders for other datasets) 30 | pass # NOTE(review): `dataset` stays undefined on this path, so `dataset[0]` below raises NameError 31 | 32 | data = dataset[0] 33 | v_num = data.x.shape[0] 34 | edge_index, edge_weight = data.edge_index, data.edge_attr 35 | edge_index, edge_weight = gcn_norm( 36 | edge_index, edge_weight,v_num, 37 | improved=False, add_self_loops=True) 38 | with jt.no_grad(): 39 | data.csc = cootocsc(edge_index, edge_weight, v_num) 40 | data.csr = cootocsr(edge_index, edge_weight, v_num) 41 | 42 | 43 | class Net(nn.Module): 44 | def __init__(self, dataset, hidden_channels, num_layers=64, alpha=0.1, lamda=0.5, dropout=0.6): 45 | super(Net, self).__init__() 46 | 47 | self.lins = nn.ModuleList() 48 | self.lins.append(nn.Linear(dataset.num_features, hidden_channels)) 49 | self.lins.append(nn.Linear(hidden_channels, dataset.num_classes)) 50 | 51 | self.convs = nn.ModuleList() 52 | for layer in range(num_layers): 53 | self.convs.append( 54 | GCN2Conv(hidden_channels, hidden_channels, spmm=args.spmm)) 55 | 56 | self.dropout = dropout 57 | self.alpha = alpha 58 | self.lamda = lamda 59 | 60 | def execute(self): 61 | x, csc, csr =
data.x, data.csc, data.csr 62 | _hidden = [] 63 | x = nn.relu(self.lins[0](x)) 64 | _hidden.append(x) 65 | 66 | for i, conv in enumerate(self.convs): 67 | x = nn.dropout(x, self.dropout, is_train=self.training) 68 | alpha = self.alpha 69 | beta = log(self.lamda / (i + 1) + 1) # per-layer weight that shrinks as the layer index i grows 70 | x = conv(x, _hidden[0], csc, csr, alpha, beta) # every layer also receives the first-layer output _hidden[0] 71 | x = nn.relu(x) 72 | 73 | x = nn.dropout(x, self.dropout, is_train=self.training) 74 | x = self.lins[1](x) 75 | 76 | return nn.log_softmax(x, dim=-1) 77 | 78 | 79 | model = Net(dataset, hidden_channels=64, num_layers=64, alpha=0.1, lamda=0.5, dropout=0.6) 80 | optimizer = nn.Adam([ 81 | dict(params=model.convs.parameters(), weight_decay=0.01), 82 | dict(params=model.lins.parameters(), weight_decay=5e-4) 83 | ], lr=0.01) 84 | 85 | 86 | print(model) 87 | 88 | 89 | def train(): 90 | model.train() 91 | out = model()[data.train_mask] 92 | label = data.y[data.train_mask] 93 | loss = nn.nll_loss(out, label) 94 | optimizer.step(loss) 95 | return float(loss) 96 | 97 | 98 | def test(): 99 | model.eval() 100 | logits, accs = model(), [] 101 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 102 | y_ = data.y[mask] 103 | logits_=logits[mask] 104 | pred, _ = jt.argmax(logits_, dim=1) 105 | acc = pred.equal(y_).sum().item() / mask.sum().item() 106 | accs.append(acc) 107 | return accs 108 | 109 | 110 | best_val_acc = test_acc = 0 111 | for epoch in range(1, 1001): 112 | loss = train() 113 | train_acc, val_acc, tmp_test_acc = test() 114 | if val_acc > best_val_acc: 115 | best_val_acc = val_acc 116 | test_acc = tmp_test_acc 117 | print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, ' 118 | f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, ' 119 | f'Final Test: {test_acc:.4f}') 120 | -------------------------------------------------------------------------------- /examples/sage_example.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import argparse 3 | 4 | import
jittor as jt 5 | from jittor import nn 6 | import sys,os 7 | root = osp.dirname(osp.dirname(osp.abspath(__file__))) 8 | sys.path.append(root) 9 | from jittor_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, OGBNodePropPredDataset, HeteroDataset, Reddit 10 | import jittor_geometric.transforms as T 11 | from jittor_geometric.nn import GCNConv, SAGEConv 12 | import time 13 | from jittor_geometric.ops import cootocsr,cootocsc 14 | from jittor_geometric.nn.conv.sage_conv import sage_norm 15 | 16 | jt.flags.use_cuda = 1 17 | jt.misc.set_global_seed(42) 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--dataset', help='graph dataset') 20 | parser.add_argument('--spmm', action='store_true', help='whether using spmm') 21 | args = parser.parse_args() 22 | dataset=args.dataset 23 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data') 24 | 25 | if dataset in ['computers', 'photo']: 26 | dataset = Amazon(path, dataset, transform=T.NormalizeFeatures()) 27 | elif dataset in ['cora', 'citeseer', 'pubmed']: 28 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 29 | elif dataset in ['chameleon', 'squirrel']: 30 | dataset = WikipediaNetwork(path, dataset, geom_gcn_preprocess=False) 31 | elif dataset in ['ogbn-arxiv','ogbn-products','ogbn-papers100M']: 32 | dataset = OGBNodePropPredDataset(name=dataset, root=path) 33 | elif dataset in ['roman_empire', 'amazon_ratings', 'minesweeper', 'questions', 'tolokers']: 34 | dataset = HeteroDataset(path, dataset) 35 | elif dataset in ['reddit']: 36 | dataset = Reddit(os.path.join(path, 'Reddit')) 37 | 38 | data = dataset[0] 39 | total_forward_time = 0.0 40 | total_backward_time = 0.0 41 | 42 | v_num = data.x.shape[0] 43 | edge_index, edge_weight = data.edge_index, data.edge_attr 44 | edge_index, edge_weight = sage_norm( 45 | edge_index, edge_weight,v_num, 46 | improved=False, add_self_loops=True) 47 | with jt.no_grad(): 48 | data.csc = cootocsc(edge_index, edge_weight, v_num) 49 | 
data.csr = cootocsr(edge_index, edge_weight, v_num) 50 | 51 | 52 | class Net(nn.Module): 53 | def __init__(self, dataset, dropout=0.8): 54 | super(Net, self).__init__() 55 | self.conv1 = SAGEConv(in_channels=dataset.num_features, out_channels=256, cached = True, root_weight = False, spmm=args.spmm) 56 | self.conv2 = SAGEConv(in_channels=256, out_channels=dataset.num_classes, cached = True, root_weight = False, spmm=args.spmm) 57 | self.dropout = dropout 58 | 59 | def execute(self): 60 | x, edge_index = data.x, data.edge_index # NOTE(review): data.csc / data.csr built above are not used by this model 61 | x = nn.relu(self.conv1(x, edge_index)) 62 | x = nn.dropout(x, self.dropout, is_train=self.training) 63 | x = self.conv2(x, edge_index) 64 | return nn.log_softmax(x, dim=1) 65 | 66 | 67 | 68 | model, data =Net(dataset), data 69 | optimizer = nn.Adam(params=model.parameters(), lr=0.001, weight_decay=5e-4) 70 | 71 | def train(): 72 | global total_forward_time, total_backward_time 73 | model.train() 74 | pred = model()[data.train_mask] 75 | label = data.y[data.train_mask] 76 | loss = nn.nll_loss(pred, label) 77 | optimizer.step(loss) 78 | 79 | def test(): 80 | model.eval() 81 | logits, accs = model(), [] 82 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 83 | y_ = data.y[mask] 84 | logits_=logits[mask] 85 | pred, _ = jt.argmax(logits_, dim=1) 86 | acc = pred.equal(y_).sum().item() / mask.sum().item() 87 | accs.append(acc) 88 | return accs 89 | 90 | 91 | best_val_acc = test_acc = 0 92 | start = time.time() 93 | for epoch in range(1, 201): 94 | train() 95 | train_acc, val_acc, tmp_test_acc = test() 96 | if val_acc > best_val_acc: 97 | best_val_acc = val_acc 98 | test_acc = tmp_test_acc 99 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 100 | print(log.format(epoch, train_acc, best_val_acc, test_acc)) # Val/Test columns report the best validation accuracy so far and its test accuracy 101 | 102 | jt.sync_all() 103 | end = time.time() 104 | print("Training_time"+str(end-start)) -------------------------------------------------------------------------------- /examples/schnet_example.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | from tqdm import tqdm 5 | import jittor as jt 6 | from jittor_geometric.datasets import QM9 7 | import jittor_geometric.transforms as T 8 | from jittor_geometric.dataloader import DataLoader 9 | from jittor import nn 10 | from jittor_geometric.nn import SchNet 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--cutoff', type=float, default=10.0, 14 | help='Cutoff distance for interatomic interactions') 15 | args = parser.parse_args() 16 | 17 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9') 18 | # dataset = QM9(path) 19 | qm9_dataset = QM9(path, transform=T.NormalizeFeatures()) 20 | print(len(qm9_dataset)) 21 | # random split train/val/test = 8/1/1 22 | split_dict = qm9_dataset.get_idx_split() 23 | 24 | # # dataloader 25 | train_loader = DataLoader(qm9_dataset[split_dict["train"]], batch_size=8, shuffle=True) 26 | valid_loader = DataLoader(qm9_dataset[split_dict["valid"]], batch_size=8, shuffle=False) 27 | test_loader = DataLoader(qm9_dataset[split_dict["test"]], batch_size=8, shuffle=False) 28 | 29 | def train(model, loader, optimizer, target): 30 | model.train() 31 | total_loss = 0 32 | for data in tqdm(loader, desc='Training'): 33 | optimizer.zero_grad() 34 | pred = model(data.z, data.pos, data.batch) 35 | loss = nn.MSELoss()(pred.view(-1), data.y[:, target]) 36 | optimizer.step(loss) 37 | total_loss += loss.item() * data.num_graphs 38 | return total_loss / len(loader.dataset) 39 | def evaluate(model, loader, target): 40 | model.eval() 41 | maes = [] 42 | for data in loader: 43 | with jt.no_grad(): 44 | pred = model(data.z, data.pos, data.batch) 45 | mae = (pred.view(-1).numpy() - data.y[:, target].numpy()).abs() 46 | maes.append(mae) 47 | mae = jt.cat(maes, dim=0) 48 | return mae.mean() 49 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | 
jt.flags.use_cuda = 1 51 | 52 | for target in range(12): 53 | # model, datasets = SchNet.from_qm9_pretrained(path, qm9_dataset, target) 54 | # train_dataset, val_dataset, test_dataset = datasets 55 | 56 | model = SchNet(hidden_channels=128, num_filters=128, num_interactions=6, 57 | num_gaussians=50, cutoff=args.cutoff) 58 | optimizer = jt.optim.Adam(model.parameters(), lr=0.0001) 59 | 60 | # Training loop 61 | best_val_mae = float('inf') 62 | for epoch in range(30): # train for 30 epochs 63 | loss = train(model, train_loader, optimizer, target) 64 | val_mae = evaluate(model, valid_loader, target) 65 | 66 | if val_mae < best_val_mae: 67 | best_val_mae = val_mae 68 | best_model = model # NOTE(review): this aliases the same object (no copy), and the test below uses `model` after the last epoch, not the best one 69 | 70 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val MAE: {val_mae:.4f}') 71 | 72 | # Test procedure stays unchanged 73 | # loader = DataLoader(test_dataset, batch_size=256) 74 | 75 | maes = [] 76 | for data in tqdm(test_loader): 77 | # data = data.to(device) 78 | with jt.no_grad(): 79 | pred = model(data.z, data.pos, data.batch) 80 | mae = (pred.view(-1) - data.y[:, target]).abs() 81 | maes.append(mae) 82 | 83 | mae = jt.cat(maes, dim=0) 84 | 85 | # Report meV instead of eV.
86 | mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae 87 | 88 | print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}') 89 | -------------------------------------------------------------------------------- /examples/sgc_example.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import argparse 3 | 4 | import jittor as jt 5 | from jittor import nn 6 | from jittor_geometric.datasets import Planetoid 7 | import jittor_geometric.transforms as T 8 | from jittor_geometric.nn import SGConv 9 | from jittor_geometric.ops import cootocsr,cootocsc 10 | from jittor_geometric.nn.conv.gcn_conv import gcn_norm 11 | 12 | jt.flags.use_cuda = 1 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--dataset', help='graph dataset') # no default: must be a name Planetoid accepts (e.g. cora) 15 | parser.add_argument('--spmm', action='store_true', help='whether using spmm') 16 | args = parser.parse_args() 17 | dataset=args.dataset 18 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data') 19 | 20 | 21 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 22 | 23 | 24 | data = dataset[0] 25 | total_forward_time = 0.0 # NOTE(review): this and total_backward_time are never read in this script 26 | total_backward_time = 0.0 27 | v_num = data.x.shape[0] 28 | edge_index, edge_weight = data.edge_index, data.edge_attr 29 | edge_index, edge_weight = gcn_norm( 30 | edge_index, edge_weight,v_num, 31 | improved=False, add_self_loops=True) 32 | with jt.no_grad(): 33 | data.csc = cootocsc(edge_index, edge_weight, v_num) 34 | data.csr = cootocsr(edge_index, edge_weight, v_num) 35 | 36 | class Net(nn.Module): 37 | def __init__(self, dataset, dropout=0.8): 38 | super(Net, self).__init__() 39 | self.conv1 = SGConv(in_channels=dataset.num_features, out_channels=64, K=2, spmm=args.spmm) 40 | self.conv2 = SGConv(in_channels=64, out_channels=dataset.num_classes, K=2, spmm=args.spmm) 41 | self.dropout = dropout 42 | 43 | def execute(self): 44 | x, csc, csr = data.x, data.csc, data.csr 45 | x =
nn.relu(self.conv1(x, csc, csr)) 46 | x = nn.dropout(x, self.dropout, is_train=self.training) 47 | x = self.conv2(x, csc, csr) 48 | return nn.log_softmax(x, dim=1) 49 | 50 | 51 | 52 | model, data = Net(dataset), data 53 | optimizer = nn.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) 54 | 55 | 56 | def train(): 57 | model.train() 58 | pred = model()[data.train_mask] 59 | label = data.y[data.train_mask] 60 | loss = nn.nll_loss(pred, label) 61 | optimizer.step(loss) 62 | 63 | 64 | def test(): 65 | model.eval() 66 | logits, accs = model(), [] 67 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 68 | y_ = data.y[mask] 69 | mask = mask 70 | tmp = [] 71 | for i in range(mask.shape[0]): 72 | if mask[i] == True: 73 | tmp.append(logits[i]) 74 | logits_ = jt.stack(tmp) 75 | pred, _ = jt.argmax(logits_, dim=1) 76 | acc = pred.equal(y_).sum().item() / mask.sum().item() 77 | accs.append(acc) 78 | return accs 79 | 80 | 81 | best_val_acc = test_acc = 0 82 | for epoch in range(1, 201): 83 | train() 84 | train_acc, val_acc, tmp_test_acc = test() 85 | if val_acc > best_val_acc: 86 | best_val_acc = val_acc 87 | test_acc = tmp_test_acc 88 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 89 | print(log.format(epoch, train_acc, best_val_acc, test_acc)) 90 | -------------------------------------------------------------------------------- /examples/spherenet_example.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import os.path as osp 3 | import sys,os 4 | root = osp.dirname(osp.dirname(osp.abspath(__file__))) 5 | sys.path.append(root) 6 | from jittor import nn 7 | from jittor_geometric.nn import EGNNConv, global_add_pool 8 | from jittor_geometric.nn.conv.spherenet_conv import SphereNet 9 | from jittor_geometric.typing import Var 10 | from jittor_geometric.datasets import QM9 11 | import jittor_geometric.transforms as T 12 | from jittor_geometric.jitgeo_loader import DataLoader 13 | import 
jittor_geometric.jitgeo_loader 14 | from tqdm import tqdm 15 | import numpy as np 16 | 17 | # sample synthetic data (e.g., random graph) 18 | def generate_data(num_nodes, num_edges): 19 | x = jt.randn((num_nodes, 6)) # 3 coordinates + 3 features 20 | edge_index = jt.randint(0, num_nodes, (2, num_edges)) # Random edge indices 21 | edge_attr = jt.randn((num_edges, 3)) # Random edge attributes 22 | return x, edge_index, edge_attr 23 | 24 | 25 | # Define MAE loss function 26 | def mae_loss(pred: Var, target: Var) -> Var: 27 | return jt.abs(pred - target).mean() 28 | 29 | 30 | # Run training 31 | def train(model, loader, optimizer): 32 | model.train() 33 | loss_accum = 0 34 | 35 | # batch_data.z, batch_data.pos, batch_data.pos 36 | for step, batch_data in enumerate(tqdm(loader, desc="Iteration")): 37 | pred = model(batch_data) 38 | loss = mae_loss(pred, batch_data.y) 39 | optimizer.step(loss) 40 | loss_accum += loss 41 | 42 | return float(loss_accum / (step + 1)) 43 | 44 | 45 | def eval(model, loader): 46 | model.eval() 47 | y_true = [] 48 | y_pred = [] 49 | with jt.no_grad(): 50 | # batch_data.z, batch_data.pos, batch_data.pos 51 | for step, batch_data in enumerate(tqdm(loader, desc="Iteration")): 52 | pred = model(batch_data) 53 | y_true.append(batch_data.y.numpy()) 54 | y_pred.append(pred.numpy()) 55 | 56 | y_true = jt.cat(y_true, dim = 0) 57 | y_pred = jt.cat(y_pred, dim = 0) 58 | 59 | return float(mae_loss(y_pred, y_true)) 60 | 61 | def main(): 62 | # data 63 | dataset_name = 'qm9' 64 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../data/QM9') 65 | qm9_dataset = QM9(path, transform=T.NormalizeFeatures()) 66 | # random split train/val/test = 8/1/1 67 | split_dict = qm9_dataset.get_idx_split() 68 | 69 | # dataloader 70 | train_loader = DataLoader(qm9_dataset[split_dict["train"]], batch_size=8, shuffle=True) 71 | valid_loader = DataLoader(qm9_dataset[split_dict["valid"]], batch_size=8, shuffle=False) 72 | test_loader = 
DataLoader(qm9_dataset[split_dict["test"]], batch_size=8, shuffle=False) 73 | 74 | # model 75 | model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4, 76 | hidden_channels=128, out_channels=1, int_emb_size=64, 77 | basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, 78 | num_spherical=7, num_radial=6, envelope_exponent=5, 79 | num_before_skip=1, num_after_skip=2, num_output_layers=3, 80 | output_init='GlorotOrthogonal', use_node_features=True) 81 | 82 | optimizer = jt.optim.Adam(model.parameters(), lr=1e-3) 83 | 84 | best_valid_mae = 1000 85 | 86 | for epoch in range(1, 3): 87 | print("=====Epoch {}".format(epoch)) 88 | print('Training...') 89 | train_mae = train(model, train_loader, optimizer) 90 | 91 | print('Evaluating...') 92 | valid_mae = eval(model, valid_loader) 93 | 94 | print('Testing...') 95 | test_mae = eval(model, test_loader) 96 | 97 | print({'Train': train_mae, 'Validation': valid_mae, 'Test': test_mae}) 98 | 99 | if valid_mae < best_valid_mae: 100 | best_valid_mae = valid_mae 101 | print(f'Best validation MAE so far: {best_valid_mae}') 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /jittor_geometric/__init__.py: -------------------------------------------------------------------------------- 1 | from types import ModuleType 2 | from importlib import import_module 3 | 4 | 5 | __version__ = '0.0.1' 6 | 7 | __all__ = [ 8 | 'jittor_geometric', 9 | '__version__', 10 | ] 11 | -------------------------------------------------------------------------------- /jittor_geometric/data/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2025-01-14 00:47:42 4 | Description: 5 | ''' 6 | from .data import Data 7 | from .dataset import Dataset 8 | from .in_memory_dataset import InMemoryDataset 9 | from .download import download_url, 
decide_download, extract_zip, extract_gz 10 | from .data import CSC,CSR 11 | from .temporal import TemporalData 12 | from .batch import Batch 13 | from .dictionary import Dictionary 14 | from .conformer import ConformerGen 15 | from .graphchunk import GraphChunk 16 | __all__ = [ 17 | 'Data', 18 | 'Dataset', 19 | 'InMemoryDataset', 20 | 'download_url', 21 | 'decide_download', 22 | 'extract_zip', 23 | 'extract_gz', 24 | 'CSC', 25 | 'CSR', 26 | 'GraphChunk', 27 | 'TemporalData', 28 | 'Batch', 29 | 'Dictionary', 30 | 'ConformerGen', 31 | ] 32 | -------------------------------------------------------------------------------- /jittor_geometric/data/download.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import ssl 4 | import os 5 | import os.path as osp 6 | import urllib.request as ur 7 | import zipfile 8 | from six.moves import urllib 9 | import gzip 10 | from .makedirs import makedirs 11 | 12 | 13 | GBFACTOR = float(1 << 30) 14 | 15 | 16 | def download_url(url, folder, log=True): 17 | r"""Downloads the content of an URL to a specific folder. 18 | 19 | Args: 20 | url (string): The url. 21 | folder (string): The folder. 22 | log (bool, optional): If :obj:`False`, will not print anything to the 23 | console. 
(default: :obj:`True`) 24 | """ 25 | 26 | filename = url.rpartition('/')[2].split('?')[0] 27 | path = osp.join(folder, filename) 28 | 29 | if osp.exists(path): # pragma: no cover 30 | if log: 31 | print('Using exist file', filename) 32 | return path 33 | 34 | if log: 35 | print('Downloading', url) 36 | 37 | makedirs(folder) 38 | 39 | context = ssl._create_unverified_context() 40 | data = urllib.request.urlopen(url, context=context) 41 | 42 | with open(path, 'wb') as f: 43 | f.write(data.read()) 44 | 45 | return path 46 | 47 | 48 | def decide_download(url): 49 | d = ur.urlopen(url) 50 | size = int(d.info()["Content-Length"])/GBFACTOR 51 | 52 | ### confirm if larger than 1GB 53 | if size > 1: 54 | return input("This will download %.2fGB. Will you proceed? (y/N)\n" % (size)).lower() == "y" 55 | else: 56 | return True 57 | 58 | 59 | def maybe_log(path, log=True): 60 | if log: 61 | print('Extracting', path) 62 | 63 | 64 | def extract_zip(path, folder, log=True): 65 | r"""Extracts a zip archive to a specific folder. 66 | Args: 67 | path (string): The path to the tar archive. 68 | folder (string): The folder. 69 | log (bool, optional): If :obj:`False`, will not print anything to the 70 | console. (default: :obj:`True`) 71 | """ 72 | maybe_log(path, log) 73 | with zipfile.ZipFile(path, 'r') as f: 74 | f.extractall(folder) 75 | 76 | 77 | def extract_gz(path: str, folder: str, log: bool = True) -> None: 78 | r"""Extracts a gz archive to a specific folder. 79 | 80 | Args: 81 | path (str): The path to the tar archive. 82 | folder (str): The folder. 83 | log (bool, optional): If :obj:`False`, will not print anything to the 84 | console. 
(default: :obj:`True`) 85 | """ 86 | maybe_log(path, log) 87 | path = osp.abspath(path) 88 | with gzip.open(path, 'r') as r: 89 | with open(osp.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w: 90 | w.write(r.read()) -------------------------------------------------------------------------------- /jittor_geometric/data/graphchunk.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from jittor_geometric.data import CSR 3 | import pickle 4 | 5 | class GraphChunk: 6 | def __init__(self, 7 | chunks: int, 8 | chunk_id: int, 9 | v_num: int, 10 | global_v_num: int, 11 | local_mask: dict = None, 12 | local_feature: object = None, 13 | local_label: object = None): 14 | 15 | self.chunks = chunks 16 | self.chunk_id = chunk_id 17 | self.v_num = v_num 18 | self.global_v_num = global_v_num 19 | self.CSR = None 20 | self.local_mask = local_mask 21 | self.local_feature = local_feature 22 | self.local_label = local_label 23 | 24 | def set_csr(self, column_indices, row_offset, edge_weight=None): 25 | """ 26 | Set the CSR (Compressed Sparse Row) representation of the graph. 27 | :param column_indices: Column indices of the non-zero elements. 28 | :param row_offset: Row offsets for the CSR format. 29 | :param edge_weight: Optional edge weights. 30 | """ 31 | self.CSR = CSR(column_indices, row_offset, edge_weight) 32 | 33 | def save(self, file_path: str): 34 | """ 35 | Save the GraphChunk instance as a binary file. 36 | :param file_path: Path to the file where the instance will be saved. 37 | """ 38 | with open(file_path, 'wb') as f: 39 | pickle.dump(self, f) 40 | 41 | @staticmethod 42 | def load(file_path: str): 43 | """ 44 | Load a GraphChunk instance from a binary file. 45 | :param file_path: Path to the file from which the instance will be loaded. 46 | :return: Loaded GraphChunk instance. 
47 | """ 48 | with open(file_path, 'rb') as f: 49 | return pickle.load(f) 50 | -------------------------------------------------------------------------------- /jittor_geometric/data/makedirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import errno 4 | 5 | 6 | def makedirs(path): 7 | try: 8 | os.makedirs(osp.expanduser(osp.normpath(path))) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST and osp.isdir(path): 11 | raise e 12 | -------------------------------------------------------------------------------- /jittor_geometric/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .random_node_loader import RandomNodeLoader 2 | from .general_loader import GeneralLoader 3 | from .neighbor_loader import NeighborLoader 4 | from .cluster_loader import ClusterLoader 5 | from .dataloader import DataLoader 6 | 7 | __all__ = [ 8 | 'GeneralLoader', 9 | 'RandomNodeLoader', 10 | 'NeighborLoader', 11 | 'ClusterLoader', 12 | 'DataLoader', 13 | ] -------------------------------------------------------------------------------- /jittor_geometric/dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from typing import Any, List, Optional, Sequence, Union, Callable 3 | import jittor 4 | from jittor.dataset.utils import collate_batch 5 | from jittor_geometric.data import Batch, Dataset 6 | from jittor_geometric.data.data import Data 7 | import numpy as np 8 | 9 | 10 | class Collater: 11 | def __init__( 12 | self, 13 | dataset: Union[Dataset, Sequence[Data]], 14 | follow_batch: Optional[List[str]] = None, 15 | exclude_keys: Optional[List[str]] = None, 16 | ): 17 | self.dataset = dataset 18 | self.follow_batch = follow_batch 19 | self.exclude_keys = exclude_keys 20 | 21 | def __call__(self, batch: List[Any]) -> Any: 22 | elem = batch[0] 23 | if 
isinstance(elem, Data): 24 | return Batch.from_data_list( 25 | batch, 26 | follow_batch=self.follow_batch, 27 | exclude_keys=self.exclude_keys, 28 | ) 29 | elif isinstance(elem, jittor.Var): 30 | return collate_batch(batch) 31 | elif isinstance(elem, float): 32 | return collate_batch(batch) 33 | elif isinstance(elem, int): # also catches bool, a subclass of int 34 | return collate_batch(batch) 35 | elif isinstance(elem, str): # strings are returned as a plain list, not collated 36 | return batch 37 | elif isinstance(elem, Mapping): 38 | return collate_batch(batch) 39 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): 40 | return collate_batch(batch) 41 | elif isinstance(elem, Sequence) and not isinstance(elem, str): 42 | return collate_batch(batch) 43 | 44 | raise TypeError(f"DataLoader found invalid type: '{type(elem)}'") 45 | 46 | 47 | class DataLoader: 48 | r"""A data loader which merges data objects from a 49 | :class:`jittor_geometric.data.Dataset` to a mini-batch. 50 | Elements are usually :class:`~jittor_geometric.data.Data` objects; Jittor 51 | Vars, numbers, strings, mappings and sequences are also collated. 52 | 53 | Args: 54 | dataset (Dataset): The dataset from which to load the data. 55 | batch_size (int, optional): How many samples per batch to load. 56 | (default: :obj:`1`) 57 | shuffle (bool, optional): If set to :obj:`True`, the data will be 58 | reshuffled at every epoch. (default: :obj:`False`) 59 | follow_batch (List[str], optional): Creates assignment batch 60 | vectors for each key in the list. (default: :obj:`None`) 61 | exclude_keys (List[str], optional): Will exclude each key in the 62 | list. (default: :obj:`None`) 63 | **kwargs (optional): Additional arguments (accepted but currently unused). A custom :obj:`collate_fn` may also be passed; it defaults to :class:`Collater`.
64 | """ 65 | def __init__( 66 | self, 67 | dataset: Union[Dataset, Sequence[Data]], 68 | batch_size: int = 1, 69 | shuffle: bool = False, 70 | follow_batch: Optional[List[str]] = None, 71 | exclude_keys: Optional[List[str]] = None, 72 | collate_fn: Optional[Callable] = None, 73 | **kwargs, 74 | ): 75 | self.dataset = dataset 76 | self.batch_size = batch_size 77 | self.shuffle = shuffle 78 | self.follow_batch = follow_batch 79 | self.exclude_keys = exclude_keys 80 | self.collate_fn = collate_fn if collate_fn is not None else Collater(dataset, follow_batch, exclude_keys) 81 | 82 | # Initialize indices 83 | self.indices = np.arange(len(dataset)) 84 | if self.shuffle: 85 | np.random.shuffle(self.indices) 86 | 87 | def __iter__(self): 88 | # Reset indices if shuffle is enabled 89 | if self.shuffle: 90 | np.random.shuffle(self.indices) 91 | 92 | # Yield batches 93 | for start_idx in range(0, len(self.indices), self.batch_size): 94 | end_idx = min(start_idx + self.batch_size, len(self.indices)) 95 | batch_indices = self.indices[start_idx:end_idx].tolist() 96 | batch = [self.dataset[i] for i in batch_indices] 97 | yield self.collate_fn(batch) 98 | 99 | def __len__(self): 100 | return (len(self.dataset) + self.batch_size - 1) // self.batch_size -------------------------------------------------------------------------------- /jittor_geometric/dataloader/general_loader.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor_geometric.data import InMemoryDataset 3 | from typing import Optional 4 | 5 | #class GeneralIterator 6 | 7 | class GeneralLoader: 8 | dataset: InMemoryDataset 9 | itercnt: int 10 | itermax: int 11 | 12 | def __init__(self, 13 | dataset: InMemoryDataset, 14 | shuffle: Optional[bool] = None): 15 | self.itercnt = 0 16 | self.itercnt = 0 17 | self.dataset = dataset 18 | 19 | def __reset__(self): 20 | self.itercnt = 0 21 | 22 | def __iter__(self): 23 | return self 24 | 25 | def 
__next__(self): 26 | if self.itercnt == 0: 27 | self.itercnt += 1 28 | return self.dataset[0] 29 | else: 30 | self.__reset__() 31 | raise StopIteration 32 | 33 | 34 | # def __collate__(self): 35 | # raise NotImplementedError -------------------------------------------------------------------------------- /jittor_geometric/dataloader/graphsaint_loader.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import sparse 3 | import copy 4 | from .general_loader import GeneralLoader 5 | from ..utils import induced_graph 6 | from jittor_geometric.partition.chunk_manager import ChunkManager 7 | import numpy as np 8 | 9 | 10 | class GraphSAINTLoader(GeneralLoader): 11 | 12 | r''' 13 | NOT YET FINISHED! 14 | ''' 15 | 16 | def __init__(self, dataset, num_parts: int, mini_splits: int, fixed: bool = False): 17 | self.data = copy.copy(dataset[0]) 18 | self.N = self.data.num_nodes 19 | self.E = self.data.num_edges 20 | self.edge_index = copy.copy(self.data.edge_index) 21 | self.data.edge_index = None 22 | 23 | self.num_parts = num_parts 24 | self.mini_splits = mini_splits 25 | self.fixed = fixed 26 | 27 | self.itermax = num_parts # one mini-batch per part in each pass 28 | 29 | self.parts = self.partition(self.edge_index, self.N, self.mini_splits) 30 | 31 | self.itercnt = 0 32 | self.n_splits = self.get_node_indices() # n_splits[k] lists the METIS clusters that __next__ merges into mini-batch k 33 | for i in self.n_splits: 34 | print(i) # NOTE(review): debug output printed on every construction 35 | 36 | def partition(self, edge_index, num_nodes, mini_splits): 37 | chunk_manager = ChunkManager(output_dir=None) 38 | partition = chunk_manager.metis_partition(edge_index, num_nodes, mini_splits) 39 | partition = jt.Var(partition) 40 | partition = partition.sort() 41 | 42 | parts = [] 43 | part_begin = 0 44 | part_end = 1 45 | for i in range(self.mini_splits): 46 | part_end = part_begin + 1 47 | while part_end <= num_nodes - 1 and partition[0][part_end] == partition[0][part_end - 1]: 48 | part_end += 1 49 | if part_begin >= part_end: 50 | parts.append(jt.zeros(0, dtype='int')) 51 | elif
part_end >= num_nodes: 52 | parts.append(partition[1][part_begin:]) 53 | else: 54 | parts.append(partition[1][part_begin: part_end]) 55 | part_begin = part_end 56 | return parts 57 | 58 | def get_node_indices(self): 59 | n_id = np.random.permutation(self.mini_splits) % self.num_parts 60 | n_ids = [jt.nonzero((n_id == i)).view(-1) for i in range(self.num_parts)] 61 | return n_ids 62 | 63 | def __reset__(self): 64 | self.itercnt = 0 65 | if self.fixed == False: 66 | self.n_ids = self.get_node_indices() 67 | 68 | def __iter__(self): 69 | return self 70 | 71 | def __next__(self): 72 | if self.itercnt < self.itermax: 73 | 74 | node_id = jt.zeros(0, dtype='int') 75 | for i in self.n_splits[self.itercnt]: 76 | node_id = jt.concat([node_id, self.parts[i]]) 77 | 78 | node_map, edge_id = induced_graph(self.edge_index, node_id, self.N) 79 | 80 | # node_mask = jt.zeros(self.N, dtype='bool') 81 | # node_mask[node_id] = True 82 | # edge_mask = node_mask[self.edge_index[0]] & node_mask[self.edge_index[1]] 83 | # edge_id = jt.nonzero(edge_mask).view(-1) 84 | 85 | data = self.data.__class__() 86 | 87 | # node_map = jt.zeros(self.N, dtype='int32') 88 | # node_map[node_id] = jt.arange(0, node_id.size(0)) 89 | data.edge_index = node_map[self.edge_index[:, edge_id]] 90 | 91 | for key, item in self.data: 92 | if key in ['num_nodes']: 93 | data[key] = node_id.size(0) 94 | elif isinstance(item, jt.Var) and item.size(0) == self.N: 95 | data[key] = item[node_id] 96 | elif isinstance(item, jt.Var) and item.size(0) == self.E: 97 | data[key] = item[edge_id] 98 | else: 99 | data[key] = item 100 | 101 | self.itercnt += 1 102 | return data 103 | 104 | else: 105 | self.__reset__() 106 | raise StopIteration -------------------------------------------------------------------------------- /jittor_geometric/dataloader/random_node_loader.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import sparse 3 | import copy 4 | from 
.general_loader import GeneralLoader 5 | from ..utils import induced_graph 6 | 7 | class RandomNodeLoader(GeneralLoader): 8 | 9 | r''' 10 | The graph dataset loader, randomly split all of the nodes into 'num_parts' mini-batches. 11 | The dataset loader yields the induced graph of the selected nodes iteratively. 12 | 13 | Args: 14 | dataset (InMemoryDataset): The original graph dataset. 15 | num_parts (int): Number of expected mini-batches. 16 | fixed (bool, optional): If set to 'True', the dataset loader will yield identical mini-batches every round. 17 | ''' 18 | 19 | 20 | def __init__(self, dataset, num_parts: int, fixed: bool = True): 21 | self.data = copy.copy(dataset[0]) 22 | self.N = self.data.num_nodes 23 | self.E = self.data.num_edges 24 | self.edge_index = copy.copy(self.data.edge_index) 25 | self.data.edge_index = None 26 | 27 | self.num_parts = num_parts 28 | self.fixed = fixed 29 | 30 | self.itermax = num_parts 31 | 32 | self.itercnt = 0 33 | self.n_ids = self.get_node_indices() 34 | 35 | def get_node_indices(self): 36 | n_id = jt.randint(0, self.num_parts, (self.N, ), dtype="int32") 37 | n_ids = [jt.nonzero((n_id == i)).view(-1) for i in range(self.num_parts)] 38 | return n_ids 39 | 40 | def __reset__(self): 41 | self.itercnt = 0 42 | if self.fixed == False: 43 | self.n_ids = self.get_node_indices() 44 | 45 | def __iter__(self): 46 | return self 47 | 48 | def __next__(self): 49 | if self.itercnt < self.itermax: 50 | 51 | node_id = None 52 | 53 | while True: 54 | node_id = self.n_ids[self.itercnt] 55 | if node_id.size(0) != 0: 56 | break 57 | else: 58 | self.itercnt += 1 59 | if self.itercnt >= self.itermax: 60 | self.__reset__() 61 | raise StopIteration 62 | 63 | node_map, edge_id = induced_graph(self.edge_index, node_id, self.N) 64 | 65 | # node_mask = jt.zeros(self.N, dtype='bool') 66 | # node_mask[node_id] = True 67 | # edge_mask = node_mask[self.edge_index[0]] & node_mask[self.edge_index[1]] 68 | # edge_id = jt.nonzero(edge_mask).view(-1) 69 | 70 | 
data = self.data.__class__() 71 | 72 | # node_map = jt.zeros(self.N, dtype='int32') 73 | # node_map[node_id] = jt.arange(0, node_id.size(0)) 74 | data.edge_index = node_map[self.edge_index[:, edge_id]] 75 | 76 | for key, item in self.data: 77 | if key in ['num_nodes']: 78 | data[key] = node_id.size(0) 79 | elif isinstance(item, jt.Var) and item.size(0) == self.N: 80 | data[key] = item[node_id] 81 | elif isinstance(item, jt.Var) and item.size(0) == self.E: 82 | data[key] = item[edge_id] 83 | else: 84 | data[key] = item 85 | 86 | self.itercnt += 1 87 | return data 88 | 89 | else: 90 | self.__reset__() 91 | raise StopIteration -------------------------------------------------------------------------------- /jittor_geometric/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .planetoid import Planetoid 2 | from .amazon import Amazon 3 | from .wikipedia_network import WikipediaNetwork 4 | from .geomgcn import GeomGCN 5 | from .ogb import OGBNodePropPredDataset 6 | from .jodie import JODIEDataset, TemporalDataLoader 7 | from .linkx import LINKXDataset 8 | from .hetero import HeteroDataset 9 | from .reddit import Reddit 10 | from .qm9 import QM9 11 | from .molecule_net import MoleculeNet 12 | 13 | __all__ = [ 14 | 'Planetoid', 15 | 'Amazon', 16 | 'WikipediaNetwork', 17 | 'GeomGCN', 18 | 'LINKXDataset', 19 | 'OGBNodePropPredDataset', 20 | 'HeteroDataset', 21 | 'JODIEDataset', 22 | 'Reddit', 23 | 'TemporalDataLoader', 24 | 'QM9', 25 | 'MoleculeNet', 26 | ] 27 | 28 | classes = __all__ 29 | -------------------------------------------------------------------------------- /jittor_geometric/datasets/amazon.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Callable, Optional 3 | 4 | import jittor as jt 5 | from jittor_geometric.io import read_npz 6 | from jittor_geometric.data import InMemoryDataset, download_url 7 | 8 | 9 | class 
Amazon(InMemoryDataset): 10 | r"""The Amazon Computers and Amazon Photo datasets from the paper 11 | "Pitfalls of Graph Neural Network Evaluation" 12 | `_. 13 | 14 | This class represents the Amazon dataset used in the paper "Pitfalls of Graph Neural Network Evaluation". In this dataset, nodes represent products, and edges indicate that two products are frequently bought together. The dataset provides product reviews represented as bag-of-words node features, and the task is to classify products into their respective categories. 15 | 16 | Dataset Details: 17 | 18 | - **Amazon Computers**: This dataset contains products related to computers, where the task is to classify the products based on the reviews and co-purchase information. 19 | - **Amazon Photo**: This dataset contains products related to photography, with a similar task of classifying products based on reviews and co-purchase data. 20 | 21 | Args: 22 | root (str): Root directory where the dataset should be saved. 23 | name (str): The name of the dataset, either :obj:`"Computers"` or :obj:`"Photo"`. 24 | transform (callable, optional): A function/transform that takes in a :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed on each access. (default: :obj:`None`) 25 | pre_transform (callable, optional): A function/transform that takes in a :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. 
(default: :obj:`None`) 26 | 27 | Example: 28 | >>> dataset = Amazon(root='/path/to/dataset', name='Computers') 29 | >>> dataset.data 30 | >>> dataset[0] # Accessing the first data point 31 | """ 32 | 33 | url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/' 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | name: str, 39 | transform: Optional[Callable] = None, 40 | pre_transform: Optional[Callable] = None, 41 | ) -> None: 42 | self.name = name.lower() 43 | assert self.name in ['computers', 'photo'] 44 | super(Amazon, self).__init__(root, transform, pre_transform) 45 | self.data, self.slices = jt.load(self.processed_paths[0]) 46 | 47 | @property 48 | def raw_dir(self) -> str: 49 | return osp.join(self.root, self.name, 'raw') 50 | 51 | @property 52 | def processed_dir(self) -> str: 53 | return osp.join(self.root, self.name, 'processed') 54 | 55 | @property 56 | def raw_file_names(self) -> str: 57 | return f'amazon_electronics_{self.name}.npz' 58 | 59 | @property 60 | def processed_file_names(self) -> str: 61 | return 'data.pkl' 62 | 63 | def download(self) -> None: 64 | download_url(self.url + self.raw_file_names, self.raw_dir) 65 | 66 | def process(self) -> None: 67 | data = read_npz(self.raw_paths[0], to_undirected=True) 68 | data = data if self.pre_transform is None else self.pre_transform(data) 69 | jt.save(self.collate([data]), self.processed_paths[0]) 70 | 71 | def __repr__(self) -> str: 72 | return '{}()'.format(self.name) -------------------------------------------------------------------------------- /jittor_geometric/datasets/master.csv: -------------------------------------------------------------------------------- 1 | ,ogbn-proteins,ogbn-products,ogbn-arxiv,ogbn-mag,ogbn-papers100M 2 | num tasks,112,1,1,1,1 3 | num classes,2,47,40,349,172 4 | eval metric,rocauc,acc,acc,acc,acc 5 | task type,binary classification,multiclass classification,multiclass classification,multiclass classification,multiclass classification 6 | 
download_name,proteins,products,arxiv,mag,papers100M-bin 7 | version,1,1,1,2,1 8 | url,http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip,http://snap.stanford.edu/ogb/data/nodeproppred/products.zip,http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip,http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip,http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip 9 | add_inverse_edge,True,True,False,False,False 10 | has_node_attr,False,True,True,True,True 11 | has_edge_attr,True,False,False,False,False 12 | split,species,sales_ranking,time,time,time 13 | additional node files,node_species,None,node_year,node_year,node_year 14 | additional edge files,None,None,None,edge_reltype,None 15 | is hetero,False,False,False,True,False 16 | binary,False,False,False,False,True 17 | -------------------------------------------------------------------------------- /jittor_geometric/datasets/reddit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from typing import Callable, Optional,List 4 | from jittor_geometric.data import ( 5 | Data, 6 | InMemoryDataset, 7 | download_url, 8 | extract_zip, 9 | ) 10 | import pandas as pd 11 | from jittor_geometric.data import InMemoryDataset, download_url 12 | import jittor as jt 13 | import numpy as np 14 | from jittor_geometric.utils import coalesce 15 | 16 | 17 | class Reddit(InMemoryDataset): 18 | r"""The Reddit dataset from the `"Inductive Representation Learning on 19 | Large Graphs" `_ paper, containing 20 | Reddit posts belonging to different communities. 21 | 22 | This dataset is designed for large-scale graph representation learning. Nodes in the graph represent Reddit posts, and edges represent interactions (e.g., comments) between posts in the same community. The task is to classify posts into one of the 41 communities based on their content and connectivity. 
23 | 24 | **Dataset Statistics:** 25 | 26 | - **Number of Nodes**: 232,965 27 | - **Number of Edges**: 114,615,892 28 | - **Number of Features**: 602 29 | - **Number of Classes**: 41 30 | 31 | The dataset is pre-split into training, validation, and test sets using node type masks. 32 | 33 | Args: 34 | root (str): Root directory where the dataset should be saved. 35 | transform (callable, optional): A function/transform that takes in a 36 | :obj:`torch_geometric.data.Data` object and returns a transformed 37 | version. The data object will be transformed before every access. 38 | (default: :obj:`None`) 39 | pre_transform (callable, optional): A function/transform that takes in 40 | an :obj:`torch_geometric.data.Data` object and returns a 41 | transformed version. The data object will be transformed before 42 | being saved to disk. (default: :obj:`None`) 43 | force_reload (bool, optional): Whether to re-process the dataset. 44 | (default: :obj:`False`) 45 | 46 | Example: 47 | >>> dataset = Reddit(root='/path/to/reddit') 48 | >>> data = dataset[0] # Access the first graph object 49 | """ 50 | 51 | url = 'https://data.dgl.ai/dataset/reddit.zip' 52 | 53 | def __init__( 54 | self, 55 | root: str, 56 | transform: Optional[Callable] = None, 57 | pre_transform: Optional[Callable] = None 58 | ) -> None: 59 | super().__init__(root, transform, pre_transform) 60 | self.data,self.slices= jt.load(self.processed_paths[0]) 61 | 62 | @property 63 | def raw_file_names(self) -> List[str]: 64 | return ['reddit_data.npz', 'reddit_graph.npz'] 65 | 66 | @property 67 | def processed_file_names(self) -> str: 68 | # return 'data.pt' 69 | return osp.join('geometric_data_processed.pkl') 70 | 71 | def download(self) -> None: 72 | path = download_url(self.url, self.raw_dir) 73 | extract_zip(path, self.raw_dir) 74 | os.unlink(path) 75 | 76 | def process(self) -> None: 77 | import scipy.sparse as sp 78 | 79 | data = np.load(osp.join(self.raw_dir, 'reddit_data.npz')) 80 | x = 
jt.array(data['feature']).to(jt.float32) 81 | y =jt.array(data['label']).to(jt.int32) 82 | split = jt.array(data['node_types']) 83 | 84 | adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz')) 85 | row = jt.array(adj.row).to(jt.int32) 86 | col = jt.array(adj.col).to(jt.int32) 87 | row = jt.unsqueeze(row, dim=1) 88 | col = jt.unsqueeze(col, dim=1) 89 | 90 | arr=[] 91 | arr.append(row) 92 | arr.append(col) 93 | arr2 = jt.concat(arr, dim=1).transpose() 94 | edge_index,_ = coalesce(arr2, num_nodes=x.size(0)) 95 | data = Data(x=x, edge_index=edge_index, y=y) 96 | data.train_mask = split == 1 97 | data.val_mask = split == 2 98 | data.test_mask = split == 3 99 | 100 | data = data if self.pre_transform is None else self.pre_transform(data) 101 | 102 | jt.save(self.collate([data]), self.processed_paths[0]) 103 | -------------------------------------------------------------------------------- /jittor_geometric/datasets/wikipedia_network.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import jittor as jt 4 | from jittor_geometric.data import Data, InMemoryDataset, download_url 5 | from typing import Callable, Optional 6 | 7 | 8 | class WikipediaNetwork(InMemoryDataset): 9 | r"""Heterophilic dataset from the paper 'A critical look at the evaluation of GNNs under 10 | heterophily: Are we really making progress?' 11 | . 12 | 13 | This class represents a collection of heterophilic graph datasets used to evaluate the performance of Graph Neural Networks (GNNs) in heterophilic settings. These datasets consist of graphs where nodes are connected based on certain relationships, and the task is to classify the nodes based on their features or labels. The datasets in this collection come from different domains, and each dataset has a unique structure and task. 
14 | 15 | Dataset Details: 16 | 17 | - **Chameleon** 18 | - **Squirrel** 19 | - **Chameleon-Filtered** 20 | - **Squirrel-Filtered** 21 | 22 | Args: 23 | root (str): Root directory where the dataset should be saved. 24 | name (str): The name of the dataset to load. Options include: 25 | - `"chameleon"` 26 | - `"squirrel"` 27 | - `"chameleon_filtered"` 28 | - `"squirrel_filtered"` 29 | transform (callable, optional): A function/transform that takes in a :obj:`Data` object 30 | and returns a transformed version. The data object will be transformed on every access. 31 | (default: :obj:`None`) 32 | pre_transform (callable, optional): A function/transform that takes in a :obj:`Data` object 33 | and returns a transformed version. The data object will be transformed before being saved to disk. 34 | (default: :obj:`None`) 35 | 36 | Example: 37 | >>> dataset = Wikipedia(root='/path/to/dataset', name='chameleon') 38 | >>> dataset.data 39 | >>> dataset[0] # Accessing the first data point 40 | """ 41 | 42 | url = ('https://github.com/yandex-research/heterophilous-graphs/raw/' 43 | 'main/data') 44 | 45 | def __init__(self, root: str, name: str, 46 | transform: Optional[Callable] = None, 47 | pre_transform: Optional[Callable] = None): 48 | self.root = root 49 | self.name = name.lower() 50 | super().__init__(root, transform, pre_transform) 51 | self.data, self.slices = jt.load(self.processed_paths[0]) # Jittor's loading method 52 | 53 | @property 54 | def raw_dir(self) -> str: 55 | return osp.join(self.root, self.name, 'raw') 56 | 57 | @property 58 | def processed_dir(self) -> str: 59 | return osp.join(self.root, self.name, 'processed') 60 | 61 | @property 62 | def raw_file_names(self) -> str: 63 | return f'{self.name}.npz' 64 | 65 | @property 66 | def processed_file_names(self) -> str: 67 | return 'data.pkl' 68 | 69 | def download(self): 70 | download_url(f'{self.url}/{self.raw_file_names}', self.raw_dir) 71 | 72 | def process(self, undirected=True): 73 | data = 
np.load(self.raw_paths[0]) 74 | 75 | x = jt.array(data['node_features']) 76 | y = jt.array(data['node_labels']) 77 | edge_index = jt.array(data['edges']).transpose() 78 | 79 | if undirected: 80 | reverse_edges = edge_index.flip(0) 81 | edge_index = jt.contrib.concat([edge_index, reverse_edges], dim=1) 82 | edge_index = jt.unique(edge_index, dim=1) 83 | 84 | train_mask = jt.array(data['train_masks']).bool() 85 | val_mask = jt.array(data['val_masks']).bool() 86 | test_mask = jt.array(data['test_masks']).bool() 87 | 88 | data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, 89 | val_mask=val_mask, test_mask=test_mask) 90 | 91 | if self.pre_transform is not None: 92 | data = self.pre_transform(data) 93 | 94 | jt.save(self.collate([data]), self.processed_paths[0]) -------------------------------------------------------------------------------- /jittor_geometric/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluators import MRR_Evaluator 2 | 3 | __all__ = [ 4 | 'MRR_Evaluator', 5 | ] 6 | 7 | classes = __all__ 8 | -------------------------------------------------------------------------------- /jittor_geometric/evaluate/evaluators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jittor as jt 3 | 4 | class MRR_Evaluator(): 5 | def __init__(self) -> None: 6 | pass 7 | 8 | def eval(self, y_pred_pos, y_pred_neg): 9 | if jt is not None and isinstance(y_pred_pos, jt.Var): 10 | y_pred_pos = y_pred_pos.detach().cpu().numpy() 11 | if jt is not None and isinstance(y_pred_neg, jt.Var): 12 | y_pred_neg = y_pred_neg.detach().cpu().numpy() 13 | if not isinstance(y_pred_pos, np.ndarray) or not isinstance(y_pred_neg, np.ndarray): 14 | raise RuntimeError( 15 | "Arguments to Evaluator need to be either numpy ndarray or jittor Var!" 
16 | ) 17 | batch_size = y_pred_pos.shape[0] 18 | y_pred_pos = y_pred_pos.reshape(-1, 1) 19 | y_pred_neg = y_pred_neg.reshape(batch_size,-1) 20 | optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1) 21 | pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1) 22 | ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1 23 | mrr_list = 1./ranking_list.astype(np.float32) 24 | return mrr_list 25 | -------------------------------------------------------------------------------- /jittor_geometric/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .txt_array import parse_txt_array, read_txt_array 2 | from .planetoid import read_planetoid_data 3 | from .npz import read_npz 4 | from .ogb import read_graph, read_heterograph 5 | from .ogb_raw import read_node_label_hetero, read_nodesplitidx_split_hetero 6 | 7 | __all__ = [ 8 | 'parse_txt_array', 9 | 'read_txt_array', 10 | 'read_planetoid_data', 11 | 'read_npz', 12 | 'read_graph', 13 | 'read_heterograph', 14 | 'read_node_label_hetero', 15 | 'read_nodesplitidx_split_hetero', 16 | ] 17 | -------------------------------------------------------------------------------- /jittor_geometric/io/npz.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import numpy as np 4 | import scipy.sparse as sp 5 | 6 | from jittor import Var 7 | import jittor as jt 8 | from jittor_geometric.data import Data 9 | from jittor_geometric.utils import remove_self_loops 10 | from jittor_geometric.utils import to_undirected as to_undirected_fn 11 | 12 | 13 | def read_npz(path: str, to_undirected: bool = True) -> Data: 14 | with np.load(path, allow_pickle=True) as f: 15 | return parse_npz(f, to_undirected=to_undirected) 16 | 17 | 18 | def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data: 19 | x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']), 20 | f['attr_shape']).todense() 21 | x = 
np.array(x) 22 | x = jt.array(x).float32() 23 | x[x > 0] = 1 24 | 25 | adj = sp.csr_matrix((f['adj_data'], f['adj_indices'], f['adj_indptr']), 26 | f['adj_shape']).tocoo() 27 | row = jt.array(adj.row).int32() 28 | col = jt.array(adj.col).int32() 29 | edge_index = jt.stack([row, col], dim=0) 30 | edge_index, _ = remove_self_loops(edge_index) 31 | 32 | if to_undirected: 33 | edge_index, _ = to_undirected_fn(edge_index, num_nodes=x.shape[0]) 34 | 35 | y = jt.array(f['labels']).int32() 36 | 37 | return Data(x=x, edge_index=edge_index, y=y) -------------------------------------------------------------------------------- /jittor_geometric/io/ogb.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from jittor_geometric.data import Data 3 | import os.path as osp 4 | import numpy as np 5 | from jittor_geometric.io.ogb_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw 6 | from tqdm.auto import tqdm 7 | import jittor as jt 8 | 9 | 10 | def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): 11 | if binary: 12 | # npz 13 | graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge) 14 | else: 15 | # csv 16 | graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files) 17 | 18 | jittor_graph_list = [] 19 | 20 | print('Converting graphs into Jittor objects...') 21 | 22 | for graph in tqdm(graph_list): 23 | g = Data() 24 | g.num_nodes = graph['num_nodes'] 25 | g.edge_index = jt.array(graph['edge_index']).int() 26 | 27 | del graph['num_nodes'] 28 | del graph['edge_index'] 29 | 30 | if graph['edge_feat'] is not None: 31 | g.edge_attr = jt.array(graph['edge_feat']) 32 | del graph['edge_feat'] 33 | 34 | if graph['node_feat'] is not None: 35 | g.x = jt.array(graph['node_feat']) 36 | del graph['node_feat'] 37 | 38 | for 
key in additional_node_files: 39 | g[key] = jt.array(graph[key]) 40 | del graph[key] 41 | 42 | for key in additional_edge_files: 43 | g[key] = jt.array(graph[key]) 44 | del graph[key] 45 | 46 | jittor_graph_list.append(g) 47 | 48 | return jittor_graph_list 49 | 50 | def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): 51 | if binary: 52 | # npz 53 | graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge) 54 | else: 55 | # csv 56 | graph_list = read_csv_heterograph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files) 57 | 58 | jittor_graph_list = [] 59 | 60 | print('Converting graphs into Jittor objects...') 61 | 62 | for graph in tqdm(graph_list): 63 | g = Data() 64 | 65 | g.__num_nodes__ = graph['num_nodes_dict'] 66 | g.num_nodes_dict = graph['num_nodes_dict'] 67 | 68 | # add edge connectivity 69 | g.edge_index_dict = {} 70 | for triplet, edge_index in graph['edge_index_dict'].items(): 71 | g.edge_index_dict[triplet] = jt.array(edge_index).int() 72 | 73 | del graph['edge_index_dict'] 74 | 75 | if graph['edge_feat_dict'] is not None: 76 | g.edge_attr_dict = {} 77 | for triplet in graph['edge_feat_dict'].keys(): 78 | g.edge_attr_dict[triplet] = jt.array(graph['edge_feat_dict'][triplet]) 79 | 80 | del graph['edge_feat_dict'] 81 | 82 | if graph['node_feat_dict'] is not None: 83 | g.x_dict = {} 84 | for nodetype in graph['node_feat_dict'].keys(): 85 | g.x_dict[nodetype] = jt.array(graph['node_feat_dict'][nodetype]) 86 | 87 | del graph['node_feat_dict'] 88 | 89 | for key in additional_node_files: 90 | g[key] = {} 91 | for nodetype in graph[key].keys(): 92 | g[key][nodetype] = jt.array(graph[key][nodetype]) 93 | 94 | del graph[key] 95 | 96 | for key in additional_edge_files: 97 | g[key] = {} 98 | for triplet in graph[key].keys(): 99 | g[key][triplet] = jt.array(graph[key][triplet]) 100 | 101 | del graph[key] 102 | 103 | 
jittor_graph_list.append(g) 104 | 105 | return jittor_graph_list 106 | 107 | if __name__ == '__main__': 108 | pass -------------------------------------------------------------------------------- /jittor_geometric/io/txt_array.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | 4 | def parse_txt_array(src, sep=None, start=0, end=None, dtype=None): 5 | src = [[float(x) for x in line.split(sep)[start:end]] for line in src] 6 | src = jt.array(src, dtype=dtype).squeeze(1) 7 | return src 8 | 9 | 10 | def read_txt_array(path, sep=None, start=0, end=None, dtype=None): 11 | with open(path, 'r') as f: 12 | src = f.read().split('\n')[:-1] 13 | return parse_txt_array(src, sep, start, end, dtype) 14 | -------------------------------------------------------------------------------- /jittor_geometric/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import * # noqa 2 | from .models import * # noqa 3 | from .aggr import * 4 | from .dense import * 5 | from .pool import * 6 | 7 | __all__ = [ 8 | 'Sequential', 9 | 'MetaLayer', 10 | 'DataParallel', 11 | 'Reshape', 12 | ] 13 | -------------------------------------------------------------------------------- /jittor_geometric/nn/aggr/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Aggregation 2 | from .multi import MultiAggregation 3 | from .basic import ( 4 | MeanAggregation, 5 | SumAggregation, 6 | MaxAggregation, 7 | MinAggregation 8 | ) 9 | __all__ = classes = [ 10 | 'Aggregation', 11 | 'MultiAggregation', 12 | 'SumAggregation', 13 | 'MeanAggregation', 14 | 'MaxAggregation', 15 | 'MinAggregation' 16 | ] -------------------------------------------------------------------------------- /jittor_geometric/nn/aggr/basic.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor 
import nn 3 | from jittor_geometric.nn.aggr import Aggregation 4 | from typing import Optional 5 | # Each aggregation below forwards to Aggregation.reduce with a fixed reduce mode ('sum', 'mean', 'max', 'min'). 6 | 7 | class SumAggregation(Aggregation): 8 | r"""An aggregation operator that sums up features across a set of elements. 9 | 10 | .. math:: 11 | \mathrm{sum}(\mathcal{X}) = \sum_{\mathbf{x}_i \in \mathcal{X}} 12 | \mathbf{x}_i. 13 | """ 14 | def execute(self, x: jt.Var, index: Optional[jt.Var] = None, 15 | ptr: Optional[jt.Var] = None, dim_size: Optional[int] = None, 16 | dim: int = -2) -> jt.Var: 17 | return self.reduce(x, index, ptr, dim_size, dim, reduce='sum') 18 | 19 | 20 | class MeanAggregation(Aggregation): 21 | r"""An aggregation operator that averages features across a set of elements. 22 | 23 | .. math:: 24 | \mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|} 25 | \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i. 26 | """ 27 | def execute(self, x: jt.Var, index: Optional[jt.Var] = None, 28 | ptr: Optional[jt.Var] = None, dim_size: Optional[int] = None, 29 | dim: int = -2) -> jt.Var: 30 | return self.reduce(x, index, ptr, dim_size, dim, reduce='mean') 31 | 32 | 33 | class MaxAggregation(Aggregation): 34 | r"""An aggregation operator that takes the feature-wise maximum across a set of elements. 35 | 36 | .. math:: 37 | \mathrm{max}(\mathcal{X}) = \max_{\mathbf{x}_i \in \mathcal{X}} 38 | \mathbf{x}_i. 39 | """ 40 | def execute(self, x: jt.Var, index: Optional[jt.Var] = None, 41 | ptr: Optional[jt.Var] = None, dim_size: Optional[int] = None, 42 | dim: int = -2) -> jt.Var: 43 | return self.reduce(x, index, ptr, dim_size, dim, reduce='max') 44 | 45 | 46 | class MinAggregation(Aggregation): 47 | r"""An aggregation operator that takes the feature-wise minimum across a set of elements. 48 | 49 | .. math:: 50 | \mathrm{min}(\mathcal{X}) = \min_{\mathbf{x}_i \in \mathcal{X}} 51 | \mathbf{x}_i.
52 | """ 53 | def execute(self, x: jt.Var, index: Optional[jt.Var] = None, 54 | ptr: Optional[jt.Var] = None, dim_size: Optional[int] = None, 55 | dim: int = -2) -> jt.Var: 56 | return self.reduce(x, index, ptr, dim_size, dim, reduce='min') 57 | 58 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: Public convolution / propagation layers exported by jittor_geometric.nn.conv. 3 | Author: lusz 4 | Date: 2025-01-10 13:52:59 5 | ''' 6 | from .message_passing import MessagePassing 7 | from .gcn_conv import GCNConv 8 | from .cheb_conv import ChebConv 9 | from .sg_conv import SGConv 10 | from .gcn2_conv import GCN2Conv 11 | from .message_passiong_nts import MessagePassingNts 12 | from .gat_conv import GATConv 13 | from .egnn_conv import EGNNConv 14 | from .appnp_conv import APPNP 15 | from .gpr_conv import GPRGNN 16 | from .even_conv import EvenNet 17 | from .bernnet_conv import BernNet 18 | from .chebnet2_conv import ChebNetII 19 | from .transformer_conv import TransformerConv 20 | from .optbasis_conv import OptBasisConv 21 | from .clustergcn_conv import ClusterGCNConv 22 | from .sage_conv import SAGEConv 23 | 24 | __all__ = [ 25 | 'MessagePassing', 26 | 'GCNConv', 27 | 'ChebConv', 28 | 'SGConv', 29 | 'GCN2Conv', 30 | 'MessagePassingNts', 31 | 'GATConv', 32 | 'EGNNConv', 33 | 'APPNP', 34 | 'GPRGNN', 35 | 'EvenNet', 36 | 'BernNet', 37 | 'ChebNetII', 38 | 'TransformerConv', 39 | 'OptBasisConv', 40 | 'ClusterGCNConv', 41 | 'SAGEConv' 42 | ] 43 | 44 | classes = __all__ 45 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/appnp_conv.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: APPNP propagation operator. 3 | Author: ivam 4 | Date: 2024-12-13 5 | ''' 6 | from typing import Optional, Tuple 7 | from jittor_geometric.typing import Adj, OptVar 8 | import jittor as jt 9 | from jittor
# NOTE: the preceding line ends with APPNP's `from jittor import`; its import
# list (`Var,nn,Module`) is restated here as a complete statement.
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


class APPNP(Module):
    r"""The graph propagation operator from the `"Predict then Propagate:
    Graph Neural Networks meet Personalized PageRank"
    <https://arxiv.org/abs/1810.05997>`_ paper.

    Each of the :obj:`K` steps computes
    :math:`\mathbf{X} \leftarrow (1 - \alpha) \mathbf{P} \mathbf{X} + \alpha \mathbf{H}`,
    where :math:`\mathbf{H}` is the input feature matrix and
    :math:`\mathbf{P}` is the propagation operator given by ``csc``/``csr``.

    Args:
        K (int): Number of propagation steps.
        alpha (float): Teleport (restart) probability.
        spmm (bool, optional): If set to `True`, uses sparse matrix
            multiplication (SPMM) for propagation when CUDA is enabled.
            Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
    """

    def __init__(self, K: int, alpha: float, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(APPNP, self).__init__(**kwargs)
        self.K = K
        self.alpha = alpha
        self.spmm = spmm
        self.reset_parameters()

    def reset_parameters(self):
        # APPNP propagation has no learnable parameters.
        pass

    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        h = x
        for _ in range(self.K):
            x = self._propagate(x, csc, csr)
            x = x * (1 - self.alpha) + self.alpha * h
        return x

    def _propagate(self, x, csc, csr):
        """Apply one propagation step (SPMM only when enabled and CUDA is active)."""
        if self.spmm and jt.flags.use_cuda == 1:
            return self.propagate_spmm(x=x, csr=csr)
        return self.propagate_msg(x=x, csc=csc, csr=csr)

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its hyper-parameters.
        return '{}(K={}, alpha={})'.format(self.__class__.__name__, self.K,
                                           self.alpha)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/bernnet_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: ivam
Date: 2024-12-13
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
import numpy as np
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from scipy.special import comb
from ..inits import glorot, zeros, ones
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


class BernNet(Module):
    r"""The graph propagation operator from the `"BernNet: Learning Arbitrary
    Graph Spectral Filters via Bernstein Approximation"
    <https://arxiv.org/abs/2106.10994>`_ paper

    Mathematical Formulation:
    .. math::
        \mathbf{Z} = \sum_{k=0}^{K} \alpha_k \frac{1}{2^K}\binom{K}{k}
        (2\mathbf{I} - \mathbf{L})^{K-k} \mathbf{L}^{k} \mathbf{X}.
    where:
        :math:`\mathbf{X}` is the input node feature matrix.
        :math:`\mathbf{Z}` is the output node feature matrix.
        :math:`\mathbf{L}` is the normalized Laplacian matrix of the graph
        (spectrum in :math:`[0, 2]`).
        :math:`\alpha_k = \mathrm{ReLU}(\theta_k)` is the learnable coefficient
        of the :math:`k`-th order Bernstein basis.

    In :meth:`execute`, ``csc2``/``csr2`` play the role of
    :math:`2\mathbf{I} - \mathbf{L}` and ``csc1``/``csr1`` the role of
    :math:`\mathbf{L}` in the formula above (presumably built that way by the
    caller — confirm).

    Args:
        K (int): Order of polynomial, or maximum number of hops considered for message passing.
        spmm (bool, optional): If set to `True`, uses sparse matrix multiplication (SPMM) for propagation. Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
    """

    def __init__(self, K: int, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(BernNet, self).__init__(**kwargs)
        self.K = K
        self.spmm = spmm
        self.temp = jt.random((self.K + 1,))

        self.reset_parameters()

    def reset_parameters(self):
        ones(self.temp)

    def execute(self, x: Var, csc1: OptVar, csr1: OptVar, csc2: OptVar, csr2: OptVar) -> Var:
        TEMP = nn.relu(self.temp)

        # tmp[i] = (operator 2)^i x, for i = 0..K
        tmp = [x]
        for _ in range(self.K):
            x = self._propagate(x, csc2, csr2)
            tmp.append(x)

        out = (comb(self.K, 0) / (2 ** self.K)) * TEMP[0] * tmp[self.K]

        for i in range(self.K):
            # (operator 1)^(i+1) applied to (operator 2)^(K-i-1) x
            x = tmp[self.K - i - 1]
            for _ in range(i + 1):
                x = self._propagate(x, csc1, csr1)
            out = out + (comb(self.K, i + 1) / (2 ** self.K)) * TEMP[i + 1] * x
        return out

    def _propagate(self, x, csc, csr):
        """Apply one propagation step (SPMM only when enabled and CUDA is active)."""
        if self.spmm and jt.flags.use_cuda == 1:
            return self.propagate_spmm(x=x, csr=csr)
        return self.propagate_msg(x=x, csc=csc, csr=csr)

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its order instead.
        return '{}(K={})'.format(self.__class__.__name__, self.K)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/chebnet2_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: ivam
Date: 2024-12-13
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
import math
import numpy as np
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes
from scipy.special import comb
from ..inits import glorot, zeros, ones
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


def cheby(i, x):
    """Evaluate the Chebyshev polynomial of the first kind T_i at x.

    Uses the recurrence T_{n} = 2 x T_{n-1} - T_{n-2} with T_0 = 1, T_1 = x.
    """
    if i == 0:
        return 1
    elif i == 1:
        return x
    T0 = 1
    T1 = x
    for _ in range(2, i + 1):
        T2 = 2 * x * T1 - T0
        T0, T1 = T1, T2
    return T2


class ChebNetII(Module):
    r"""The graph propagation operator from the `"Convolutional Neural Networks
    on Graphs with Chebyshev Approximation, Revisited"
    <https://arxiv.org/abs/2202.03580>`_ paper


    Mathematical Formulation:
    .. math::
        \mathbf{Z} = \frac{c_0}{2}\mathbf{X} + \sum_{k=1}^{K} c_k T_{k}(\tilde{\mathbf{L}}) \mathbf{X}.
    where:
        :math:`\mathbf{X}` is the input node feature matrix.
        :math:`\mathbf{Z}` is the output node feature matrix.
        :math:`T_{k}` is the Chebyshev polynomial of order :math:`k`.
        :math:`c_k` is the coefficient of the :math:`k`-th order Chebyshev polynomial, obtained by
        interpolating ReLU-clamped learnable values at the :math:`K+1` Chebyshev nodes.
        :math:`\tilde{\mathbf{L}}` is the normalized Laplacian matrix of the graph, translated to the interval :math:`[-1,1]`.

    Args:
        K (int): Order of polynomial, or maximum number of hops considered for message passing.
        spmm (bool, optional): If set to `True`, uses sparse matrix multiplication (SPMM) for propagation. Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
    """

    def __init__(self, K: int, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(ChebNetII, self).__init__(**kwargs)
        self.K = K
        self.spmm = spmm
        self.temp = jt.random((self.K + 1,))
        self.reset_parameters()

    def reset_parameters(self):
        ones(self.temp)

    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        coe_tmp = nn.relu(self.temp)
        coe = coe_tmp.clone()

        # Chebyshev interpolation: c_i = 2/(K+1) * sum_j gamma_j T_i(x_j),
        # with x_j = cos((K - j + 0.5) * pi / (K + 1)) the Chebyshev nodes.
        for i in range(self.K + 1):
            coe[i] = coe_tmp[0] * cheby(i, math.cos((self.K + 0.5) * math.pi / (self.K + 1)))
            for j in range(1, self.K + 1):
                x_j = math.cos((self.K - j + 0.5) * math.pi / (self.K + 1))
                coe[i] = coe[i] + coe_tmp[j] * cheby(i, x_j)
            coe[i] = 2 * coe[i] / (self.K + 1)

        Tx_0 = x
        if self.K == 0:
            # Only the constant term exists; coe[1] would be out of range.
            return coe[0] / 2 * Tx_0

        Tx_1 = self._propagate(x, csc, csr)
        out = coe[0] / 2 * Tx_0 + coe[1] * Tx_1
        for i in range(2, self.K + 1):
            # Chebyshev recurrence: T_i(L) x = 2 L T_{i-1}(L) x - T_{i-2}(L) x
            Tx_2 = self._propagate(Tx_1, csc, csr)
            Tx_2 = 2 * Tx_2 - Tx_0
            out = out + coe[i] * Tx_2
            Tx_0, Tx_1 = Tx_1, Tx_2

        return out

    def _propagate(self, x, csc, csr):
        """Apply one propagation step (SPMM only when enabled and CUDA is active)."""
        if self.spmm and jt.flags.use_cuda == 1:
            return self.propagate_spmm(x=x, csr=csr)
        return self.propagate_msg(x=x, csc=csc, csr=csr)

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its order instead.
        return '{}(K={})'.format(self.__class__.__name__, self.K)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/even_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: ivam
Date: 2024-12-13
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
import numpy as np
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


class EvenNet(Module):
    r"""The propagation operator from the `"EvenNet: Ignoring Odd-Hop Neighbors
    Improves Robustness of Graph Neural Networks"
    <https://arxiv.org/abs/2205.13892>`_ paper.

    Only even powers of the propagation operator contribute to the output:

    .. math::
        \mathbf{Z} = \gamma_0 \mathbf{X} + \sum_{k=1}^{\lfloor K/2 \rfloor}
        \gamma_{2k} \mathbf{P}^{2k} \mathbf{X},

    where the learnable coefficients :math:`\gamma` are initialised PPR-style
    from :obj:`alpha`. The coefficients of odd hops exist but are not used.

    Args:
        K (int): Maximum number of hops considered for message passing.
        alpha (float): Parameter controlling the initial weighting of different hops.
        spmm (bool, optional): If set to `True`, uses sparse matrix multiplication (SPMM) for propagation. Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
    """

    def __init__(self, K: int, alpha: float, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(EvenNet, self).__init__(**kwargs)
        self.K = K
        self.alpha = alpha
        self.spmm = spmm
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise the hop coefficients PPR-style as a trainable Var."""
        TEMP = self.alpha * (1 - self.alpha) ** np.arange(self.K + 1)
        TEMP[-1] = (1 - self.alpha) ** self.K
        # Keep the coefficients a jittor parameter; a raw numpy array would
        # silently stop them from being trained.
        self.temp = nn.Parameter(jt.array(TEMP))

    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        out = x * (self.temp[0])
        for k in range(self.K):
            x = self._propagate(x, csc, csr)
            # x now holds the (k+1)-hop representation; keep only even hops.
            if k % 2 == 1:
                out = out + self.temp[k + 1] * x
        return out

    def _propagate(self, x, csc, csr):
        """Apply one propagation step (SPMM only when enabled and CUDA is active)."""
        if self.spmm and jt.flags.use_cuda == 1:
            return self.propagate_spmm(x=x, csr=csr)
        return self.propagate_msg(x=x, csc=csc, csr=csr)

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its hyper-parameters.
        return '{}(K={}, alpha={})'.format(self.__class__.__name__, self.K,
                                           self.alpha)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/gat_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: lusz
Date: 2024-06-26 10:57:06
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar

import jittor as jt
from jittor import Var
from jittor_geometric.nn.conv import MessagePassingNts
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import ScatterToEdge, EdgeSoftmax, aggregateWithWeight, ScatterToVertex


class GATConv(MessagePassingNts):
    r"""The graph convolutional operator from the `"Graph Attention Networks"
    2018 ICLR <https://arxiv.org/abs/1710.10903>`_ paper

    The layer runs as: linear transform + ReLU on vertices, scatter of
    (source, destination) features to edges, LeakyReLU attention scores
    normalised by :obj:`EdgeSoftmax`, then aggregation back to vertices.
    """

    _cached_edge_index: Optional[Tuple[Var, Var]]
    _cached_csc: Optional[CSC]

    def __init__(self, in_channels: int, out_channels: int, e_num: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        self.weight = jt.random((in_channels, out_channels))
        self.edge_weight = jt.random((2 * out_channels, 1))
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.edge_weight)
        self._cached_adj_t = None
        self._cached_csc = None

    def execute(self, x: Var, csc: CSC) -> Var:
        """"""
        out = self.vertex_forward(x)
        out = self.propagate(x=out, csc=csc)
        return out

    def propagate(self, x, csc):
        e_msg = self.scatter_to_edge(x, csc)
        out = self.edge_forward(e_msg, csc)
        out = self.scatter_to_vertex(out, csc)
        return out

    def scatter_to_edge(self, x, csc) -> Var:
        # Concatenate source and destination features per edge.
        out1 = ScatterToEdge(x, csc, "src")
        out2 = ScatterToEdge(x, csc, "dst")
        out = jt.contrib.concat([out1, out2], dim=1)
        return out

    def edge_forward(self, x, csc) -> Var:
        out = x @ self.edge_weight
        m = jt.nn.leaky_relu(out, scale=0.2)
        a = EdgeSoftmax(m, csc)
        # The first half of each edge row holds the source features.
        half_dim = int(jt.size(x, 1) / 2)
        e_msg = x[:, 0:half_dim]
        return e_msg * a

    def scatter_to_vertex(self, edge, csc) -> Var:
        out = ScatterToVertex(edge, csc, 'src')
        return out

    def vertex_forward(self, x: Var) -> Var:
        out = x @ self.weight
        out = jt.nn.relu(out)
        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/gcn2_conv.py:
# --------------------------------------------------------------------------------
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar

from math import log

import jittor as jt
from jittor import Var
from jittor_geometric.nn.conv import MessagePassing
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight

from ..inits import glorot


class GCN2Conv(MessagePassing):
    r"""The graph convolutional operator with initial residual connections and
    identity mapping (GCNII) from the `"Simple and Deep Graph Convolutional
    Networks" <https://arxiv.org/abs/2007.02133>`_ paper.

    This class implements the GCNII layer, which combines initial residual connections and identity mapping
    to enable deeper graph convolutional networks without oversmoothing. The layer supports both message-passing
    and sparse matrix multiplication (SPMM) for efficient propagation.

    Mathematical Formulation (as computed by :meth:`execute`):
    .. math::
        \mathbf{H}^{(l)} = \mathbf{P}\big( (1 - \beta)(1 - \alpha) \mathbf{H}^{(l-1)} +
        \beta \mathbf{H}^{(l-1)} \mathbf{\Theta}_1 \big) +
        (1 - \beta)\alpha \mathbf{H}^{(0)} + \beta \mathbf{H}^{(0)} \mathbf{\Theta}_2

    where:
        - :math:`\mathbf{H}^{(l)}` is the node feature matrix at layer :math:`l`.
        - :math:`\mathbf{H}^{(0)}` is the initial node feature matrix.
        - :math:`\mathbf{P}` is the propagation operator given by ``csc``/``csr``.
        - :math:`\mathbf{\Theta}_1` and :math:`\mathbf{\Theta}_2` are learnable weight matrices.
        - :math:`\alpha` controls the strength of the initial residual connection.
        - :math:`\beta` balances feature aggregation and transformation.

    Args:
        in_channels (int): Number of input features per node.
        out_channels (int): Number of output features per node.
        cached (bool, optional): If set to `True`, caches the normalized edge indices. Default is `False`.
        add_self_loops (bool, optional): Accepted but currently unused by this layer. Default is `True`.
        spmm (bool, optional): If set to `True`, uses sparse matrix multiplication (SPMM) for propagation. Default is `False`.
        **kwargs (optional): Additional arguments for the base `MessagePassing` class.
    """

    def __init__(self, in_channels: int, out_channels: int, cached: bool = False, add_self_loops: bool = True,
                 spmm: bool = False, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GCN2Conv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self._cached_edge_index = None

        self.weight1 = jt.random((in_channels, out_channels))
        self.weight2 = jt.random((in_channels, out_channels))

        self.spmm = spmm
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight1)
        glorot(self.weight2)

    def execute(self, x: Var, x_0: Var, csc: OptVar, csr: OptVar, alpha: float, beta: float) -> Var:
        support = (1 - beta) * (1 - alpha) * x + beta * jt.matmul(x, self.weight1)
        initial = (1 - beta) * (alpha) * x_0 + beta * jt.matmul(x_0, self.weight2)
        if self.spmm and jt.flags.use_cuda == 1:
            out = self.propagate_spmm(x=support, csr=csr) + initial
        else:
            out = self.propagate_msg(x=support, csc=csc, csr=csr) + initial
        return out

    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # alpha/beta are per-call arguments of execute(), not attributes.
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/gpr_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: ivam
Date: 2024-12-13
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
import numpy as np
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


class GPRGNN(Module):
    r"""The graph propagation operator from the `"Adaptive Universal
    Generalized PageRank Graph Neural Network"
    <https://arxiv.org/abs/2006.07988>`_ paper

    Mathematical Formulation:
    .. math::
        \mathbf{Z} = \sum_{k=0}^{K} \alpha_k \mathbf{P}^{k} \mathbf{X}.
    where:
        :math:`\mathbf{X}` is the input node feature matrix.
        :math:`\mathbf{Z}` is the output node feature matrix.
        :math:`\mathbf{P}` is the normalized adjacency matrix of the graph.
        :math:`\alpha_k` is the parameter for the :math:`k`-th order polynomial.

    Args:
        K (int): Order of polynomial, or maximum number of hops considered for message passing.
        alpha (float): Parameter controlling the weighting of different hops
            (used by the 'PPR' and 'NPPR' initialisations).
        Init (str): Initialization method for the propagation weights. Possible values are 'SGC', 'PPR', 'NPPR', 'Random', 'WS'.
        spmm (bool, optional): If set to `True`, uses sparse matrix multiplication (SPMM) for propagation. Default is `True`.
        Gamma (array-like, optional): The :obj:`K + 1` initial coefficients used
            when :obj:`Init='WS'`. Default is `None`.
        **kwargs (optional): Additional arguments for the base `Module`.

    Raises:
        ValueError: If :obj:`Init='WS'` and :obj:`Gamma` is missing or does
            not contain exactly :obj:`K + 1` values.
    """

    def __init__(self, K: int, alpha: float, Init: str, spmm: bool = True,
                 Gamma=None, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GPRGNN, self).__init__(**kwargs)
        self.K = K
        self.alpha = alpha
        self.Init = Init

        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
        if Init == 'SGC':
            # SGC-like: all weight on the K-th hop
            TEMP = 0.0 * np.ones(K + 1)
            TEMP[-1] = 1.0
        elif Init == 'PPR':
            # PPR-like
            TEMP = alpha * (1 - alpha) ** np.arange(K + 1)
            TEMP[-1] = (1 - alpha) ** K
        elif Init == 'NPPR':
            # Negative PPR
            TEMP = (alpha) ** np.arange(K + 1)
            TEMP = TEMP / np.sum(np.abs(TEMP))
        elif Init == 'Random':
            # Random
            bound = np.sqrt(3 / (K + 1))
            TEMP = np.random.uniform(-bound, bound, K + 1)
            TEMP = TEMP / np.sum(np.abs(TEMP))
        else:
            # 'WS': caller-specified coefficients
            if Gamma is None:
                raise ValueError("Init='WS' requires the Gamma coefficients")
            TEMP = np.asarray(Gamma, dtype=np.float64)
            if TEMP.shape != (K + 1,):
                raise ValueError(
                    f'Gamma must contain K + 1 = {K + 1} values, '
                    f'got shape {TEMP.shape}')

        TEMP_jt = jt.array(TEMP)
        self.temp = nn.Parameter(TEMP_jt)

        self.spmm = spmm
        self.reset_parameters()

    def reset_parameters(self):
        pass

    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        out = x * (self.temp[0])
        for k in range(self.K):
            x = self._propagate(x, csc, csr)
            out = out + self.temp[k + 1] * x
        return out

    def _propagate(self, x, csc, csr):
        """Apply one propagation step (SPMM only when enabled and CUDA is active)."""
        if self.spmm and jt.flags.use_cuda == 1:
            return self.propagate_spmm(x=x, csr=csr)
        return self.propagate_msg(x=x, csc=csc, csr=csr)

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its hyper-parameters.
        return '{}(K={}, alpha={}, Init={})'.format(
            self.__class__.__name__, self.K, self.alpha, self.Init)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/message_passiong_nts.py:
# --------------------------------------------------------------------------------
import os
import os.path as osp
import sys
from typing import List, Optional, Set
from inspect import Parameter

import jittor as jt
from jittor import nn, Module
from jittor import Var
from jittor_geometric.typing import Adj, Size

from .utils.inspector import Inspector
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import aggregateWithWeight


class MessagePassingNts(Module):
    r"""Base class for layers written as a four-stage graph pipeline.

    Subclasses override the stage hooks; in this base class every hook except
    :meth:`aggregate_with_weight` is a stub that returns :obj:`None`.

    Stages:
        1. :meth:`scatter_to_edge` - scatter source/destination vertex features to edges.
        2. :meth:`edge_forward` - compute a per-edge output message.
        3. :meth:`scatter_to_vertex` - gather edge messages into new vertex features.
        4. :meth:`vertex_forward` - transform vertex features.

    Args:
        aggr (str, optional): Accepted but currently not stored or used.
        flow (str, optional): Accepted but currently not stored or used.
        node_dim (int, optional): Accepted but currently not stored or used.
    """

    # Argument names treated as special; not referenced by this class itself.
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassingNts, self).__init__()

    # graph operations
    def propagate(self, x):
        return

    def aggregate_with_weight(self, x, csc, csr) -> Var:
        """
        Used for GCN demo, combine 'scatter_to_edge' with 'scatter_to_vertex'
        """
        output = aggregateWithWeight(x, csc, csr)
        return output

    def scatter_to_edge(self, x) -> Var:
        """
        ScatterToEdge is an edge message generating operation that
        scatters the source and destination representations
        to edges for the EdgeForward computation
        """
        return

    def edge_forward(self, x: Var) -> Var:
        """
        EdgeForward is a parameterized function defined on each
        edge to generate an output message by combining the edge
        representation with the representations of source and destination.
        """
        return

    def scatter_to_vertex(self, x, csc) -> Var:
        """
        Scatter_to_vertex takes incoming edge-associated Vars as input
        and outputs a new vertex representation for next layer's computation
        """
        return

    def vertex_forward(self, x: Var) -> Var:
        """
        VertexForward is a parameterized function defined on each vertex
        to generate new vertex representation by applying zero or several
        NN models on aggregated neighborhood representations.
        """
        return

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/optbasis_conv.py:
# --------------------------------------------------------------------------------
'''
Description:
Author: Yuhe Guo
Date: 2024-12-30
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
import numpy as np
from jittor import Var, nn, Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


class OptBasisConv(Module):
    r"""The filter from the `"Graph Neural Networks with Learnable and Optimal
    Polynomial Bases" <https://arxiv.org/abs/2302.12432>`_ paper.

    This class implements the OptBasisConv architecture, which implicitly utilizes the optimal polynomial bases on each channel via three-term recurrence propagation.
    Check Algorithm 4 and Algorithm 5 in the paper for more details.

    Mathematical Formulation:
        Please refer to Algorithm 2, 4 and 5 in paper for more details.

    Args:
        K (int): Order of polynomial bases.
        n_channels (int): Number of signal channels to be filtered.
        spmm (bool, optional): Stored on the layer, but propagation in this
            layer always goes through :obj:`SpmmCsr`. Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
    """
    def __init__(self, K: int, n_channels: int, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(OptBasisConv, self).__init__(**kwargs)

        self.K = K
        self.spmm = spmm
        self.n_channels = n_channels

        self.reset_parameters()

    def reset_parameters(self):
        # One coefficient row per channel: [1, 0, ..., 0] (identity filter).
        t = jt.zeros(self.K + 1)
        t[0] = 1
        t = t.repeat(self.n_channels, 1)
        self.alpha_params = jt.Var(t)

    def three_term_prop(self, csr, last_h, second_last_h):
        """One step of the three-term recurrence: propagate, orthogonalise
        against the previous two bases (per channel), then normalise."""
        rst = self.propagate_spmm(x=last_h, csr=csr)
        _t = jt.linalg.einsum('nh,nh->h', rst, last_h)
        rst = rst - jt.linalg.einsum('h,nh->nh', _t, last_h)
        _t = jt.linalg.einsum('nh,nh->h', rst, second_last_h)
        rst = rst - jt.linalg.einsum('h,nh->nh', _t, second_last_h)
        rst = rst / jt.clamp((jt.norm(rst, dim=0)), 1e-8)
        return rst

    def execute(self, x, csr):
        # Tiny noise keeps the per-channel norms away from zero.
        blank_noise = jt.randn_like(x) * 1e-5
        x = x + blank_noise
        h0 = x / jt.clamp((jt.norm(x, dim=0)), 1e-8)
        rst = jt.zeros_like(h0)
        rst = rst + self.alpha_params[:, 0] * h0

        last_h = h0
        second_last_h = jt.zeros_like(h0)

        for i in range(1, self.K + 1):
            h_i = self.three_term_prop(csr, last_h, second_last_h)
            rst = rst + self.alpha_params[:, i] * h_i
            second_last_h = last_h
            last_h = h_i

        return rst

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        # This layer has no in/out channel attributes; report its hyper-parameters.
        return '{}(K={}, n_channels={})'.format(self.__class__.__name__,
                                                self.K, self.n_channels)

# --------------------------------------------------------------------------------
# /jittor_geometric/nn/conv/sg_conv.py:
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from jittor_geometric.typing import Adj, OptVar 3 | 4 | import jittor as jt 5 | from jittor import Var 6 | from jittor.nn import Linear 7 | from jittor_geometric.nn.conv import MessagePassing 8 | 9 | from ..inits import glorot, zeros 10 | from jittor_geometric.data import CSC, CSR 11 | from jittor_geometric.ops import SpmmCsr, aggregateWithWeight 12 | 13 | 14 | class SGConv(MessagePassing): 15 | r"""The simple graph convolutional operator from the `"Simplifying Graph 16 | Convolutional Networks" `_ paper. 17 | 18 | This class implements the Simplified Graph Convolution (SGC) layer, which removes nonlinearities and collapses weight 19 | matrices across layers to achieve a simplified and computationally efficient graph convolution. 20 | 21 | Mathematical Formulation: 22 | .. math:: 23 | \mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 24 | \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta}, 25 | 26 | where: 27 | - :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with added self-loops. 28 | - :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` is its diagonal degree matrix. 29 | - :math:`K` controls the number of propagation steps. 30 | - The adjacency matrix can include other values than :obj:`1`, representing edge weights via the optional `edge_weight` variable. 31 | 32 | Args: 33 | in_channels (int): Number of input features per node. 34 | out_channels (int): Number of output features per node. 35 | K (int, optional): Number of propagation steps. Default is `1`. 36 | bias (bool, optional): Whether to include a learnable bias term. Default is `True`. 37 | **kwargs (optional): Additional arguments for the `MessagePassing` class. 
38 | """ 39 | 40 | _cached_x: Optional[Var] 41 | 42 | def __init__(self, in_channels: int, out_channels: int, K: int = 1, bias: bool = True, spmm:bool=True, **kwargs): 43 | kwargs.setdefault('aggr', 'add') 44 | super(SGConv, self).__init__(**kwargs) 45 | 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | self.K = K 49 | self.lin = Linear(in_channels, out_channels, bias=bias) 50 | self.spmm=spmm 51 | self.reset_parameters() 52 | 53 | 54 | def reset_parameters(self): 55 | glorot(self.lin.parameters()[0]) 56 | zeros(self.lin.parameters()[1]) 57 | self._cached_adj_t = None 58 | self._cached_csc=None 59 | 60 | 61 | def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var: 62 | """Perform forward propagation.""" 63 | 64 | for k in range(self.K): 65 | if self.spmm and jt.flags.use_cuda == 1: 66 | x = self.propagate_spmm(x=x, csr=csr) 67 | else: 68 | x = self.propagate_msg(x=x, csc=csc, csr=csr) 69 | return self.lin(x) 70 | 71 | # propagate by message passing 72 | def propagate_msg(self,x, csc: CSC, csr:CSR): 73 | out = aggregateWithWeight(x,csc,csr) 74 | return out 75 | 76 | # propagate by spmm 77 | def propagate_spmm(self, x, csr:CSR): 78 | out = SpmmCsr(x,csr) 79 | return out 80 | 81 | def __repr__(self): 82 | return '{}({}, {}, K={})'.format(self.__class__.__name__, 83 | self.in_channels, self.out_channels, 84 | self.K) 85 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2025-01-08 16:35:05 5 | ''' 6 | from .inspector import Inspector 7 | from .typing import * -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/inspector.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | from collections import 
OrderedDict 4 | from typing import Dict, List, Any, Optional, Callable, Set 5 | 6 | from .typing import parse_types 7 | 8 | 9 | class Inspector(object): 10 | def __init__(self, base_class: Any): 11 | self.base_class: Any = base_class 12 | self.params: Dict[str, Dict[str, Any]] = {} 13 | 14 | # func (Callable): 需要检查的函数。 pop_first (bool): 是否移除第一个参数(例如 `self`)。默认为 False 15 | def inspect(self, func: Callable, 16 | pop_first: bool = False) -> Dict[str, Any]: 17 | params = inspect.signature(func).parameters 18 | params = OrderedDict(params) 19 | if pop_first: 20 | params.popitem(last=False) 21 | self.params[func.__name__] = params 22 | 23 | def keys(self, func_names: Optional[List[str]] = None) -> Set[str]: 24 | keys = [] 25 | for func in func_names or list(self.params.keys()): 26 | keys += self.params[func].keys() 27 | # print(keys) ['x_j', 'edge_weight', 'index', 'ptr', 'dim_size'] 28 | return set(keys) 29 | 30 | def __implements__(self, cls, func_name: str) -> bool: 31 | if cls.__name__ == 'MessagePassing': 32 | return False 33 | if func_name in cls.__dict__.keys(): 34 | return True 35 | return any(self.__implements__(c, func_name) for c in cls.__bases__) 36 | 37 | def implements(self, func_name: str) -> bool: 38 | return self.__implements__(self.base_class.__class__, func_name) 39 | 40 | def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]: 41 | out: Dict[str, str] = {} 42 | for func_name in func_names or list(self.params.keys()): 43 | func = getattr(self.base_class, func_name) 44 | arg_types = parse_types(func)[0][0] 45 | for key in self.params[func_name].keys(): 46 | if key in out and out[key] != arg_types[key]: 47 | raise ValueError( 48 | (f'Found inconsistent types for argument {key}. 
' 49 | f'Expected type {out[key]} but found type ' 50 | f'{arg_types[key]}.')) 51 | out[key] = arg_types[key] 52 | return out 53 | 54 | def distribute(self, func_name, kwargs: Dict[str, Any]): 55 | # 确保所有必需的参数都被提供,并且在缺失时使用默认值 56 | out = {} 57 | for key, param in self.params[func_name].items(): 58 | data = kwargs.get(key, inspect.Parameter.empty) 59 | if data is inspect.Parameter.empty: 60 | if param.default is inspect.Parameter.empty: 61 | raise TypeError(f'Required parameter {key} is empty.') 62 | data = param.default 63 | out[key] = data 64 | return out 65 | 66 | 67 | def func_header_repr(func: Callable, keep_annotation: bool = True) -> str: 68 | source = inspect.getsource(func) 69 | signature = inspect.signature(func) 70 | 71 | if keep_annotation: 72 | return ''.join(re.split(r'(\).*?:.*?\n)', source, 73 | maxsplit=1)[:2]).strip() 74 | 75 | params_repr = ['self'] 76 | for param in signature.parameters.values(): 77 | params_repr.append(param.name) 78 | if param.default is not inspect.Parameter.empty: 79 | params_repr[-1] += f'={param.default}' 80 | 81 | return f'def {func.__name__}({", ".join(params_repr)}):' 82 | 83 | 84 | def func_body_repr(func: Callable, keep_annotation: bool = True) -> str: 85 | source = inspect.getsource(func) 86 | body_repr = re.split(r'\).*?:.*?\n', source, maxsplit=1)[1] 87 | if not keep_annotation: 88 | body_repr = re.sub(r'\s*# type:.*\n', '', body_repr) 89 | return body_repr 90 | -------------------------------------------------------------------------------- /jittor_geometric/nn/conv/utils/typing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | import pyparsing as pp 4 | from itertools import product 5 | from collections import OrderedDict 6 | from typing import Callable, Tuple, Dict, List 7 | 8 | 9 | def split_types_repr(types_repr: str) -> List[str]: 10 | out = [] 11 | i = depth = 0 12 | for j, char in enumerate(types_repr): 13 | if char == '[': 14 | depth += 
1 15 | elif char == ']': 16 | depth -= 1 17 | elif char == ',' and depth == 0: 18 | out.append(types_repr[i:j].strip()) 19 | i = j + 1 20 | out.append(types_repr[i:].strip()) 21 | return out 22 | 23 | 24 | def sanitize(type_repr: str): 25 | type_repr = re.sub(r'', r'\1', type_repr) 26 | type_repr = type_repr.replace('typing.', '') 27 | type_repr = type_repr.replace('torch_sparse.tensor.', '') 28 | type_repr = type_repr.replace('Adj', 'Union[Tensor, SparseTensor]') 29 | 30 | # Replace `Union[..., NoneType]` by `Optional[...]`. 31 | sexp = pp.nestedExpr(opener='[', closer=']') 32 | tree = sexp.parseString(f'[{type_repr.replace(",", " ")}]').asList()[0] 33 | 34 | def union_to_optional_(tree): 35 | for i in range(len(tree)): 36 | e, n = tree[i], tree[i + 1] if i + 1 < len(tree) else [] 37 | if e == 'Union' and n[-1] == 'NoneType': 38 | tree[i] = 'Optional' 39 | tree[i + 1] = tree[i + 1][:-1] 40 | elif e == 'Union' and 'NoneType' in n: 41 | idx = n.index('NoneType') 42 | n[idx] = [n[idx - 1]] 43 | n[idx - 1] = 'Optional' 44 | elif isinstance(e, list): 45 | tree[i] = union_to_optional_(e) 46 | return tree 47 | 48 | tree = union_to_optional_(tree) 49 | type_repr = re.sub(r'\'|\"', '', str(tree)[1:-1]).replace(', [', '[') 50 | 51 | return type_repr 52 | 53 | 54 | def param_type_repr(param) -> str: 55 | if param.annotation is inspect.Parameter.empty: 56 | return 'jittor.Var' 57 | return sanitize(re.split(r':|='.strip(), str(param))[1]) 58 | 59 | 60 | def return_type_repr(signature) -> str: 61 | return_type = signature.return_annotation 62 | if return_type is inspect.Parameter.empty: 63 | return 'jittor.Var' 64 | elif str(return_type)[:6] != ' List[Tuple[Dict[str, str], str]]: 73 | source = inspect.getsource(func) 74 | signature = inspect.signature(func) 75 | 76 | # Parse `# type: (...) -> ...` annotation. Note that it is allowed to pass 77 | # multiple `# type:` annotations in `forward()`. 
78 | iterator = re.finditer(r'#\s*type:\s*\((.*)\)\s*->\s*(.*)\s*\n', source) 79 | matches = list(iterator) 80 | 81 | if len(matches) > 0: 82 | out = [] 83 | args = list(signature.parameters.keys()) 84 | for match in matches: 85 | arg_types_repr, return_type = match.groups() 86 | arg_types = split_types_repr(arg_types_repr) 87 | arg_types = OrderedDict((k, v) for k, v in zip(args, arg_types)) 88 | return_type = return_type.split('#')[0].strip() 89 | out.append((arg_types, return_type)) 90 | return out 91 | 92 | # Alternatively, parse annotations using the inspected signature. 93 | else: 94 | ps = signature.parameters 95 | arg_types = OrderedDict((k, param_type_repr(v)) for k, v in ps.items()) 96 | return [(arg_types, return_type_repr(signature))] 97 | 98 | 99 | def resolve_types(arg_types: Dict[str, str], 100 | return_type_repr: str) -> List[Tuple[List[str], str]]: 101 | out = [] 102 | for type_repr in arg_types.values(): 103 | if type_repr[:5] == 'Union': 104 | out.append(split_types_repr(type_repr[6:-1])) 105 | else: 106 | out.append([type_repr]) 107 | return [(x, return_type_repr) for x in product(*out)] 108 | -------------------------------------------------------------------------------- /jittor_geometric/nn/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .merge_predictor import MergeLayer 2 | from .time_encoder import TimeEncoder 3 | 4 | __all__ = [ 5 | 'MergeLayer', 6 | 'TimeEncoder', 7 | ] 8 | 9 | classes = __all__ 10 | -------------------------------------------------------------------------------- /jittor_geometric/nn/dense/merge_predictor.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | import numpy as np 4 | 5 | class MergeLayer(nn.Module): 6 | 7 | def __init__(self, input_dim1: int, input_dim2: int, hidden_dim: int, output_dim: int): 8 | """ 9 | Merge Layer to merge two inputs via: input_dim1 + 
input_dim2 -> hidden_dim -> output_dim. 10 | :param input_dim1: int, dimension of first input 11 | :param input_dim2: int, dimension of the second input 12 | :param hidden_dim: int, hidden dimension 13 | :param output_dim: int, dimension of the output 14 | """ 15 | super().__init__() 16 | self.fc1 = nn.Linear(input_dim1 + input_dim2, hidden_dim) 17 | self.fc2 = nn.Linear(hidden_dim, output_dim) 18 | self.act = nn.ReLU() 19 | 20 | def execute(self, input_1: jt.Var, input_2: jt.Var): 21 | """ 22 | merge and project the inputs 23 | :param input_1: Var, shape (*, input_dim1) 24 | :param input_2: Var, shape (*, input_dim2) 25 | :return: 26 | """ 27 | # Var, shape (*, input_dim1 + input_dim2) 28 | x = jt.cat([input_1, input_2], dim=1) 29 | # Var, shape (*, output_dim) 30 | h = self.fc2(self.act(self.fc1(x))) 31 | return h 32 | 33 | -------------------------------------------------------------------------------- /jittor_geometric/nn/dense/time_encoder.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import jittor.nn as nn 3 | import numpy as np 4 | 5 | class TimeEncoder(nn.Module): 6 | 7 | def __init__(self, time_dim: int, parameter_requires_grad: bool = True): 8 | """ 9 | Time encoder. 
10 | :param time_dim: int, dimension of time encodings 11 | :param parameter_requires_grad: boolean, whether the parameter in TimeEncoder needs gradient 12 | """ 13 | super(TimeEncoder, self).__init__() 14 | 15 | self.time_dim = time_dim 16 | # trainable parameters for time encoding 17 | self.w = nn.Linear(1, time_dim) 18 | self.w.weight = nn.Parameter((jt.Var(1 / 10 ** np.linspace(0, 9, time_dim, dtype=np.float32))).reshape(time_dim, -1)) 19 | self.w.bias = nn.Parameter(jt.zeros(time_dim)) 20 | 21 | if not parameter_requires_grad: 22 | self.w.weight.requires_grad = False 23 | self.w.bias.requires_grad = False 24 | 25 | def execute(self, timestamps: jt.Var): 26 | """ 27 | compute time encodings of time in timestamps 28 | :param timestamps: Var, shape (batch_size, seq_len) 29 | :return: 30 | """ 31 | # Var, shape (batch_size, seq_len, 1) 32 | timestamps = timestamps.unsqueeze(dim=2) 33 | 34 | # Var, shape (batch_size, seq_len, time_dim) 35 | output = jt.cos(self.w(timestamps)) 36 | 37 | return output -------------------------------------------------------------------------------- /jittor_geometric/nn/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jittor as jt 4 | from jittor import init 5 | 6 | 7 | def uniform(size, var): 8 | if var is not None: 9 | bound = 1.0 / math.sqrt(size) 10 | init.uniform_(var, -bound, bound) 11 | 12 | 13 | def kaiming_uniform(var, fan, a): 14 | if var is not None: 15 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 16 | init.uniform_(var, -bound, bound) 17 | 18 | 19 | def glorot(var): 20 | if var is not None: 21 | stdv = math.sqrt(6.0 / (var.size(-2) + var.size(-1))) 22 | init.uniform_(var, -stdv, stdv) 23 | 24 | def glorot_orthogonal(var, scale): 25 | if var is not None: 26 | # 步骤1:正交初始化 27 | # Jittor目前没有直接的正交初始化,需要自己实现 28 | rows = var.size(-2) 29 | cols = var.size(-1) 30 | 31 | # 创建随机矩阵 32 | flattened = var.view(rows, cols) 33 | if rows < cols: 34 | flattened = 
flattened.transpose((1, 0)) 35 | 36 | # QR分解实现正交初始化 37 | q, r = jt.linalg.qr(jt.randn((max(rows, cols), max(rows, cols)))) 38 | # 处理符号以确保结果确定性 39 | d = jt.diag(r) 40 | ph = jt.nn.sign(d) 41 | q *= ph 42 | 43 | if rows < cols: 44 | q = q.transpose((1, 0)) 45 | 46 | var = q[:rows, :cols] 47 | 48 | scale /= ((var.size(-2) + var.size(-1)) * var.var()) 49 | 50 | var *= math.sqrt(scale) 51 | 52 | return var 53 | 54 | def xavier_normal(var): 55 | if var is not None: 56 | stdv = math.sqrt(2.0 / (var.size(-2) + var.size(-1))) 57 | init.gauss_(var, mean=0.0, std=stdv) 58 | 59 | def zeros(var): 60 | if var is not None: 61 | init.constant_(var, 0) 62 | 63 | 64 | def ones(var): 65 | if var is not None: 66 | init.constant_(var, 1) 67 | 68 | 69 | def normal(var, mean, std): 70 | if var is not None: 71 | var.assign(jt.normal(mean, std, size=var.size)) 72 | 73 | 74 | def reset(nn): 75 | def _reset(item): 76 | if hasattr(item, 'reset_parameters'): 77 | item.reset_parameters() 78 | 79 | if nn is not None: 80 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 81 | for item in nn.children(): 82 | _reset(item) 83 | else: 84 | _reset(nn) 85 | -------------------------------------------------------------------------------- /jittor_geometric/nn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .tgn import TGNMemory 2 | from .dyrep import DyRepMemory 3 | from .jodie import JODIEEmbedding, compute_src_dst_node_time_shifts 4 | from .graphmixer import GraphMixer 5 | from .dygformer import DyGFormer 6 | from .schnet import SchNet 7 | from .unimol import UniMolModel 8 | from .dimenet import DimeNet 9 | 10 | __all__ = [ 11 | 'TGNMemory', 12 | 'DyRepMemory', 13 | 'JODIEEmbedding', 14 | 'GraphMixer', 15 | 'DyGFormer', 16 | 'SchNet', 17 | 'DimeNet', 18 | 'compute_src_dst_node_time_shifts', 19 | 'UniMolModel', 20 | ] 21 | 22 | classes = __all__ 23 | -------------------------------------------------------------------------------- 
/jittor_geometric/nn/pool/__init__.py: -------------------------------------------------------------------------------- 1 | from .glob import global_add_pool 2 | 3 | __all__ = [ 4 | 'global_add_pool', 5 | ] -------------------------------------------------------------------------------- /jittor_geometric/nn/pool/glob.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import jittor as jt 3 | from jittor_geometric.utils import scatter 4 | 5 | def global_add_pool(x: jt.Var, batch: Optional[jt.Var], 6 | size: Optional[int] = None) -> jt.Var: 7 | r"""Returns batch-wise graph-level-outputs by adding node features 8 | across the node dimension. 9 | 10 | For a single graph :math:`\mathcal{G}_i`, its output is computed by 11 | 12 | .. math:: 13 | \mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n. 14 | 15 | Functional method of the 16 | :class:`~jittor_geometric.nn.aggr.SumAggregation` module. 17 | 18 | Args: 19 | x (jt.Var): Node feature matrix 20 | :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. 21 | batch (jt.Var, optional): The batch vector 22 | :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns 23 | each node to a specific example. 24 | size (int, optional): The number of examples :math:`B`. 25 | Automatically calculated if not given. 
(default: :obj:`None`) 26 | """ 27 | dim = -1 if x.ndim == 1 else -2 28 | 29 | if batch is None: 30 | return x.sum(dim=dim, keepdims=x.ndim <= 2) 31 | return scatter(x, batch, dim=dim, dim_size=size, reduce='sum') -------------------------------------------------------------------------------- /jittor_geometric/ops/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2024-06-21 11:01:25 4 | Description: 5 | ''' 6 | from .aggregateWithWeight import aggregateWithWeight 7 | from .cootocsc import cootocsc 8 | from .cootocsr import cootocsr 9 | from .toundirected import toUndirected 10 | from .scatterToEdge import ScatterToEdge 11 | from .edgesoftmax import EdgeSoftmax 12 | from .scatterToVertex import ScatterToVertex 13 | from .spmmcsr import SpmmCsr 14 | from .spmmcoo import SpmmCoo 15 | from .saparse_ops import from_nodes,to_nodes -------------------------------------------------------------------------------- /jittor_geometric/ops/aggregateWithWeight.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-06-21 14:50:39 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn,Var 10 | from jittor import Function 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from jittor_geometric.data import CSC, CSR 13 | module_path = os.path.dirname(__file__) 14 | src = os.path.join(module_path, "cpp/aggregate_op.cc") 15 | header = os.path.join(module_path, "cpp/aggregate_op.h") 16 | 17 | aggregate_op = jt.compile_custom_ops((src, header)) 18 | # Run the test 19 | class AggregateFunc(Function): 20 | def execute(self,x,csc,csr,edge_weight): 21 | self.csc=csc 22 | self.csr=csr 23 | self.weight=edge_weight 24 | if isinstance(edge_weight, Var)==False: 25 | edge_weight=csc.edge_weight 26 | indices=csc.row_indices 27 | 
offset=csc.column_offset 28 | output=jt.zeros_like(x) 29 | aggregate_op.aggregate(output,x,indices,offset,edge_weight,True).fetch_sync() 30 | return output 31 | 32 | def grad(self, grad_output): 33 | if isinstance(self.weight, Var)==False: 34 | edge_weight=self.csr.edge_weight 35 | else: 36 | edge_weight=self.weight 37 | indices=self.csr.column_indices 38 | offset=self.csr.row_offset 39 | output_grad=jt.zeros_like(grad_output) 40 | aggregate_op.aggregate(output_grad,grad_output,indices,offset,edge_weight,False).fetch_sync() 41 | return output_grad,None,None 42 | 43 | ''' 44 | description: This function performs aggregation on the vertex embedding matrix using CSC (Compressed Sparse Column) 45 | and CSR (Compressed Sparse Row) representations of the graph 46 | param {*} x The vertex embedding matrix of shape (v_num, dim), where `v_num` is the number of vertices and `dim` is the dimension of the embeddings. 47 | param {*} csc 48 | param {*} csr 49 | return {*} 50 | author: xuchaoxin 51 | ''' 52 | def aggregateWithWeight(x,csc,csr,weight=None): 53 | out = AggregateFunc.apply(x,csc,csr,weight) 54 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/cootocsc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: Convert COO to CSC 3 | Author: lusz 4 | Date: 2024-06-21 20:20:48 5 | ''' 6 | 7 | import jittor as jt 8 | import os 9 | import sys 10 | from jittor import nn 11 | from jittor import Function 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 13 | from jittor_geometric.data import CSC 14 | module_path = os.path.dirname(__file__) 15 | src = os.path.join(module_path, "cpp/cootocsc_op.cc") 16 | header = os.path.join(module_path, "cpp/cootocsc_op.h") 17 | 18 | cootocsc_op = jt.compile_custom_ops((src, header)) 19 | 20 | 21 | ''' 22 | description: Converts a graph from COO (Coordinate) format to CSC 
(Compressed Sparse Row) format. 23 | param {*} edge_index(Var): The indices of the edges in the COO format. It is expected to be a 2D Var where each column represents an edge, with the first row containing source nodes and the second row containing destination nodes. 24 | param {*} edge_weight(Var): The weights of the edges in the COO format. It is a 1D Var where each element represents the weight of the corresponding edge. 25 | param {*} v_num(int): The number of vertices in the graph. 26 | return {*}: Returns a CSC representation of the graph, which includes column indices, row offsets, and edge weights. 27 | author: lusz 28 | ''' 29 | def cootocsc(edge_index,edge_weight,v_num): 30 | e_num=jt.size(edge_weight,0) 31 | csc_edge_weight=jt.zeros(e_num) 32 | row_indices = jt.zeros((e_num,), dtype='int32') 33 | column_offset = jt.zeros((v_num+1,), dtype='int32') 34 | cootocsc_op.cootocsc(edge_index, edge_weight, row_indices, column_offset, csc_edge_weight, v_num).fetch_sync() 35 | csc=CSC(row_indices,column_offset,csc_edge_weight) 36 | return csc -------------------------------------------------------------------------------- /jittor_geometric/ops/cootocsr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: Convert COO to CSR 3 | Author: lusz 4 | Date: 2024-06-21 19:40:07 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from jittor_geometric.data import CSR 13 | module_path = os.path.dirname(__file__) 14 | src = os.path.join(module_path, "cpp/cootocsr_op.cc") 15 | header = os.path.join(module_path, "cpp/cootocsr_op.h") 16 | 17 | cootocsr_op = jt.compile_custom_ops((src, header)) 18 | 19 | 20 | ''' 21 | description: Converts a graph from COO (Coordinate) format to CSR (Compressed Sparse Row) format. 
22 | param {*} edge_index(Var): The indices of the edges in the COO format. It is expected to be a 2D Var where each column represents an edge, with the first row containing source nodes and the second row containing destination nodes. 23 | param {*} edge_weight(Var): The weights of the edges in the COO format. It is a 1D Var where each element represents the weight of the corresponding edge. 24 | param {*} v_num(int): The number of vertices in the graph. 25 | return {*}: Returns a CSR representation of the graph, which includes column indices, row offsets, and edge weights. 26 | author: lusz 27 | ''' 28 | def cootocsr(edge_index,edge_weight,v_num): 29 | e_num=jt.size(edge_weight,0) 30 | csr_edge_weight=jt.zeros(e_num) 31 | column_indices = jt.zeros((e_num,), dtype='int32') 32 | row_offset = jt.zeros((v_num+1,), dtype='int32') 33 | cootocsr_op.cootocsr(edge_index, edge_weight, column_indices, row_offset, csr_edge_weight, v_num).fetch_sync() 34 | csr=CSR(column_indices,row_offset,csr_edge_weight) 35 | return csr -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/aggregate_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-21 14:14:12 5 | */ 6 | #pragma once 7 | #include "op.h" 8 | #include 9 | #include 10 | #include 11 | namespace jittor { 12 | 13 | struct AggregateOp : Op { 14 | Var* x; 15 | Var* outputVar; 16 | Var* indices; 17 | Var* offset; 18 | Var* weight; 19 | bool forward; 20 | Var* output; 21 | AggregateOp(Var* outputVar, Var* x_,Var* indices_,Var* offset_,Var* weight_,bool forward_); 22 | const char* name() const override { return "aggregate"; } 23 | DECLARE_jit_run; 24 | }; 25 | 26 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/cootocsc_op.cc: -------------------------------------------------------------------------------- 1 
| /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-21 20:20:17 5 | */ 6 | #include "var.h" 7 | #include "cootocsc_op.h" 8 | 9 | namespace jittor { 10 | #ifndef JIT 11 | CootocscOp::CootocscOp(Var* edge_index_, Var* coo_edge_weight_, Var* row_indices_, Var* column_offset_, Var* csc_edge_weight_, int v_num_) : 12 | edge_index(edge_index_), coo_edge_weight(coo_edge_weight_), row_indices(row_indices_), column_offset(column_offset_), csc_edge_weight(csc_edge_weight_),v_num(v_num_) { 13 | flags.set(NodeFlags::_cpu, 1); 14 | output = create_output(nullptr, coo_edge_weight->dtype()); 15 | } 16 | 17 | void CootocscOp::jit_prepare(JK& jk) { 18 | add_jit_define(jk, "T", coo_edge_weight->dtype()); 19 | add_jit_define(jk, "Tint", edge_index->dtype()); 20 | } 21 | 22 | #else // JIT 23 | void CootocscOp::jit_run() { 24 | Tint max_threads = std::thread::hardware_concurrency(); 25 | auto* __restrict__ e_x = edge_index->ptr(); 26 | auto* __restrict__ e_w = coo_edge_weight->ptr(); 27 | auto* __restrict__ e_wr = csc_edge_weight->ptr(); 28 | auto* __restrict__ r_i = row_indices->ptr(); 29 | auto* __restrict__ col_off = column_offset->ptr(); 30 | 31 | Tint edge_size = edge_index->shape[1]; 32 | // #pragma omp parallel for num_threads(max_threads) schedule(guided) 33 | for (Tint i = 0; i < edge_size; i++) { 34 | __sync_fetch_and_add(&col_off[e_x[i + edge_size] + 1], 1); 35 | } 36 | 37 | for (Tint i = 0; i < v_num; ++i) { 38 | col_off[i + 1] += col_off[i]; 39 | } 40 | 41 | Tint* vertex_index = (Tint*) calloc(v_num, sizeof(Tint)); 42 | // #pragma omp parallel for num_threads(max_threads) schedule(guided) 43 | for (Tint i = 0; i < edge_size; i++) { 44 | Tint src = e_x[i]; 45 | Tint dst = e_x[i + edge_size]; 46 | Tint index = __sync_fetch_and_add((Tint *)&vertex_index[dst], 1); 47 | __sync_fetch_and_add((Tint *)&index, col_off[dst]); 48 | // index += col_off[dst]; 49 | r_i[index] = src; 50 | e_wr[index] = e_w[i]; 51 | } 52 | std::free(vertex_index); 53 | } 54 | #endif // JIT 
55 | 56 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/cootocsc_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-21 20:20:26 5 | */ 6 | #pragma once 7 | #include "op.h" 8 | #include 9 | #include 10 | #include 11 | namespace jittor { 12 | 13 | struct CootocscOp : Op { 14 | Var* row_indices; 15 | Var* column_offset; 16 | Var* csc_edge_weight; // CSC 17 | 18 | Var* edge_index; 19 | Var* coo_edge_weight; // COO 20 | 21 | Var* output; 22 | int v_num; 23 | 24 | CootocscOp(Var* edge_index_, Var* coo_edge_weight_, Var* row_indices_, Var* column_offset_, Var* csc_edge_weight_, int v_num_); 25 | const char* name() const override { return "cootocsc"; } 26 | DECLARE_jit_run; 27 | }; 28 | 29 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/cootocsr_op.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-21 20:20:26 5 | */ 6 | #include "var.h" 7 | #include "cootocsr_op.h" 8 | 9 | 10 | namespace jittor { 11 | #ifndef JIT 12 | CootocsrOp::CootocsrOp(Var* edge_index_,Var* coo_edge_weight_,Var* column_indices_,Var* row_offset_,Var* csr_edge_weight_,int v_num_) : 13 | edge_index(edge_index_), coo_edge_weight(coo_edge_weight_),column_indices(column_indices_), row_offset(row_offset_),csr_edge_weight(csr_edge_weight_),v_num(v_num_){ 14 | flags.set(NodeFlags::_cpu, 1); 15 | output = create_output(nullptr,coo_edge_weight->dtype()); 16 | } 17 | 18 | void CootocsrOp::jit_prepare(JK& jk) { 19 | add_jit_define(jk, "T", coo_edge_weight->dtype()); 20 | add_jit_define(jk, "Tint", edge_index->dtype()); 21 | } 22 | 23 | #else // JIT 24 | void CootocsrOp::jit_run() { 25 | Tint max_threads = std::thread::hardware_concurrency(); 26 | auto* __restrict__ 
e_x = edge_index->ptr(); 27 | auto* __restrict__ e_w = coo_edge_weight->ptr(); 28 | auto* __restrict__ e_wr = csr_edge_weight->ptr(); 29 | auto* __restrict__ col_indices = column_indices->ptr(); 30 | auto* __restrict__ row_off = row_offset->ptr(); 31 | 32 | Tint edge_size = edge_index->shape[1]; 33 | // Initialize row_offset 34 | // #pragma omp parallel for num_threads(max_threads) schedule(guided) 35 | for (int i = 0; i < edge_size; i++) { 36 | __sync_fetch_and_add(&row_off[e_x[i] + 1], 1); 37 | } 38 | 39 | for (int i = 0; i < v_num; i++) { 40 | row_off[i + 1] += row_off[i]; 41 | } 42 | 43 | Tint* vertex_index = (Tint*) calloc(v_num, sizeof(Tint)); 44 | // #pragma omp parallel for num_threads(max_threads) schedule(guided) 45 | for (int i = 0; i < edge_size; i++) { 46 | Tint src = e_x[i]; 47 | Tint dst = e_x[i + edge_size]; 48 | Tint index = __sync_fetch_and_add((Tint *)&vertex_index[src], 1); 49 | __sync_fetch_and_add((Tint *)&index, row_off[src]); 50 | // index += row_off[src]; 51 | col_indices[index] = dst; 52 | e_wr[index] = e_w[i]; 53 | } 54 | std::free(vertex_index); // free不在jittor命名空间里 55 | 56 | 57 | } 58 | #endif // JIT 59 | 60 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/cootocsr_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Author: lusz 3 | * @Date: 2024-06-20 21:40:53 4 | * @Description: 5 | */ 6 | #pragma once 7 | #include "op.h" 8 | #include 9 | #include 10 | #include 11 | namespace jittor { 12 | 13 | struct CootocsrOp : Op { 14 | Var* column_indices; 15 | Var* row_offset; 16 | Var* csr_edge_weight; // CSR 17 | 18 | Var* edge_index; 19 | Var* coo_edge_weight;// COO 20 | 21 | Var* output; 22 | int v_num; 23 | 24 | CootocsrOp(Var* edge_index_,Var* coo_edge_weight_,Var* column_indices_,Var* row_offset_,Var* csr_edge_weight_,int v_num_); 25 | const char* name() const override { return "cootocsr"; } 26 | DECLARE_jit_run; 27 | 
}; 28 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/edgesoftmax_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-07-03 13:50:18 5 | */ 6 | 7 | #pragma once 8 | #include "op.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | namespace jittor { 14 | 15 | struct EdgesoftmaxOp : Op { 16 | Var* x; 17 | Var* outputVar; 18 | Var* indices; 19 | Var* offset; 20 | Var* edge_weight; 21 | Var* output; 22 | EdgesoftmaxOp(Var* outputVar_, Var* x_, Var* indices_,Var* offset_); 23 | const char* name() const override { return "edgesoftmax"; } 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/edgesoftmaxbackward_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-07-04 16:16:31 5 | */ 6 | 7 | 8 | #pragma once 9 | #include "op.h" 10 | #include 11 | #include 12 | #include 13 | #include 14 | namespace jittor { 15 | 16 | struct EdgesoftmaxbackwardOp : Op { 17 | Var* x; 18 | Var* outputVar; 19 | Var* y; 20 | Var* indices; 21 | Var* offset; 22 | Var* edge_weight; 23 | Var* output; 24 | EdgesoftmaxbackwardOp(Var* outputVar_, Var* x_,Var* y_, Var* indices_,Var* offset_); 25 | const char* name() const override { return "edgesoftmaxbackward"; } 26 | DECLARE_jit_run; 27 | }; 28 | 29 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/edgetovertex_op.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-21 14:14:03 5 | */ 6 | #include "var.h" 7 | #include "edgetovertex_op.h" 8 | 9 | 10 | namespace jittor { 11 | #ifndef JIT 12 | 13 | 
EdgetovertexOp::EdgetovertexOp(Var* outputVar_, Var* x_, Var* indices_,Var* offset_, int flag_) : 14 | outputVar(outputVar_), x(x_), indices(indices_),offset(offset_),flag(flag_) { 15 | flags.set(NodeFlags::_cpu, 1); 16 | flags.set(NodeFlags::_cuda, 1); 17 | output = create_output(nullptr, x->dtype()); 18 | } 19 | 20 | void EdgetovertexOp::jit_prepare(JK& jk) { 21 | add_jit_define(jk, "T", x->dtype()); 22 | add_jit_define(jk, "Tint", indices->dtype()); 23 | } 24 | 25 | #else // JIT 26 | #ifdef JIT_cpu 27 | void EdgetovertexOp::jit_run() { 28 | auto* __restrict__ out_ptr = outputVar->ptr(); 29 | auto* __restrict__ x_ptr = x->ptr(); 30 | auto* __restrict__ i_ptr=indices->ptr(); 31 | auto* __restrict__ o_ptr=offset->ptr(); 32 | int e_num=indices->shape[0]; 33 | int v_num=offset->shape[0]-1; 34 | int feature_dim=x->shape[1]; 35 | int node; 36 | if(flag==0){ 37 | for(int vtx=0;vtx 59 | __global__ void gather_msg_to_dst( T_v* dst_feature, T_v* message, 60 | const T_l *row_indices,const T_l *column_offset, 61 | T_l batch_size_, T_l feature_size_){ 62 | int threadId = blockIdx.x *blockDim.x + threadIdx.x; 63 | for(long i=threadId;iptr(); 79 | auto* __restrict__ x_ptr = x->ptr(); 80 | auto* __restrict__ i_ptr = indices->ptr(); 81 | auto* __restrict__ o_ptr = offset->ptr(); 82 | Tint v_num=outputVar->shape[0]; 83 | Tint feature_dim=x->shape[1]; 84 | // std::cout<<<>>( 93 | out_ptr, x_ptr, i_ptr, o_ptr, 94 | v_num, feature_dim); 95 | } 96 | #endif //cuda 97 | #endif // JIT 98 | 99 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/edgetovertex_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-28 17:08:33 5 | */ 6 | 7 | #pragma once 8 | #include "op.h" 9 | #include 10 | #include 11 | #include 12 | namespace jittor { 13 | 14 | struct EdgetovertexOp : Op { 15 | Var* x; 16 | Var* outputVar; 17 | Var* 
indices; 18 | Var* offset; 19 | Var* output; 20 | int flag; 21 | EdgetovertexOp(Var* outputVar_, Var* x_, Var* indices_,Var* offset_, int flag_); 22 | const char* name() const override { return "edgetovertex"; } 23 | DECLARE_jit_run; 24 | }; 25 | 26 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/scattertoedge_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-28 17:08:33 5 | */ 6 | 7 | #pragma once 8 | #include "op.h" 9 | #include 10 | #include 11 | #include 12 | namespace jittor { 13 | 14 | struct ScattertoedgeOp : Op { 15 | Var* x; 16 | Var* outputVar; 17 | Var* edge_weight; 18 | Var* indices; 19 | Var* offset; 20 | bool with_weight; 21 | Var* output; 22 | int flag; 23 | ScattertoedgeOp(Var* outputVar_, Var* x_, Var* indices_,Var* offset_,Var* edge_weight_,bool with_weight_,int flag_); 24 | ScattertoedgeOp(Var* outputVar_, Var* x_, Var* indices_,Var* offset_, bool with_weight_,int flag_); 25 | const char* name() const override { return "scattertoedge"; } 26 | DECLARE_jit_run; 27 | }; 28 | 29 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/spmmcoo_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-11-10 21:15:59 5 | */ 6 | #pragma once 7 | #include "op.h" 8 | #include "cusparse.h" 9 | #include 10 | #include 11 | #include "helper_cuda.h" 12 | namespace jittor { 13 | 14 | struct SpmmcooOp : Op { 15 | Var* x; 16 | Var* outputVar; 17 | Var* row_indices; 18 | Var* col_indices; 19 | Var* value; 20 | Var* output; 21 | int A_row; 22 | int A_col; 23 | SpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_,Var* col_indices_,Var* value_,int A_row,int A_col); 24 | const char* name() const override { return "spmmcoo"; } 25 | 
DECLARE_jit_run; 26 | }; 27 | 28 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/spmmcsr_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-11-03 15:03:08 5 | */ 6 | #pragma once 7 | #include "op.h" 8 | 9 | #include "cusparse.h" 10 | namespace jittor { 11 | 12 | struct SpmmcsrOp : Op { 13 | Var* x; 14 | Var* outputVar; 15 | Var* col_indices; 16 | Var* row_offset; 17 | Var* value; 18 | NanoString dtype; 19 | Var* output; 20 | int A_row; 21 | int A_col; 22 | SpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_,Var* value_,Var* row_offset_,int A_row,int A_col); 23 | const char* name() const override { return "spmmcsr"; } 24 | DECLARE_jit_run; 25 | }; 26 | 27 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/cpp/toundirected_op.cc: -------------------------------------------------------------------------------- 1 | /* 2 | * @Description: 3 | * @Author: lusz 4 | * @Date: 2024-06-23 16:06:20 5 | */ 6 | #include "var.h" 7 | #include "toundirected_op.h" 8 | 9 | 10 | namespace jittor { 11 | #ifndef JIT 12 | ToundirectedOp::ToundirectedOp(Var* edge_index_,Var* edge_attr_,int num_edges_,int num_nodes_,Var* new_edge_index_,Var* new_edge_attr_,NanoString dtype_): 13 | edge_index(edge_index_),edge_attr(edge_attr_),num_edges(num_edges_),num_nodes(num_nodes_),new_edge_index(new_edge_index_),new_edge_attr(new_edge_attr_),dtype(dtype_){ 14 | flags.set(NodeFlags::_cpu, 1); 15 | output = create_output(nullptr,dtype); 16 | } 17 | 18 | void ToundirectedOp::jit_prepare(JK& jk) { 19 | add_jit_define(jk, "T", dtype); 20 | 21 | } 22 | #else // JIT 23 | struct Edge { 24 | int row; 25 | int col; 26 | T data; 27 | }; 28 | bool edge_less(const Edge& e1, const Edge& e2) { 29 | if (e1.row != e2.row) 30 | return e1.row < e2.row; 31 | return e1.col < e2.col; 32 
| } 33 | void ToundirectedOp::jit_run() { 34 | auto* __restrict__ e_x = edge_index->ptr(); 35 | auto* __restrict__ e_a = edge_attr->ptr(); 36 | std::vector edges; 37 | for (int i = 0; i < num_edges; ++i) { 38 | edges.push_back({ e_x[i], e_x[i+num_edges], e_a[i] }); 39 | edges.push_back({ e_x[i+num_edges], e_x[i], e_a[i] }); 40 | } 41 | std::sort(edges.begin(), edges.end(), edge_less); 42 | edges.erase(std::unique(edges.begin(), edges.end(), [](const Edge& e1, const Edge& e2) { 43 | return e1.row == e2.row && e1.col == e2.col; 44 | }), edges.end()); 45 | NanoVector index_shape; 46 | NanoVector attr_shape; 47 | int length=edges.size(); 48 | index_shape.push_back(2); 49 | index_shape.push_back(length); 50 | attr_shape.push_back(1); 51 | attr_shape.push_back(length); 52 | new_edge_index->set_shape(index_shape); 53 | new_edge_attr->set_shape(attr_shape); 54 | auto* __restrict__ n_e_x = new_edge_index->ptr(); 55 | auto* __restrict__ n_e_a = new_edge_attr->ptr(); 56 | for(int i=0 ; i 10 | #include 11 | #include 12 | #include 13 | #include 14 | namespace jittor { 15 | 16 | struct ToundirectedOp : Op { 17 | Var* output; 18 | Var* edge_index; 19 | Var* edge_attr; 20 | Var* new_edge_index; 21 | Var* new_edge_attr; 22 | int num_edges; 23 | int num_nodes; 24 | NanoString dtype; 25 | ToundirectedOp(Var* edge_index_,Var* edge_attr_,int num_edges_,int num_nodes_,Var* new_edge_index_,Var* new_edge_attr_,NanoString dtype_=ns_float32); 26 | const char* name() const override { return "toundirected"; } 27 | DECLARE_jit_run; 28 | }; 29 | 30 | } // jittor -------------------------------------------------------------------------------- /jittor_geometric/ops/edgesoftmax.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-07-03 13:50:35 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | 
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from jittor_geometric.data import CSC, CSR 13 | 14 | 15 | # class EdgeSoftMaxFunc(Function): 16 | # def execute(self,x,csc): 17 | # v_num = jt.size(csc.column_offset, 0) - 1 18 | # out = None 19 | # for vtx in range(v_num): 20 | # start = csc.column_offset[vtx] 21 | # end = csc.column_offset[vtx + 1] 22 | # slice = x[start:end] 23 | # softmax_slice = jt.nn.softmax(slice) 24 | # if out is None: 25 | # out = softmax_slice 26 | # else: 27 | # out = jt.contrib.concat([out, softmax_slice], dim=0) 28 | # self.y = out 29 | # self.csc = csc 30 | # return out 31 | 32 | # def grad(self, grad_output): 33 | # # print(grad_output) 34 | # v_num = jt.size(self.csc.column_offset, 0) - 1 35 | # output_grad = None 36 | # for vtx in range(v_num): 37 | # start = self.csc.column_offset[vtx] 38 | # end = self.csc.column_offset[vtx + 1] 39 | # slice = grad_output[start:end] 40 | # imr = self.y[start:end] 41 | # d_o = imr * slice - imr * (slice.sum() * imr) 42 | # if output_grad is None: 43 | # output_grad = d_o 44 | # else: 45 | # output_grad = jt.contrib.concat([output_grad, d_o], dim=0) 46 | # return output_grad, None 47 | 48 | # def EdgeSoftmax(x,csc): 49 | # out = EdgeSoftMaxFunc.apply(x,csc) 50 | # return out 51 | 52 | module_path = os.path.dirname(__file__) 53 | src = os.path.join(module_path, "cpp/edgesoftmax_op.cc") 54 | header = os.path.join(module_path, "cpp/edgesoftmax_op.h") 55 | edge_softmax_op = jt.compile_custom_ops((src, header)) 56 | 57 | src_b = os.path.join(module_path, "cpp/edgesoftmaxbackward_op.cc") 58 | header_b = os.path.join(module_path, "cpp/edgesoftmaxbackward_op.h") 59 | edge_softmax_backward_op = jt.compile_custom_ops((src_b, header_b)) 60 | 61 | 62 | class EdgeSoftmaxFunc(Function): 63 | def execute(self,x,csc): 64 | self.x=x 65 | self.csc=csc 66 | output=jt.zeros_like(x) 67 | edge_softmax_op.edgesoftmax(output,x,csc.row_indices,csc.column_offset) 68 | 
self.y=output 69 | return output 70 | 71 | def grad(self, grad_output): 72 | output_grad=jt.zeros_like(grad_output) 73 | edge_softmax_backward_op.edgesoftmaxbackward(output_grad,grad_output,self.y,self.csc.row_indices,self.csc.column_offset) 74 | return output_grad,None 75 | 76 | 77 | 78 | def EdgeSoftmax(x,csc): 79 | out = EdgeSoftmaxFunc.apply(x,csc) 80 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/repeat_interleave.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | def repeat_interleave(x, repeats, dim=None): 3 | if isinstance(repeats, int): 4 | if dim is None: 5 | x = x.reshape(-1) 6 | dim = 0 7 | if dim < 0: 8 | dim += x.ndim 9 | 10 | # 计算目标形状 11 | tar_shape = list(x.shape) 12 | tar_shape[dim] = tar_shape[dim] * repeats 13 | dims = [] 14 | for i in range(len(tar_shape)): 15 | if dim == i: 16 | dims.append(f"i{i}/{repeats}") 17 | else: 18 | dims.append(f"i{i}") 19 | return x.reindex(tar_shape, dims) 20 | 21 | elif isinstance(repeats, jt.Var): 22 | # 检查 repeats 在指定维度上的大小是否与输入张量一致 23 | if dim is None: 24 | raise ValueError("When repeats is a jt.Var, dim must be specified.") 25 | if dim < 0: 26 | dim += x.ndim 27 | if repeats.shape[0] != x.shape[dim]: 28 | raise ValueError(f"repeats must have the same size as input along dimension {dim}.") 29 | 30 | result = [] 31 | # 对指定维度进行逐个元素重复 32 | for i in range(x.shape[dim]): 33 | # 提取切片,获取第 i 个元素 34 | slice_obj = [slice(None)] * x.ndim 35 | slice_obj[dim] = slice(i, i + 1) 36 | sliced_x = x[tuple(slice_obj)] 37 | 38 | expanded_x = sliced_x 39 | for _ in range(repeats[i].item() - 1): # 重复 repeats[i] - 1 次 40 | expanded_x = jt.concat([expanded_x, sliced_x], dim=dim) # 沿指定维度拼接 41 | result.append(expanded_x) # 将扩展后的元素加入结果列表 42 | result = jt.concat(result, dim=dim) 43 | 44 | return result 45 | else: 46 | raise ValueError("repeats should be either int or jt.Var") 47 | 48 | 49 | 
-------------------------------------------------------------------------------- /jittor_geometric/ops/saparse_ops.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2025-01-11 13:40:29 5 | ''' 6 | import jittor as jt 7 | from jittor_geometric.data import CSC, CSR 8 | 9 | def from_nodes(csc, nodes): 10 | """ 11 | Given a CSC structure and a set of input nodes, find all the neighbor nodes. 12 | 13 | Parameters: 14 | csc (CSC): Compressed Sparse Column structure. 15 | nodes (Var): Input node IDs (Var type). 16 | 17 | Returns: 18 | Var: A Var containing all neighbor nodes corresponding to the input nodes. 19 | """ 20 | if csc is not None: 21 | # Use CSC (column-based) to find neighbors 22 | column_offset = csc.column_offset 23 | row_indices = csc.row_indices 24 | 25 | # Extract neighbors for each input node 26 | neighbors = [] 27 | for i in range(nodes.shape[0]): 28 | node = nodes[i] 29 | start = column_offset[node] 30 | end = column_offset[node + 1] 31 | neighbors.append(row_indices[start:end]) 32 | 33 | else: 34 | raise ValueError("CSC structure must be provided.") 35 | 36 | # Flatten the list of neighbors and remove duplicates 37 | neighbors = jt.Var(jt.contrib.concat([n for n in neighbors])) 38 | # unique_neighbors = jt.unique(neighbors) 39 | 40 | return neighbors 41 | 42 | 43 | def to_nodes(csr, nodes): 44 | """ 45 | Given a CSR structure and a set of input nodes, find all the neighbor nodes. 46 | 47 | Parameters: 48 | csr (CSR): Compressed Sparse Row structure. 49 | nodes (Var): Input node IDs (Var type). 50 | 51 | Returns: 52 | Var: A Var containing all neighbor nodes corresponding to the input nodes. 
53 | """ 54 | if csr is not None: 55 | # Use CSR to find neighbors 56 | row_offset = csr.row_offset 57 | column_indices = csr.column_indices 58 | neighbors = [] 59 | for i in range(nodes.shape[0]): 60 | node = nodes[i] 61 | start = row_offset[node] 62 | end = row_offset[node + 1] 63 | neighbors.append(column_indices[start:end]) 64 | 65 | else: 66 | raise ValueError("CSR structure must be provided.") 67 | 68 | # Flatten the list of neighbors and remove duplicates 69 | neighbors = jt.Var(jt.contrib.concat([n for n in neighbors])) 70 | # unique_neighbors = jt.unique(neighbors) 71 | 72 | return neighbors 73 | 74 | -------------------------------------------------------------------------------- /jittor_geometric/ops/scatterToEdge.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: COO version of scatter to edge for GAT 3 | Author: lusz 4 | Date: 2024-06-28 17:10:12 5 | ''' 6 | 7 | import jittor as jt 8 | import os 9 | import sys 10 | from jittor import nn 11 | from jittor import Function 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 13 | from jittor_geometric.data import CSC, CSR 14 | module_path = os.path.dirname(__file__) 15 | src = os.path.join(module_path, "cpp/scattertoedge_op.cc") 16 | header = os.path.join(module_path, "cpp/scattertoedge_op.h") 17 | srcb = os.path.join(module_path, "cpp/edgetovertex_op.cc") 18 | headerb = os.path.join(module_path, "cpp/edgetovertex_op.h") 19 | scatter_op = jt.compile_custom_ops((src, header)) 20 | scatter_backward_op = jt.compile_custom_ops((srcb, headerb)) 21 | 22 | class ScatterToEdgeFunc(Function): 23 | def execute(self,x,csc,flow): 24 | self.flow=flow 25 | self.csc=csc 26 | # output dim 27 | e_num=jt.size(csc.row_indices,0) 28 | feature_dim=jt.size(x,1) 29 | v_num=jt.size(x,0) 30 | self.e_num=e_num 31 | self.v_num=v_num 32 | self.feature_dim=feature_dim 33 | output=jt.zeros(e_num,feature_dim) 34 | flag=1 35 | if 
flow=="src": 36 | flag=0 37 | self.flag=flag 38 | scatter_op.scattertoedge(output,x,csc.row_indices,csc.column_offset,False,flag).fetch_sync() 39 | 40 | return output 41 | 42 | def grad(self, grad_output): 43 | output_grad=jt.zeros(self.v_num,self.feature_dim) 44 | csc=self.csc 45 | scatter_backward_op.edgetovertex(output_grad,grad_output,csc.row_indices,csc.column_offset,self.flag).fetch_sync() 46 | return output_grad,None,None 47 | 48 | 49 | def ScatterToEdge(x,csc,flow): 50 | out = ScatterToEdgeFunc.apply(x,csc,flow) 51 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/scatterToVertex.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-07-05 17:22:55 5 | ''' 6 | 7 | 8 | import jittor as jt 9 | import os 10 | import sys 11 | from jittor import nn 12 | from jittor import Function 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 14 | from jittor_geometric.data import CSC, CSR 15 | module_path = os.path.dirname(__file__) 16 | src = os.path.join(module_path, "cpp/edgetovertex_op.cc") 17 | header = os.path.join(module_path, "cpp/edgetovertex_op.h") 18 | srcb = os.path.join(module_path, "cpp/scattertoedge_op.cc") 19 | headerb = os.path.join(module_path, "cpp/scattertoedge_op.h") 20 | scatter_op = jt.compile_custom_ops((src, header)) 21 | scatter_backward_op = jt.compile_custom_ops((srcb, headerb)) 22 | # Run the test 23 | class ScatterToVertexFunc(Function): 24 | def execute(self,x,csc,flow): 25 | self.flow=flow 26 | self.csc=csc 27 | # output dim 28 | # print(x.shape) 29 | # print(csc.row_indices.shape) 30 | # print(csc.column_offset.shape) 31 | e_num=jt.size(csc.row_indices,0) 32 | feature_dim=jt.size(x,1) 33 | v_num=jt.size(csc.column_offset,0)-1 34 | self.e_num=e_num 35 | self.v_num=v_num 36 | self.feature_dim=feature_dim 37 | output=jt.zeros(v_num,feature_dim) 
38 | scatter_op.edgetovertex(output,x,csc.row_indices,csc.column_offset,1).fetch_sync() 39 | return output 40 | 41 | def grad(self, grad_output): 42 | output_grad=jt.zeros(self.e_num,self.feature_dim) 43 | csc=self.csc 44 | scatter_backward_op.scattertoedge(output_grad,grad_output,csc.row_indices,csc.column_offset,False,1).fetch_sync() 45 | return output_grad,None,None 46 | 47 | 48 | def ScatterToVertex(x,csc,flow): 49 | out = ScatterToVertexFunc.apply(x,csc,flow) 50 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/spmmcoo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-11-11 14:10:31 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | module_path = os.path.dirname(__file__) 13 | # src = os.path.join(module_path, "cpp/spmmcoo_op.cc") 14 | # header = os.path.join(module_path, "cpp/spmmcoo_op.h") 15 | # spmmcoo_op = jt.compile_custom_ops((src, header)) 16 | from jittor.compile_extern import cusparse_ops # latest jittor 17 | # Run the test 18 | jt.flags.use_cuda=1 19 | class SpmmCooFunc(Function): 20 | def execute(self,x,edge_index,edge_weight,trans_A,trans_B): 21 | self.edge_index=edge_index 22 | row_indices=edge_index[0,:] 23 | col_indices=edge_index[1,:] 24 | self.row_indices=row_indices 25 | self.col_indices=col_indices 26 | self.edge_weight=edge_weight 27 | feature_dim=jt.size(x,1) 28 | v_num=jt.size(x,0) 29 | self.v_num=v_num 30 | self.feature_dim=feature_dim 31 | self.trans_A=trans_A 32 | self.trans_B=trans_B 33 | output=jt.zeros(v_num,feature_dim) 34 | cusparse_ops.cusparse_spmmcoo(output,x,row_indices,col_indices,edge_weight,v_num,v_num,trans_A,trans_B).fetch_sync() 35 | return output 36 | 37 | def grad(self, grad_output): 38 | 
output_grad=jt.zeros(self.v_num,self.feature_dim) 39 | cusparse_ops.cusparse_spmmcoo(output_grad,grad_output,self.row_indices,self.col_indices,self.edge_weight,self.v_num,self.v_num).fetch_sync() 40 | return output_grad,None,None 41 | 42 | 43 | def SpmmCoo(x,edge_index,edge_weight,trans_A=True,trans_B=False): 44 | out = SpmmCooFunc.apply(x,edge_index,edge_weight,trans_A,trans_B) 45 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/spmmcsr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-11-06 19:05:55 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from jittor_geometric.data import CSR 13 | module_path = os.path.dirname(__file__) 14 | from jittor.compile_extern import cusparse_ops 15 | # src = os.path.join(module_path, "cpp/spmmcsr_op.cc") 16 | # header = os.path.join(module_path, "cpp/spmmcsr_op.h") 17 | # spmmcsr_op = jt.compile_custom_ops((src, header)) 18 | # latest jittor 19 | # Run the test 20 | jt.flags.use_cuda=1 21 | class SpmmCsrFunc(Function): 22 | def execute(self,x,csr,trans_A,trans_B): 23 | self.csr=csr 24 | feature_dim=jt.size(x,1) 25 | v_num=jt.size(csr.row_offset,0)-1 26 | self.v_num=v_num 27 | self.feature_dim=feature_dim 28 | output=jt.zeros(v_num,feature_dim) 29 | self.trans_A=trans_A 30 | self.trans_B=trans_B 31 | cusparse_ops.cusparse_spmmcsr(output,x,csr.column_indices,csr.edge_weight,csr.row_offset,v_num,v_num,trans_A,trans_B).fetch_sync() 32 | # spmmcsr_op.spmmcsr(output,x,csr.column_indices,csr.edge_weight,csr.row_offset,v_num,v_num).fetch_sync() 33 | return output 34 | 35 | def grad(self, grad_output): 36 | output_grad=jt.zeros(self.v_num,self.feature_dim) 37 | 
cusparse_ops.cusparse_spmmcsr(output_grad,grad_output,self.csr.column_indices,self.csr.edge_weight,self.csr.row_offset,self.v_num,self.v_num,self.trans_A,self.trans_B).fetch_sync() 38 | # spmmcsr_op.spmmcsr(output_grad,grad_output,self.csr.column_indices,self.csr.edge_weight,self.csr.row_offset,self.v_num,self.v_num).fetch_sync() 39 | return output_grad,None 40 | 41 | 42 | def SpmmCsr(x,csr,trans_A=True,trans_B=False): 43 | out = SpmmCsrFunc.apply(x,csr,trans_A,trans_B) 44 | return out -------------------------------------------------------------------------------- /jittor_geometric/ops/toundirected.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: Converts the graph to an undirected graph 3 | Author: lusz 4 | Date: 2024-06-23 14:45:47 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 10 | from jittor_geometric.utils.num_nodes import maybe_num_nodes 11 | module_path = os.path.dirname(__file__) 12 | src = os.path.join(module_path, "cpp/toundirected_op.cc") 13 | header = os.path.join(module_path, "cpp/toundirected_op.h") 14 | toundirected_op = jt.compile_custom_ops((src, header)) 15 | 16 | def toUndirected(edge_index, edge_attr,num_nodes): 17 | num_edges=jt.size(edge_index,1) 18 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 19 | new_edge_index=jt.zeros_like(edge_index) 20 | new_edge_attr=jt.zeros_like(edge_attr) 21 | toundirected_op.toundirected(edge_index,edge_attr,num_edges,num_nodes,new_edge_index,new_edge_attr,edge_attr.dtype) 22 | return new_edge_index,new_edge_attr 23 | -------------------------------------------------------------------------------- /jittor_geometric/partition/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-11-15 14:13:33 5 | ''' 6 | from .chunk_manager import 
ChunkManager 7 | from .partition_graph import partition_graph 8 | 9 | __all__ = [ 10 | 'ChunkManager', 11 | ] 12 | 13 | classes = __all__ 14 | -------------------------------------------------------------------------------- /jittor_geometric/partition/partition_graph.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2024-11-15 15:28:04 4 | Description: Load and preprocess GNN datasets with relative paths. 5 | ''' 6 | import os.path as osp 7 | import argparse 8 | import pickle 9 | import os 10 | import sys 11 | import jittor as jt 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 13 | #from jittor_geometric.data import GraphChunk,CSR 14 | from jittor import Var 15 | from jittor_geometric.partition.chunk_manager import ChunkManager 16 | from jittor_geometric.datasets import Planetoid, Amazon, WikipediaNetwork, OGBNodePropPredDataset, HeteroDataset, Reddit 17 | import jittor_geometric.transforms as T 18 | from pymetis import part_graph 19 | import numpy as np 20 | 21 | 22 | def partition_graph(dataset_name, num_parts, use_gdc=False): 23 | """Partition a graph dataset using METIS.""" 24 | script_dir = osp.dirname(osp.realpath(__file__)) 25 | path = osp.join(script_dir, '..', '..', 'data') 26 | 27 | # Load dataset 28 | if dataset_name in ['computers', 'photo']: 29 | dataset = Amazon(path, dataset_name, transform=T.NormalizeFeatures()) 30 | elif dataset_name in ['cora', 'citeseer', 'pubmed']: 31 | dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures()) 32 | elif dataset_name in ['chameleon', 'squirrel']: 33 | dataset = WikipediaNetwork(path, dataset_name, geom_gcn_preprocess=False) 34 | elif dataset_name in ['ogbn-arxiv', 'ogbn-products', 'ogbn-papers100M']: 35 | dataset = OGBNodePropPredDataset(name=dataset_name, root=path) 36 | elif dataset_name in ['roman_empire', 'amazon_ratings', 'minesweeper', 'questions', 'tolokers']: 37 | dataset 
= HeteroDataset(path, dataset_name) 38 | elif dataset_name in ['reddit']: 39 | dataset = Reddit(os.path.join(path, 'Reddit')) 40 | data = dataset[0] 41 | edge_index = data.edge_index.numpy() 42 | num_nodes = data.x.shape[0] 43 | 44 | # METIS partition 45 | reorder_dir = osp.join(path, "reorder", f"{dataset_name}_{num_parts}part") 46 | chunk_manager = ChunkManager(output_dir=reorder_dir) 47 | partition = chunk_manager.metis_partition(edge_index, num_nodes, num_parts) 48 | partition = np.array(partition) 49 | os.makedirs(reorder_dir, exist_ok=True) 50 | binary_file_path = osp.join(reorder_dir, f"{dataset_name}_partition_{num_parts}.bin") 51 | if not osp.exists(binary_file_path): 52 | with open(binary_file_path, 'wb') as f: 53 | pickle.dump(partition, f) 54 | print("Partition file saved.") 55 | else: 56 | print("Partition file already exists. Skipping.") 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--use_gdc', action='store_true', help='Use GDC preprocessing.') 61 | parser.add_argument('--dataset', type=str, required=True, help='Name of the GNN dataset to load.') 62 | parser.add_argument('--num_parts', type=int, required=True, help='Partition number.') 63 | args = parser.parse_args() 64 | partition_graph(args.dataset, args.num_parts, args.use_gdc) -------------------------------------------------------------------------------- /jittor_geometric/tests/test_aggregate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2024-06-21 10:21:33 4 | Description: 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | from jittor_geometric.data import CSC, CSR 12 | current_file_path = os.path.abspath(__file__) 13 | test_path = os.path.dirname(current_file_path) 14 | module_path = os.path.dirname(test_path) 15 | # print(module_path) 16 | src = os.path.join(module_path, 
"ops/cpp/aggregate_op.cc") 17 | header = os.path.join(module_path, "ops/cpp/aggregate_op.h") 18 | 19 | aggregate_op = jt.compile_custom_ops((src, header)) 20 | # Run the test 21 | class MyFunc(Function): 22 | def execute(self,x,csc,csr): 23 | self.csc=csc 24 | self.csr=csr 25 | edge_weight=csc.edge_weight 26 | indices=csc.row_indices 27 | offset=csc.column_offset 28 | dtype=edge_weight.dtype 29 | output=x 30 | aggregate_op.aggregate(output,x,indices,offset,edge_weight,True).fetch_sync() 31 | return output 32 | 33 | def grad(self, grad_output): 34 | edge_weight=self.csr.edge_weight 35 | indices=self.csr.column_indices 36 | offset=self.csr.row_offset 37 | dtype=edge_weight.dtype 38 | output_grad=grad_output 39 | aggregate_op.aggregate(output_grad,grad_output,indices,offset,edge_weight,False).fetch_sync() 40 | return output_grad,None,None 41 | 42 | jt.flags.lazy_execution = 0 43 | x=jt.array([[3.0, 2.0, 1.0],[3.0, 2.0, 1.0],[3.0, 2.0, 1.0],[3.0, 2.0, 1.0]]) 44 | y=jt.array([[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]) 45 | # csc 46 | row_indices=jt.array([0,0,1,2]) 47 | col_offset=jt.array([0,1,3,4]) 48 | csc_weight=jt.array([1.0,2.0,3,0,4.0]) 49 | csc=CSC(row_indices, col_offset, csc_weight) 50 | # csr 51 | col_indices=jt.array([0,1,1,2]) 52 | row_offset=jt.array([0,2,3,4]) 53 | csr_weight=jt.array([3.0,1.0,4,0,2.0]) 54 | csr=CSR(col_indices, row_offset, csr_weight) 55 | 56 | func = MyFunc() 57 | print("x") 58 | abs_x=x.abs().sum() 59 | print(abs_x) 60 | output=func(x,csc,csr) 61 | print("out") 62 | abs_out=output.abs().sum() 63 | print(abs_out) 64 | 65 | # 计算损失并进行反向传播 66 | print(output.shape) 67 | print(y.shape) 68 | loss = nn.BCELoss() 69 | loss_var = loss(output, y) 70 | di = jt.grad(loss_var, [x]) 71 | 72 | print("Input Variable Gradient:", di) -------------------------------------------------------------------------------- /jittor_geometric/tests/test_csr.py: -------------------------------------------------------------------------------- 1 
| ''' 2 | Author: lusz 3 | Date: 2024-06-20 22:10:53 4 | Description: convert COO to CSR 5 | ''' 6 | 7 | import jittor as jt 8 | import os 9 | from jittor_geometric.data import CSR 10 | from jittor_geometric.ops import cootocsr 11 | 12 | def test_coo_to_csr(): 13 | jt.flags.use_cuda = 0 14 | jt.flags.lazy_execution = 0 15 | 16 | edge_index = jt.array([[0, 0, 1, 1, 2], [1, 2, 2, 3, 3]]) 17 | edge_weight = jt.array([1.0, 2.0, 3.0, 4.0, 5.0]) 18 | v_num = 4 19 | csr=cootocsr(edge_index, edge_weight ,v_num) 20 | 21 | print("CSR Edge Weight:", csr.edge_weight) 22 | print("Column Indices:", csr.column_indices) 23 | print("Row Offset:", csr.row_offset) 24 | 25 | test_coo_to_csr() -------------------------------------------------------------------------------- /jittor_geometric/tests/test_edgesoftmax.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-07-04 12:01:14 5 | ''' 6 | 7 | import jittor as jt 8 | import os 9 | import sys 10 | from jittor import nn 11 | from jittor import Function 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 13 | from jittor_geometric.data import CSC, CSR 14 | from jittor_geometric.ops import EdgeSoftmax 15 | 16 | 17 | jt.flags.use_cuda=0 18 | jt.flags.lazy_execution = 0 19 | x=jt.array([1.0,2.0,3.0,4.0]) 20 | y=jt.array([1.0,1.0,1.0,1.0]) 21 | # csc 22 | row_indices=jt.array([0,0,1,2]) 23 | col_offset=jt.array([0,1,3,4]) 24 | csc_weight=jt.array([1.0,2.0,3.0,4.0]) 25 | csc=CSC(row_indices, col_offset, csc_weight) 26 | 27 | output=EdgeSoftmax(x,csc) 28 | print(x) 29 | print(output) 30 | # print(y.shape) 31 | loss = nn.BCELoss() 32 | loss_var = loss(output, y) 33 | di = jt.grad(loss_var, [x]) 34 | 35 | print("Input Variable Gradient:", di) -------------------------------------------------------------------------------- /jittor_geometric/tests/test_edgetovertex.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2024-06-21 10:21:33 4 | Description: 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor import Function 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from jittor_geometric.data import CSC, CSR 13 | from jittor_geometric.ops import ScatterToVertex 14 | 15 | 16 | 17 | 18 | jt.flags.use_cuda=1 19 | jt.flags.lazy_execution = 0 20 | x=jt.array([[1.0, 1.0, 1.0],[2.0, 2.0, 2.0],[3.0, 3.0, 3.0],[3.0, 2.0, 1.0]]) 21 | y=jt.array([[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]) 22 | # csc 23 | row_indices=jt.array([0,0,1,2]) 24 | col_offset=jt.array([0,1,3,4]) 25 | csc_weight=jt.array([1.0,2.0,3.0,4.0]) 26 | csc=CSC(row_indices, col_offset, csc_weight) 27 | output=ScatterToVertex(x,csc,"src") 28 | print(x) 29 | print(output) 30 | # print(y.shape) 31 | loss = nn.BCELoss() 32 | print(output.shape) 33 | print(y.shape) 34 | loss_var = loss(output, y) 35 | di = jt.grad(loss_var, [x]) 36 | 37 | print("Input Variable Gradient:", di) -------------------------------------------------------------------------------- /jittor_geometric/tests/test_fromtonodes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2025-01-11 13:51:38 5 | ''' 6 | import jittor as jt 7 | import sys,os 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 9 | from jittor_geometric.data import CSC, CSR 10 | from jittor_geometric.ops import from_nodes,to_nodes 11 | 12 | col_offset=jt.array([0,2,4,5,6,7]) 13 | row_indices=jt.array([0,3,1,4,0,4,2]) 14 | csc_weight=None 15 | csc=CSC(row_indices, col_offset, csc_weight) 16 | 17 | row_offset=jt.array([0,2,3,4,5,7]) 18 | col_indices=jt.array([0,2,1,4,0,1,3]) 19 | csr_weight=None 20 | csr=CSR(col_indices, row_offset, 
csr_weight) 21 | 22 | nodes = jt.array([1, 2, 2, 4]) 23 | result1 = from_nodes(csc=csc, nodes=nodes) 24 | result2 = to_nodes(csr=csr ,nodes=nodes) 25 | print(result1) # 0 1 2 4 26 | print(result2) # 1 3 4 27 | 28 | -------------------------------------------------------------------------------- /jittor_geometric/tests/test_mp_spmm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-12-28 19:35:50 5 | ''' 6 | import jittor as jt 7 | import os,sys 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 9 | from jittor_geometric.ops import SpmmCsr,aggregateWithWeight,cootocsc,cootocsr,SpmmCoo 10 | from jittor_geometric.data import CSC,CSR 11 | def test_spmm_csr(): 12 | jt.flags.use_cuda = 1 13 | jt.flags.lazy_execution = 0 14 | x = jt.array([[3.0, 2.0, 1.0, 0.0, 5.0, 0.0, 1.0, 0.0, 2.0, 0.0], 15 | [1.0, 0.0, 2.0, 3.0, 0.0, 4.0, 0.0, 5.0, 1.0, 0.0], 16 | [0.0, 6.0, 0.0, 0.0, 7.0, 0.0, 8.0, 0.0, 9.0, 0.0], 17 | [4.0, 0.0, 0.0, 1.0, 0.0, 5.0, 0.0, 0.0, 0.0, 6.0], 18 | [0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0], 19 | [7.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 5.0, 0.0, 6.0], 20 | [1.0, 3.0, 0.0, 2.0, 0.0, 0.0, 4.0, 5.0, 6.0, 0.0], 21 | [0.0, 8.0, 0.0, 0.0, 7.0, 1.0, 0.0, 9.0, 2.0, 0.0], 22 | [9.0, 0.0, 6.0, 0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 3.0], 23 | [0.0, 1.0, 0.0, 4.0, 0.0, 0.0, 6.0, 0.0, 5.0, 7.0]], dtype="float32") 24 | edge_index = jt.array([[0, 0, 1, 1, 2 ,2, 3, 4, 4,5,5,7,8], [1, 2, 2, 3, 3, 4, 5,8,9,5,8,9,9]]) 25 | edge_weight = jt.array([1.0, 2.0, 3.0, 4.0, 5.0,3.0,5.0,1.0,2.0,3.0,2.0,3.0,2.0], dtype="float32") 26 | csr=cootocsr(edge_index,edge_weight,10) 27 | csc=cootocsc(edge_index,edge_weight,10) 28 | output_msg = aggregateWithWeight(x,csc,csr) 29 | print(output_msg) 30 | output_spmm= SpmmCsr(x,csr) 31 | print(output_spmm) 32 | output_coo=SpmmCoo(x,edge_index,edge_weight) 33 | print(output_coo) 34 | jt.flags.use_cuda = 0 35 | 
output_cpu = aggregateWithWeight(x,csc,csr) 36 | print(output_cpu) 37 | 38 | test_spmm_csr() -------------------------------------------------------------------------------- /jittor_geometric/tests/test_repeat_interleave.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import os.path as osp 3 | import sys,os 4 | current_dir = osp.dirname(osp.abspath(__file__)) 5 | root = osp.dirname(osp.dirname(current_dir)) 6 | sys.path.append(root) 7 | 8 | from jittor_geometric.ops.repeat_interleave import repeat_interleave 9 | 10 | y1= jt.array([1, 2, 3]) 11 | result1= repeat_interleave(y1, 2) 12 | print(result1) 13 | 14 | y2= jt.array([[1, 2], [3, 4]]) 15 | result2= repeat_interleave(y2, 2) 16 | print(result2) 17 | 18 | result3= repeat_interleave(y2, 3,1) 19 | print(result3) 20 | 21 | result4= repeat_interleave(y2, jt.array([1, 2]),0) 22 | print(result4) 23 | 24 | result4= repeat_interleave(y2, jt.array([1, 2]),1) 25 | print(result4) 26 | 27 | result4= repeat_interleave(y2, jt.array([1, 2]),-2) 28 | print(result4) 29 | 30 | y3= jt.array([[[1, 2, 3], [4, 5, 6]], 31 | [[7, 8, 9], [10, 11, 12]]]) 32 | 33 | # Test dim=1, repeats = [1, 2] 34 | result5 = repeat_interleave(y3, jt.array([1, 2]), dim=1) 35 | print("dim=1:", result5) 36 | 37 | # Test dim=2, repeats = [2, 1, 3] 38 | result6 = repeat_interleave(y3, jt.array([2, 1, 3]), dim=2) 39 | print("dim=2:", result6) 40 | -------------------------------------------------------------------------------- /jittor_geometric/tests/test_scatter_to_edge.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: lusz 3 | Date: 2024-06-21 10:21:33 4 | Description: 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 11 | from jittor_geometric.data import CSC, CSR 12 | from jittor_geometric.ops import 
ScatterToEdge 13 | jt.flags.use_cuda=1 14 | x=jt.array([[1.0, 1.0, 1.0],[2.0, 2.0, 2.0],[3.0, 3.0, 3.0],[3.0, 2.0, 1.0]]) 15 | y=jt.array([[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]) 16 | # csc 17 | row_indices=jt.array([0,0,1,2]) 18 | col_offset=jt.array([0,1,3,4]) 19 | csc_weight=jt.array([1.0,2.0,3.0,4.0]) 20 | csc=CSC(row_indices, col_offset, csc_weight) 21 | # csr 22 | col_indices=jt.array([0,1,1,2]) 23 | row_offset=jt.array([0,2,3,4]) 24 | csr_weight=jt.array([1.0,2.0,3.0,4.0]) 25 | csr=CSR(col_indices, row_offset, csr_weight) 26 | 27 | # output=ScatterToEdge(x,csc,"src") 28 | output=ScatterToEdge(x,csc,"dst") 29 | print(x) 30 | print(output) 31 | # print(y.shape) 32 | loss = nn.BCELoss() 33 | loss_var = loss(output, y) 34 | di = jt.grad(loss_var, [x]) 35 | 36 | print("Input Variable Gradient:", di) -------------------------------------------------------------------------------- /jittor_geometric/tests/test_spmmcoo.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-11-11 15:13:39 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor import nn 10 | from jittor_geometric.ops import SpmmCoo 11 | 12 | def test_spmm_coo(): 13 | jt.flags.use_cuda = 1 14 | jt.flags.lazy_execution = 0 15 | x=jt.array([[3.0, 2.0, 1.0],[4.0, 2.0, 2.0],[1.0, 2.0, 3.0]]) 16 | edge_index=jt.array([[0,0,1,2],[1,2,2,1]]) 17 | edge_weight=jt.array([1.0,1.0,1.0,1.0]) 18 | output=SpmmCoo(x,edge_index,edge_weight) 19 | print(output) 20 | 21 | test_spmm_coo() -------------------------------------------------------------------------------- /jittor_geometric/tests/test_spmmcsr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-11-05 16:25:39 5 | ''' 6 | import jittor as jt 7 | import os 8 | from jittor_geometric.data import CSR 9 | from jittor_geometric.ops import SpmmCsr 10 
| def test_spmm_csr(): 11 | jt.flags.use_cuda = 1 12 | jt.flags.lazy_execution = 0 13 | x=jt.array([[3.0, 2.0, 1.0],[3.0, 2.0, 1.0],[3.0, 2.0, 1.0]]) 14 | col_indices=jt.array([0,1,1,2],dtype='int64') 15 | row_offset=jt.array([0,2,3,4],dtype='int64') 16 | csr_weight=jt.array([3.0,1.0,4.0,2.0], dtype='float32') 17 | csr=CSR(column_indices=col_indices,row_offset=row_offset,edge_weight=csr_weight) 18 | output=SpmmCsr(x,csr) 19 | print(output) 20 | 21 | test_spmm_csr() -------------------------------------------------------------------------------- /jittor_geometric/tests/test_undirected.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: 3 | Author: lusz 4 | Date: 2024-06-23 14:53:18 5 | ''' 6 | import jittor as jt 7 | import os 8 | import sys 9 | from jittor_geometric.ops import toUndirected 10 | 11 | jt.flags.lazy_execution = 0 12 | edge_index = jt.array([[0, 1, 1], 13 | [2, 0, 2]]) 14 | edge_attr = jt.array([1., 3., 2.]) 15 | num_edges=3 16 | num_nodes=3 17 | new_edge_index,new_edge_attr=toUndirected(edge_index,edge_attr,num_nodes) 18 | print(new_edge_index) 19 | print(new_edge_attr) -------------------------------------------------------------------------------- /jittor_geometric/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize_features import NormalizeFeatures 2 | 3 | __all__ = [ 4 | 'NormalizeFeatures', 5 | ] 6 | 7 | classes = __all__ 8 | -------------------------------------------------------------------------------- /jittor_geometric/transforms/normalize_features.py: -------------------------------------------------------------------------------- 1 | class NormalizeFeatures(object): 2 | r"""Row-normalizes node features to sum-up to one.""" 3 | 4 | def __call__(self, data): 5 | data.x = data.x / data.x.sum(1, keepdims=True).clamp(min_v=1) 6 | return data 7 | 8 | def __repr__(self): 9 | return 
'{}()'.format(self.__class__.__name__) 10 | -------------------------------------------------------------------------------- /jittor_geometric/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union 2 | 3 | from jittor import Var 4 | 5 | Adj = Optional[Var] 6 | OptVar = Optional[Var] 7 | PairVar = Tuple[Var, Var] 8 | OptPairVar = Tuple[Var, Optional[Var]] 9 | PairOptVar = Tuple[Optional[Var], Optional[Var]] 10 | Size = Optional[Tuple[int, int]] 11 | NoneType = Optional[Var] 12 | 13 | EdgeType = Tuple[str, str, str] 14 | NodeType = str 15 | SparseVar = Optional[Var] 16 | jt_lib = object 17 | jt_scatter = object -------------------------------------------------------------------------------- /jittor_geometric/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .coalesce import coalesce 2 | from .degree import degree 3 | from .loop import (contains_self_loops, remove_self_loops, 4 | segregate_self_loops, add_self_loops, 5 | add_remaining_self_loops) 6 | from .isolated import contains_isolated_nodes, remove_isolated_nodes 7 | from .get_laplacian import get_laplacian 8 | from .undirected import to_undirected 9 | from .sort import index_sort, unique 10 | from .sparse import is_jittor_sparse_tensor 11 | from .scatter import scatter 12 | from .induced_graph import induced_graph 13 | from .neighbor_sampler import neighbor_sampler, randomwalk_sampler 14 | from .one_hot import one_hot 15 | from .num_nodes import maybe_num_nodes 16 | from .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles 17 | 18 | __all__ = [ 19 | 'coalesce', 20 | 'degree', 21 | 'contains_self_loops', 22 | 'remove_self_loops', 23 | 'segregate_self_loops', 24 | 'add_self_loops', 25 | 'add_remaining_self_loops', 26 | 'contains_isolated_nodes', 27 | 'remove_isolated_nodes', 28 | 'get_laplacian', 29 | 'undirected' 30 | 'index_sort', 31 | 'is_jittor_sparse_tensor', 32 
| 'scatter', 33 | 'induced_graph', 34 | 'unique', 35 | 'neighbor_sampler', 36 | 'randomwalk_sampler', 37 | 'one_hot', 38 | 'maybe_num_nodes', 39 | 'from_rdmol', 40 | 'to_rdmol', 41 | 'from_smiles', 42 | 'to_smiles', 43 | ] 44 | 45 | classes = __all__ 46 | -------------------------------------------------------------------------------- /jittor_geometric/utils/coalesce.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import Var 3 | from jittor_geometric.utils.num_nodes import maybe_num_nodes 4 | from typing import Optional 5 | 6 | def coalesce(edge_index, 7 | edge_weight: Optional[Var] = None, 8 | num_nodes: Optional[int] = None, 9 | reduce: str = 'sum', 10 | is_sorted: bool = False, 11 | sort_by_row: bool = True): 12 | 13 | num_edges = edge_index.shape[1] 14 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 15 | 16 | idx = jt.zeros(num_edges + 1).astype(edge_index.dtype) 17 | idx[0] = -1 18 | idx[1:] = edge_index[1 - int(sort_by_row)] 19 | idx[1:] = idx[1:].mul(num_nodes).add(edge_index[int(sort_by_row)]) 20 | 21 | if not is_sorted: 22 | idx[1:], perm = jt.sort(idx[1:], dim=0) 23 | edge_index = edge_index[:, perm] 24 | if edge_weight is not None: 25 | edge_weight = edge_weight[perm] 26 | 27 | mask = idx[1:] > idx[:-1] 28 | 29 | if jt.all(mask): 30 | return edge_index, edge_weight 31 | 32 | edge_index = edge_index[:, mask] 33 | 34 | if edge_weight is None: 35 | return edge_index, None 36 | else: 37 | num_edges = edge_index.shape[1] 38 | idx = jt.arange(0, num_edges, dtype=jt.int32) 39 | idx = idx - (~mask).cumsum(0) 40 | edge_weight = jt.zeros((num_edges,)).scatter_(0, idx, edge_weight, reduce=reduce) 41 | return edge_index, edge_weight -------------------------------------------------------------------------------- /jittor_geometric/utils/degree.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 
4 | 5 | from .num_nodes import maybe_num_nodes 6 | 7 | 8 | def degree(index, num_nodes: Optional[int] = None, 9 | dtype: Optional[int] = None): 10 | r"""Computes the (unweighted) degree of a given one-dimensional index 11 | Var. 12 | """ 13 | N = maybe_num_nodes(index, num_nodes) 14 | out = jt.zeros((N, ), dtype=dtype) 15 | one = jt.ones((index.size(0), ), dtype=out.dtype) 16 | return jt.scatter(out, 0, index, one, reduce = 'sum') 17 | -------------------------------------------------------------------------------- /jittor_geometric/utils/get_laplacian.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | from jittor import Var 5 | from jittor_geometric.utils import add_self_loops, remove_self_loops 6 | 7 | from .num_nodes import maybe_num_nodes 8 | 9 | 10 | def get_laplacian(edge_index, edge_weight: Optional[Var] = None, 11 | normalization: Optional[str] = None, 12 | dtype: Optional[int] = None, 13 | num_nodes: Optional[int] = None): 14 | r""" Computes the graph Laplacian of the graph given by :obj:`edge_index` 15 | and optional :obj:`edge_weight`. 16 | 17 | Args: 18 | edge_index (Var int32): The edge indices. 19 | edge_weight (Var, optional): One-dimensional edge weights. 20 | (default: :obj:`None`) 21 | normalization (str, optional): The normalization scheme for the graph 22 | Laplacian (default: :obj:`None`): 23 | 24 | 1. :obj:`None`: No normalization 25 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 26 | 27 | 2. :obj:`"sym"`: Symmetric normalization 28 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} 29 | \mathbf{D}^{-1/2}` 30 | 31 | 3. :obj:`"rw"`: Random-walk normalization 32 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` 33 | dtype (Var.dtype, optional): The desired data type of returned Var 34 | in case :obj:`edge_weight=None`. 
(default: :obj:`None`) 35 | num_nodes (int, optional): The number of nodes, *i.e.* 36 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 37 | """ 38 | 39 | if normalization is not None: 40 | assert normalization in ['sym', 'rw'] # 'Invalid normalization' 41 | 42 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 43 | 44 | if edge_weight is None: 45 | edge_weight = jt.ones((edge_index.size(1)), dtype=dtype) 46 | 47 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 48 | 49 | row, col = edge_index[0], edge_index[1] 50 | shape = list(edge_weight.shape) 51 | shape[0] = num_nodes 52 | deg = jt.zeros(shape) 53 | deg = jt.scatter(deg, 0, row, src=edge_weight, reduce='add') 54 | if normalization is None: 55 | # L = D - A. 56 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 57 | edge_weight = jt.concat([-edge_weight, deg], dim=0) 58 | elif normalization == 'sym': 59 | # Compute A_norm = -D^{-1/2} A D^{-1/2}. 60 | deg_inv_sqrt = deg.pow(-0.5) 61 | # deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0) 62 | 63 | for i in range(deg_inv_sqrt.shape[0]): 64 | if deg_inv_sqrt[i] == float('inf'): 65 | deg_inv_sqrt[i] = 0 66 | edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 67 | 68 | # L = I - A_norm. 69 | edge_index, tmp = add_self_loops(edge_index, -edge_weight, 70 | fill_value=1., num_nodes=num_nodes) 71 | assert tmp is not None 72 | edge_weight = tmp 73 | else: 74 | # Compute A_norm = -D^{-1} A. 75 | deg_inv = 1.0 / deg 76 | deg_inv.masked_fill(deg_inv == float('inf'), 0) 77 | edge_weight = deg_inv[row] * edge_weight 78 | 79 | # L = I - A_norm. 
80 | edge_index, tmp = add_self_loops(edge_index, -edge_weight, 81 | fill_value=1., num_nodes=num_nodes) 82 | assert tmp is not None 83 | edge_weight = tmp 84 | 85 | return edge_index, edge_weight 86 | -------------------------------------------------------------------------------- /jittor_geometric/utils/induced_graph.py: -------------------------------------------------------------------------------- 1 | import jittor 2 | from typing import Tuple 3 | 4 | def induced_graph( 5 | edge_index: jittor.Var, 6 | node_selected: jittor.Var, 7 | max_nodes: int 8 | ) -> Tuple[jittor.Var, jittor.Var]: 9 | r"""Generate the node-induced graph of the original graph described by 10 | 'edge_index'. It is expected that max_nodes is larger than, or equal to 11 | the real max node index in edge_index, or may cause errors. 12 | 13 | Args: 14 | edge_index (jittor.Var): The edge list describing the whole graph. 15 | node_selected (jittor.Var): The node list of the expected induced 16 | graph. 17 | max_nodes (int): The maximum node index in edge_index. 18 | """ 19 | node_mask = jittor.zeros(max_nodes, dtype = "bool") 20 | node_mask[node_selected] = True 21 | node_map = jittor.zeros(max_nodes, dtype = "int") 22 | node_map[node_selected] = jittor.arange(0, node_selected.size(0)) 23 | 24 | edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]] 25 | edge_selected = jittor.nonzero(edge_mask).view(-1) 26 | 27 | return node_map, edge_selected -------------------------------------------------------------------------------- /jittor_geometric/utils/isolated.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import Var 3 | from jittor_geometric.utils import remove_self_loops, segregate_self_loops 4 | 5 | from .num_nodes import maybe_num_nodes 6 | 7 | 8 | def contains_isolated_nodes(edge_index, num_nodes=None): 9 | r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains 10 | isolated nodes. 
11 | 12 | Args: 13 | edge_index (Var int32): The edge indices. 14 | num_nodes (int, optional): The number of nodes, *i.e.* 15 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 16 | 17 | :rtype: bool 18 | """ 19 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 20 | (row, col), _ = remove_self_loops(edge_index) 21 | 22 | return jt.unique(jt.concat((row, col))).size(0) < num_nodes 23 | 24 | 25 | def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None): 26 | r"""Removes the isolated nodes from the graph given by :attr:`edge_index` 27 | with optional edge attributes :attr:`edge_attr`. 28 | In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter 29 | out isolated node features later on. 30 | Self-loops are preserved for non-isolated nodes. 31 | 32 | Args: 33 | edge_index (Var int32): The edge indices. 34 | edge_attr (Var, optional): Edge weights or multi-dimensional 35 | edge features. (default: :obj:`None`) 36 | num_nodes (int, optional): The number of nodes, *i.e.* 37 | :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`) 38 | 39 | :rtype: (Var int32, Var, Var bool) 40 | """ 41 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 42 | 43 | out = segregate_self_loops(edge_index, edge_attr) 44 | edge_index, edge_attr, loop_edge_index, loop_edge_attr = out 45 | 46 | mask = jt.zeros((num_nodes), dtype=Var.bool) 47 | mask[edge_index.view(-1)] = 1 48 | 49 | assoc = jt.full((num_nodes, ), -1, dtype=Var.int32) 50 | assoc[mask] = jt.arange(mask.sum()) 51 | edge_index = assoc[edge_index] 52 | 53 | loop_mask = jt.zeros_like(mask) 54 | loop_mask[loop_edge_index[0]] = 1 55 | loop_mask = loop_mask & mask 56 | loop_assoc = jt.full_like(assoc, -1) 57 | loop_assoc[loop_edge_index[0]] = jt.arange(loop_edge_index.size(1)) 58 | loop_idx = loop_assoc[loop_mask] 59 | loop_edge_index = assoc[loop_edge_index[:, loop_idx]] 60 | 61 | edge_index = jt.concat([edge_index, loop_edge_index], dim=1) 62 | 63 | if edge_attr is not None: 64 | loop_edge_attr = loop_edge_attr[loop_idx] 65 | edge_attr = jt.concat([edge_attr, loop_edge_attr], dim=0) 66 | 67 | return edge_index, edge_attr, mask 68 | -------------------------------------------------------------------------------- /jittor_geometric/utils/neighbor_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional, Tuple 3 | import jittor as jt 4 | import copy 5 | 6 | def neighbor_sampler( 7 | neighbor_list: jt.Var, 8 | neighbor_offs: jt.Var, 9 | neighbor_nums: jt.Var, 10 | source_node: jt.Var, 11 | max_nodes : Optional[int] = None, 12 | max_edges : Optional[int] = None 13 | ) -> jt.Var: 14 | r"""Samples the neighbor of all the source nodes, in the given graph 15 | described by neighbor_list, neighbor_hook, and neighbor_offs. 16 | 17 | Args: 18 | neighbor_list (jittor.Var): An ordered list of neighbors. 19 | neighbor_offs (jittor.Var): For each node i, the neighbors of i 20 | are neighbor_list[neighbor_offs[i], neighbor_offs[i+1]]. 
There is 21 | neighbor_offs[max_nodes], which should equals to negihbor_list.size(0). 22 | neighbor_nums (jittor.Var): For each node i, the number of its 23 | neighbors is neighbor_nums[i] == neighbor_offs[i+1] - 24 | neighbor_offs[i].For neighbor_nums[i] == 0, should be preprocessed 25 | (usually assigned max_edges). 26 | source_node (jittor,Var): The source node of sampling. 27 | max_nodes (int, optional): The total number of nodes, assert to be (the 28 | length of neighbor_offs) - 1 and (the length of neighbor_nums) . 29 | max_edges (int, optional): The total number of edges, assert to be the 30 | length of neighbor_list. 31 | """ 32 | 33 | if max_nodes is None: 34 | max_nodes = neighbor_nums.size(0) 35 | if max_edges is None: 36 | max_edges = neighbor_list.size(0) 37 | idx = jt.randint_like(source_node, 0, max_edges) 38 | idx = (idx % neighbor_nums[source_node] + neighbor_offs[source_node]) % max_edges 39 | dst = neighbor_list[1,idx] 40 | 41 | return dst 42 | 43 | 44 | def randomwalk_sampler( 45 | neighbor_list: jt.Var, 46 | neighbor_offs: jt.Var, 47 | neighbor_nums: jt.Var, 48 | source_node: jt.Var, 49 | walk_length: int, 50 | max_nodes : Optional[int] = None, 51 | max_edges : Optional[int] = None 52 | ) -> jt.Var: 53 | r"""Samples the random_walk of all the source nodes with length 'walk_length', 54 | in the given graph described by neighbor_list, neighbor_hook, and neighbor_offs. 55 | 56 | Args: 57 | neighbor_list (jittor.Var): An ordered list of neighbors. 58 | neighbor_offs (jittor.Var): For each node i, the neighbors of i 59 | are neighbor_list[neighbor_offs[i], neighbor_offs[i+1]]. There is 60 | neighbor_offs[max_nodes], which should equals to negihbor_list.size(0). 61 | neighbor_nums (jittor.Var): For each node i, the number of its 62 | neighbors is neighbor_nums[i] == neighbor_offs[i+1] - 63 | neighbor_offs[i].For neighbor_nums[i] == 0, should be preprocessed 64 | (usually assigned max_edges). 65 | source_node (jittor,Var): The source node of sampling. 
66 | walk_length (int): The length of random walk. 67 | max_nodes (int, optional): The total number of nodes, assert to be (the 68 | length of neighbor_offs) - 1 and (the length of neighbor_nums) . 69 | max_edges (int, optional): The total number of edges, assert to be the 70 | length of neighbor_list. 71 | """ 72 | if max_nodes is None: 73 | max_nodes = neighbor_nums.size(0) 74 | if max_edges is None: 75 | max_edges = neighbor_list.size(0) 76 | dst = jt.zeros([source_node.size(0), walk_length + 1], dtype = 'int') 77 | dst[:, 0] = source_node 78 | source = copy.copy(source_node) 79 | for i in range(walk_length): 80 | target = neighbor_sampler(neighbor_list, neighbor_offs, neighbor_nums, source, max_nodes, max_edges) 81 | dst[:, i+1] = target 82 | source = copy.copy(target) 83 | return dst 84 | -------------------------------------------------------------------------------- /jittor_geometric/utils/num_nodes.py: -------------------------------------------------------------------------------- 1 | from jittor import Var 2 | 3 | 4 | def maybe_num_nodes(edge_index, num_nodes=None): 5 | if num_nodes is not None: 6 | return num_nodes 7 | elif isinstance(edge_index, Var): 8 | return int(edge_index.max()) + 1 9 | else: 10 | return max(edge_index.size(0), edge_index.size(1)) 11 | -------------------------------------------------------------------------------- /jittor_geometric/utils/one_hot.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | from jittor import Var 5 | 6 | 7 | def one_hot( 8 | index: Var, 9 | num_classes: Optional[int] = None, 10 | ) -> Var: 11 | r"""Taskes a one-dimensional :obj:`index` var and returns a one-hot 12 | encoded representation of it with shape :obj:`[*, num_classes]` that has 13 | zeros everywhere except where the index of last dimension matches the 14 | corresponding value of the input var, in which case it will be :obj:`1`. 
15 | 16 | Args: 17 | index (jittor.Var): The one-dimensional input var. 18 | num_classes (int, optional): The total number of classes. If set to 19 | :obj:`None`, the number of classes will be inferred as one greater 20 | than the largest class value in the input var. 21 | (default: :obj:`None`) 22 | dtype (jittor.dtype, optional): The :obj:`dtype` of the output var. 23 | """ 24 | if index.dim() != 1: 25 | raise ValueError("'index' var needs to be one-dimensional") 26 | 27 | if num_classes is None: 28 | num_classes = int(index.max()) + 1 29 | 30 | out = jt.zeros((index.size(0), num_classes)) 31 | return out.scatter_(1, index.unsqueeze(1), jt.Var([1])) -------------------------------------------------------------------------------- /jittor_geometric/utils/sort.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import jittor 4 | 5 | import jittor_geometric.typing 6 | from jittor_geometric.typing import jt_lib 7 | 8 | 9 | def index_sort( 10 | inputs: jittor.Var, 11 | max_value: Optional[int] = None, 12 | ) -> Tuple[jittor.Var, jittor.Var]: 13 | r"""Sorts the elements of the :obj:`inputs` tensor in ascending order. 14 | It is expected that :obj:`inputs` is one-dimensional and that it only 15 | contains positive integer values. If :obj:`max_value` is given, it can 16 | be used by the underlying algorithm for better performance. 17 | 18 | Args: 19 | inputs (jittor.Var): A vector with positive integer values. 20 | max_value (int, optional): The maximum value stored inside 21 | :obj:`inputs`. This value can be an estimation, but needs to be 22 | greater than or equal to the real maximum. 
23 | (default: :obj:`None`) 24 | """ 25 | if not jittor_geometric.typing.WITH_INDEX_SORT: # pragma: no cover 26 | return inputs.sort() 27 | return jt_lib.ops.index_sort(inputs, max_value=max_value) 28 | 29 | def unique( 30 | inputs: jittor.Var, 31 | max_value: Optional[int] = None, 32 | ) -> jittor.Var: 33 | r"""Sorts the elements of the :obj:`inputs` tensor in ascending order, 34 | and then delete the duplicated values. 35 | It is expected that :obj:`inputs` is one-dimensional and that it only 36 | contains positive integer values. If :obj:`max_value` is given, it can 37 | be used by the underlying algorithm for better performance (well this 38 | is fake because the index_sort() above cannot be directly called). 39 | 40 | Args: 41 | inputs (jittor.Var): A vector with positive integer values. 42 | max_value (int, optional): The maximum value stored inside 43 | :obj:`inputs`. This value can be an estimation, but needs to be 44 | greater than or equal to the real maximum. 45 | (default: :obj:`None`) 46 | """ 47 | res = inputs.sort()[0] 48 | #print(res) 49 | selection = jittor.zeros_like(res, dtype = "bool") 50 | selection[0] = True 51 | selection[1:] = res[1:] != res[:-1] 52 | res = res[selection] 53 | return res 54 | -------------------------------------------------------------------------------- /jittor_geometric/utils/sort_edge_index.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import jittor 4 | from jittor import Var 5 | 6 | from jittor_geometric.typing import OptVar 7 | from jittor_geometric.utils import index_sort 8 | from jittor_geometric.utils.num_nodes import maybe_num_nodes 9 | 10 | MISSING = '???' 
11 | 12 | 13 | # @torch.jit._overload 14 | # def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): 15 | # # type: (Tensor, str, Optional[int], bool) -> Tensor # noqa 16 | # pass 17 | 18 | 19 | # @torch.jit._overload 20 | # def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): 21 | # # type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tuple[Tensor, Optional[Tensor]] # noqa 22 | # pass 23 | 24 | 25 | # @torch.jit._overload 26 | # def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): 27 | # # type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa 28 | # pass 29 | 30 | 31 | def sort_edge_index( 32 | edge_index: Var, 33 | edge_attr: Union[OptVar, List[Var], str] = MISSING, 34 | num_nodes: Optional[int] = None, 35 | sort_by_row: bool = True, 36 | ) -> Union[Var, Tuple[Var, OptVar], Tuple[Var, List[Var]]]: 37 | """Row-wise sorts :obj:`edge_index`. 38 | 39 | Args: 40 | edge_index (LongTensor): The edge indices. 41 | edge_attr (Tensor or List[Tensor], optional): Edge weights or multi- 42 | dimensional edge features. 43 | If given as a list, will re-shuffle and remove duplicates for all 44 | its entries. (default: :obj:`None`) 45 | num_nodes (int, optional): The number of nodes, *i.e.* 46 | :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) 47 | sort_by_row (bool, optional): If set to :obj:`False`, will sort 48 | :obj:`edge_index` column-wise. 49 | 50 | :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else 51 | (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) 52 | 53 | .. warning:: 54 | 55 | From :pyg:`PyG >= 2.3.0` onwards, this function will always return a 56 | tuple whenever :obj:`edge_attr` is passed as an argument (even in case 57 | it is set to :obj:`None`). 
58 | 59 | Examples: 60 | 61 | >>> edge_index = torch.tensor([[2, 1, 1, 0], 62 | [1, 2, 0, 1]]) 63 | >>> edge_attr = torch.tensor([[1], [2], [3], [4]]) 64 | >>> sort_edge_index(edge_index) 65 | tensor([[0, 1, 1, 2], 66 | [1, 0, 2, 1]]) 67 | 68 | >>> sort_edge_index(edge_index, edge_attr) 69 | (tensor([[0, 1, 1, 2], 70 | [1, 0, 2, 1]]), 71 | tensor([[4], 72 | [3], 73 | [2], 74 | [1]])) 75 | """ 76 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 77 | 78 | idx = edge_index[1 - int(sort_by_row)] * num_nodes 79 | idx += edge_index[int(sort_by_row)] 80 | 81 | _, perm = index_sort(idx, max_value=num_nodes * num_nodes) 82 | 83 | edge_index = edge_index[:, perm] 84 | 85 | if edge_attr is None: 86 | return edge_index, None 87 | if isinstance(edge_attr, Var): 88 | return edge_index, edge_attr[perm] 89 | if isinstance(edge_attr, (list, tuple)): 90 | return edge_index, [e[perm] for e in edge_attr] 91 | 92 | return edge_index 93 | -------------------------------------------------------------------------------- /jittor_geometric/utils/sparse.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple, Union 2 | 3 | import jittor 4 | from jittor import Var 5 | 6 | from jittor_geometric.typing import SparseVar 7 | 8 | 9 | def is_jittor_sparse_tensor(src: Any) -> bool: 10 | r"""Returns :obj:`True` if the input :obj:`src` is a 11 | :class:`jittor.sparse.Tensor` (in any sparse layout). 12 | 13 | Args: 14 | src (Any): The input object to be checked. 
15 | """ 16 | if isinstance(src, Var): 17 | if src.layout == jittor.sparse_coo: 18 | return True 19 | if src.layout == jittor.sparse_csr: 20 | return True 21 | if src.layout == jittor.sparse_csc: 22 | return True 23 | return False -------------------------------------------------------------------------------- /jittor_geometric/utils/undirected.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import jittor as jt 4 | from jittor import Var 5 | from jittor_geometric.utils.coalesce import coalesce 6 | 7 | 8 | def to_undirected(edge_index, 9 | edge_weight: Optional[Var] = None, 10 | num_nodes: Optional[int] = None, 11 | reduce: str = 'add'): 12 | 13 | row, col = edge_index[0], edge_index[1] 14 | row, col = jt.concat([row, col], dim=0), jt.concat([col, row], dim=0) 15 | edge_index = jt.stack([row, col], dim=0) 16 | edge_weight = jt.concat([edge_weight, edge_weight], dim=0) if edge_weight is not None else None 17 | 18 | return coalesce(edge_index, edge_weight, num_nodes, reduce) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase==3.24.0 2 | astunparse==1.6.3 3 | autograd==1.7.0 4 | cupy==13.3.0 5 | Flask==3.1.0 6 | huggingface_hub==0.27.1 7 | jittor==1.3.9.14 8 | numpy==1.24.0 9 | pandas==2.2.3 10 | Pillow==11.1.0 11 | PyMetis==2023.1.1 12 | pyparsing==3.2.1 13 | pywebio==1.8.3 14 | recommonmark==0.7.1 15 | schnetpack==2.0.0 16 | scikit_learn==1.6.1 17 | scipy==1.15.1 18 | setuptools==69.5.1 19 | six==1.16.0 20 | sphinx_rtd_theme==3.0.2 21 | sympy==1.13.3 22 | tqdm==4.66.4 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | 3 | 4 | setup( 5 | name='jittor_geometric', 6 | 
version='0.1', 7 | # author='Your Name', 8 | # author_email='your.email@example.com', 9 | packages=find_packages(), 10 | package_data={ 11 | 'jittor_geometric': ['ops/cpp/*.cc', 'ops/cpp/*.h'], 12 | }, 13 | # description='A brief description of the library', 14 | long_description=open('README.md').read(), 15 | long_description_content_type='text/markdown', 16 | classifiers=[ 17 | ], 18 | ) 19 | --------------------------------------------------------------------------------